Coverage for src/mlopus/utils/mongo.py: 97%
60 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import re
2from datetime import datetime
3from typing import Dict, Any, TypeVar, Literal, Tuple, Iterable, Callable, List
5from mlopus.utils import pydantic, dicts, time_utils
# Generic type of the objects searched by `find_all`.
T = TypeVar("T")

# A MongoDB filter document, e.g. {"name": "x", "age": {"$gt": 3}}.
Query = Dict[str, Any]

# Sort order of a single key: 1 for ascending, -1 for descending.
Direction = Literal[1, -1]

# Ordered sort specification as (field, direction) pairs; earlier pairs take priority.
Sorting = List[Tuple[str, Direction]]
def preprocess_query_or_doc(obj: dicts.AnyDict) -> dicts.AnyDict:
    """Encode the datetime leaves of a query or document before handing it to mongomock.

    mongomock doesn't handle datetime values quite well, so every leaf value (as visited by
    ``dicts.map_leaf_vals``) that is a ``datetime`` is replaced by ``time_utils.safe_repr(value)``.
    All other leaves pass through unchanged. ``deprocess_query_or_doc`` is the inverse.
    """

    def _encode_leaf(value: Any) -> Any:
        if not isinstance(value, datetime):
            return value
        return time_utils.safe_repr(value)

    return dicts.map_leaf_vals(obj, _encode_leaf)
def deprocess_query_or_doc(obj: dicts.AnyDict) -> dicts.AnyDict:
    """Undo ``preprocess_query_or_doc``.

    Every leaf value of ``obj`` is passed through ``time_utils.maybe_parse_safe_repr``.
    That helper presumably turns safe-repr strings back into datetimes and leaves other
    values untouched; its implementation is not visible here.
    """
    decode_leaf = time_utils.maybe_parse_safe_repr
    return dicts.map_leaf_vals(obj, decode_leaf)
def find_all(
    objs: Iterable[T],
    query: Query,
    sorting: Sorting | None = None,
    to_doc: Callable[[T], dicts.AnyDict] = lambda x: x,
    from_doc: Callable[[dicts.AnyDict], T] = lambda x: x,
) -> Iterable[T]:
    """Find all objects matching query in MongoDB query language.

    The objects are loaded into a throwaway in-memory mongomock collection and filtered there.

    :param objs: Objects to search.
    :param query: MongoDB filter document.
    :param sorting: Optional list of ``(field, direction)`` pairs (direction 1 or -1).
    :param to_doc: Converts an object into a document before insertion.
    :param from_doc: Converts a matching document back into an object.
    :return: Lazy iterable of the matching objects, ordered by ``sorting`` when given.
    """
    from mongomock.mongo_client import MongoClient

    coll = MongoClient(tz_aware=True).db.collection

    # Materialize the batch first. pymongo's insert_many rejects an empty batch, so skip
    # the insert when there is nothing to search; the find below then yields nothing.
    docs_in = [preprocess_query_or_doc(to_doc(x)) for x in objs]
    if docs_in:
        coll.insert_many(docs_in)

    docs_out = coll.find(preprocess_query_or_doc(query), sort=sorting)

    # Undo the preprocessing *before* `from_doc`, mirroring insertion (to_doc, then preprocess),
    # so `from_doc` receives real datetimes and `deprocess` is only ever applied to documents.
    # NOTE(review): documents read back also carry the `_id` assigned on insert; confirm that
    # `from_doc` implementations tolerate it.
    return (from_doc(deprocess_query_or_doc(doc)) for doc in docs_out)
42class Mongo2Sql(pydantic.BaseModel):
43 """Basic MongoDB to SQL conversor for partial push-down of queries."""
45 field_pattern: re.Pattern = re.compile(r"[\w.]+") # Allowed chars in field name in SQL query
47 def parse_sorting(self, sorting: Sorting, coll: str) -> Tuple[str, Sorting]: # noqa
48 """Parse MongoDB sorting rule into SQL expression and rule remainder for partial sort push-down."""
49 clauses = []
50 remainder = []
52 for raw_subj, direction in sorting:
53 if raw_subj.startswith("$") or not (subj := self._parse_subj(coll, raw_subj)):
54 remainder.append((raw_subj, direction))
55 else:
56 clauses.append(f"{subj} {self._parse_direction(coll, direction)}")
58 return ", ".join(clauses), remainder
60 @classmethod
61 def _parse_direction(cls, coll: str, direction: Direction) -> str: # noqa
62 return {1: "ASC", -1: "DESC"}[direction]
64 def parse_query(self, query: Query, coll: str) -> Tuple[str, Query]:
65 """Parse MongoDB query into SQL expression and query remainder for partial query push-down."""
66 clauses = []
67 remainder = {}
69 for raw_subj, raw_filter in query.items():
70 if raw_subj.startswith("$") or not (subj := self._parse_subj(coll, raw_subj)):
71 remainder[raw_subj] = raw_filter
72 continue
74 if not isinstance(raw_filter, dict):
75 raw_filter = {"$eq": raw_filter}
77 for raw_pred, raw_obj in raw_filter.items():
78 clause_args = (coll, subj, raw_pred, raw_obj)
79 if (pred := self._parse_pred(*clause_args)) and (obj := self._parse_obj(*clause_args)):
80 clauses.append(" ".join([subj, pred, obj]))
81 else:
82 remainder.setdefault(raw_subj, {})[raw_pred] = raw_obj
84 return " AND ".join(clauses), remainder
86 def _parse_subj(self, coll: str, subj: str) -> str | None:
87 return subj if subj is not None and self.field_pattern.fullmatch(subj) else None
89 @classmethod
90 def _parse_pred(cls, coll: str, subj: str, raw_pred: Any, raw_obj: Any) -> str | None: # noqa
91 return {
92 "$eq": "IS" if raw_obj is None else "=",
93 "$gt": ">",
94 "$lt": "<",
95 "$neq": "!=",
96 "$gte": ">=",
97 "$lte": "<=",
98 }.get(raw_pred)
100 @classmethod
101 def _parse_obj(cls, coll: str, subj: str, pred: Any, raw_obj: Any) -> str | None: # noqa
102 match raw_obj:
103 case None:
104 return "NULL"
105 case str():
106 return "'%s'" % raw_obj.translate(str.maketrans({"'": "''"}))
107 case _:
108 return str(raw_obj)