Coverage for src/mlopus/mlflow/providers/mlflow.py: 94%
324 statements

import contextlib
import hashlib
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, TypeVar, Callable, Tuple, List, Literal, Mapping, Iterable, Set, Iterator, ContextManager

from mlflow import MlflowClient as _NativeMlflowClient, entities as native
from mlflow.entities import model_registry as native_model_registry
from mlflow.store.entities import PagedList
from mlflow.tracking import artifact_utils as mlflow_artifact_utils
from mlflow.utils import rest_utils as mlflow_rest_utils

from mlopus.mlflow.api.base import BaseMlflowApi
from mlopus.mlflow.api.common import schema, patterns
from mlopus.utils import pydantic, mongo, dicts, json_utils, iter_utils, urls, string_utils, env_utils

logger = logging.getLogger(__name__)

A = TypeVar("A")  # Any type
A2 = TypeVar("A2")  # Any type

# Entity types used in MLOpus
E = schema.Experiment
R = schema.Run
M = schema.Model
V = schema.ModelVersion
T = TypeVar("T", bound=schema.BaseEntity)

# Types used natively by open source MLflow
NE = native.Experiment
NR = native.Run
NM = native_model_registry.RegisteredModel
NV = native_model_registry.ModelVersion
NT = TypeVar("NT", NE, NR, NM, NV)


class MlflowClient(_NativeMlflowClient):
    """Patch of native MlflowClient."""

    def healthcheck(self):
        """Check if service is available."""
        if (url := urls.parse_url(self.tracking_uri)).scheme in ("http", "https"):
            creds = self._tracking_client.store.get_host_creds()
            response = mlflow_rest_utils.http_request(creds, "/health", "GET")
            assert response.status_code == 200, f"Healthcheck failed for URI '{url}'"
        elif not urls.is_local(url):
            raise NotImplementedError(f"No healthcheck for URI '{url}'")

    def get_artifact_uri(self, run_id: str) -> str:
        """Get root URL of remote run artifacts, according to MLflow server."""
        return mlflow_artifact_utils.get_artifact_uri(run_id, tracking_uri=self.tracking_uri)

    @property
    def tracking_client(self):
        """Public access to tracking client."""
        return self._tracking_client


class KeepUntouched(pydantic.BaseModel):
    """Set of rules for keys of tags/params/metrics whose values are to be kept untouched.

    Their values are exchanged to/from the MLflow server without any pre-processing/encoding/escaping.
    """

    prefixes: Dict[str, Set[str]] = {"tags": {"mlflow"}}

    def __call__(self, key_parts: Tuple[str, ...], scope: Literal["tags", "params", "metrics"]) -> bool:
        return key_parts[0] in self.prefixes.get(scope, ())
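

# Illustrative check of the default rules (hypothetical keys): reserved
# `mlflow.*` tags pass through untouched, while user tags get processed.
#
#   keep = KeepUntouched()
#   keep(("mlflow", "parentRunId"), scope="tags")  # -> True (exchanged as-is)
#   keep(("my_tag",), scope="tags")                # -> False (JSON-encoded/escaped)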


class MaxLength(pydantic.BaseModel):
    """Max value lengths."""

    tag: int = 500
    param: int = 500


class MlflowDataTranslation(pydantic.BaseModel):
    """Translates native MLflow data to MLOpus format and back."""

    ignore_json_errors: bool = True
    max_length: MaxLength = MaxLength()
    keep_untouched: KeepUntouched = KeepUntouched()

    class Config(BaseMlflowApi.Config):
        """Class constants."""

        STATUS_TO_MLFLOW: Dict[schema.RunStatus, native.RunStatus] = {
            schema.RunStatus.FAILED: native.RunStatus.FAILED,
            schema.RunStatus.RUNNING: native.RunStatus.RUNNING,
            schema.RunStatus.FINISHED: native.RunStatus.FINISHED,
            schema.RunStatus.SCHEDULED: native.RunStatus.SCHEDULED,
        }

        STATUS_FROM_MLFLOW: Dict[native.RunStatus, schema.RunStatus | None] = {
            **{v: k for k, v in STATUS_TO_MLFLOW.items()},
            native.RunStatus.KILLED: schema.RunStatus.FAILED,
        }

    # =======================================================================================================
    # === Datetime translation ==============================================================================

    @classmethod
    def mlflow_ts_to_datetime(cls, timestamp: int | None) -> datetime | None:
        """Parse MLflow timestamp as datetime."""
        return None if timestamp is None else datetime.fromtimestamp(timestamp / 1000)

    @classmethod
    def datetime_to_mlflow_ts(cls, datetime_: datetime) -> int:
        """Coerce datetime to MLflow timestamp."""
        return int(datetime_.timestamp() * 1000)
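
    # Illustrative round-trip (note: `datetime.fromtimestamp` uses the local timezone):
    #
    #   ts = MlflowDataTranslation.datetime_to_mlflow_ts(datetime(2025, 1, 1))  # epoch millis
    #   MlflowDataTranslation.mlflow_ts_to_datetime(ts)  # -> datetime(2025, 1, 1, 0, 0)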

    # =======================================================================================================
    # === Enum translation ==================================================================================

    @classmethod
    def run_status_from_mlflow(cls, status: native.RunStatus | str) -> schema.RunStatus:
        """Parse run status enum value from MLflow."""
        status = native.RunStatus.from_string(status) if isinstance(status, str) else status
        return cls.Config.STATUS_FROM_MLFLOW[status]

    @classmethod
    def run_status_to_mlflow(cls, status: schema.RunStatus, as_str: bool = False) -> native.RunStatus | str:
        """Coerce run status enum value to MLflow format."""
        status = cls.Config.STATUS_TO_MLFLOW[status]
        return native.RunStatus.to_string(status) if as_str else status

    # =======================================================================================================
    # === Dict translation ==================================================================================

    @classmethod
    @string_utils.retval_matches(patterns.TAG_PARAM_OR_METRIC_KEY)
    def encode_key(cls, key: str) -> str:
        """Make sure dict key is safe for storage and query."""
        return str(key)

    @classmethod
    def _decode_key(cls, key: str) -> str:
        """Inverse func of `encode_key`."""
        return key

    @classmethod
    def flatten_key(cls, key_parts: Iterable[str]) -> str:
        """Turn a nested dict key into a flat one by joining with a dot delimiter."""
        return ".".join(key_parts)

    @classmethod
    def unflatten_key(cls, key: str) -> Tuple[str, ...]:
        """Inverse func of `flatten_key`."""
        return tuple(key.split("."))
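
    # Illustrative: flat MLflow keys are dot-joined tuples and back.
    #
    #   MlflowDataTranslation.flatten_key(("metrics", "loss"))  # -> "metrics.loss"
    #   MlflowDataTranslation.unflatten_key("metrics.loss")     # -> ("metrics", "loss")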

    def preprocess_key(self, key_parts: Tuple[str, ...]) -> str:
        """Process key of tag, param or metric to be sent to MLflow."""
        return self.flatten_key(self.encode_key(k) for k in key_parts)

    def _preprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2 | None]
    ) -> Iterable[Tuple[str, A2]]:
        for key_parts, val in dicts.flatten(data).items():
            if (mapped_val := val_mapper(key_parts, val)) is not None:
                yield self.preprocess_key(key_parts), mapped_val

    def preprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2 | None]
    ) -> Dict[str, A2]:
        """Pre-process dict of tags, params or metrics to be compatible with MLflow.

        - Flatten nested keys as tuples
        - URL-encode dots and other forbidden chars in keys
        - Join tuple keys as dot-delimited strings
        - Map leaf-values (tags and params as JSON, metrics as float)
        """
        return dict(self._preprocess_dict(data, val_mapper))
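
    # Illustrative (assuming default settings): nested keys are flattened and
    # leaf values mapped before being sent to MLflow.
    #
    #   mdt = MlflowDataTranslation()
    #   mdt.preprocess_dict({"model": {"lr": 0.1}}, val_mapper=mdt.process_param)
    #   # -> {"model.lr": "0.1"}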

    def _deprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2]
    ) -> Iterable[Tuple[Tuple[str, ...], A2]]:
        """Inverse func of `_preprocess_dict`."""
        for key, val in data.items():
            key_parts = tuple(self._decode_key(k) for k in self.unflatten_key(key))
            yield key_parts, val_mapper(key_parts, val)

    def deprocess_dict(self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2]) -> Dict[str, A2]:
        """Inverse func of `preprocess_dict`."""
        return dicts.unflatten(self._deprocess_dict(data, val_mapper))

    def process_tag(self, key_parts: Tuple[str, ...], val: Any) -> str | None:
        """JSON-encode user-tag value."""
        if not self.keep_untouched(key_parts, scope="tags"):
            val = string_utils.escape_sql_single_quote(json_utils.dumps(val))
            if len(str(val)) > self.max_length.tag:
                val = None
                logger.warning("Ignoring tag above max length of %s: %s", self.max_length.tag, key_parts)
        return val

    def _deprocess_tag(self, key_parts: Tuple[str, ...], val: str) -> Any:
        """JSON-decode user-tag value from MLflow."""
        if not self.keep_untouched(key_parts, scope="tags"):
            val = json_utils.loads(string_utils.unscape_sql_single_quote(val), ignore_errors=self.ignore_json_errors)
        return val

    def process_param(self, key_parts: Tuple[str, ...], val: Any) -> str:  # noqa
        """JSON-encode param value."""
        if not self.keep_untouched(key_parts, scope="params"):
            val = string_utils.escape_sql_single_quote(json_utils.dumps(val))
            if len(str(val)) > self.max_length.param:
                val = None
                logger.warning("Ignoring param above max length of %s: %s", self.max_length.param, key_parts)
        return val

    def _deprocess_param(self, key_parts: Tuple[str, ...], val: str) -> Any:  # noqa
        """JSON-decode param value from MLflow."""
        if not self.keep_untouched(key_parts, scope="params"):
            val = json_utils.loads(string_utils.unscape_sql_single_quote(val), ignore_errors=self.ignore_json_errors)
        return val

    def process_metric(self, key_parts: Tuple[str, ...], val: Any) -> float:  # noqa
        """Coerce metric value to float."""
        if not self.keep_untouched(key_parts, scope="metrics"):
            val = float(val)
        return val

    def _deprocess_metric(self, key_parts: Tuple[str, ...], val: Any) -> float:  # noqa
        """Coerce metric value from MLflow to float."""
        if not self.keep_untouched(key_parts, scope="metrics"):
            val = float(val)
        return val

    def preprocess_tags(self, data: Mapping) -> Dict[str, str]:
        """Prepare tags dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_tag)

    def deprocess_tags(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_tags`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_tag)

    def preprocess_params(self, data: Mapping) -> Dict[str, str]:
        """Prepare params dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_param)

    def deprocess_params(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_params`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_param)

    def preprocess_metrics(self, data: Mapping) -> Dict[str, float]:
        """Prepare metrics dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_metric)

    def deprocess_metrics(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_metrics`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_metric)
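
    # Illustrative round-trip through the wrappers above (assuming default settings):
    #
    #   mdt = MlflowDataTranslation()
    #   sent = mdt.preprocess_tags({"meta": {"owner": "bob"}})  # -> {"meta.owner": '"bob"'}
    #   mdt.deprocess_tags(sent)                                # -> {"meta": {"owner": "bob"}}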


class MlflowQueryPushDown(mongo.Mongo2Sql):
    """Basic MongoDB-query-to-MLflow-SQL converter."""

    mongo_subj_2_mlflow: Dict[str, Dict[str, str]] = {
        "exp": {
            "name": "name",
        },
        "run": {
            "id": "id",
            "name": "run_name",
            "status": "status",
            "end_time": "end_time",
            "start_time": "start_time",
        },
        "model": {
            "name": "name",
        },
        "mv": {
            "run.id": "run_id",
            "model.name": "name",
            "version": "version_number",
        },
    }

    data_translation: MlflowDataTranslation = None
    nested_subjects: Set[str] = {"metrics"}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if intersect := self.nested_subjects.intersection({"tags", "params"}):
            logger.warning(
                "Pushing down queries for the nested subject(s) %s may produce "
                "incomplete results because of MLflow SQL limitations over JSON fields.",
                intersect,
            )

    def _parse_subj(self, coll: str, subj: str) -> str | None:
        if (scope := (parts := subj.split("."))[0]) in self.nested_subjects:
            subj = f"{scope}.{self.data_translation.preprocess_key(tuple(parts[1:]))}"
        else:
            subj = self.mongo_subj_2_mlflow[coll].get(subj)
        return super()._parse_subj(coll, subj)

    def _parse_obj(self, coll: str, subj: str, pred: Any, raw_obj: Any) -> str | None:
        if coll == "run" and subj in ("start_time", "end_time") and isinstance(raw_obj, datetime):
            raw_obj = self.data_translation.datetime_to_mlflow_ts(raw_obj)
        elif isinstance(raw_obj, schema.RunStatus):
            raw_obj = self.data_translation.run_status_to_mlflow(raw_obj, as_str=True)
        elif process := {
            "tags": self.data_translation.process_tag,
            "params": self.data_translation.process_param,
            "metrics": self.data_translation.process_metric,
        }.get((parts := subj.split("."))[0]):
            return "'%s'" % process(tuple(parts[1:]), raw_obj)  # noqa

        return super()._parse_obj(coll, subj, pred, raw_obj)

    def parse_exp(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for experiments search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="exp"), *self.parse_sorting(sorting, coll="exp")  # noqa

    def parse_run(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[str, mongo.Query, str, mongo.Sorting, List[str]]:
        """Parse query and sorting rule for runs search, return SQL expression and remainder for each, plus exp IDs."""
        if isinstance(exp_ids := query.pop("exp.id", None), str):
            exp_ids = [exp_ids]
        elif isinstance(exp_ids, dict) and set(exp_ids.keys()) == {"$in"}:
            exp_ids = exp_ids.pop("$in")
        if not (isinstance(exp_ids, (list, tuple)) and len(exp_ids) > 0 and all(isinstance(x, str) for x in exp_ids)):
            raise ValueError(
                f"{self.__class__.__name__}: `exp.id` must be specified when querying runs. "
                'Example: {"exp.id": {"$in": ["42", "13"]}}'
            )
        return *self.parse_query(query, coll="run"), *self.parse_sorting(sorting, coll="run"), exp_ids  # noqa
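
    # Illustrative (assuming `Mongo2Sql` can push both predicates down): a runs query
    # splits into an MLflow SQL filter, a residual Mongo query/sorting, and the
    # experiment IDs extracted from `exp.id`.
    #
    #   push = MlflowQueryPushDown(data_translation=MlflowDataTranslation())
    #   push.parse_run({"exp.id": "42", "status": schema.RunStatus.FINISHED}, sorting=[])
    #   # -> (<SQL filter>, <query remainder>, <SQL sorting>, <sorting remainder>, ["42"])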

    def parse_model(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for models search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="model"), *self.parse_sorting(sorting, coll="model")  # noqa

    def parse_mv(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for model version search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="mv"), *self.parse_sorting(sorting, coll="mv")  # noqa


class MlflowTagKeys(pydantic.BaseModel):
    """Default tag keys."""

    artifacts_repo: str = "artifacts_repo"
    parent_run_id: str = "mlflow.parentRunId"


class MlflowApi(BaseMlflowApi):
    """MLflow API provider based on open source MLflow.

    **Plugin name:** `mlflow`

    **Requires extras:** `mlflow`

    **Default cache dir:** `~/.cache/mlopus/mlflow-providers/mlflow/<hashed-tracking-uri>`

    Assumptions:
    - No artifacts proxy.
    - SQL database is server-managed.
    """

    tracking_uri: str = pydantic.Field(
        default=None,
        description=(
            "MLflow server URL or path to a local directory. "
            "Defaults to the environment variable `MLFLOW_TRACKING_URI`, "
            "falls back to `~/.cache/mlflow`."
        ),
    )

    healthcheck: bool = pydantic.Field(
        default=True,
        description=(
            "If true and not in :attr:`~mlopus.mlflow.BaseMlflowApi.offline_mode`, "
            "eagerly attempt connection to the server after initialization."
        ),
    )

    client_settings: Dict[str, str | int] = pydantic.Field(
        default_factory=dict,
        description=(
            "MLflow client settings. Keys are like the open-source MLflow environment variables, "
            "but lower case and without the `MLFLOW_` prefix. Example: `http_request_max_retries`. "
            "See: https://mlflow.org/docs/latest/python_api/mlflow.environment_variables.html"
        ),
    )

    tag_keys: MlflowTagKeys = pydantic.Field(
        repr=False,
        default_factory=MlflowTagKeys,
        description="Tag keys for storing internal information such as parent run ID.",
    )

    query_push_down: MlflowQueryPushDown = pydantic.Field(
        repr=False,
        default_factory=MlflowQueryPushDown,
        description=(
            "Utility for partial translation of MongoDB queries to open-source MLflow SQL. "
            "Users may replace this with a different implementation when subclassing the API."
        ),
    )

    data_translation: MlflowDataTranslation = pydantic.Field(
        repr=False,
        default_factory=MlflowDataTranslation,
        description=(
            "Utility for translating keys and values from MLOpus schema to native MLflow schema and back. "
            "Users may replace this with a different implementation when subclassing the API."
        ),
    )
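
    # Usage sketch (hypothetical values; assumes a reachable tracking server):
    #
    #   api = MlflowApi(
    #       tracking_uri="http://localhost:5000",
    #       client_settings={"http_request_max_retries": 3},
    #   )
    #   finished = api.find_runs({"exp.id": "42", "status": schema.RunStatus.FINISHED})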

    # =======================================================================================================
    # === Pydantic validators ===============================================================================

    @pydantic.root_validator(pre=True)  # noqa
    @classmethod
    def _valid_tracking_uri(cls, values: dicts.AnyDict) -> dicts.AnyDict:
        """Use default if provided value is None or empty string."""
        raw_url = values.get("tracking_uri") or os.environ.get("MLFLOW_TRACKING_URI") or (Path.home() / ".cache/mlflow")
        values["tracking_uri"] = str(urls.parse_url(raw_url, resolve_if_local=True))
        return values

    def __init__(self, **kwargs):
        """Let the query push down use the same data translator as the API."""
        super().__init__(**kwargs)
        if self.query_push_down.data_translation is None:
            self.query_push_down.data_translation = self.data_translation

        if self.healthcheck and not self.offline_mode:
            with self._client() as cli:
                cli.healthcheck()

    # =======================================================================================================
    # === Properties ========================================================================================

    @property
    def _default_cache_id(self) -> str:
        """Sub-dir under default cache dir. Only used if `cache_dir` is not specified."""
        return hashlib.md5(self.tracking_uri.encode()).hexdigest()[:16]
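
    # Illustrative (hypothetical URI): the default cache ID is the first 16 hex
    # chars of the MD5 digest of the tracking URI, giving a stable per-server sub-dir.
    #
    #   hashlib.md5(b"http://localhost:5000").hexdigest()[:16]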

    # =======================================================================================================
    # === Client ============================================================================================

    def _get_client(self) -> MlflowClient:
        assert not self.offline_mode, "Cannot use MlflowClient in offline mode."
        return MlflowClient(self.tracking_uri)

    @contextlib.contextmanager
    def _client(self) -> ContextManager[MlflowClient]:
        with env_utils.using_env_vars({"MLFLOW_%s" % k.upper(): str(v) for k, v in self.client_settings.items()}):
            yield self._get_client()
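
    # Illustrative: `client_settings` keys are upper-cased and prefixed with `MLFLOW_`
    # while the client is in use, e.g.:
    #
    #   client_settings={"http_request_max_retries": 3}
    #   # -> temporary env var MLFLOW_HTTP_REQUEST_MAX_RETRIES=3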

    def _using_client(self, func: Callable[[MlflowClient], A]) -> A:
        with self._client() as client:
            return func(client)

    # =======================================================================================================
    # === Metadata parsers ==================================================================================

    def _parse_experiment(self, native_experiment: NE) -> E:
        return E(
            name=native_experiment.name,
            id=native_experiment.experiment_id,
            tags=self.data_translation.deprocess_tags(native_experiment.tags),
        )

    def _parse_run(self, native_run: NR, experiment: E) -> R:
        if not (repo := native_run.data.tags.get(self.tag_keys.artifacts_repo)):
            repo = self._using_client(lambda client: client.get_artifact_uri(native_run.info.run_id))

        return R(
            exp=experiment,
            id=native_run.info.run_id,
            name=native_run.info.run_name,
            repo=str(urls.parse_url(repo, resolve_if_local=True)),
            tags=self.data_translation.deprocess_tags(native_run.data.tags),
            params=self.data_translation.deprocess_params(native_run.data.params),
            metrics=self.data_translation.deprocess_metrics(native_run.data.metrics),
            status=self.data_translation.run_status_from_mlflow(native_run.info.status),
            end_time=self.data_translation.mlflow_ts_to_datetime(native_run.info.end_time),
            start_time=self.data_translation.mlflow_ts_to_datetime(native_run.info.start_time),
        )

    def _parse_model(self, native_model: NM) -> M:
        return M(
            name=native_model.name,
            tags=self.data_translation.deprocess_tags(native_model.tags),
        )

    def _parse_model_version(self, native_mv: NV, model: M, run: R) -> V:
        path_in_run = (
            str(urls.parse_url(native_mv.source)).removeprefix(f"runs:///{run.id}/").removeprefix(run.repo).strip("/")
        )

        return V(
            run=run,
            model=model,
            path_in_run=path_in_run,
            version=str(native_mv.version),
            tags=self.data_translation.deprocess_tags(native_mv.tags),
        )

    # =======================================================================================================
    # === Implementations of abstract methods from `BaseMlflowApi` ==========================================

    def _impl_default_cache_dir(self) -> Path:
        """Get default cache dir based on the current MLflow API settings."""
        return Path.home().joinpath(".cache/mlopus/mlflow-providers/mlflow", self._default_cache_id)

    def _impl_get_exp_url(self, exp_id: str) -> urls.Url:
        """Get Experiment URL."""
        path = "" if urls.is_local(base := self.tracking_uri) else "#/experiments"
        return urls.urljoin(base, path, exp_id)

    def _impl_get_run_url(self, run_id: str, exp_id: str) -> urls.Url:
        """Get Run URL."""
        path = "" if urls.is_local(base := self._impl_get_exp_url(exp_id)) else "runs"
        return urls.urljoin(base, path, run_id)

    def _impl_get_model_url(self, name: str) -> urls.Url:
        """Get URL to registered model."""
        name = patterns.encode_model_name(name)
        path = "models/%s" if urls.is_local(base := self.tracking_uri) else "#/models/%s"
        return urls.urljoin(base, path % name)

    def _impl_get_mv_url(self, name: str, version: str) -> urls.Url:
        """Get model version URL."""
        name = patterns.encode_model_name(name)
        path = "models/%s/version-%s" if urls.is_local(base := self.tracking_uri) else "#/models/%s/versions/%s"
        return urls.urljoin(base, path % (name, version))

    def _impl_fetch_exp(self, exp_id: str) -> E:
        """Get Experiment by ID."""
        with self._client() as cli:
            return self._parse_experiment(cli.get_experiment(exp_id))

    def _impl_fetch_run(self, run_id: str) -> R:
        """Get Run by ID."""
        native_run = self._using_client(lambda client: client.get_run(run_id))
        experiment = self._fetch_exp(native_run.info.experiment_id)
        return self._parse_run(native_run, experiment)

    def _impl_fetch_model(self, name: str) -> M:
        """Get registered Model by name."""
        return self._parse_model(self._using_client(lambda client: client.get_registered_model(name)))

    def _impl_fetch_mv(self, name_and_version: Tuple[str, str]) -> V:
        """Get ModelVersion by name and version."""
        native_mv = self._using_client(lambda client: client.get_model_version(*name_and_version))
        run = self._fetch_run(native_mv.run_id)
        model = self._fetch_model(native_mv.name)
        return self._parse_model_version(native_mv, model, run)

    def _impl_find_experiments(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[E]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_exp(query, sorting)
        logger.debug("Query push down (exp): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NE](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_experiments(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        )

        return query_remainder, sort_remainder, paginator.map_results(self._parse_experiment)

    def _impl_find_runs(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[R]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder, exp_ids = self.query_push_down.parse_run(
            query, sorting
        )
        logger.debug("Query push down (run): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NR](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_runs(
                        page_token=token,
                        experiment_ids=exp_ids,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        ).map_results(
            lambda native_run: self._parse_run(
                native_run=native_run,
                experiment=self._fetch_exp(native_run.info.experiment_id),
            ),
        )

        return query_remainder, sort_remainder, paginator

    def _impl_find_models(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[M]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_model(query, sorting)
        logger.debug("Query push down (model): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NM](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_registered_models(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        )

        return query_remainder, sort_remainder, paginator.map_results(self._parse_model)

    def _impl_find_mv(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[V]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_mv(query, sorting)
        logger.debug("Query push down (mv): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NV](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_model_versions(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        ).map_results(
            lambda native_mv: self._parse_model_version(
                native_mv=native_mv,
                run=self._fetch_run(native_mv.run_id),
                model=self._fetch_model(native_mv.name),
            ),
        )

        return query_remainder, sort_remainder, paginator

    def _impl_find_child_runs(self, run: R) -> Iterator[R]:
        """Find child runs."""
        return self.find_runs({"exp.id": run.exp.id, f"tags.{self.tag_keys.parent_run_id}": run.id})

    def _impl_create_exp(self, name: str, tags: Mapping) -> E:
        """Create experiment and return its metadata."""
        tags = self.data_translation.preprocess_tags(tags)
        return self._get_exp(self._using_client(lambda client: client.create_experiment(name=name, tags=tags)))

    def _impl_create_model(self, name: str, tags: Mapping) -> M:
        """Create registered model and return its metadata."""
        tags = self.data_translation.preprocess_tags(tags)
        return self._parse_model(self._using_client(lambda client: client.create_registered_model(name, tags)))

    def _impl_create_run(
        self, exp_id: str, name: str | None, repo: urls.Url | None, parent_run_id: str | None = None
    ) -> str:
        """Create run."""
        with self._client() as client:
            run_id = client.create_run(exp_id, run_name=name).info.run_id
            if repo:
                client.set_tag(run_id, key=self.tag_keys.artifacts_repo, value=str(repo), synchronous=True)
            if parent_run_id:
                client.set_tag(run_id, key=self.tag_keys.parent_run_id, value=parent_run_id, synchronous=True)
            return run_id

    def _impl_set_run_status(self, run_id: str, status: schema.RunStatus):
        """Set Run status."""
        with self._client() as client:
            client.update_run(run_id, status=self.data_translation.run_status_to_mlflow(status, as_str=True))

    def _impl_set_run_end_time(self, run_id: str, end_time: datetime):
        """Set Run end time."""
        end_time = self.data_translation.datetime_to_mlflow_ts(end_time)
        self._using_client(lambda client: client.tracking_client.store.update_run_info(run_id, None, end_time, None))

    def _impl_update_exp_tags(self, exp_id: str, tags: Mapping):
        """Update Exp tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_experiment_tag(exp_id, k, v)

    def _impl_update_run_tags(self, run_id: str, tags: Mapping):
        """Update Run tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_tag(run_id, k, v, synchronous=True)

    def _impl_update_model_tags(self, name: str, tags: Mapping):
        """Update Model tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_registered_model_tag(name, k, v)

    def _impl_update_mv_tags(self, name: str, version: str, tags: Mapping):
        """Update ModelVersion tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_model_version_tag(name, version, k, v)

    def _impl_log_run_params(self, run_id: str, params: Mapping):
        """Log run params."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_params(params).items():
                client.log_param(run_id, k, v, synchronous=True)

    def _impl_log_run_metrics(self, run_id: str, metrics: Mapping):
        """Log run metrics."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_metrics(metrics).items():
                client.log_metric(run_id, k, v, synchronous=True)

    def _impl_register_mv(self, model: M, run: R, path_in_run: str, version: str | None, tags: Mapping) -> V:
        """Register model version."""
        assert version is None, f"Arbitrary `version` not supported in '{self.__class__.__name__}'"
        tags = self.data_translation.preprocess_tags(tags)
        source = str(urls.urljoin(run.repo, path_in_run))
        native_mv = self._using_client(lambda client: client.create_model_version(model.name, source, run.id, tags))
        return self._parse_model_version(native_mv, model, run)


def _open_paged_list(paged_list: PagedList[A]) -> iter_utils.Page[A]:
    """Convert MLflow native PagedList to `iter_utils.Page`."""
    return iter_utils.Page(token=paged_list.token, results=paged_list)
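

# Illustrative (hypothetical page): a native PagedList carries the items plus a
# continuation token, which `iter_utils.Page` exposes to the paginators above.
#
#   page = _open_paged_list(PagedList(["a", "b"], token="next-token"))
#   page.token    # -> "next-token"
#   page.results  # -> ["a", "b"]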