Coverage for src/mlopus/mlflow/traits.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import functools
2import logging
3from typing import Dict, Any
5import mlopus.mlflow
6from mlopus.utils import pydantic, dicts
7from .api.base import BaseMlflowApi
8from .api.common import schema
9from .api.run import RunApi
10from .utils import get_api
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class MlflowApiMixin(pydantic.BaseModel):
    """Mixin for pydantic models that carry an MLflow API instance.

    The ``mlflow_api`` field accepts either a ready-made
    :class:`BaseMlflowApi` or a ``dict`` of keyword arguments, in which case
    the instance is built by ``get_api()`` while the field is validated.

    Example:

    .. code-block:: python

        class Foo(MlflowApiMixin):
            pass

        foo = Foo(
            mlflow_api={"plugin": "...", "conf": {...}}  # kwargs for `mlopus.mlflow.get_api()`
        )

        foo.mlflow_api  # BaseMlflowApi
    """

    # Excluded from exports; a missing value is resolved by the validator below.
    mlflow_api: BaseMlflowApi = pydantic.Field(
        default=None,
        exclude=True,
        description=(
            "Instance of :class:`~mlopus.mlflow.BaseMlflowApi` "
            "or a `dict` of keyword arguments for :func:`mlopus.mlflow.get_api`."
        ),
    )

    @pydantic.validator("mlflow_api", pre=True)  # noqa
    @classmethod
    def _load_mlflow_api(cls, value: BaseMlflowApi | Dict[str, Any] | None) -> BaseMlflowApi:
        """Coerce the raw ``mlflow_api`` input into a :class:`BaseMlflowApi`.

        :param value: An API instance, a dict of kwargs for ``get_api()``, or None.
        :return: ``value`` itself when it is already an API instance, otherwise
            the result of ``get_api(**kwargs)``.
        """
        if isinstance(value, BaseMlflowApi):
            return value

        # A falsy input (None or an empty dict) means `get_api()` is called without arguments.
        kwargs = value or {}
        return get_api(**kwargs)

    def using(self, mlflow_api: BaseMlflowApi) -> "MlflowApiMixin":
        """Return a copy of this object whose ``mlflow_api`` is the given API.

        :param mlflow_api: API instance to place on the copy.
        :return: The copy produced by ``self.copy(update=...)``.
        """
        overrides = {"mlflow_api": mlflow_api}
        return self.copy(update=overrides)
class ExpConf(pydantic.BaseModel):
    """Experiment settings consumed by `MlflowRunManager`."""

    # No default is declared for the experiment name.
    name: str = pydantic.Field(
        description="Used when getting or creating the experiment.",
    )
class RunConf(pydantic.BaseModel, pydantic.EmptyStrAsMissing):
    """Run settings consumed by `MlflowRunManager`.

    Every field is optional: the scalar fields default to ``None`` and
    ``tags`` defaults to a fresh empty dict.
    """

    # Identifies an existing run to resume.
    id: str | None = pydantic.Field(
        default=None,
        description="Run ID for resuming a previous run.",
    )

    # The two fields below only matter when a new run is started.
    name: str | None = pydantic.Field(
        default=None,
        description="Run name for starting a new run.",
    )
    parent: str | None = pydantic.Field(
        default=None,
        description="Parent run ID for starting a new run.",
    )

    # Used both to look up an ongoing run and to tag a newly started one.
    tags: dicts.AnyDict = pydantic.Field(
        description="Run tags for starting a new run or finding an ongoing one.",
        default_factory=dict,
    )
class MlflowRunManager(MlflowApiMixin):
    """A pydantic object that holds a reference to an ongoing MLflow Run.

    1. If ``run.id`` is given, that run is resumed.
    2. Otherwise, an ongoing run is searched for in ``exp.name`` containing ``run.tags``
    3. If none can be found, a new run is started in ``exp.name`` containing ``run.tags``

    Example:

    .. code-block:: python

        config = {
            "api": {...},  # kwargs for `mlopus.mlflow.get_api()`
            "exp": {"name": ...},
            "run": {"name": ..., "tags": ..., "id": ...},
        }

        foo_1 = MlflowRunManager(**config)
        foo_2 = MlflowRunManager(**config)

        # Objects with same config share the same managed run
        assert foo_1.run.id == foo_2.run.id

    Accessing the cached property `run` triggers the resume/search/creation of the run.
    """

    # Redeclared from `MlflowApiMixin` to expose the field under the alias `api`
    # (and without the `default=None` that the mixin declares).
    mlflow_api: BaseMlflowApi = pydantic.Field(
        alias="api",
        exclude=True,
        description=(
            "Instance of :class:`BaseMlflowApi` or a `dict` of keyword arguments for :func:`mlopus.mlflow.get_api`."
        ),
    )

    # NOTE(review): `ExpConf.name` declares no default, so calling `ExpConf()` as a
    # default factory presumably fails validation when `exp` is omitted - confirm.
    exp: ExpConf = pydantic.Field(
        default_factory=ExpConf,
        description="Experiment specification (created if doesn't exist). Used to find or create the run.",
    )

    # Stored as `run_conf` (input alias `run`) so that the name `run` stays free
    # for the cached property holding the resolved run handle.
    run_conf: RunConf = pydantic.Field(
        alias="run",
        default_factory=RunConf,
        description="Run specification, used for resuming, finding or creating the run.",
    )

    @functools.cached_property
    def run(self) -> RunApi:
        """API handle for the ongoing MLflow Run.

        Resolved by `_resolve_run()` on first access and cached afterwards.
        """
        return self._resolve_run()

    def _resolve_run(self) -> RunApi:
        """Resume, find or start run and return API handle.

        :return: Handle of the run selected by ``run_conf.id``, or of the first
            running run in the experiment that matches ``run_conf.tags``, or of
            a newly started run.
        :raises AssertionError: If ``run_conf.id`` is unset and ``run_conf.tags`` is empty.
        """
        if run_id := self.run_conf.id:
            # `resume()` is only invoked when the run is not currently running; note that
            # the call happens while building the log arguments, so it is a side effect
            # of this statement regardless of the configured log level.
            if (run := self.mlflow_api.get_run(run_id)).status != mlopus.mlflow.RunStatus.RUNNING:
                logger.info("MLflow Run URL: %s", run.resume().url)
            return run

        # Explicit raise instead of a bare `assert`: assertions are removed under
        # `python -O`, and without this guard the query below would contain only the
        # status and experiment filters, matching any running run in the experiment.
        # `AssertionError` is kept so the exception type seen by callers is unchanged.
        if not self.run_conf.tags:
            raise AssertionError("Cannot locate shared run or start a new one without tags.")

        query = {
            # Each key produced by `dicts.flatten` is joined with dots under the `tags.` prefix.
            **{"tags.%s" % ".".join(k): v for k, v in dicts.flatten(self.run_conf.tags).items()},
            "status": schema.RunStatus.RUNNING,
            "exp.id": (exp := self.mlflow_api.get_or_create_exp(self.exp.name)).id,
        }

        # Only the first result is used (presumably the most recently started run,
        # given the `("start_time", -1)` sorting - confirm against `find_runs`).
        for run in self.mlflow_api.find_runs(query, sorting=[("start_time", -1)]):
            return run.resume()

        # No running run matched the tags: start a new one in the experiment.
        run = self.mlflow_api.start_run(exp, self.run_conf.name, self.run_conf.tags, parent=self.run_conf.parent)
        logger.info("MLflow Run URL: %s", run.url)
        return run
144class MlflowRunMixin(pydantic.BaseModel):
145 """Mixin for pydantic classes that hold a reference to a `MlflowRunManager`."""
147 run_manager: MlflowRunManager | None = pydantic.Field(
148 exclude=True,
149 alias="mlflow",
150 description="Instance or dict to be parsed into instance of :class:`~mlopus.mlflow.MlflowRunManager`",
151 )