Coverage for src/mlopus/mlflow/traits.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import functools 

2import logging 

3from typing import Dict, Any 

4 

5import mlopus.mlflow 

6from mlopus.utils import pydantic, dicts 

7from .api.base import BaseMlflowApi 

8from .api.common import schema 

9from .api.run import RunApi 

10from .utils import get_api 

11 

# Module-level logger named after this module; used to report MLflow run URLs.
logger = logging.getLogger(name=__name__)

13 

14 

class MlflowApiMixin(pydantic.BaseModel):
    """Mixin for pydantic classes that hold a reference to a Mlflow API instance.

    A value supplied for ``mlflow_api`` is run through :func:`mlopus.mlflow.get_api`
    during validation, unless it already is a :class:`~mlopus.mlflow.BaseMlflowApi`.

    Example:

    .. code-block:: python

        class Foo(MlflowApiMixin):
            pass

        foo = Foo(
            mlflow_api={"plugin": "...", "conf": {...}}  # kwargs for `mlopus.mlflow.get_api()`
        )

        foo.mlflow_api  # BaseMlflowApi
    """

    # NOTE(review): the default `None` only reaches `_load_mlflow_api` if this model validates
    # defaults (configured outside this file, if at all); otherwise an omitted `mlflow_api`
    # stays `None` instead of becoming `get_api()` — confirm intent.
    mlflow_api: BaseMlflowApi = pydantic.Field(
        exclude=True,
        default=None,
        description=(
            "Instance of :class:`~mlopus.mlflow.BaseMlflowApi` "
            "or a `dict` of keyword arguments for :func:`mlopus.mlflow.get_api`."
        ),
    )

    @pydantic.validator("mlflow_api", pre=True)  # noqa
    @classmethod
    def _load_mlflow_api(cls, value: BaseMlflowApi | Dict[str, Any] | None) -> BaseMlflowApi:
        """Coerce the raw field value into an API instance.

        :param value: An API instance (returned as-is), a dict of keyword arguments for
            :func:`mlopus.mlflow.get_api`, or ``None``/empty dict (``get_api()`` with no arguments).
        :return: The API instance.
        """
        if isinstance(value, BaseMlflowApi):
            return value
        # Falsy values (None, {}) mean "use get_api() defaults".
        return get_api(**(value or {}))

    def using(self, mlflow_api: BaseMlflowApi) -> "MlflowApiMixin":
        """Get a copy of this object that uses the specified MLflow API.

        :param mlflow_api: API instance for the copy. It is assigned through ``copy(update=...)``,
            which presumably does not re-run validators, so pass an instance rather than a dict.
        :return: A copy of this object with ``mlflow_api`` replaced.
        """
        return self.copy(update={"mlflow_api": mlflow_api})

51 

52 

class ExpConf(pydantic.BaseModel):
    """Experiment settings consumed by `MlflowRunManager`.

    The experiment is identified by its name alone: that name is used to look the
    experiment up, or to create it when it does not exist yet.
    """

    name: str = pydantic.Field(
        description="Used when getting or creating the experiment.",
    )

57 

58 

class RunConf(pydantic.BaseModel, pydantic.EmptyStrAsMissing):
    """Run settings consumed by `MlflowRunManager`.

    When ``id`` is set, that run is resumed and the remaining fields are not consulted.
    Otherwise ``tags`` are required to search for an ongoing run, and ``name`` / ``parent``
    only matter if a new run ends up being started. (Empty strings are presumably treated
    as missing, via ``EmptyStrAsMissing``.)
    """

    id: str | None = pydantic.Field(
        description="Run ID for resuming a previous run.",
        default=None,
    )
    name: str | None = pydantic.Field(
        description="Run name for starting a new run.",
        default=None,
    )
    parent: str | None = pydantic.Field(
        description="Parent run ID for starting a new run.",
        default=None,
    )
    tags: dicts.AnyDict = pydantic.Field(
        description="Run tags for starting a new run or finding an ongoing one.",
        default_factory=dict,
    )

69 

70 

class MlflowRunManager(MlflowApiMixin):
    """A pydantic object that holds a reference to an ongoing MLflow Run.

    The run is resolved lazily, the first time the cached property ``run`` is accessed:

    1. If ``run.id`` is given, that run is resumed (unless it is already running).
    2. Otherwise, an ongoing run containing ``run.tags`` is searched for in experiment
       ``exp.name`` (created if it doesn't exist); the first match sorted by start time is resumed.
    3. If none can be found, a new run is started in ``exp.name`` with ``run.tags``
       (plus the optional ``run.name`` and ``run.parent``).

    Example:

    .. code-block:: python

        config = {
            "api": {...},  # kwargs for `mlopus.mlflow.get_api()`
            "exp": {"name": ...},
            "run": {"name": ..., "tags": ..., "id": ...},
        }

        foo_1 = MlflowRunManager(**config)
        foo_2 = MlflowRunManager(**config)

        # Objects with same config share the same managed run
        assert foo_1.run.id == foo_2.run.id

        # Accessing the cached property `run` triggers the resume/search/creation of the run.
    """

    mlflow_api: BaseMlflowApi = pydantic.Field(
        alias="api",
        exclude=True,
        description=(
            "Instance of :class:`BaseMlflowApi` or a `dict` of keyword arguments for :func:`mlopus.mlflow.get_api`."
        ),
    )

    # NOTE(review): `ExpConf.name` has no default, so the `ExpConf` default factory cannot succeed
    # without arguments; in practice `exp` must be passed even when only `run.id` is used — confirm.
    exp: ExpConf = pydantic.Field(
        default_factory=ExpConf,
        description="Experiment specification (created if doesn't exist). Used to find or create the run.",
    )

    run_conf: RunConf = pydantic.Field(
        alias="run",
        default_factory=RunConf,
        description="Run specification, used for resuming, finding or creating the run.",
    )

    @functools.cached_property
    def run(self) -> RunApi:
        """API handle for the ongoing MLflow Run (resolved on first access, then cached)."""
        return self._resolve_run()

    def _resolve_run(self) -> RunApi:
        """Resume, find or start the run and return its API handle.

        :return: API handle of the resumed, found or newly started run.
        :raises AssertionError: If no ``run.id`` is given and ``run.tags`` is empty, because a
            search without tag filters would match any ongoing run in the experiment.
        """
        if run_id := self.run_conf.id:
            run = self.mlflow_api.get_run(run_id)
            if run.status != schema.RunStatus.RUNNING:
                logger.info("MLflow Run URL: %s", run.resume().url)
            return run

        if not self.run_conf.tags:
            # Raised explicitly rather than via `assert` so the guard survives `python -O`;
            # AssertionError is kept so existing callers/tests catching it still work.
            raise AssertionError("Cannot locate shared run or start a new one without tags.")

        # Nested tags are flattened into dotted filter keys, e.g. {"a": {"b": 1}} -> "tags.a.b".
        tag_filters = {
            "tags." + ".".join(key): value for key, value in dicts.flatten(self.run_conf.tags).items()
        }
        exp = self.mlflow_api.get_or_create_exp(self.exp.name)
        query = {**tag_filters, "status": schema.RunStatus.RUNNING, "exp.id": exp.id}

        # NOTE(review): `("start_time", -1)` presumably sorts newest-first, so the most recently
        # started matching run is picked — confirm against `find_runs`.
        for candidate in self.mlflow_api.find_runs(query, sorting=[("start_time", -1)]):
            resumed = candidate.resume()
            logger.info("MLflow Run URL: %s", resumed.url)
            return resumed

        run = self.mlflow_api.start_run(exp, self.run_conf.name, self.run_conf.tags, parent=self.run_conf.parent)
        logger.info("MLflow Run URL: %s", run.url)
        return run

142 

143 

class MlflowRunMixin(pydantic.BaseModel):
    """Mixin for pydantic models that carry a :class:`MlflowRunManager`.

    The manager is supplied under the ``mlflow`` key (the field alias) and is
    excluded from serialization.
    """

    # NOTE(review): no default is declared, so depending on the pydantic version this field may be
    # required even though it accepts `None` — confirm intent.
    run_manager: MlflowRunManager | None = pydantic.Field(
        alias="mlflow",
        description="Instance or dict to be parsed into instance of :class:`~mlopus.mlflow.MlflowRunManager`",
        exclude=True,
    )