Coverage for src/mlopus/utils/mongo.py: 97%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import re 

2from datetime import datetime 

3from typing import Dict, Any, TypeVar, Literal, Tuple, Iterable, Callable, List 

4 

5from mlopus.utils import pydantic, dicts, time_utils 

6 

# Generic type of the objects accepted and returned by `find_all`.
T = TypeVar("T")

# A query document in MongoDB query language: field name (or `$` operator) -> filter.
Query = Dict[str, Any]

# Sort direction in MongoDB convention: 1 is ascending, -1 is descending.
Direction = Literal[1, -1]

# Ordered list of (field name, direction) pairs.
Sorting = List[Tuple[str, Direction]]

14 

15 

def preprocess_query_or_doc(obj: dicts.AnyDict) -> dicts.AnyDict:
    """Swap datetime leaf values of a query or document for their safe repr.

    This works around mongomock not handling datetime values reliably.
    Values that are not `datetime` instances are passed through unchanged.
    """

    def _encode_leaf(leaf: Any) -> Any:
        # Only datetime values need the workaround.
        if isinstance(leaf, datetime):
            return time_utils.safe_repr(leaf)
        return leaf

    return dicts.map_leaf_vals(obj, _encode_leaf)

19 

20 

def deprocess_query_or_doc(obj: dicts.AnyDict) -> dicts.AnyDict:
    """Undo `preprocess_query_or_doc`.

    Every leaf value is passed through `time_utils.maybe_parse_safe_repr`.
    """
    decode_leaf = time_utils.maybe_parse_safe_repr
    return dicts.map_leaf_vals(obj, decode_leaf)

24 

25 

def find_all(
    objs: Iterable[T],
    query: Query,
    sorting: Sorting | None = None,
    to_doc: Callable[[T], dicts.AnyDict] = lambda x: x,
    from_doc: Callable[[dicts.AnyDict], T] = lambda x: x,
) -> Iterable[T]:
    """Find all objects matching a query written in MongoDB query language.

    The objects are loaded into a throwaway mongomock collection, filtered and
    sorted there, and converted back lazily.

    Args:
        objs: Objects to search through.
        query: Filter in MongoDB query language.
        sorting: Optional list of (field, direction) pairs passed to `find` as `sort`.
        to_doc: Converts an object into the document that gets inserted.
        from_doc: Converts a matching document back into an object.

    Returns:
        A generator over the matching objects, in the order given by the cursor.
    """
    # Imported lazily so that mongomock is only required when this function is called.
    from mongomock.mongo_client import MongoClient

    collection = MongoClient(tz_aware=True).db.collection
    collection.insert_many(preprocess_query_or_doc(to_doc(obj)) for obj in objs)
    cursor = collection.find(preprocess_query_or_doc(query), sort=sorting)

    # NOTE(review): `from_doc` runs before `deprocess_query_or_doc`, so it receives
    # datetimes in their safe-repr form - confirm that this order is intended.
    return (deprocess_query_or_doc(from_doc(doc)) for doc in cursor)

40 

41 

42class Mongo2Sql(pydantic.BaseModel): 

43 """Basic MongoDB to SQL conversor for partial push-down of queries.""" 

44 

45 field_pattern: re.Pattern = re.compile(r"[\w.]+") # Allowed chars in field name in SQL query 

46 

47 def parse_sorting(self, sorting: Sorting, coll: str) -> Tuple[str, Sorting]: # noqa 

48 """Parse MongoDB sorting rule into SQL expression and rule remainder for partial sort push-down.""" 

49 clauses = [] 

50 remainder = [] 

51 

52 for raw_subj, direction in sorting: 

53 if raw_subj.startswith("$") or not (subj := self._parse_subj(coll, raw_subj)): 

54 remainder.append((raw_subj, direction)) 

55 else: 

56 clauses.append(f"{subj} {self._parse_direction(coll, direction)}") 

57 

58 return ", ".join(clauses), remainder 

59 

60 @classmethod 

61 def _parse_direction(cls, coll: str, direction: Direction) -> str: # noqa 

62 return {1: "ASC", -1: "DESC"}[direction] 

63 

64 def parse_query(self, query: Query, coll: str) -> Tuple[str, Query]: 

65 """Parse MongoDB query into SQL expression and query remainder for partial query push-down.""" 

66 clauses = [] 

67 remainder = {} 

68 

69 for raw_subj, raw_filter in query.items(): 

70 if raw_subj.startswith("$") or not (subj := self._parse_subj(coll, raw_subj)): 

71 remainder[raw_subj] = raw_filter 

72 continue 

73 

74 if not isinstance(raw_filter, dict): 

75 raw_filter = {"$eq": raw_filter} 

76 

77 for raw_pred, raw_obj in raw_filter.items(): 

78 clause_args = (coll, subj, raw_pred, raw_obj) 

79 if (pred := self._parse_pred(*clause_args)) and (obj := self._parse_obj(*clause_args)): 

80 clauses.append(" ".join([subj, pred, obj])) 

81 else: 

82 remainder.setdefault(raw_subj, {})[raw_pred] = raw_obj 

83 

84 return " AND ".join(clauses), remainder 

85 

86 def _parse_subj(self, coll: str, subj: str) -> str | None: 

87 return subj if subj is not None and self.field_pattern.fullmatch(subj) else None 

88 

89 @classmethod 

90 def _parse_pred(cls, coll: str, subj: str, raw_pred: Any, raw_obj: Any) -> str | None: # noqa 

91 return { 

92 "$eq": "IS" if raw_obj is None else "=", 

93 "$gt": ">", 

94 "$lt": "<", 

95 "$neq": "!=", 

96 "$gte": ">=", 

97 "$lte": "<=", 

98 }.get(raw_pred) 

99 

100 @classmethod 

101 def _parse_obj(cls, coll: str, subj: str, pred: Any, raw_obj: Any) -> str | None: # noqa 

102 match raw_obj: 

103 case None: 

104 return "NULL" 

105 case str(): 

106 return "'%s'" % raw_obj.translate(str.maketrans({"'": "''"})) 

107 case _: 

108 return str(raw_obj)