Coverage for src/mlopus/mlflow/api/model.py: 89%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import typing 

2from pathlib import Path 

3from typing import Callable, Iterator, Mapping 

4 

5from mlopus.utils import dicts, pydantic, mongo 

6from . import entity, contract 

7from .common import schema, decorators 

8from .mv import ModelVersionApi 

9 

# Alias for the project's run identifier type; used below to annotate the
# `run` argument of `ModelApi.log_version`.
RunIdentifier = contract.RunIdentifier

# Short alias for the model version schema. It is not referenced by the code
# visible in this module; presumably kept for import elsewhere — confirm before removing.
V = schema.ModelVersion

13 

14 

class ModelApi(schema.Model, entity.EntityApi):
    """Registered model metadata with MLflow API handle."""

    def _get_latest_data(self) -> schema.Model:
        """Get latest data for this entity. Used for self update after methods with the `require_update` decorator."""
        return self.api.get_model(self)

    @property
    def url(self) -> str:
        """This model's URL."""
        return self.api.get_model_url(self)

    def cache_meta(self) -> "ModelApi":
        """Fetch latest metadata for this model and save it to cache.

        :return: Result of applying the freshly cached metadata to this object via
                 ``_use_values_from`` (presumably ``self``, updated — confirm).
        """
        return self._use_values_from(self.api.cache_model_meta(self))

    def export_meta(self, target: Path) -> "ModelApi":
        """Export model metadata cache to target.

        :param target: Cache export path.
        :return: Result of applying the exported metadata to this object via
                 ``_use_values_from`` (presumably ``self``, updated — confirm).
        """
        return self._use_values_from(self.api.export_model_meta(self, target))

    @decorators.require_update
    def set_tags(self, tags: Mapping) -> "ModelApi":
        """Set tags on this model.

        Because of the `require_update` decorator, this object is refreshed with the
        latest data (see :meth:`_get_latest_data`) after the tags are set.

        :param tags: See :attr:`schema.Model.tags`.
        :return: This same object.
        """
        self.api.set_tags_on_model(self, tags)
        return self

    @pydantic.validate_arguments
    def get_version(self, version: str) -> ModelVersionApi:
        """Get ModelVersion API by version identifier.

        :param version: Version identifier.
        :return: Metadata with API handle for version `version` of this model.
        """
        # Model versions are addressed by the (model name, version) pair.
        return typing.cast(ModelVersionApi, self.api.get_model_version((self.name, version)))

    @pydantic.validate_arguments
    def find_versions(
        self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None
    ) -> Iterator[ModelVersionApi]:
        """Search versions of this model with query in MongoDB query language.

        The ``model.name`` key of the query is reserved: it is always set to this
        model's name, restricting the search to this model's versions
        (``dicts.set_reserved_key`` presumably rejects a query that already sets it).

        :param query: Query in MongoDB query language. The caller's dict is not modified.
        :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).
        :return: Iterator over matching model versions with API handles.
        """
        # Shallow-copy so that inserting the reserved key never writes into the
        # caller's dict; otherwise a query dict reused across calls would already
        # contain `model.name` on the second call.
        if query is not None:
            query = dict(query)
        scoped_query = dicts.set_reserved_key(query, key="model.name", val=self.name)
        results = self.api.find_model_versions(scoped_query, sorting)
        return typing.cast(Iterator[ModelVersionApi], results)

    @pydantic.validate_arguments
    def log_version(
        self,
        run: RunIdentifier,
        source: Path | Callable[[Path], None],
        path_in_run: str | None = None,
        keep_the_source: bool | None = None,
        allow_duplication: bool | None = None,
        use_cache: bool | None = None,
        version: str | None = None,
        tags: Mapping | None = None,
    ) -> ModelVersionApi:
        """Publish artifact file or dir as model version inside the specified experiment run.

        :param run: | Run ID or object.

        :param source: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.source`

        :param path_in_run: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_model_version.path_in_run`

        :param keep_the_source: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.keep_the_source`

        :param allow_duplication: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.allow_duplication`

        :param use_cache: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_run_artifact.use_cache`

        :param version: | See :paramref:`~mlopus.mlflow.BaseMlflowApi.log_model_version.version`

        :param tags: | Model version tags.
                     | See :attr:`schema.ModelVersion.tags`

        :return: New model version metadata with API handle.
        """
        # Imported locally (not at module top), presumably to avoid an import cycle
        # between the model and run API modules — confirm.
        from .run import RunApi

        mv = self.api.log_model_version(
            self, run, source, path_in_run, keep_the_source, allow_duplication, use_cache, version, tags
        )

        if isinstance(run, RunApi):
            mv.run = run  # inject live run object so the mv gets updates regarding the run status

        return typing.cast(ModelVersionApi, mv)