Coverage for src/mlopus/mlflow/api/base.py: 87%

626 statements  


1 import contextlib

2 import logging

3 import os.path

4 import tempfile

5 import typing

6 from abc import ABC, abstractmethod

7 from datetime import datetime

8 from pathlib import Path

9 from typing import Type, TypeVar, Callable, Iterator, Tuple, Mapping

10

11 from mlopus.utils import pydantic, paths, urls, mongo, string_utils, iter_utils

12 from . import contract

13 from .common import schema, decorators, serde, exceptions, patterns, transfer

14 from .exp import ExpApi

15 from .model import ModelApi

16 from .mv import ModelVersionApi

17 from .run import RunApi

18

19 logger = logging.getLogger(__name__)

20

21 A = TypeVar("A")  # Any type

22

23 # Entity types

24 E = schema.Experiment

25 R = schema.Run

26 M = schema.Model

27 V = schema.ModelVersion

28 T = TypeVar("T", bound=schema.BaseEntity)

29

30 # Identifier types

31 ExpIdentifier = contract.ExpIdentifier | ExpApi

32 RunIdentifier = contract.RunIdentifier | RunApi

33 ModelIdentifier = contract.ModelIdentifier | ModelApi

34 ModelVersionIdentifier = contract.ModelVersionIdentifier | ModelVersionApi

35

36

37 class BaseMlflowApi(contract.MlflowApiContract, ABC, frozen=True):

38 """Base class for API clients that use "MLflow-like" backends for experiment tracking and model registry. 

39 

40 Important: 

41 Implementations of this interface are meant to be thread-safe and independent of env vars/globals, 

42 so multiple API instances can coexist in the same program if necessary. 

43 """ 

44 

45 cache_dir: Path | None = pydantic.Field( 

46 default=None, 

47 description=( 

48 "Root path for cached artifacts and metadata. " 

49 "If not specified, a default is determined by the respective API plugin."

50 ), 

51 ) 

52 

53 offline_mode: bool = pydantic.Field( 

54 default=False, 

55 description=( 

56 "If `True`, block all operations that attempt communication " 

57 "with the MLflow server (i.e.: only use cached metadata). " 

58 "Artifacts are still accessible if they are cached or if " 

59 ":attr:`pull_artifacts_in_offline_mode` is `True`." 

60 ), 

61 ) 

62 

63 pull_artifacts_in_offline_mode: bool = pydantic.Field( 

64 default=False, 

65 description=( 

66 "If `True`, allow pulling artifacts from storage to cache in offline mode. " 

67 "Useful if caching metadata only and pulling artifacts on demand " 

68 "(the artifact's URL must be known beforehand, e.g. by caching the metadata of its parent entity). " 

69 ), 

70 ) 

71 

72 temp_artifacts_dir: Path = pydantic.Field( 

73 default=None, 

74 description=( 

75 "Path for temporary artifacts written by artifact dumpers before publishing. These files are "

76 "preserved after a publish error (e.g.: an upload interruption). Defaults to a path inside the local cache."

77 ), 

78 ) 

79 

80 cache_local_artifacts: bool = pydantic.Field( 

81 default=False, 

82 description=( 

83 "Use local cache even if the run artifacts repository is in the local file system. " 

84 "May be used for testing cache without connecting to a remote MLflow server. "

85 "Not recommended in production because of unnecessary duplicated disk usage."

86 ), 

87 ) 

88 

89 always_pull_artifacts: bool = pydantic.Field( 

90 default=False, 

91 description=( 

92 "When accessing a cached artifact file or dir, re-sync it with the remote artifacts repository, even " 

93 "on a cache hit. Prevents accessing stale data if the remote artifact has been changed in the meantime. " 

94 "The default data transfer utility (based on rclone) is pretty efficient for syncing directories, but " 

95 "enabling this option may still add the overhead of checksum calculation for directories with many files."

96 ), 

97 ) 

98 

99 file_transfer: transfer.FileTransfer = pydantic.Field( 

100 repr=False, 

101 default_factory=transfer.FileTransfer, 

102 description=( 

103 "Utility for uploading/downloading artifact files or dirs. Also used for listing files. Based on " 

104 "RClone by default. Users may replace this with a different implementation when subclassing the API." 

105 ), 

106 ) 

107 

108 entity_serializer: serde.EntitySerializer = pydantic.Field( 

109 repr=False, 

110 default_factory=serde.EntitySerializer, 

111 description=( 

112 "Utility for (de)serializing entity metadata (i.e.: exp, runs, models, versions). "

113 "Users may replace this with a different implementation when subclassing the API." 

114 ), 

115 ) 

116 
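For orientation, a minimal configuration sketch follows. `MyMlflowApi` stands for a hypothetical concrete subclass of this base class; the field names are the ones defined above.

    from pathlib import Path

    api = MyMlflowApi(  # hypothetical concrete implementation of BaseMlflowApi
        cache_dir=Path("~/.cache/mlopus"),     # expanded and resolved in __init__
        offline_mode=False,                    # allow communication with the server
        pull_artifacts_in_offline_mode=False,  # offline reads use cached artifacts only
        always_pull_artifacts=False,           # trust cache hits without re-syncing
    )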

117 def __init__(self, **kwargs): 

118 super().__init__(**kwargs) 

119 

120 # Apply default to cache dir, expand and resolve 

121 if paths.is_cwd(cache := self.cache_dir or ""): 

122 cache = self._impl_default_cache_dir() 

123 pydantic.force_set_attr(self, key="cache_dir", val=cache.expanduser().resolve()) 

124 

125 # Apply default to temp artifacts dir, expand and resolve 

126 if paths.is_cwd(tmp := self.temp_artifacts_dir or ""): 

127 tmp = self._artifacts_cache.joinpath("temp") 

128 pydantic.force_set_attr(self, key="temp_artifacts_dir", val=tmp.expanduser().resolve()) 

129 

130 @property 

131 def in_offline_mode(self) -> "BaseMlflowApi": 

132 """Get an offline copy of this API.""" 

133 return self.model_copy(update={"offline_mode": True}) 

134 
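A short usage sketch for the offline copy (assuming `api` is a configured instance of a concrete subclass; the run ID is illustrative):

    offline_api = api.in_offline_mode    # same settings, but offline_mode=True
    run = offline_api.get_run("abc123")  # resolved from cached metadata only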

135 # ======================================================================================================= 

136 # === Metadata cache locators =========================================================================== 

137 

138 @property 

139 def _metadata_cache(self) -> Path: 

140 return self.cache_dir.joinpath("metadata") 

141 

142 @property 

143 def _exp_cache(self) -> Path: 

144 return self._metadata_cache.joinpath("exp") 

145 

146 @property 

147 def _run_cache(self) -> Path: 

148 return self._metadata_cache.joinpath("run") 

149 

150 @property 

151 def _model_cache(self) -> Path: 

152 return self._metadata_cache.joinpath("model") 

153 

154 @property 

155 def _mv_cache(self) -> Path: 

156 return self._metadata_cache.joinpath("mv") 

157 

158 def _get_exp_cache(self, exp_id: str) -> Path: 

159 return self._exp_cache.joinpath(exp_id) 

160 

161 def _get_run_cache(self, run_id: str) -> Path: 

162 return self._run_cache.joinpath(run_id) 

163 

164 def _get_model_cache(self, name: str) -> Path: 

165 return self._model_cache.joinpath(patterns.encode_model_name(name)) 

166 

167 def _get_mv_cache(self, name: str, version: str) -> Path: 

168 return self._mv_cache.joinpath(patterns.encode_model_name(name), version) 

169 

170 # ======================================================================================================= 

171 # === Artifact cache locators =========================================================================== 

172 

173 @property 

174 def _artifacts_cache(self) -> Path: 

175 return self.cache_dir.joinpath("artifacts") 

176 

177 @property 

178 def _run_artifacts_cache(self) -> Path: 

179 return self._artifacts_cache.joinpath("runs") 

180 

181 def _get_run_artifact_cache_path( 

182 self, run: RunIdentifier, path_in_run: str = "", allow_base_resolve: bool = True 

183 ) -> Path: 

184 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=allow_base_resolve) 

185 return self._run_artifacts_cache.joinpath(self._coerce_run_id(run), path_in_run) 

186 

187 def _get_temp_artifacts_dir(self) -> Path: 

188 return Path(tempfile.mkdtemp(dir=paths.ensure_is_dir(self.temp_artifacts_dir))) 

189 

190 # ======================================================================================================= 

191 # === Cache protection ================================================================================== 

192 

193 @contextlib.contextmanager 

194 def _lock_run_artifact(self, run_id: str, path_in_run: str, allow_base_resolve: bool = True) -> Path: 

195 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=allow_base_resolve) 

196 with paths.dir_lock(self._get_run_artifact_cache_path(run_id)) as path: 

197 yield path.joinpath(path_in_run) 

198 

199 # ======================================================================================================= 

200 # === Cache cleanup ===================================================================================== 

201 

202 def _clean_temp_artifacts(self): 

203 paths.ensure_non_existing(self.temp_artifacts_dir, force=True) 

204 

205 def _clean_all_meta_cache(self): 

206 for path in (self._exp_cache, self._run_cache, self._model_cache, self._mv_cache): 

207 with paths.dir_lock(path): 

208 paths.ensure_empty_dir(path, force=True) 

209 

210 def _clean_run_artifact(self, run_id: str, path_in_run: str = ""): 

211 with self._lock_run_artifact(run_id, path_in_run) as path: 

212 paths.ensure_non_existing(path, force=True) 

213 

214 def _clean_all_runs_artifacts(self): 

215 with paths.dir_lock(self._run_artifacts_cache) as path: 

216 paths.ensure_non_existing(path, force=True) 

217 

218 def _clean_all_cache(self): 

219 self._clean_all_meta_cache() 

220 self._clean_all_runs_artifacts() 

221 

222 # ======================================================================================================= 

223 # === Metadata Getters ================================================================================== 

224 

225 @classmethod 

226 def _meta_fetcher(cls, fetcher: Callable[[A], T], args: A) -> Callable[[], T]: 

227 return lambda: fetcher(args) 

228 

229 def _meta_cache_reader(self, path: Path, type_: Type[T]) -> Callable[[], T]: 

230 return lambda: self.entity_serializer.load(type_, path) 

231 

232 def _meta_cache_writer(self, lock_dir: Path, path: Path, type_: Type[T]) -> Callable[[T], None]: 

233 def _write_meta_cache(meta: T): 

234 assert isinstance(meta, type_) 

235 with paths.dir_lock(lock_dir): 

236 self.entity_serializer.save(meta, path) 

237 

238 return _write_meta_cache 

239 

240 def _get_meta( 

241 self, 

242 fetcher: Callable[[], T], 

243 cache_reader: Callable[[], T], 

244 cache_writer: Callable[[T], None], 

245 force_cache_refresh: bool = False, 

246 ) -> T: 

247 """Get metadata.""" 

248 if force_cache_refresh: 

249 if self.offline_mode: 

250 raise RuntimeError("Cannot refresh cache in offline mode.")

251 cache_writer(meta := fetcher()) 

252 elif not self.offline_mode: 

253 meta = fetcher() 

254 else: 

255 meta = cache_reader() 

256 

257 return meta 

258 

259 def _get_exp(self, exp_id: str, **cache_opts: bool) -> E: 

260 """Get Experiment metadata.""" 

261 return self._get_meta( 

262 self._meta_fetcher(self._fetch_exp, exp_id), 

263 self._meta_cache_reader(cache := self._get_exp_cache(exp_id), E), 

264 self._meta_cache_writer(self._exp_cache, cache, E), 

265 **cache_opts, 

266 ) 

267 

268 def _get_run(self, run_id: str, **cache_opts: bool) -> R: 

269 """Get Run metadata.""" 

270 return self._get_meta( 

271 self._meta_fetcher(self._fetch_run, run_id), 

272 self._meta_cache_reader(cache := self._get_run_cache(run_id), R), 

273 lambda run: [ 

274 self._meta_cache_writer(self._run_cache, cache, R)(run), 

275 self._meta_cache_writer(self._exp_cache, self._get_exp_cache(run.exp.id), E)(run.exp), 

276 ][0], 

277 **cache_opts, 

278 ) 

279 

280 def _get_model(self, name: str, **cache_opts: bool) -> M: 

281 """Get Model metadata.""" 

282 return self._get_meta( 

283 self._meta_fetcher(self._fetch_model, name), 

284 self._meta_cache_reader(cache := self._get_model_cache(name), M), 

285 self._meta_cache_writer(self._model_cache, cache, M), 

286 **cache_opts, 

287 ) 

288 

289 def _get_mv(self, name_and_version: Tuple[str, str], **cache_opts: bool) -> V: 

290 """Get ModelVersion metadata.""" 

291 return self._get_meta( 

292 self._meta_fetcher(self._fetch_mv, name_and_version), 

293 self._meta_cache_reader(cache := self._get_mv_cache(*name_and_version), V), 

294 lambda mv: [ 

295 self._meta_cache_writer(self._mv_cache, cache, V)(mv), 

296 self._meta_cache_writer(self._model_cache, self._get_model_cache(mv.model.name), M)(mv.model), 

297 ][0], 

298 **cache_opts, 

299 ) 

300 

301 @decorators.online 

302 def _fetch_exp(self, exp_id: str) -> E: 

303 return self._impl_fetch_exp(exp_id) 

304 

305 @decorators.online 

306 def _fetch_run(self, run_id: str) -> R: 

307 return self._impl_fetch_run(run_id) 

308 

309 @decorators.online 

310 def _fetch_model(self, name: str) -> M: 

311 return self._impl_fetch_model(name) 

312 

313 @decorators.online 

314 def _fetch_mv(self, name_and_version: Tuple[str, str]) -> V: 

315 return self._impl_fetch_mv(name_and_version) 

316 

317 def _export_meta(self, meta: T, cache: Path, target: Path): 

318 paths.ensure_only_parents(target := target / cache.relative_to(self.cache_dir), force=True) 

319 self.entity_serializer.save(meta, target) 

320 

321 # ======================================================================================================= 

322 # === Metadata Finders ================================================================================== 

323 

324 def _find_experiments(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[E]: 

325 return self._find_meta(self._exp_cache, E, query, sorting, self._impl_find_experiments) 

326 

327 def _find_runs(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[R]: 

328 return self._find_meta(self._run_cache, R, query, sorting, self._impl_find_runs) 

329 

330 def _find_models(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[M]: 

331 return self._find_meta(self._model_cache, M, query, sorting, self._impl_find_models) 

332 

333 def _find_mv(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[V]: 

334 return self._find_meta(self._mv_cache, V, query, sorting, self._impl_find_mv) 

335 

336 def _find_meta( 

337 self, 

338 cache: Path, 

339 type_: Type[T], 

340 query: mongo.Query, 

341 sorting: mongo.Sorting, 

342 finder: Callable[[mongo.Query, mongo.Sorting], Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[T]]], 

343 ) -> Iterator[T]: 

344 if self.offline_mode: 

345 logger.warning("Metadata search in offline mode may yield incomplete or stale results.") 

346 cached = (self.entity_serializer.load(type_, x) for x in paths.iter_files(cache)) 

347 paginator = iter_utils.Paginator[T].single_page(cached) 

348 else: 

349 query, sorting, paginator = finder(query, sorting) 

350 

351 if sorting: 

352 paginator = paginator.collapse() 

353 

354 if query: 

355 paginator = paginator.map_pages( 

356 lambda results: list( 

357 mongo.find_all( 

358 results, 

359 query=query, 

360 sorting=sorting, 

361 to_doc=lambda obj: obj.dict(), 

362 from_doc=type_.parse_obj, 

363 ) 

364 ) 

365 ) 

366 

367 for page in paginator: 

368 for result in page: 

369 yield result 

370 
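To illustrate the contract between `_find_meta` and the `_impl_find_*` methods: the backend returns whatever part of the query and sorting it could not push down, plus a paginator of raw results, and the remainder is applied in memory above. A rough sketch of a hypothetical implementation (the backend helper is an assumption):

    def _impl_find_runs(self, query, sorting):
        # Hypothetical backend: push down only a filter on experiment ID and
        # return the untouched remainder for in-memory evaluation by _find_meta.
        exp_id = query.pop("exp.id", None) if isinstance(query, dict) else None
        results = self._backend_search_runs(exp_id)  # assumed backend-specific helper
        return query, sorting, iter_utils.Paginator.single_page(results)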

371 # ======================================================================================================= 

372 # === Artifact getters ================================================================================== 

373 

374 def _list_run_artifacts(self, run: RunIdentifier, path_in_run: str = "") -> transfer.LsResult: 

375 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True) 

376 

377 if self.offline_mode and not self.pull_artifacts_in_offline_mode: 

378 logger.warning("Listing run artifacts in offline mode may yield incomplete or stale results.") 

379 subject = self._get_run_artifact_cache_path(run, path_in_run) 

380 elif urls.is_local(subject := urls.urljoin(self._coerce_run(run).repo, path_in_run)): 

381 subject = subject.path 

382 

383 return self.file_transfer.ls(subject) 

384 

385 def _pull_run_artifact(self, run: RunIdentifier, path_in_run: str) -> Path: 

386 """Pull artifact from run repo to local cache, unless repo is already local.""" 

387 if self.offline_mode and not self.pull_artifacts_in_offline_mode: 

388 raise RuntimeError("Artifact pull is disabled.") 

389 

390 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True) 

391 

392 with self._lock_run_artifact((run := self._coerce_run(run)).id, path_in_run) as target: 

393 if urls.is_local(url := run.repo_url): 

394 source = Path(url.path) / path_in_run 

395 

396 if self.cache_local_artifacts: 

397 paths.place_path(source, target, mode="copy", overwrite=True) 

398 else: 

399 logger.info("Run artifacts repo is local, nothing to pull.") 

400 return source 

401 else: 

402 self.file_transfer.pull_files(urls.urljoin(run.repo_url, path_in_run), target) 

403 

404 return target 

405 

406 def _get_run_artifact(self, run: RunIdentifier, path_in_run: str) -> Path: 

407 """Get path to local run artifact, may trigger a cache pull.""" 

408 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True) 

409 cache = self._get_run_artifact_cache_path(run, path_in_run) 

410 

411 if (not self.offline_mode or self.pull_artifacts_in_offline_mode) and ( 

412 not cache.exists() or self.always_pull_artifacts 

413 ): 

414 return self._pull_run_artifact(run, path_in_run) 

415 

416 if not cache.exists(): 

417 raise FileNotFoundError(cache) 

418 

419 return cache 

420 

421 def _place_run_artifact( 

422 self, 

423 run: RunIdentifier, 

424 path_in_run: str, 

425 target: Path, 

426 link: bool, 

427 overwrite: bool, 

428 ): 

429 """Place local run artifact on target path, may trigger a cache pull. 

430 

431 The resulting files are always write-protected, but directories are not. 

432 """ 

433 mode = typing.cast(paths.PathOperation, "link" if link else "copy") 

434 

435 if (src := self._get_run_artifact(run, path_in_run)).is_dir() and link: 

436 paths.ensure_only_parents(target, force=overwrite) 

437 

438 for dirpath, _, filenames in os.walk(src): # Recursively create symbolic links for files 

439 relpath = Path(dirpath).relative_to(src) 

440 for filename in filenames: 

441 paths.place_path(Path(dirpath) / filename, target / relpath / filename, mode, overwrite) 

442 else: 

443 paths.place_path(src, target, mode, overwrite) 

444 

445 if target.is_dir() and not link: # Recursively fix permissions of copied directories 

446 target.chmod(paths.Mode.rwx) 

447 

448 for dirpath, dirnames, _ in os.walk(target): 

449 for dirname in dirnames: 

450 Path(dirpath, dirname).chmod(paths.Mode.rwx) 

451 

452 # ======================================================================================================= 

453 # === Arguments pre-processing ========================================================================== 

454 

455 def _coerce_exp(self, exp: ExpIdentifier) -> E: 

456 match exp: 

457 case E(): 

458 return exp 

459 case str(): 

460 return self._get_exp(exp) 

461 case _: 

462 raise TypeError("Expected Experiment, ExpApi or experiment ID as string.") 

463 

464 @classmethod 

465 @string_utils.retval_matches(patterns.EXP_ID) 

466 def _coerce_exp_id(cls, exp: ExpIdentifier) -> str: 

467 match exp: 

468 case E(): 

469 return exp.id 

470 case str(): 

471 return exp 

472 case _: 

473 raise TypeError("Expected Experiment, ExpApi or experiment ID as string.") 

474 

475 def _coerce_run(self, run: RunIdentifier) -> R: 

476 match run: 

477 case R(): 

478 return run 

479 case str(): 

480 return self._get_run(run) 

481 case _: 

482 raise TypeError("Expected Run, RunApi or run ID as string.") 

483 

484 @classmethod 

485 @string_utils.retval_matches(patterns.RUN_ID) 

486 def _coerce_run_id(cls, run: RunIdentifier) -> str: 

487 match run: 

488 case R(): 

489 return run.id 

490 case str(): 

491 return run 

492 case _: 

493 raise TypeError("Expected Run, RunApi or run ID as string.") 

494 

495 def _coerce_model(self, model: ModelIdentifier) -> M: 

496 match model: 

497 case M(): 

498 return model 

499 case str(): 

500 return self._get_model(model) 

501 case _: 

502 raise TypeError("Expected Model, ModelApi or model name as string.") 

503 

504 @classmethod 

505 @string_utils.retval_matches(patterns.MODEL_NAME) 

506 def _coerce_model_name(cls, model: ModelIdentifier) -> str: 

507 match model: 

508 case M(): 

509 return model.name 

510 case str(): 

511 return model 

512 case _: 

513 raise TypeError("Expected Model, ModelApi or model name as string.") 

514 

515 def _coerce_mv(self, mv: ModelVersionIdentifier) -> V: 

516 match mv: 

517 case V(): 

518 return mv 

519 case (str(), str()): 

520 return self._get_mv(*mv) 

521 case _: 

522 raise TypeError("Expected ModelVersion or tuple of (name, version) strings.")

523 

524 @classmethod 

525 @string_utils.retval_matches(patterns.MODEL_NAME, index=0) 

526 @string_utils.retval_matches(patterns.MODEL_VERSION, index=1) 

527 def _coerce_mv_tuple(cls, mv: ModelVersionIdentifier) -> Tuple[str, str]: 

528 match mv: 

529 case V(): 

530 return mv.model.name, mv.version 

531 case (str(), str()): 

532 return mv 

533 case _: 

534 raise TypeError("Expected ModelVersion or tuple of (name, version) strings.")

535 

536 @classmethod 

537 def _valid_path_in_run(cls, path_in_run: str, allow_empty: bool = False) -> str: 

538 """Validate the `path_in_run` for an artifact or model. 

539 

540 - Cannot be empty, unless specified 

541 - Slashes are trimmed 

542 - Cannot do path climbing (e.g.: "../") 

543 - Cannot do current path referencing (e.g.: "./") 

544 

545 Valid path example: "a/b/c" 

546 """ 

547 if (path_in_run := str(path_in_run).strip("/")) and os.path.abspath(path := "/root/" + path_in_run) == path: 

548 return path_in_run 

549 if allow_empty and path_in_run == "": 

550 return path_in_run 

551 raise paths.IllegalPath(f"`path_in_run={path_in_run}`") 

552 
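For example (values are illustrative, behaviour follows the checks above):

    BaseMlflowApi._valid_path_in_run("a/b/c")               # -> "a/b/c"
    BaseMlflowApi._valid_path_in_run("/a/b/c/")             # -> "a/b/c" (slashes trimmed)
    BaseMlflowApi._valid_path_in_run("", allow_empty=True)  # -> ""
    BaseMlflowApi._valid_path_in_run("../escape")           # raises paths.IllegalPath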

553 # ======================================================================================================= 

554 # === Experiment tracking =============================================================================== 

555 

556 @decorators.online 

557 def _create_exp(self, name: str, tags: Mapping) -> E: 

558 """Create experiment and return its metadata.""" 

559 return self._impl_create_exp(name, tags) 

560 

561 @decorators.online 

562 def _create_model(self, name: str, tags: Mapping) -> M: 

563 """Create registered model and return its metadata.""" 

564 return self._impl_create_model(self._coerce_model_name(name), tags) 

565 

566 @decorators.online 

567 def _create_run( 

568 self, 

569 exp: ExpIdentifier, 

570 name: str | None, 

571 repo: urls.Url | str | None, 

572 tags: Mapping, 

573 status: schema.RunStatus, 

574 parent: RunIdentifier | None, 

575 ) -> R: 

576 """Create a run whose start time is the current UTC time."""

577 if repo is not None: 

578 repo = urls.parse_url(repo, resolve_if_local=True) 

579 run_id = self._impl_create_run( 

580 self._coerce_exp_id(exp), name, repo, parent_run_id=self._coerce_run_id(parent) if parent else None 

581 ) 

582 self._update_run_tags(run_id, tags) 

583 self._set_run_status(run_id, status) 

584 return self._get_run(run_id) 

585 

586 @decorators.online 

587 def _set_run_status(self, run: RunIdentifier, status: schema.RunStatus): 

588 self._impl_set_run_status(self._coerce_run_id(run), status) 

589 

590 @decorators.online 

591 def _set_run_end_time(self, run: RunIdentifier, end_time: datetime): 

592 self._impl_set_run_end_time(self._coerce_run_id(run), end_time) 

593 

594 @decorators.online 

595 def _update_exp_tags(self, exp: ExpIdentifier, tags: Mapping): 

596 self._impl_update_exp_tags(self._coerce_exp_id(exp), tags) 

597 

598 @decorators.online 

599 def _update_run_tags(self, run: RunIdentifier, tags: Mapping): 

600 self._impl_update_run_tags(self._coerce_run_id(run), tags) 

601 

602 @decorators.online 

603 def _update_model_tags(self, model: ModelIdentifier, tags: Mapping): 

604 self._impl_update_model_tags(self._coerce_model_name(model), tags) 

605 

606 @decorators.online 

607 def _update_mv_tags(self, mv: ModelVersionIdentifier, tags: Mapping): 

608 self._impl_update_mv_tags(*self._coerce_mv_tuple(mv), tags=tags) 

609 

610 @decorators.online 

611 def _log_run_params(self, run: RunIdentifier, params: Mapping): 

612 self._impl_log_run_params(self._coerce_run_id(run), params) 

613 

614 @decorators.online 

615 def _log_run_metrics(self, run: RunIdentifier, metrics: Mapping): 

616 self._impl_log_run_metrics(self._coerce_run_id(run), metrics) 

617 

618 # ======================================================================================================= 

619 # === Model registry ==================================================================================== 

620 

621 @decorators.online 

622 def _register_mv( 

623 self, model: ModelIdentifier, run: RunIdentifier, path_in_run: str, version: str | None, tags: Mapping 

624 ) -> V: 

625 path_in_run = self._valid_path_in_run(path_in_run) 

626 return self._impl_register_mv(self._coerce_model(model), self._coerce_run(run), path_in_run, version, tags) 

627 

628 # ======================================================================================================= 

629 # === Abstract Methods ================================================================================== 

630 

631 @abstractmethod 

632 def _impl_default_cache_dir(self) -> Path: 

633 """Get default cache dir based on the current MLflow API settings.""" 

634 

635 @abstractmethod 

636 def _impl_get_exp_url(self, exp_id: str) -> urls.Url: 

637 """Get Experiment URL.""" 

638 

639 @abstractmethod 

640 def _impl_get_run_url(self, run_id: str, exp_id: str) -> urls.Url: 

641 """Get Run URL.""" 

642 

643 @abstractmethod 

644 def _impl_get_model_url(self, name: str) -> urls.Url: 

645 """Get URL to registered model.""" 

646 

647 @abstractmethod 

648 def _impl_get_mv_url(self, name: str, version: str) -> urls.Url: 

649 """Get model version URL.""" 

650 

651 @abstractmethod 

652 def _impl_fetch_exp(self, exp_id: str) -> E: 

653 """Get Experiment by ID.""" 

654 

655 @abstractmethod 

656 def _impl_fetch_run(self, run_id: str) -> R: 

657 """Get Run by ID.""" 

658 

659 @abstractmethod 

660 def _impl_fetch_model(self, name: str) -> M: 

661 """Get registered Model by name.""" 

662 

663 @abstractmethod 

664 def _impl_fetch_mv(self, name_and_version: Tuple[str, str]) -> V: 

665 """Get ModelVersion by name and version.""" 

666 

667 @abstractmethod 

668 def _impl_find_experiments( 

669 self, query: mongo.Query, sorting: mongo.Sorting 

670 ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[E]]: 

671 """Push down MongoDB query where possible and return query remainder with results iterator.""" 

672 

673 @abstractmethod 

674 def _impl_find_runs( 

675 self, query: mongo.Query, sorting: mongo.Sorting 

676 ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[R]]: 

677 """Push down MongoDB query where possible and return query remainder with results iterator.""" 

678 

679 @abstractmethod 

680 def _impl_find_models( 

681 self, query: mongo.Query, sorting: mongo.Sorting 

682 ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[M]]: 

683 """Push down MongoDB query where possible and return query remainder with results iterator.""" 

684 

685 @abstractmethod 

686 def _impl_find_mv( 

687 self, query: mongo.Query, sorting: mongo.Sorting 

688 ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[V]]: 

689 """Push down MongoDB query where possible and return query remainder with results iterator.""" 

690 

691 @abstractmethod 

692 def _impl_find_child_runs(self, run: R) -> Iterator[R]: 

693 """Find child runs.""" 

694 

695 @abstractmethod 

696 def _impl_create_exp(self, name: str, tags: Mapping) -> E: 

697 """Create experiment and return its metadata.""" 

698 

699 @abstractmethod 

700 def _impl_create_model(self, name: str, tags: Mapping) -> M: 

701 """Create registered model and return its metadata.""" 

702 

703 @abstractmethod 

704 def _impl_create_run( 

705 self, exp_id: str, name: str | None, repo: urls.Url | None, parent_run_id: str | None = None 

706 ) -> str: 

707 """Create experiment run.""" 

708 

709 @abstractmethod 

710 def _impl_set_run_status(self, run_id: str, status: schema.RunStatus): 

711 """Set Run status.""" 

712 

713 @abstractmethod 

714 def _impl_set_run_end_time(self, run_id: str, end_time: datetime): 

715 """Set Run end time.""" 

716 

717 @abstractmethod 

718 def _impl_update_exp_tags(self, exp_id: str, tags: Mapping): 

719 """Update Exp tags.""" 

720 

721 @abstractmethod 

722 def _impl_update_run_tags(self, run_id: str, tags: Mapping): 

723 """Update Run tags.""" 

724 

725 @abstractmethod 

726 def _impl_update_model_tags(self, name: str, tags: Mapping): 

727 """Update Model tags.""" 

728 

729 @abstractmethod 

730 def _impl_update_mv_tags(self, name: str, version: str, tags: Mapping): 

731 """Update ModelVersion tags."""

732 

733 @abstractmethod 

734 def _impl_log_run_params(self, run_id: str, params: Mapping): 

735 """Log run params.""" 

736 

737 @abstractmethod 

738 def _impl_log_run_metrics(self, run_id: str, metrics: Mapping): 

739 """Log run metrics.""" 

740 

741 @abstractmethod 

742 def _impl_register_mv(self, model: M, run: R, path_in_run: str, version: str | None, tags: Mapping) -> V: 

743 """Register model version.""" 

744 

745 # ======================================================================================================= 

746 # === Public Methods ==================================================================================== 

747 

748 def clean_all_cache(self): 

749 """Clean all cached metadata and artifacts.""" 

750 self._clean_all_cache() 

751 

752 def clean_temp_artifacts(self): 

753 """Clean temporary artifacts.""" 

754 self._clean_temp_artifacts() 

755 

756 @pydantic.validate_arguments 

757 def clean_cached_run_artifact(self, run: RunIdentifier, path_in_run: str = ""): 

758 """Clean cached artifact for specified run. 

759 

760 :param run: Run ID or object. 

761 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

762 """ 

763 self._clean_run_artifact(self._coerce_run_id(run), path_in_run) 

764 

765 @pydantic.validate_arguments 

766 def clean_cached_model_artifact(self, model_version: ModelVersionIdentifier): 

767 """Clean cached artifact for specified model version. 

768 

769 :param model_version: Model version object or `(name, version)` tuple. 

770 """ 

771 mv = self._coerce_mv(model_version) 

772 self.clean_cached_run_artifact(mv.run, mv.path_in_run) 

773 

774 @pydantic.validate_arguments 

775 def list_run_artifacts(self, run: RunIdentifier, path_in_run: str = "") -> transfer.LsResult: 

776 """List run artifacts in repo. 

777 

778 :param run: Run ID or object. 

779 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

780 """ 

781 return self._list_run_artifacts(run, path_in_run) 

782 

783 @pydantic.validate_arguments 

784 def list_model_artifact(self, model_version: ModelVersionIdentifier, path_suffix: str = "") -> transfer.LsResult: 

785 """List model version artifacts in repo. 

786 

787 :param model_version: Model version object or `(name, version)` tuple. 

788 :param path_suffix: Plain relative path inside model artifact dir (e.g.: `a/b/c`). 

789 """ 

790 return self.list_run_artifacts( 

791 run=(mv := self._coerce_mv(model_version)).run, 

792 path_in_run=mv.path_in_run + "/" + path_suffix.strip("/"), 

793 ) 

794 

795 @pydantic.validate_arguments 

796 def cache_run_artifact(self, run: RunIdentifier, path_in_run: str = "") -> Path: 

797 """Pull run artifact from MLflow server to local cache. 

798 

799 :param run: Run ID or object. 

800 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

801 """ 

802 return self._pull_run_artifact(run, path_in_run) 

803 

804 @pydantic.validate_arguments 

805 def cache_model_artifact(self, model_version: ModelVersionIdentifier) -> Path: 

806 """Pull model version artifact from MLflow server to local cache. 

807 

808 :param model_version: Model version object or `(name, version)` tuple. 

809 """ 

810 mv = self._coerce_mv(model_version) 

811 return self.cache_run_artifact(mv.run, mv.path_in_run) 

812 

813 @pydantic.validate_arguments 

814 def get_run_artifact(self, run: RunIdentifier, path_in_run: str = "") -> Path: 

815 """Get local path to run artifact. 

816 

817 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

818 

819 :param run: Run ID or object. 

820 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

821 """ 

822 return self._get_run_artifact(self._coerce_run_id(run), path_in_run) 

823 

824 @pydantic.validate_arguments 

825 def get_model_artifact(self, model_version: ModelVersionIdentifier) -> Path: 

826 """Get local path to model artifact. 

827 

828 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

829 

830 :param model_version: Model version object or `(name, version)` tuple. 

831 """ 

832 mv = self._coerce_mv(model_version) 

833 return self.get_run_artifact(mv.run, mv.path_in_run) 

834 

835 @pydantic.validate_arguments 

836 def place_run_artifact( 

837 self, 

838 run: RunIdentifier, 

839 target: Path, 

840 path_in_run: str = "", 

841 overwrite: bool = False, 

842 link: bool = True, 

843 ): 

844 """Place run artifact on target path. 

845 

846 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

847 The resulting files are always write-protected, but directories are not. 

848 

849 :param run: Run ID or object. 

850 :param target: Target path. 

851 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

852 :param overwrite: Overwrite target path if exists. 

853 :param link: Use symbolic link instead of copy. 

854 """ 

855 self._place_run_artifact(self._coerce_run_id(run), path_in_run, target, link, overwrite) 

856 

857 @pydantic.validate_arguments 

858 def place_model_artifact( 

859 self, 

860 model_version: ModelVersionIdentifier, 

861 target: Path, 

862 overwrite: bool = False, 

863 link: bool = True, 

864 ): 

865 """Place model version artifact on target path. 

866 

867 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

868 

869 :param model_version: Model version object or `(name, version)` tuple. 

870 :param target: Target path. 

871 :param overwrite: Overwrite target path if exists. 

872 :param link: Use symbolic link instead of copy. 

873 """ 

874 mv = self._coerce_mv(model_version) 

875 self.place_run_artifact(mv.run, target, mv.path_in_run, overwrite, link) 

876 

877 @pydantic.validate_arguments 

878 def export_run_artifact( 

879 self, 

880 run: RunIdentifier, 

881 target: Path, 

882 path_in_run: str = "", 

883 ) -> Path: 

884 """Export run artifact cache to target path while keeping the original cache structure. 

885 

886 The target path can then be used as cache dir by the `generic` MLflow API in offline mode. 

887 

888 :param run: Run ID or object. 

889 :param target: Cache export path. 

890 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

891 """ 

892 if paths.is_sub_dir(target, self.cache_dir) or paths.is_sub_dir(self.cache_dir, target): 

893 raise paths.IllegalPath(f"Cannot export cache to itself, its subdirs or parents: {target}") 

894 cache = self._get_run_artifact(run, path_in_run) 

895 target = self.model_copy(update={"cache_dir": target})._get_run_artifact_cache_path(run, path_in_run) 

896 paths.place_path(cache, target, mode="copy", overwrite=True) 

897 paths.rchmod(target, paths.Mode.rwx) # Exported caches are not write-protected 

898 return target 

899 
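A sketch of the export/offline round trip. `GenericMlflowApi` and the paths are assumptions; any implementation of this interface pointed at the export directory with `offline_mode=True` should behave the same way:

    export_dir = Path("/tmp/mlflow-export")  # illustrative target
    api.export_run_meta(run_id, export_dir)  # metadata, cache layout preserved
    api.export_run_artifact(run_id, export_dir, path_in_run="model")

    offline = GenericMlflowApi(cache_dir=export_dir, offline_mode=True)  # hypothetical class
    local_model_dir = offline.get_run_artifact(run_id, "model")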

900 @pydantic.validate_arguments 

901 def export_model_artifact( 

902 self, 

903 model_version: ModelVersionIdentifier, 

904 target: Path, 

905 ) -> Path: 

906 """Export model version artifact cache to target path while keeping the original cache structure. 

907 

908 The target path can then be used as cache dir by the `generic` MLflow API in offline mode. 

909 

910 :param model_version: Model version object or `(name, version)` tuple. 

911 :param target: Cache export path. 

912 """ 

913 mv = self._coerce_mv(model_version) 

914 return self.export_run_artifact(mv.run, target, mv.path_in_run) 

915 

916 @pydantic.validate_arguments 

917 def load_run_artifact(self, run: RunIdentifier, loader: Callable[[Path], A], path_in_run: str = "") -> A: 

918 """Load run artifact. 

919 

920 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

921 

922 :param run: Run ID or object. 

923 :param loader: Loader callback. 

924 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

925 """ 

926 return loader(self._get_run_artifact(self._coerce_run_id(run), path_in_run)) 

927 

928 @pydantic.validate_arguments 

929 def load_model_artifact(self, model_version: ModelVersionIdentifier, loader: Callable[[Path], A]) -> A: 

930 """Load model version artifact. 

931 

932 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`. 

933 

934 :param model_version: Model version object or `(name, version)` tuple. 

935 :param loader: Loader callback. 

936 """ 

937 mv = self._coerce_mv(model_version) 

938 logger.info("Loading model: %s v%s", mv.model.name, mv.version) 

939 return self.load_run_artifact(mv.run, loader, mv.path_in_run) 

940 

941 @decorators.online 

942 @pydantic.validate_arguments 

943 def log_run_artifact( 

944 self, 

945 run: RunIdentifier, 

946 source: Path | Callable[[Path], None], 

947 path_in_run: str | None = None, 

948 keep_the_source: bool | None = None, 

949 allow_duplication: bool | None = None, 

950 use_cache: bool | None = None, 

951 ): 

952 """Publish artifact file or dir to experiment run. 

953 

954 The flags :paramref:`keep_the_source`, :paramref:`allow_duplication` and :paramref:`use_cache` are 

955 experimental and may conflict with one another. It is recommended to leave them unspecified, so this 

956 method will do a best-effort to use cache if it makes sense to, keep the source files if it makes 

957 sense to (possibly as a symbolic link) and avoid duplicated disk usage when possible. 

958 

959 :param run: | Run ID or object. 

960 

961 :param source: | Path to artifact file or dir, or a dumper callback. 

962 | If it's a callback and the upload is interrupted, the temporary artifact is kept. 

963 

964 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`) 

965 

966 - If `source` is a `Path`: Defaults to file or dir name. 

967 - If `source` is a callback: No default available. 

968 

969 :param keep_the_source: 

970 - If `source` is a `Path`: Keep that file or dir (defaults to `True`). 

971 - If `source` is a callback: Keep the temporary artifact, even after a successful upload (defaults to `False`). 

972 

973 :param allow_duplication: | If `False`, a `source` file or dir may be replaced with a symbolic link to the local cache in order to avoid duplicated disk usage. 

974 | Defaults to `True` if :paramref:`keep_the_source` is `True` and the run artifacts repo is local. 

975 

976 :param use_cache: | If `True`, keep artifact in local cache after publishing. 

977 | Defaults to `True` if the run artifacts repo is remote. 

978 """ 

979 tmp = None 

980 

981 if using_dumper := callable(source): 

982 logger.debug("Using temporary artifact path: %s", tmp := self._get_temp_artifacts_dir()) 

983 source(source := tmp.joinpath("artifact")) 

984 

985 if (source := Path(source).expanduser().resolve()).is_relative_to(self._run_artifacts_cache): 

986 raise paths.IllegalPath(f"Source path points to artifact cache: {source}") 

987 

988 if keep_the_source is None: 

989 keep_the_source = False if using_dumper else True # noqa: SIM211 

990 

991 try: 

992 if using_dumper: 

993 assert path_in_run, "When using an artifact dumper, `path_in_run` must be specified." 

994 else: 

995 path_in_run = path_in_run or source.name 

996 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=False) 

997 

998 run = self._coerce_run(run) 

999 target = urls.urljoin(run.repo_url, path_in_run) 

1000 

1001 if repo_is_local := urls.is_local(target): 

1002 if use_cache is None: 

1003 use_cache = False 

1004 

1005 if allow_duplication is None: 

1006 allow_duplication = True if keep_the_source or use_cache else False # noqa: SIM211,SIM210 

1007 

1008 if keep_the_source: 

1009 if allow_duplication: 

1010 mode = "copy" 

1011 else: 

1012 raise RuntimeError("Cannot keep the source without duplication when artifacts repo is local") 

1013 else: 

1014 mode = "move" 

1015 

1016 paths.place_path( 

1017 source, 

1018 target.path, 

1019 mode=typing.cast(paths.PathOperation, mode), 

1020 overwrite=True, 

1021 move_abs_links=False, 

1022 ) 

1023 else: 

1024 if use_cache is None: 

1025 use_cache = True 

1026 

1027 if allow_duplication is None: 

1028 allow_duplication = False 

1029 

1030 self.file_transfer.push_files(source, target) 

1031 except BaseException as exc: 

1032 raise exceptions.FailedToPublishArtifact(source) from exc 

1033 

1034 logger.debug(f"Artifact successfully published to '{target}'") 

1035 

1036 if use_cache: 

1037 with self._lock_run_artifact(run.id, path_in_run, allow_base_resolve=False) as cache: 

1038 if repo_is_local: 

1039 if allow_duplication: 

1040 paths.place_path(target.path, cache, mode="copy", overwrite=True) 

1041 else: 

1042 raise RuntimeError("Cannot cache artifact without duplication when run artifacts repo is local") 

1043 elif keep_the_source: 

1044 if allow_duplication: 

1045 paths.place_path(source, cache, mode="copy", overwrite=True) 

1046 else: 

1047 logger.warning("Keeping the `source` as a symbolic link to the cached artifact") 

1048 logger.debug(f"{source} -> {cache}") 

1049 paths.place_path(source, cache, mode="move", overwrite=True) 

1050 paths.place_path(cache, source, mode="link", overwrite=True) 

1051 else: 

1052 paths.place_path(source, cache, mode="move", overwrite=True) 

1053 

1054 if not keep_the_source: 

1055 paths.ensure_non_existing(source, force=True) 

1056 if tmp is not None: 

1057 paths.ensure_non_existing(tmp, force=True) 

1058 
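Two hedged usage sketches (run ID, paths and the `joblib` dumper are illustrative, not part of this API):

    # 1) Publish an existing file or directory; the source is kept by default.
    api.log_run_artifact(run_id, Path("outputs/model_dir"), path_in_run="model")

    # 2) Publish via a dumper callback; `path_in_run` is then mandatory and the
    #    temporary file is discarded after a successful upload.
    api.log_run_artifact(
        run_id,
        lambda path: joblib.dump(estimator, path),  # assumes joblib and a fitted estimator
        path_in_run="model.joblib",
    )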

1059 @pydantic.validate_arguments 

1060 def log_model_version( 

1061 self, 

1062 model: ModelIdentifier, 

1063 run: RunIdentifier, 

1064 source: Path | Callable[[Path], None], 

1065 path_in_run: str | None = None, 

1066 keep_the_source: bool | None = None, 

1067 allow_duplication: bool | None = None, 

1068 use_cache: bool | None = None, 

1069 version: str | None = None, 

1070 tags: Mapping | None = None, 

1071 ) -> ModelVersionApi: 

1072 """Publish artifact file or dir as model version inside the specified experiment run. 

1073 

1074 :param model: | Model name or object. 

1075 

1076 :param run: | Run ID or object. 

1077 

1078 :param source: | See :paramref:`log_run_artifact.source` 

1079 

1080 :param path_in_run: | Plain relative path inside run artifacts (e.g.: `a/b/c`). 

1081 | Defaults to model name. 

1082 

1083 :param keep_the_source: | See :paramref:`log_run_artifact.keep_the_source` 

1084 

1085 :param allow_duplication: | See :paramref:`log_run_artifact.allow_duplication` 

1086 

1087 :param use_cache: | See :paramref:`log_run_artifact.use_cache` 

1088 

1089 :param version: | Arbitrary model version 

1090 | (not supported by all backends). 

1091 

1092 :param tags: | Model version tags. 

1093 | See :class:`schema.ModelVersion.tags` 

1094 

1095 :return: New model version metadata with API handle. 

1096 """ 

1097 logger.info("Logging version of model '%s'", model_name := self._coerce_model_name(model)) 

1098 path_in_run = path_in_run or patterns.encode_model_name(model_name) 

1099 self.log_run_artifact(run, source, path_in_run, keep_the_source, allow_duplication, use_cache) 

1100 return ModelVersionApi(**self._register_mv(model, run, path_in_run, version, tags or {})).using(self) 

1101 
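A hedged sketch (model name, run and dumper are illustrative; the registered model is created beforehand or resolved by name):

    mv = api.log_model_version(
        model="my-classifier",
        run=run_id,
        source=lambda path: joblib.dump(estimator, path),  # assumed serialization
        tags={"stage": "candidate"},
    )
    print(mv.version)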

1102 @pydantic.validate_arguments 

1103 def get_exp_url(self, exp: ExpIdentifier) -> str: 

1104 """Get Experiment URL. 

1105 

1106 :param exp: Exp ID or object. 

1107 """ 

1108 return str(self._impl_get_exp_url(self._coerce_exp_id(exp))) 

1109 

1110 @pydantic.validate_arguments 

1111 def get_run_url(self, run: RunIdentifier, exp: ExpIdentifier | None = None) -> str: 

1112 """Get Run URL. 

1113 

1114 :param run: Run ID or object. 

1115 :param exp: Exp ID or object. 

1116 

1117 Caveats: 

1118 - :paramref:`exp` must be specified on :attr:`~BaseMlflowApi.offline_mode` 

1119 if :paramref:`run` is an ID and the run metadata is not in cache. 

1120 """ 

1121 exp = self._coerce_run(run).exp if exp is None else exp 

1122 return str(self._impl_get_run_url(self._coerce_run_id(run), self._coerce_exp_id(exp))) 

1123 

1124 @pydantic.validate_arguments 

1125 def get_model_url(self, model: ModelIdentifier) -> str: 

1126 """Get URL to registered model. 

1127 

1128 :param model: Model name or object. 

1129 """ 

1130 return str(self._impl_get_model_url(self._coerce_model_name(model))) 

1131 

1132 @pydantic.validate_arguments 

1133 def get_model_version_url(self, model_version: ModelVersionIdentifier) -> str: 

1134 """Get model version URL. 

1135 

1136 :param model_version: Model version object or `(name, version)` tuple. 

1137 """ 

1138 return str(self._impl_get_mv_url(*self._coerce_mv_tuple(model_version))) 

1139 

1140 @pydantic.validate_arguments 

1141 def get_exp(self, exp: ExpIdentifier, **cache_opts: bool) -> ExpApi: 

1142 """Get Experiment API by ID. 

1143 

1144 :param exp: Exp ID or object. 

1145 """ 

1146 return ExpApi(**self._get_exp(self._coerce_exp_id(exp), **cache_opts)).using(self) 

1147 

1148 @pydantic.validate_arguments 

1149 def get_run(self, run: RunIdentifier, **cache_opts: bool) -> RunApi: 

1150 """Get Run API by ID. 

1151 

1152 :param run: Run ID or object. 

1153 """ 

1154 return RunApi(**self._get_run(self._coerce_run_id(run), **cache_opts)).using(self) 

1155 

1156 @pydantic.validate_arguments 

1157 def get_model(self, model: ModelIdentifier, **cache_opts: bool) -> ModelApi: 

1158 """Get Model API by name. 

1159 

1160 :param model: Model name or object. 

1161 """ 

1162 return ModelApi(**self._get_model(self._coerce_model_name(model), **cache_opts)).using(self) 

1163 

1164 @pydantic.validate_arguments 

1165 def get_model_version(self, model_version: ModelVersionIdentifier, **cache_opts: bool) -> ModelVersionApi: 

1166 """Get ModelVersion API by name and version. 

1167 

1168 :param model_version: Model version object or `(name, version)` tuple. 

1169 """ 

1170 return ModelVersionApi(**self._get_mv(self._coerce_mv_tuple(model_version), **cache_opts)).using(self) 

1171 

1172 @pydantic.validate_arguments 

1173 def find_exps(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[ExpApi]: 

1174 """Search experiments with query in MongoDB query language. 

1175 

1176 :param query: Query in MongoDB query language. 

1177 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`). 

1178 """ 

1179 return (ExpApi(**x).using(self) for x in self._find_experiments(query or {}, sorting or [])) 

1180 

1181 @pydantic.validate_arguments 

1182 def find_runs(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[RunApi]: 

1183 """Search runs with query in MongoDB query language. 

1184 

1185 :param query: Query in MongoDB query language. 

1186 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`). 

1187 """ 

1188 return (RunApi(**x).using(self) for x in self._find_runs(query or {}, sorting or [])) 

1189 

1190 @pydantic.validate_arguments 

1191 def find_models(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[ModelApi]: 

1192 """Search registered models with query in MongoDB query language. 

1193 

1194 :param query: Query in MongoDB query language. 

1195 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`). 

1196 """ 

1197 return (ModelApi(**x).using(self) for x in self._find_models(query or {}, sorting or [])) 

1198 

1199 @pydantic.validate_arguments 

1200 def find_model_versions( 

1201 self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None 

1202 ) -> Iterator[ModelVersionApi]: 

1203 """Search model versions with query in MongoDB query language. 

1204 

1205 :param query: Query in MongoDB query language. 

1206 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`). 

1207 """ 

1208 return (ModelVersionApi(**x).using(self) for x in self._find_mv(query or {}, sorting or [])) 

1209 
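Queries and sorting follow MongoDB conventions; the field names below are illustrative of the entity schema:

    for mv in api.find_model_versions(
        query={"model.name": "my-classifier", "tags.stage": "candidate"},
        sorting=[("version", -1)],
    ):
        print(mv.version, mv.tags)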

1210 @pydantic.validate_arguments 

1211 def find_child_runs(self, parent: RunIdentifier) -> Iterator[RunApi]: 

1212 """Find child runs. 

1213 

1214 :param parent: Run ID or object. 

1215 """ 

1216 return (RunApi(**x).using(self) for x in self._impl_find_child_runs(self._coerce_run(parent))) 

1217 

1218 @pydantic.validate_arguments 

1219 def cache_exp_meta(self, exp: ExpIdentifier) -> ExpApi: 

1220 """Get latest Experiment metadata and save to local cache. 

1221 

1222 :param exp: Experiment ID or object. 

1223 """ 

1224 return self.get_exp(exp, force_cache_refresh=True) 

1225 

1226 @pydantic.validate_arguments 

1227 def cache_run_meta(self, run: RunIdentifier) -> RunApi: 

1228 """Get latest Run metadata and save to local cache. 

1229 

1230 :param run: Run ID or object. 

1231 """ 

1232 return self.get_run(run, force_cache_refresh=True) 

1233 

1234 @pydantic.validate_arguments 

1235 def cache_model_meta(self, model: ModelIdentifier) -> ModelApi: 

1236 """Get latest Model metadata and save to local cache. 

1237 

1238 :param model: Model name or object. 

1239 """ 

1240 return self.get_model(model, force_cache_refresh=True) 

1241 

1242 @pydantic.validate_arguments 

1243 def cache_model_version_meta(self, model_version: ModelVersionIdentifier) -> ModelVersionApi: 

1244 """Get latest model version metadata and save to local cache. 

1245 

1246 :param model_version: Model version object or `(name, version)` tuple. 

1247 """ 

1248 return self.get_model_version(model_version, force_cache_refresh=True) 

1249 

1250 @pydantic.validate_arguments 

1251 def export_exp_meta(self, exp: ExpIdentifier, target: Path) -> ExpApi: 

1252 """Export experiment metadata cache to target. 

1253 

1254 :param exp: Experiment ID or object. 

1255 :param target: Cache export path. 

1256 """ 

1257 self._export_meta(exp := self.get_exp(id_ := self._coerce_exp_id(exp)), self._get_exp_cache(id_), target) 

1258 return exp 

1259 

1260 @pydantic.validate_arguments 

1261 def export_run_meta(self, run: RunIdentifier, target: Path) -> RunApi: 

1262 """Export run metadata cache to target. 

1263 

1264 :param run: Run ID or object. 

1265 :param target: Cache export path. 

1266 """ 

1267 self._export_meta(run := self.get_run(id_ := self._coerce_run_id(run)), self._get_run_cache(id_), target) 

1268 self.export_exp_meta(run.exp, target) 

1269 return run 

1270 

1271 @pydantic.validate_arguments 

1272 def export_model_meta(self, model: ModelIdentifier, target: Path) -> ModelApi: 

1273 """Export model metadata cache to target. 

1274 

1275 :param model: Model name or object. 

1276 :param target: Cache export path. 

1277 """ 

1278 name = self._coerce_model_name(model) 

1279 self._export_meta(model := self.get_model(name), self._get_model_cache(name), target) 

1280 return model 

1281 

1282 @pydantic.validate_arguments 

1283 def export_model_version_meta(self, mv: ModelVersionIdentifier, target: Path) -> ModelVersionApi: 

1284 """Export model version metadata cache to target. 

1285 

1286 :param mv: Model version object or `(name, version)` tuple. 

1287 :param target: Cache export path. 

1288 """ 

1289 tup = self._coerce_mv_tuple(mv) 

1290 self._export_meta(mv := self.get_model_version(tup), self._get_mv_cache(*tup), target) 

1291 self.export_model_meta(mv.model, target) 

1292 return mv 

1293 

1294 @pydantic.validate_arguments 

1295 def create_exp(self, name: str, tags: Mapping | None = None) -> ExpApi: 

1296 """Create Experiment and return its API. 

1297 

1298 :param name: See :attr:`schema.Experiment.name`. 

1299 :param tags: See :attr:`schema.Experiment.tags`. 

1300 """ 

1301 return ExpApi(**self._create_exp(name, tags or {})).using(self) 

1302 

1303 @pydantic.validate_arguments 

1304 def get_or_create_exp(self, name: str) -> ExpApi: 

1305 """Get or create Experiment and return its API. 

1306 

1307 :param name: See :attr:`schema.Experiment.name`. 

1308 """ 

1309 for exp in self._find_experiments({"name": name}, []): 

1310 break 

1311 else: 

1312 exp = self._create_exp(name, tags={}) 

1313 

1314 return ExpApi(**exp).using(self) 

1315 

1316 @pydantic.validate_arguments 

1317 def create_model(self, name: str, tags: Mapping | None = None) -> ModelApi: 

1318 """Create registered model and return its API. 

1319 

1320 :param name: See :attr:`schema.Model.name`. 

1321 :param tags: See :attr:`schema.Model.tags`. 

1322 """ 

1323 return ModelApi(**self._create_model(name, tags or {})).using(self) 

1324 

1325 @pydantic.validate_arguments 

1326 def get_or_create_model(self, name: str) -> ModelApi: 

1327 """Get or create registered Model and return its API. 

1328 

1329 :param name: See :attr:`schema.Model.name`. 

1330 """ 

1331 for model in self._find_models({"name": name}, []): 

1332 break 

1333 else: 

1334 model = self._create_model(name, tags={}) 

1335 

1336 return ModelApi(**model).using(self) 

1337 

1338 @pydantic.validate_arguments 

1339 def create_run( 

1340 self, 

1341 exp: ExpIdentifier, 

1342 name: str | None = None, 

1343 tags: Mapping | None = None, 

1344 repo: str | urls.Url | None = None, 

1345 parent: RunIdentifier | None = None, 

1346 ) -> RunApi: 

1347 """Declare a new experiment run to be used later. 

1348 

1349 :param exp: Experiment ID or object. 

1350 :param name: See :attr:`schema.Run.name`. 

1351 :param tags: See :attr:`schema.Run.tags`. 

1352 :param repo: (Experimental) Cloud storage URL to be used as alternative run artifacts repository. 

1353 :param parent: Parent run ID or object. 

1354 """ 

1355 return RunApi(**self._create_run(exp, name, repo, tags or {}, schema.RunStatus.SCHEDULED, parent)).using(self) 

1356 

1357 @pydantic.validate_arguments 

1358 def start_run( 

1359 self, 

1360 exp: ExpIdentifier, 

1361 name: str | None = None, 

1362 tags: Mapping | None = None, 

1363 repo: str | urls.Url | None = None, 

1364 parent: RunIdentifier | None = None, 

1365 ) -> RunApi: 

1366 """Start a new experiment run. 

1367 

1368 :param exp: Experiment ID or object. 

1369 :param name: See :attr:`schema.Run.name`. 

1370 :param tags: See :attr:`schema.Run.tags`. 

1371 :param repo: (Experimental) Cloud storage URL to be used as alternative run artifacts repository. 

1372 :param parent: Parent run ID or object. 

1373 """ 

1374 return RunApi(**self._create_run(exp, name, repo, tags or {}, schema.RunStatus.RUNNING, parent)).using(self) 

1375 

1376 @pydantic.validate_arguments 

1377 def resume_run(self, run: RunIdentifier) -> RunApi: 

1378 """Resume a previous experiment run. 

1379 

1380 :param run: Run ID or object. 

1381 """ 

1382 self._set_run_status(run_id := self._coerce_run_id(run), schema.RunStatus.RUNNING) 

1383 return self.get_run(run_id) 

1384 

1385 @pydantic.validate_arguments 

1386 def end_run(self, run: RunIdentifier, succeeded: bool = True) -> RunApi: 

1387 """End experiment run. 

1388 

1389 :param run: Run ID or object. 

1390 :param succeeded: Whether the run was successful. 

1391 """ 

1392 status = schema.RunStatus.FINISHED if succeeded else schema.RunStatus.FAILED 

1393 self._set_run_status(run_id := self._coerce_run_id(run), status) 

1394 self._set_run_end_time(run_id, datetime.now()) 

1395 return self.get_run(run_id) 

1396 
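Putting the tracking methods together, a typical workflow might look like the sketch below (experiment name, tags, params and metrics are illustrative):

    exp = api.get_or_create_exp("my-experiment")
    run = api.start_run(exp, name="baseline", tags={"team": "mlops"})
    try:
        api.log_params(run, {"lr": 1e-3, "epochs": 10})
        api.log_metrics(run, {"accuracy": 0.93})
        api.end_run(run, succeeded=True)
    except Exception:
        api.end_run(run, succeeded=False)
        raise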

1397 @pydantic.validate_arguments 

1398 def set_tags_on_exp(self, exp: ExpIdentifier, tags: Mapping): 

1399 """Set tags on experiment. 

1400 

1401 :param exp: Experiment ID or object. 

1402 :param tags: See :attr:`schema.Experiment.tags`. 

1403 """ 

1404 self._update_exp_tags(exp, tags) 

1405 

1406 @pydantic.validate_arguments 

1407 def set_tags_on_run(self, run: RunIdentifier, tags: Mapping): 

1408 """Set tags on experiment run. 

1409 

1410 :param run: Run ID or object. 

1411 :param tags: See :attr:`schema.Run.tags`. 

1412 """ 

1413 self._update_run_tags(run, tags) 

1414 

1415 @pydantic.validate_arguments 

1416 def set_tags_on_model(self, model: ModelIdentifier, tags: Mapping): 

1417 """Set tags on registered model. 

1418 

1419 :param model: Model name or object. 

1420 :param tags: See :attr:`schema.Model.tags`. 

1421 """ 

1422 self._update_model_tags(model, tags) 

1423 

1424 @pydantic.validate_arguments 

1425 def set_tags_on_model_version(self, model_version: ModelVersionIdentifier, tags: Mapping): 

1426 """Set tags on model version. 

1427 

1428 :param model_version: Model version object or `(name, version)` tuple. 

1429 :param tags: See :attr:`schema.Model.tags`. 

1430 """ 

1431 self._update_mv_tags(model_version, tags) 

1432 

1433 @pydantic.validate_arguments 

1434 def log_params(self, run: RunIdentifier, params: Mapping): 

1435 """Log params to experiment run. 

1436 

1437 :param run: Run ID or object. 

1438 :param params: See :attr:`schema.Run.params`. 

1439 """ 

1440 self._log_run_params(run, params) 

1441 

1442 @pydantic.validate_arguments 

1443 def log_metrics(self, run: RunIdentifier, metrics: Mapping): 

1444 """Log metrics to experiment run. 

1445 

1446 :param run: Run ID or object. 

1447 :param metrics: See :attr:`schema.Run.metrics`. 

1448 """ 

1449 self._log_run_metrics(run, metrics)