Coverage for src/mlopus/mlflow/providers/mlflow.py: 94%

324 statements  

coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

import contextlib
import hashlib
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, TypeVar, Callable, Tuple, List, Literal, Mapping, Iterable, Set, Iterator, ContextManager

from mlflow import MlflowClient as _NativeMlflowClient, entities as native
from mlflow.entities import model_registry as native_model_registry
from mlflow.store.entities import PagedList
from mlflow.tracking import artifact_utils as mlflow_artifact_utils
from mlflow.utils import rest_utils as mlflow_rest_utils

from mlopus.mlflow.api.base import BaseMlflowApi
from mlopus.mlflow.api.common import schema, patterns
from mlopus.utils import pydantic, mongo, dicts, json_utils, iter_utils, urls, string_utils, env_utils

logger = logging.getLogger(__name__)

A = TypeVar("A")  # Any type
A2 = TypeVar("A2")  # Any type

# Entity types used in MLOpus
E = schema.Experiment
R = schema.Run
M = schema.Model
V = schema.ModelVersion
T = TypeVar("T", bound=schema.BaseEntity)

# Types used natively by open source MLflow
NE = native.Experiment
NR = native.Run
NM = native_model_registry.RegisteredModel
NV = native_model_registry.ModelVersion
NT = TypeVar("NT", NE, NR, NM, NV)

class MlflowClient(_NativeMlflowClient):
    """Patch of native MlflowClient."""

    def healthcheck(self):
        """Check if service is available."""
        if (url := urls.parse_url(self.tracking_uri)).scheme in ("http", "https"):
            creds = self._tracking_client.store.get_host_creds()
            response = mlflow_rest_utils.http_request(creds, "/health", "GET")
            assert response.status_code == 200, f"Healthcheck failed for URI '{url}'"
        elif not urls.is_local(url):
            raise NotImplementedError(f"No healthcheck for URI '{url}'")

    def get_artifact_uri(self, run_id: str) -> str:
        """Get root URL of remote run artifacts, according to MLflow server."""
        return mlflow_artifact_utils.get_artifact_uri(run_id, tracking_uri=self.tracking_uri)

    @property
    def tracking_client(self):
        """Public access to tracking client."""
        return self._tracking_client
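
# Usage sketch (illustrative; the URL is a placeholder for any MLflow tracking server):
#
#   client = MlflowClient("http://localhost:5000")
#   client.healthcheck()  # raises AssertionError unless GET /health returns 200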

class KeepUntouched(pydantic.BaseModel):
    """Set of rules for keys of tags/params/metrics whose values are to be kept untouched.

    Their values are exchanged to/from the MLflow server without any pre-processing/encoding/escaping.
    """

    prefixes: Dict[str, Set[str]] = {"tags": {"mlflow"}}

    def __call__(self, key_parts: Tuple[str, ...], scope: Literal["tags", "params", "metrics"]) -> bool:
        return key_parts[0] in self.prefixes.get(scope, ())
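
# Example with the defaults above: system tags under the `mlflow` prefix pass
# through verbatim, while user tags get the JSON treatment.
#
#   rules = KeepUntouched()
#   rules(("mlflow", "parentRunId"), scope="tags")  # True -> exchanged as-is
#   rules(("team", "name"), scope="tags")           # False -> pre-processed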

class MaxLength(pydantic.BaseModel):
    """Max value lengths."""

    tag: int = 500
    param: int = 500

81 """Translates native MLflow data to MLOpus format and back.""" 

82 

83 ignore_json_errors: bool = True 

84 max_length: MaxLength = MaxLength() 

85 keep_untouched: KeepUntouched = KeepUntouched() 

86 

87 class Config(BaseMlflowApi.Config): 

88 """Class constants.""" 

89 

90 STATUS_TO_MLFLOW: Dict[schema.RunStatus, native.RunStatus] = { 

91 schema.RunStatus.FAILED: native.RunStatus.FAILED, 

92 schema.RunStatus.RUNNING: native.RunStatus.RUNNING, 

93 schema.RunStatus.FINISHED: native.RunStatus.FINISHED, 

94 schema.RunStatus.SCHEDULED: native.RunStatus.SCHEDULED, 

95 } 

96 

97 STATUS_FROM_MLFLOW: Dict[native.RunStatus, schema.RunStatus | None] = { 

98 **{v: k for k, v in STATUS_TO_MLFLOW.items()}, 

99 native.RunStatus.KILLED: schema.RunStatus.FAILED, 

100 } 

101 

    # =======================================================================================================
    # === Datetime translation ==============================================================================

    @classmethod
    def mlflow_ts_to_datetime(cls, timestamp: int | None) -> datetime | None:
        """Parse MLflow timestamp as datetime."""
        return None if timestamp is None else datetime.fromtimestamp(timestamp / 1000)

    @classmethod
    def datetime_to_mlflow_ts(cls, datetime_: datetime) -> int:
        """Coerce datetime to MLflow timestamp."""
        return int(datetime_.timestamp() * 1000)
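
    # Example (illustrative): MLflow timestamps are epoch milliseconds, so
    #
    #   datetime_to_mlflow_ts(datetime(2024, 1, 1, tzinfo=timezone.utc))  # -> 1704067200000
    #
    # (assumes `timezone` imported from `datetime`, which this module does not do)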

    # =======================================================================================================
    # === Enum translation ==================================================================================

    @classmethod
    def run_status_from_mlflow(cls, status: native.RunStatus | str) -> schema.RunStatus:
        """Parse run status enum value from MLflow."""
        status = native.RunStatus.from_string(status) if isinstance(status, str) else status
        return cls.Config.STATUS_FROM_MLFLOW[status]

    @classmethod
    def run_status_to_mlflow(cls, status: schema.RunStatus, as_str: bool = False) -> native.RunStatus | str:
        """Coerce run status enum value to MLflow format."""
        status = cls.Config.STATUS_TO_MLFLOW[status]
        return native.RunStatus.to_string(status) if as_str else status
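
    # Note the asymmetry configured in `STATUS_FROM_MLFLOW`: MLflow's KILLED maps to
    # MLOpus FAILED, e.g. run_status_from_mlflow("KILLED") -> schema.RunStatus.FAILED,
    # but no MLOpus status maps back to KILLED.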

    # =======================================================================================================
    # === Dict translation ==================================================================================

    @classmethod
    @string_utils.retval_matches(patterns.TAG_PARAM_OR_METRIC_KEY)
    def encode_key(cls, key: str) -> str:
        """Make sure dict key is safe for storage and query."""
        return str(key)

    @classmethod
    def _decode_key(cls, key: str) -> str:
        """Inverse func of `encode_key`."""
        return key

    @classmethod
    def flatten_key(cls, key_parts: Iterable[str]) -> str:
        """Turn a nested dict key into a flat one by joining with dot delimiter."""
        return ".".join(key_parts)

    @classmethod
    def unflatten_key(cls, key: str) -> Tuple[str, ...]:
        """Inverse func of `flatten_key`."""
        return tuple(key.split("."))
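
    # Round trip (illustrative):
    #
    #   flatten_key(("metrics", "loss"))  # -> "metrics.loss"
    #   unflatten_key("metrics.loss")     # -> ("metrics", "loss")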

    def preprocess_key(self, key_parts: Tuple[str, ...]) -> str:
        """Process key of tag, param or metric to be sent to MLflow."""
        return self.flatten_key(self.encode_key(k) for k in key_parts)

    def _preprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2 | None]
    ) -> Iterable[Tuple[str, A2]]:
        for key_parts, val in dicts.flatten(data).items():
            if (mapped_val := val_mapper(key_parts, val)) is not None:
                yield self.preprocess_key(key_parts), mapped_val

    def preprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2 | None]
    ) -> Dict[str, A2]:
        """Pre-process dict of tags, params or metrics to be compatible with MLflow.

        - Flatten nested keys as tuples
        - URL-encode dots and other forbidden chars in keys
        - Join tuple keys as dot-delimited strings
        - Map leaf-values (tags and params as json, metrics as float)
        """
        return dict(self._preprocess_dict(data, val_mapper))
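
    # Sketch of the pipeline described above (illustrative; assumes `dicts.flatten`
    # yields tuple keys for nested dicts):
    #
    #   preprocess_dict({"a": {"b": 1}}, val_mapper=lambda key, val: str(val))
    #   # -> {"a.b": "1"}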

    def _deprocess_dict(
        self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2]
    ) -> Iterable[Tuple[Tuple[str, ...], A2]]:
        """Inverse func of `_preprocess_dict`."""
        for key, val in data.items():
            key_parts = tuple(self._decode_key(k) for k in self.unflatten_key(key))
            yield key_parts, val_mapper(key_parts, val)

    def deprocess_dict(self, data: Mapping[str, A], val_mapper: Callable[[Tuple[str, ...], A], A2]) -> Dict[str, A2]:
        """Inverse func of `preprocess_dict`."""
        return dicts.unflatten(self._deprocess_dict(data, val_mapper))

    def process_tag(self, key_parts: Tuple[str, ...], val: Any) -> str | None:
        """JSON encode user-tag."""
        if not self.keep_untouched(key_parts, scope="tags"):
            val = string_utils.escape_sql_single_quote(json_utils.dumps(val))
            if len(str(val)) > self.max_length.tag:
                val = None
                logger.warning("Ignoring tag above max length of %s: %s", self.max_length.tag, key_parts)
        return val

    def _deprocess_tag(self, key_parts: Tuple[str, ...], val: str) -> Any:
        """JSON decode user-tag value from MLflow."""
        if not self.keep_untouched(key_parts, scope="tags"):
            val = json_utils.loads(string_utils.unscape_sql_single_quote(val), ignore_errors=self.ignore_json_errors)
        return val

    def process_param(self, key_parts: Tuple[str, ...], val: Any) -> str | None:  # noqa
        """JSON encode param."""
        if not self.keep_untouched(key_parts, scope="params"):
            val = string_utils.escape_sql_single_quote(json_utils.dumps(val))
            if len(str(val)) > self.max_length.param:
                val = None
                logger.warning("Ignoring param above max length of %s: %s", self.max_length.param, key_parts)
        return val

    def _deprocess_param(self, key_parts: Tuple[str, ...], val: str) -> Any:  # noqa
        """JSON decode param value from MLflow."""
        if not self.keep_untouched(key_parts, scope="params"):
            val = json_utils.loads(string_utils.unscape_sql_single_quote(val), ignore_errors=self.ignore_json_errors)
        return val

    def process_metric(self, key_parts: Tuple[str, ...], val: Any) -> float:  # noqa
        """Coerce metric value to float."""
        if not self.keep_untouched(key_parts, scope="metrics"):
            val = float(val)
        return val

    def _deprocess_metric(self, key_parts: Tuple[str, ...], val: Any) -> float:  # noqa
        """Coerce metric val from MLflow to float."""
        if not self.keep_untouched(key_parts, scope="metrics"):
            val = float(val)
        return val

    def preprocess_tags(self, data: Mapping) -> Dict[str, str]:
        """Prepare tags dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_tag)

    def deprocess_tags(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_tags`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_tag)

    def preprocess_params(self, data: Mapping) -> Dict[str, str]:
        """Prepare params dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_param)

    def deprocess_params(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_params`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_param)

    def preprocess_metrics(self, data: Mapping) -> Dict[str, float]:
        """Prepare metrics dict to be sent to native MLflow API."""
        return self.preprocess_dict(data, val_mapper=self.process_metric)

    def deprocess_metrics(self, data: Mapping) -> Dict[str, Any]:
        """Inverse func of `preprocess_metrics`."""
        return self.deprocess_dict(data, val_mapper=self._deprocess_metric)
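
# Round-trip sketch (illustrative; assumes `json_utils.dumps` behaves like `json.dumps`):
#
#   dt = MlflowDataTranslation()
#   raw = dt.preprocess_tags({"team": {"name": "abc"}})  # -> {"team.name": '"abc"'}
#   dt.deprocess_tags(raw)                               # -> {"team": {"name": "abc"}}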

class MlflowQueryPushDown(mongo.Mongo2Sql):
    """Basic MongoDB-query-to-MLflow-SQL converter."""

    mongo_subj_2_mlflow: Dict[str, Dict[str, str]] = {
        "exp": {
            "name": "name",
        },
        "run": {
            "id": "id",
            "name": "run_name",
            "status": "status",
            "end_time": "end_time",
            "start_time": "start_time",
        },
        "model": {
            "name": "name",
        },
        "mv": {
            "run.id": "run_id",
            "model.name": "name",
            "version": "version_number",
        },
    }
    data_translation: MlflowDataTranslation = None
    nested_subjects: Set[str] = {"metrics"}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if intersect := self.nested_subjects.intersection({"tags", "params"}):
            logger.warning(
                "Pushing down queries for the nested subject(s) %s may produce "
                "incomplete results because of MLflow SQL limitations over JSON fields.",
                intersect,
            )

    def _parse_subj(self, coll: str, subj: str) -> str | None:
        if (scope := (parts := subj.split("."))[0]) in self.nested_subjects:
            subj = f"{scope}.%s" % self.data_translation.preprocess_key(tuple(parts[1:]))
        else:
            subj = self.mongo_subj_2_mlflow[coll].get(subj)
        return super()._parse_subj(coll, subj)

    def _parse_obj(self, coll: str, subj: str, pred: Any, raw_obj: Any) -> str | None:
        if coll == "run" and subj in ("start_time", "end_time") and isinstance(raw_obj, datetime):
            raw_obj = self.data_translation.datetime_to_mlflow_ts(raw_obj)
        elif isinstance(raw_obj, schema.RunStatus):
            raw_obj = self.data_translation.run_status_to_mlflow(raw_obj, as_str=True)
        elif process := {
            "tags": self.data_translation.process_tag,
            "params": self.data_translation.process_param,
            "metrics": self.data_translation.process_metric,
        }.get((parts := subj.split("."))[0]):
            return "'%s'" % process(tuple(parts[1:]), raw_obj)  # noqa

        return super()._parse_obj(coll, subj, pred, raw_obj)

    def parse_exp(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for experiments search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="exp"), *self.parse_sorting(sorting, coll="exp")  # noqa

    def parse_run(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[str, mongo.Query, str, mongo.Sorting, List[str]]:
        """Parse query and sorting rule for runs search, return SQL expression and remainder for each, plus exp IDs."""
        if isinstance(exp_ids := query.pop("exp.id", None), str):
            exp_ids = [exp_ids]
        elif isinstance(exp_ids, dict) and set(exp_ids.keys()) == {"$in"}:
            exp_ids = exp_ids.pop("$in")
        if not (isinstance(exp_ids, (list, tuple)) and len(exp_ids) > 0 and all(isinstance(x, str) for x in exp_ids)):
            raise ValueError(
                f"{self.__class__.__name__}: `exp.id` must be specified when querying runs. "
                'Example: {"exp.id": {"$in": ["42", "13"]}}'
            )
        return *self.parse_query(query, coll="run"), *self.parse_sorting(sorting, coll="run"), exp_ids  # noqa

    def parse_model(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for models search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="model"), *self.parse_sorting(sorting, coll="model")  # noqa

    def parse_mv(self, query: mongo.Query, sorting: mongo.Sorting) -> Tuple[str, mongo.Query, str, mongo.Sorting]:
        """Parse query and sorting rule for model version search, return SQL expression and remainder for each."""
        return *self.parse_query(query, coll="mv"), *self.parse_sorting(sorting, coll="mv")  # noqa
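
# Example (illustrative; the `sorting=[]` shape is an assumption): `parse_run`
# extracts the mandatory experiment IDs and splits the rest of the Mongo-style
# query into an MLflow filter string plus a client-side remainder.
#
#   pd = MlflowQueryPushDown(data_translation=MlflowDataTranslation())
#   pd.parse_run({"exp.id": {"$in": ["42", "13"]}, "status": "FINISHED"}, sorting=[])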

class MlflowTagKeys(pydantic.BaseModel):
    """Default tag keys."""

    artifacts_repo: str = "artifacts_repo"
    parent_run_id: str = "mlflow.parentRunId"

class MlflowApi(BaseMlflowApi):
    """MLflow API provider based on open source MLflow.

    **Plugin name:** `mlflow`

    **Requires extras:** `mlflow`

    **Default cache dir:** `~/.cache/mlopus/mlflow-providers/mlflow/<hashed-tracking-uri>`

    Assumptions:
    - No artifacts proxy.
    - SQL database is server-managed.
    """

    tracking_uri: str = pydantic.Field(
        default=None,
        description=(
            "MLflow server URL or path to a local directory. "
            "Defaults to the environment variable `MLFLOW_TRACKING_URI`, "
            "falls back to `~/.cache/mlflow`."
        ),
    )

    healthcheck: bool = pydantic.Field(
        default=True,
        description=(
            "If true and not in :attr:`~mlopus.mlflow.BaseMlflowApi.offline_mode`, "
            "eagerly attempt connection to the server after initialization."
        ),
    )

    client_settings: Dict[str, str | int] = pydantic.Field(
        default_factory=dict,
        description=(
            "MLflow client settings. Keys are like the open-source MLflow environment variables, "
            "but lower case and without the `MLFLOW_` prefix. Example: `http_request_max_retries`. "
            "See: https://mlflow.org/docs/latest/python_api/mlflow.environment_variables.html"
        ),
    )

    tag_keys: MlflowTagKeys = pydantic.Field(
        repr=False,
        default_factory=MlflowTagKeys,
        description="Tag keys for storing internal information such as parent run ID.",
    )

    query_push_down: MlflowQueryPushDown = pydantic.Field(
        repr=False,
        default_factory=MlflowQueryPushDown,
        description=(
            "Utility for partial translation of MongoDB queries to open-source MLflow SQL. "
            "Users may replace this with a different implementation when subclassing the API."
        ),
    )

    data_translation: MlflowDataTranslation = pydantic.Field(
        repr=False,
        default_factory=MlflowDataTranslation,
        description=(
            "Utility for translating keys and values from MLOpus schema to native MLflow schema and back. "
            "Users may replace this with a different implementation when subclassing the API."
        ),
    )
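
    # Construction sketch (illustrative; URL and setting values are placeholders):
    #
    #   api = MlflowApi(
    #       tracking_uri="http://localhost:5000",
    #       client_settings={"http_request_max_retries": 3},
    #   )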

    # =======================================================================================================
    # === Pydantic validators ===============================================================================

    @pydantic.root_validator(pre=True)  # noqa
    @classmethod
    def _valid_tracking_uri(cls, values: dicts.AnyDict) -> dicts.AnyDict:
        """Use default if provided value is None or empty string."""
        raw_url = values.get("tracking_uri") or os.environ.get("MLFLOW_TRACKING_URI") or (Path.home() / ".cache/mlflow")
        values["tracking_uri"] = str(urls.parse_url(raw_url, resolve_if_local=True))
        return values

    def __init__(self, **kwargs):
        """Let the query push down use the same data translator as the API."""
        super().__init__(**kwargs)
        if self.query_push_down.data_translation is None:
            self.query_push_down.data_translation = self.data_translation

        if self.healthcheck and not self.offline_mode:
            with self._client() as cli:
                cli.healthcheck()

    # =======================================================================================================
    # === Properties ========================================================================================

    @property
    def _default_cache_id(self) -> str:
        """Sub-dir under default cache dir. Only used if `cache_dir` is not specified."""
        return hashlib.md5(self.tracking_uri.encode()).hexdigest()[:16]
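
    # The cache id is deterministic: a given `tracking_uri` always maps to the same
    # 16-char hex sub-dir (see `_impl_default_cache_dir` below).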

    # =======================================================================================================
    # === Client ============================================================================================

    def _get_client(self) -> MlflowClient:
        assert not self.offline_mode, "Cannot use MlflowClient in offline mode."
        return MlflowClient(self.tracking_uri)

    @contextlib.contextmanager
    def _client(self) -> ContextManager[MlflowClient]:
        with env_utils.using_env_vars({"MLFLOW_%s" % k.upper(): str(v) for k, v in self.client_settings.items()}):
            yield self._get_client()

    def _using_client(self, func: Callable[[MlflowClient], A]) -> A:
        with self._client() as client:
            return func(client)

456 # ======================================================================================================= 

457 # === Metadata parsers ================================================================================== 

458 

459 def _parse_experiment(self, native_experiment: NE) -> E: 

460 return E( 

461 name=native_experiment.name, 

462 id=native_experiment.experiment_id, 

463 tags=self.data_translation.deprocess_tags(native_experiment.tags), 

464 ) 

465 

466 def _parse_run(self, native_run: NR, experiment: E) -> R: 

467 if not (repo := native_run.data.tags.get(self.tag_keys.artifacts_repo)): 

468 repo = self._using_client(lambda client: client.get_artifact_uri(native_run.info.run_id)) 

469 

470 return R( 

471 exp=experiment, 

472 id=native_run.info.run_id, 

473 name=native_run.info.run_name, 

474 repo=str(urls.parse_url(repo, resolve_if_local=True)), 

475 tags=self.data_translation.deprocess_tags(native_run.data.tags), 

476 params=self.data_translation.deprocess_params(native_run.data.params), 

477 metrics=self.data_translation.deprocess_metrics(native_run.data.metrics), 

478 status=self.data_translation.run_status_from_mlflow(native_run.info.status), 

479 end_time=self.data_translation.mlflow_ts_to_datetime(native_run.info.end_time), 

480 start_time=self.data_translation.mlflow_ts_to_datetime(native_run.info.start_time), 

481 ) 

482 

483 def _parse_model(self, native_model: NM) -> M: 

484 return M( 

485 name=native_model.name, 

486 tags=self.data_translation.deprocess_tags(native_model.tags), 

487 ) 

488 

489 def _parse_model_version(self, native_mv: NV, model: M, run: R) -> V: 

490 path_in_run = ( 

491 str(urls.parse_url(native_mv.source)).removeprefix(f"runs:///{run.id}/").removeprefix(run.repo).strip("/") 

492 ) 

493 

494 return V( 

495 run=run, 

496 model=model, 

497 path_in_run=path_in_run, 

498 version=str(native_mv.version), 

499 tags=self.data_translation.deprocess_tags(native_mv.tags), 

500 ) 

501 

    # =======================================================================================================
    # === Implementations of abstract methods from `BaseMlflowApi` ==========================================

    def _impl_default_cache_dir(self) -> Path:
        """Get default cache dir based on the current MLflow API settings."""
        return Path.home().joinpath(".cache/mlopus/mlflow-providers/mlflow", self._default_cache_id)

    def _impl_get_exp_url(self, exp_id: str) -> urls.Url:
        """Get Experiment URL."""
        path = "" if urls.is_local(base := self.tracking_uri) else "#/experiments"
        return urls.urljoin(base, path, exp_id)

    def _impl_get_run_url(self, run_id: str, exp_id: str) -> urls.Url:
        """Get Run URL."""
        path = "" if urls.is_local(base := self._impl_get_exp_url(exp_id)) else "runs"
        return urls.urljoin(base, path, run_id)

    def _impl_get_model_url(self, name: str) -> urls.Url:
        """Get URL to registered model."""
        name = patterns.encode_model_name(name)
        path = "models/%s" if urls.is_local(base := self.tracking_uri) else "#/models/%s"
        return urls.urljoin(base, path % name)

    def _impl_get_mv_url(self, name: str, version: str) -> urls.Url:
        """Get model version URL."""
        name = patterns.encode_model_name(name)
        path = "models/%s/version-%s" if urls.is_local(base := self.tracking_uri) else "#/models/%s/versions/%s"
        return urls.urljoin(base, path % (name, version))
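
    # For a remote server these follow the MLflow UI scheme (illustrative,
    # depending on `urls.urljoin` semantics):
    #
    #   _impl_get_exp_url("42")          # -> <tracking_uri>/#/experiments/42
    #   _impl_get_run_url("abc1", "42")  # -> <tracking_uri>/#/experiments/42/runs/abc1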

    def _impl_fetch_exp(self, exp_id: str) -> E:
        """Get Experiment by ID."""
        with self._client() as cli:
            return self._parse_experiment(cli.get_experiment(exp_id))

    def _impl_fetch_run(self, run_id: str) -> R:
        """Get Run by ID."""
        native_run = self._using_client(lambda client: client.get_run(run_id))
        experiment = self._fetch_exp(native_run.info.experiment_id)
        return self._parse_run(native_run, experiment)

    def _impl_fetch_model(self, name: str) -> M:
        """Get registered Model by name."""
        return self._parse_model(self._using_client(lambda client: client.get_registered_model(name)))

    def _impl_fetch_mv(self, name_and_version: Tuple[str, str]) -> V:
        """Get ModelVersion by name and version."""
        native_mv = self._using_client(lambda client: client.get_model_version(*name_and_version))
        run = self._fetch_run(native_mv.run_id)
        model = self._fetch_model(native_mv.name)
        return self._parse_model_version(native_mv, model, run)

    def _impl_find_experiments(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[E]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_exp(query, sorting)
        logger.debug("Query push down (exp): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NE](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_experiments(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        )

        return query_remainder, sort_remainder, paginator.map_results(self._parse_experiment)

    def _impl_find_runs(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[R]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder, exp_ids = self.query_push_down.parse_run(
            query, sorting
        )
        logger.debug("Query push down (run): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NR](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_runs(
                        page_token=token,
                        experiment_ids=exp_ids,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        ).map_results(
            lambda native_run: self._parse_run(
                native_run=native_run,
                experiment=self._fetch_exp(native_run.info.experiment_id),
            ),
        )

        return query_remainder, sort_remainder, paginator

    def _impl_find_models(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[M]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_model(query, sorting)
        logger.debug("Query push down (model): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NM](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_registered_models(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        )

        return query_remainder, sort_remainder, paginator.map_results(self._parse_model)

    def _impl_find_mv(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[V]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""
        filter_expr, query_remainder, sort_expr, sort_remainder = self.query_push_down.parse_mv(query, sorting)
        logger.debug("Query push down (mv): %s %s", filter_expr, sort_expr)

        paginator = iter_utils.Paginator[NV](
            lambda token: _open_paged_list(
                self._using_client(
                    lambda client: client.search_model_versions(
                        page_token=token,
                        filter_string=filter_expr,
                        order_by=sort_expr.split(", ") if sort_expr else None,
                    )
                ),
            ),
        ).map_results(
            lambda native_mv: self._parse_model_version(
                native_mv=native_mv,
                run=self._fetch_run(native_mv.run_id),
                model=self._fetch_model(native_mv.name),
            ),
        )

        return query_remainder, sort_remainder, paginator

    def _impl_find_child_runs(self, run: R) -> Iterator[R]:
        """Find child runs."""
        return self.find_runs({"exp.id": run.exp.id, f"tags.{self.tag_keys.parent_run_id}": run.id})

    def _impl_create_exp(self, name: str, tags: Mapping) -> E:
        """Create experiment and return its metadata."""
        tags = self.data_translation.preprocess_tags(tags)
        return self._get_exp(self._using_client(lambda client: client.create_experiment(name=name, tags=tags)))

    def _impl_create_model(self, name: str, tags: Mapping) -> M:
        """Create registered model and return its metadata."""
        tags = self.data_translation.preprocess_tags(tags)
        return self._parse_model(self._using_client(lambda client: client.create_registered_model(name, tags)))

    def _impl_create_run(
        self, exp_id: str, name: str | None, repo: urls.Url | None, parent_run_id: str | None = None
    ) -> str:
        """Create run."""
        with self._client() as client:
            run_id = client.create_run(exp_id, run_name=name).info.run_id
            if repo:
                client.set_tag(run_id, key=self.tag_keys.artifacts_repo, value=str(repo), synchronous=True)
            if parent_run_id:
                client.set_tag(run_id, key=self.tag_keys.parent_run_id, value=parent_run_id, synchronous=True)
            return run_id

    def _impl_set_run_status(self, run_id: str, status: schema.RunStatus):
        """Set Run status."""
        with self._client() as client:
            client.update_run(run_id, status=self.data_translation.run_status_to_mlflow(status, as_str=True))

    def _impl_set_run_end_time(self, run_id: str, end_time: datetime):
        """Set Run end time."""
        end_time = self.data_translation.datetime_to_mlflow_ts(end_time)
        self._using_client(lambda client: client.tracking_client.store.update_run_info(run_id, None, end_time, None))

    def _impl_update_exp_tags(self, exp_id: str, tags: Mapping):
        """Update Exp tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_experiment_tag(exp_id, k, v)

    def _impl_update_run_tags(self, run_id: str, tags: Mapping):
        """Update Run tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_tag(run_id, k, v, synchronous=True)

    def _impl_update_model_tags(self, name: str, tags: Mapping):
        """Update Model tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_registered_model_tag(name, k, v)

    def _impl_update_mv_tags(self, name: str, version: str, tags: Mapping):
        """Update ModelVersion tags."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_tags(tags).items():
                client.set_model_version_tag(name, version, k, v)

    def _impl_log_run_params(self, run_id: str, params: Mapping):
        """Log run params."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_params(params).items():
                client.log_param(run_id, k, v, synchronous=True)

    def _impl_log_run_metrics(self, run_id: str, metrics: Mapping):
        """Log run metrics."""
        with self._client() as client:
            for k, v in self.data_translation.preprocess_metrics(metrics).items():
                client.log_metric(run_id, k, v, synchronous=True)

    def _impl_register_mv(self, model: M, run: R, path_in_run: str, version: str | None, tags: Mapping) -> V:
        """Register model version."""
        assert version is None, f"Arbitrary `version` not supported in '{self.__class__.__name__}'"
        tags = self.data_translation.preprocess_tags(tags)
        source = str(urls.urljoin(run.repo, path_in_run))
        native_mv = self._using_client(lambda client: client.create_model_version(model.name, source, run.id, tags))
        return self._parse_model_version(native_mv, model, run)

def _open_paged_list(paged_list: PagedList[A]) -> iter_utils.Page[A]:
    """Convert MLflow native PagedList to `iter_utils.Page`."""
    return iter_utils.Page(token=paged_list.token, results=paged_list)
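
# Usage sketch (illustrative; assumes the base API exposes `find_experiments`,
# the public counterpart of `_impl_find_experiments` above):
#
#   api = MlflowApi(tracking_uri="http://localhost:5000")
#   for exp in api.find_experiments({"name": "my-exp"}):
#       print(exp.id, exp.name)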