Coverage for src/mlopus/mlflow/api/model.py: 89%
37 statements
Generated by coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import typing
2from pathlib import Path
3from typing import Callable, Iterator, Mapping
5from mlopus.utils import dicts, pydantic, mongo
6from . import entity, contract
7from .common import schema, decorators
8from .mv import ModelVersionApi
# Module-level shorthands: `V` aliases the model-version schema type, and
# `RunIdentifier` is the run reference type accepted by `ModelApi.log_version`.
V, RunIdentifier = schema.ModelVersion, contract.RunIdentifier
class ModelApi(schema.Model, entity.EntityApi):
    """Registered model metadata with MLflow API handle.

    Every method below delegates to ``self.api`` (the MLflow API handle),
    passing this model (or its name) as the target entity.
    """

    def _get_latest_data(self) -> schema.Model:
        """Get latest data for this entity.

        Used for self update after methods with the `require_update` decorator.

        :return: Fresh model metadata fetched through ``self.api.get_model``.
        """
        return self.api.get_model(self)

    @property
    def url(self) -> str:
        """This model's URL, as reported by ``self.api.get_model_url``."""
        return self.api.get_model_url(self)

    def cache_meta(self) -> "ModelApi":
        """Fetch latest metadata for this model and save it to cache.

        :return: Result of applying ``self._use_values_from`` to the metadata returned
            by ``self.api.cache_model_meta`` (presumably ``self``, refreshed with
            those values).
        """
        return self._use_values_from(self.api.cache_model_meta(self))

    def export_meta(self, target: Path) -> "ModelApi":
        """Export model metadata cache to target.

        :param target: Cache export path.

        :return: Result of applying ``self._use_values_from`` to the metadata returned
            by ``self.api.export_model_meta`` (presumably ``self``, refreshed with
            those values).
        """
        return self._use_values_from(self.api.export_model_meta(self, target))

    @decorators.require_update
    def set_tags(self, tags: Mapping) -> "ModelApi":
        """Set tags on this model.

        Because of the ``require_update`` decorator, this object's data is refreshed
        afterwards (see ``_get_latest_data``).

        :param tags: See :attr:`schema.Model.tags`.

        :return: This same object.
        """
        self.api.set_tags_on_model(self, tags)
        return self

    @pydantic.validate_arguments
    def get_version(self, version: str) -> ModelVersionApi:
        """Get ModelVersion API by version identifier.

        :param version: Version identifier. It is looked up together with this
            model's name as the pair ``(self.name, version)``.

        :return: Model version metadata with API handle.
        """
        return typing.cast(ModelVersionApi, self.api.get_model_version((self.name, version)))

    @pydantic.validate_arguments
    def find_versions(
        self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None
    ) -> Iterator[ModelVersionApi]:
        """Search versions of this model with query in MongoDB query language.

        The ``model.name`` key of the query is set to this model's name via
        ``dicts.set_reserved_key``, which restricts results to versions of this
        model. How a ``None`` query, or a caller-supplied ``model.name`` key, is
        handled is decided by that helper.

        :param query: Query in MongoDB query language.
        :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).

        :return: Iterator over matching model versions with API handles.
        """
        results = self.api.find_model_versions(dicts.set_reserved_key(query, key="model.name", val=self.name), sorting)
        return typing.cast(Iterator[ModelVersionApi], results)

    @pydantic.validate_arguments
    def log_version(
        self,
        run: RunIdentifier,
        source: Path | Callable[[Path], None],
        path_in_run: str | None = None,
        keep_the_source: bool | None = None,
        allow_duplication: bool | None = None,
        use_cache: bool | None = None,
        version: str | None = None,
        tags: Mapping | None = None,
    ) -> ModelVersionApi:
        """Publish artifact file or dir as model version inside the specified experiment run.

        :param run: | Run ID or object.

        :param source: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.source`

        :param path_in_run: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_model_version.path_in_run`

        :param keep_the_source: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.keep_the_source`

        :param allow_duplication: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.allow_duplication`

        :param use_cache: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.use_cache`

        :param version: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_model_version.version`

        :param tags: | Model version tags.
                     | See :attr:`schema.ModelVersion.tags`

        :return: New model version metadata with API handle.
        """
        # Imported locally rather than at module level; presumably this avoids a
        # circular import between the model and run API modules. TODO confirm.
        from .run import RunApi

        mv = self.api.log_model_version(
            self, run, source, path_in_run, keep_the_source, allow_duplication, use_cache, version, tags
        )

        if isinstance(run, RunApi):
            mv.run = run  # inject live run object so the mv gets updates regarding the run status

        return typing.cast(ModelVersionApi, mv)