Coverage for src/mlopus/mlflow/api/base.py: 87%
626 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import contextlib
2import logging
3import os.path
4import tempfile
5import typing
6from abc import ABC, abstractmethod
7from datetime import datetime
8from pathlib import Path
9from typing import Type, TypeVar, Callable, Iterator, Tuple, Mapping
11from mlopus.utils import pydantic, paths, urls, mongo, string_utils, iter_utils
12from . import contract
13from .common import schema, decorators, serde, exceptions, patterns, transfer
14from .exp import ExpApi
15from .model import ModelApi
16from .mv import ModelVersionApi
17from .run import RunApi
logger = logging.getLogger(__name__)

A = TypeVar("A")  # Any type (used for generic callables below)

# Entity types: short aliases for the metadata schema classes used throughout this module.
E = schema.Experiment
R = schema.Run
M = schema.Model
V = schema.ModelVersion
T = TypeVar("T", bound=schema.BaseEntity)  # Any entity type

# Identifier types: each entity may be referenced by its metadata object,
# its API handle, or (via the contract types) a plain string identifier.
ExpIdentifier = contract.ExpIdentifier | ExpApi
RunIdentifier = contract.RunIdentifier | RunApi
ModelIdentifier = contract.ModelIdentifier | ModelApi
ModelVersionIdentifier = contract.ModelVersionIdentifier | ModelVersionApi
class BaseMlflowApi(contract.MlflowApiContract, ABC, frozen=True):
    """Base class for API clients that use "MLflow-like" backends for experiment tracking and model registry.

    Important:
        Implementations of this interface are meant to be thread-safe and independent of env vars/globals,
        so multiple API instances can coexist in the same program if necessary.
    """

    cache_dir: Path | None = pydantic.Field(
        default=None,
        description=(
            "Root path for cached artifacts and metadata. "
            "If not specified, then a default is determined by the respective API plugin."
        ),
    )

    offline_mode: bool = pydantic.Field(
        default=False,
        description=(
            "If `True`, block all operations that attempt communication "
            "with the MLflow server (i.e.: only use cached metadata). "
            "Artifacts are still accessible if they are cached or if "
            ":attr:`pull_artifacts_in_offline_mode` is `True`."
        ),
    )

    pull_artifacts_in_offline_mode: bool = pydantic.Field(
        default=False,
        description=(
            "If `True`, allow pulling artifacts from storage to cache in offline mode. "
            "Useful if caching metadata only and pulling artifacts on demand "
            "(the artifact's URL must be known beforehand, e.g. by caching the metadata of its parent entity). "
        ),
    )

    # NOTE: declared as `Path` with a `None` default on purpose: `__init__` always
    # force-sets a resolved path, so the attribute is never `None` after construction.
    temp_artifacts_dir: Path = pydantic.Field(
        default=None,
        description=(
            "Path for temporary artifacts that are stored by artifact dumpers before being published and preserved "
            "after a publish error (e.g.: an upload interruption). Defaults to a path inside the local cache."
        ),
    )

    cache_local_artifacts: bool = pydantic.Field(
        default=False,
        description=(
            "Use local cache even if the run artifacts repository is in the local file system. "
            "May be used for testing cache without connecting to a remote MLflow server. "
            # Fixed typo ("unecessary") and missing space after "server." above.
            "Not recommended in production because of unnecessary duplicated disk usage. "
        ),
    )

    always_pull_artifacts: bool = pydantic.Field(
        default=False,
        description=(
            "When accessing a cached artifact file or dir, re-sync it with the remote artifacts repository, even "
            "on a cache hit. Prevents accessing stale data if the remote artifact has been changed in the meantime. "
            "The default data transfer utility (based on rclone) is pretty efficient for syncing directories, but "
            "enabling this option may still add some overhead of calculating checksums if they contain many files."
        ),
    )

    file_transfer: transfer.FileTransfer = pydantic.Field(
        repr=False,
        default_factory=transfer.FileTransfer,
        description=(
            "Utility for uploading/downloading artifact files or dirs. Also used for listing files. Based on "
            "RClone by default. Users may replace this with a different implementation when subclassing the API."
        ),
    )

    entity_serializer: serde.EntitySerializer = pydantic.Field(
        repr=False,
        default_factory=serde.EntitySerializer,
        description=(
            "Utility for (de)serializing entity metadata (i.e.: exp, runs, models, versions)."
            "Users may replace this with a different implementation when subclassing the API."
        ),
    )
    def __init__(self, **kwargs):
        """Initialize the API client and resolve directory defaults.

        The model is frozen, so the resolved paths are applied with
        `pydantic.force_set_attr` after regular validation.
        """
        super().__init__(**kwargs)

        # Apply default to cache dir, expand and resolve
        # (an unset/empty `cache_dir` is detected via `paths.is_cwd("")` — project helper;
        # presumably an empty path resolves to the CWD and means "not specified").
        if paths.is_cwd(cache := self.cache_dir or ""):
            cache = self._impl_default_cache_dir()
        pydantic.force_set_attr(self, key="cache_dir", val=cache.expanduser().resolve())

        # Apply default to temp artifacts dir, expand and resolve
        # (defaults to `<cache_dir>/artifacts/temp`).
        if paths.is_cwd(tmp := self.temp_artifacts_dir or ""):
            tmp = self._artifacts_cache.joinpath("temp")
        pydantic.force_set_attr(self, key="temp_artifacts_dir", val=tmp.expanduser().resolve())
130 @property
131 def in_offline_mode(self) -> "BaseMlflowApi":
132 """Get an offline copy of this API."""
133 return self.model_copy(update={"offline_mode": True})
135 # =======================================================================================================
136 # === Metadata cache locators ===========================================================================
138 @property
139 def _metadata_cache(self) -> Path:
140 return self.cache_dir.joinpath("metadata")
142 @property
143 def _exp_cache(self) -> Path:
144 return self._metadata_cache.joinpath("exp")
146 @property
147 def _run_cache(self) -> Path:
148 return self._metadata_cache.joinpath("run")
150 @property
151 def _model_cache(self) -> Path:
152 return self._metadata_cache.joinpath("model")
154 @property
155 def _mv_cache(self) -> Path:
156 return self._metadata_cache.joinpath("mv")
158 def _get_exp_cache(self, exp_id: str) -> Path:
159 return self._exp_cache.joinpath(exp_id)
161 def _get_run_cache(self, run_id: str) -> Path:
162 return self._run_cache.joinpath(run_id)
164 def _get_model_cache(self, name: str) -> Path:
165 return self._model_cache.joinpath(patterns.encode_model_name(name))
167 def _get_mv_cache(self, name: str, version: str) -> Path:
168 return self._mv_cache.joinpath(patterns.encode_model_name(name), version)
170 # =======================================================================================================
171 # === Artifact cache locators ===========================================================================
173 @property
174 def _artifacts_cache(self) -> Path:
175 return self.cache_dir.joinpath("artifacts")
177 @property
178 def _run_artifacts_cache(self) -> Path:
179 return self._artifacts_cache.joinpath("runs")
181 def _get_run_artifact_cache_path(
182 self, run: RunIdentifier, path_in_run: str = "", allow_base_resolve: bool = True
183 ) -> Path:
184 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=allow_base_resolve)
185 return self._run_artifacts_cache.joinpath(self._coerce_run_id(run), path_in_run)
187 def _get_temp_artifacts_dir(self) -> Path:
188 return Path(tempfile.mkdtemp(dir=paths.ensure_is_dir(self.temp_artifacts_dir)))
190 # =======================================================================================================
191 # === Cache protection ==================================================================================
193 @contextlib.contextmanager
194 def _lock_run_artifact(self, run_id: str, path_in_run: str, allow_base_resolve: bool = True) -> Path:
195 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=allow_base_resolve)
196 with paths.dir_lock(self._get_run_artifact_cache_path(run_id)) as path:
197 yield path.joinpath(path_in_run)
199 # =======================================================================================================
200 # === Cache cleanup =====================================================================================
202 def _clean_temp_artifacts(self):
203 paths.ensure_non_existing(self.temp_artifacts_dir, force=True)
205 def _clean_all_meta_cache(self):
206 for path in (self._exp_cache, self._run_cache, self._model_cache, self._mv_cache):
207 with paths.dir_lock(path):
208 paths.ensure_empty_dir(path, force=True)
210 def _clean_run_artifact(self, run_id: str, path_in_run: str = ""):
211 with self._lock_run_artifact(run_id, path_in_run) as path:
212 paths.ensure_non_existing(path, force=True)
214 def _clean_all_runs_artifacts(self):
215 with paths.dir_lock(self._run_artifacts_cache) as path:
216 paths.ensure_non_existing(path, force=True)
218 def _clean_all_cache(self):
219 self._clean_all_meta_cache()
220 self._clean_all_runs_artifacts()
222 # =======================================================================================================
223 # === Metadata Getters ==================================================================================
225 @classmethod
226 def _meta_fetcher(cls, fetcher: Callable[[A], T], args: A) -> Callable[[], T]:
227 return lambda: fetcher(args)
229 def _meta_cache_reader(self, path: Path, type_: Type[T]) -> Callable[[], T]:
230 return lambda: self.entity_serializer.load(type_, path)
232 def _meta_cache_writer(self, lock_dir: Path, path: Path, type_: Type[T]) -> Callable[[T], None]:
233 def _write_meta_cache(meta: T):
234 assert isinstance(meta, type_)
235 with paths.dir_lock(lock_dir):
236 self.entity_serializer.save(meta, path)
238 return _write_meta_cache
240 def _get_meta(
241 self,
242 fetcher: Callable[[], T],
243 cache_reader: Callable[[], T],
244 cache_writer: Callable[[T], None],
245 force_cache_refresh: bool = False,
246 ) -> T:
247 """Get metadata."""
248 if force_cache_refresh:
249 if self.offline_mode:
250 raise RuntimeError("Cannot refresh cache on offline mode.")
251 cache_writer(meta := fetcher())
252 elif not self.offline_mode:
253 meta = fetcher()
254 else:
255 meta = cache_reader()
257 return meta
259 def _get_exp(self, exp_id: str, **cache_opts: bool) -> E:
260 """Get Experiment metadata."""
261 return self._get_meta(
262 self._meta_fetcher(self._fetch_exp, exp_id),
263 self._meta_cache_reader(cache := self._get_exp_cache(exp_id), E),
264 self._meta_cache_writer(self._exp_cache, cache, E),
265 **cache_opts,
266 )
268 def _get_run(self, run_id: str, **cache_opts: bool) -> R:
269 """Get Run metadata."""
270 return self._get_meta(
271 self._meta_fetcher(self._fetch_run, run_id),
272 self._meta_cache_reader(cache := self._get_run_cache(run_id), R),
273 lambda run: [
274 self._meta_cache_writer(self._run_cache, cache, R)(run),
275 self._meta_cache_writer(self._exp_cache, self._get_exp_cache(run.exp.id), E)(run.exp),
276 ][0],
277 **cache_opts,
278 )
280 def _get_model(self, name: str, **cache_opts: bool) -> M:
281 """Get Model metadata."""
282 return self._get_meta(
283 self._meta_fetcher(self._fetch_model, name),
284 self._meta_cache_reader(cache := self._get_model_cache(name), M),
285 self._meta_cache_writer(self._model_cache, cache, M),
286 **cache_opts,
287 )
289 def _get_mv(self, name_and_version: Tuple[str, str], **cache_opts: bool) -> V:
290 """Get ModelVersion metadata."""
291 return self._get_meta(
292 self._meta_fetcher(self._fetch_mv, name_and_version),
293 self._meta_cache_reader(cache := self._get_mv_cache(*name_and_version), V),
294 lambda mv: [
295 self._meta_cache_writer(self._mv_cache, cache, V)(mv),
296 self._meta_cache_writer(self._model_cache, self._get_model_cache(mv.model.name), M)(mv.model),
297 ][0],
298 **cache_opts,
299 )
    # Server fetchers: thin wrappers over the plugin `_impl_fetch_*` hooks.
    # Guarded by `decorators.online` (presumably rejects calls in offline mode — see decorator).

    @decorators.online
    def _fetch_exp(self, exp_id: str) -> E:
        """Fetch Experiment metadata from the server."""
        return self._impl_fetch_exp(exp_id)

    @decorators.online
    def _fetch_run(self, run_id: str) -> R:
        """Fetch Run metadata from the server."""
        return self._impl_fetch_run(run_id)

    @decorators.online
    def _fetch_model(self, name: str) -> M:
        """Fetch Model metadata from the server."""
        return self._impl_fetch_model(name)

    @decorators.online
    def _fetch_mv(self, name_and_version: Tuple[str, str]) -> V:
        """Fetch ModelVersion metadata from the server."""
        return self._impl_fetch_mv(name_and_version)
317 def _export_meta(self, meta: T, cache: Path, target: Path):
318 paths.ensure_only_parents(target := target / cache.relative_to(self.cache_dir), force=True)
319 self.entity_serializer.save(meta, target)
    # =======================================================================================================
    # === Metadata Finders ==================================================================================

    # Thin wrappers binding each entity type's cache dir, schema class and plugin finder to `_find_meta`.

    def _find_experiments(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[E]:
        """Find Experiments matching `query`, sorted by `sorting`."""
        return self._find_meta(self._exp_cache, E, query, sorting, self._impl_find_experiments)

    def _find_runs(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[R]:
        """Find Runs matching `query`, sorted by `sorting`."""
        return self._find_meta(self._run_cache, R, query, sorting, self._impl_find_runs)

    def _find_models(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[M]:
        """Find Models matching `query`, sorted by `sorting`."""
        return self._find_meta(self._model_cache, M, query, sorting, self._impl_find_models)

    def _find_mv(self, query: mongo.Query, sorting: mongo.Sorting) -> Iterator[V]:
        """Find ModelVersions matching `query`, sorted by `sorting`."""
        return self._find_meta(self._mv_cache, V, query, sorting, self._impl_find_mv)
336 def _find_meta(
337 self,
338 cache: Path,
339 type_: Type[T],
340 query: mongo.Query,
341 sorting: mongo.Sorting,
342 finder: Callable[[mongo.Query, mongo.Sorting], Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[T]]],
343 ) -> Iterator[T]:
344 if self.offline_mode:
345 logger.warning("Metadata search in offline mode may yield incomplete or stale results.")
346 cached = (self.entity_serializer.load(type_, x) for x in paths.iter_files(cache))
347 paginator = iter_utils.Paginator[T].single_page(cached)
348 else:
349 query, sorting, paginator = finder(query, sorting)
351 if sorting:
352 paginator = paginator.collapse()
354 if query:
355 paginator = paginator.map_pages(
356 lambda results: list(
357 mongo.find_all(
358 results,
359 query=query,
360 sorting=sorting,
361 to_doc=lambda obj: obj.dict(),
362 from_doc=type_.parse_obj,
363 )
364 )
365 )
367 for page in paginator:
368 for result in page:
369 yield result
371 # =======================================================================================================
372 # === Artifact getters ==================================================================================
374 def _list_run_artifacts(self, run: RunIdentifier, path_in_run: str = "") -> transfer.LsResult:
375 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True)
377 if self.offline_mode and not self.pull_artifacts_in_offline_mode:
378 logger.warning("Listing run artifacts in offline mode may yield incomplete or stale results.")
379 subject = self._get_run_artifact_cache_path(run, path_in_run)
380 elif urls.is_local(subject := urls.urljoin(self._coerce_run(run).repo, path_in_run)):
381 subject = subject.path
383 return self.file_transfer.ls(subject)
    def _pull_run_artifact(self, run: RunIdentifier, path_in_run: str) -> Path:
        """Pull artifact from run repo to local cache, unless repo is already local.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts (empty = whole artifact root).
        :return: Local path to the artifact (in cache, or inside the local repo itself).
        :raises RuntimeError: In offline mode without :attr:`pull_artifacts_in_offline_mode`.
        """
        if self.offline_mode and not self.pull_artifacts_in_offline_mode:
            raise RuntimeError("Artifact pull is disabled.")

        path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True)

        # Hold the cache lock for the whole pull so concurrent readers don't see partial data.
        with self._lock_run_artifact((run := self._coerce_run(run)).id, path_in_run) as target:
            if urls.is_local(url := run.repo_url):
                source = Path(url.path) / path_in_run

                if self.cache_local_artifacts:
                    paths.place_path(source, target, mode="copy", overwrite=True)
                else:
                    # Local repo: return the source path in place instead of duplicating into cache.
                    logger.info("Run artifacts repo is local, nothing to pull.")
                    return source
            else:
                self.file_transfer.pull_files(urls.urljoin(run.repo_url, path_in_run), target)

        return target
406 def _get_run_artifact(self, run: RunIdentifier, path_in_run: str) -> Path:
407 """Get path to local run artifact, may trigger a cache pull."""
408 path_in_run = self._valid_path_in_run(path_in_run, allow_empty=True)
409 cache = self._get_run_artifact_cache_path(run, path_in_run)
411 if (not self.offline_mode or self.pull_artifacts_in_offline_mode) and (
412 not cache.exists() or self.always_pull_artifacts
413 ):
414 return self._pull_run_artifact(run, path_in_run)
416 if not cache.exists():
417 raise FileNotFoundError(cache)
419 return cache
    def _place_run_artifact(
        self,
        run: RunIdentifier,
        path_in_run: str,
        target: Path,
        link: bool,
        overwrite: bool,
    ):
        """Place local run artifact on target path, may trigger a cache pull.

        The resulting files are always write-protected, but directories are not.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts.
        :param target: Target path.
        :param link: Use symbolic links instead of copies.
        :param overwrite: Overwrite the target path if it exists.
        """
        mode = typing.cast(paths.PathOperation, "link" if link else "copy")

        # Linking a dir: link each file individually so the dir itself stays writable.
        if (src := self._get_run_artifact(run, path_in_run)).is_dir() and link:
            paths.ensure_only_parents(target, force=overwrite)

            for dirpath, _, filenames in os.walk(src):  # Recursively create symbolic links for files
                relpath = Path(dirpath).relative_to(src)
                for filename in filenames:
                    paths.place_path(Path(dirpath) / filename, target / relpath / filename, mode, overwrite)
        else:
            paths.place_path(src, target, mode, overwrite)

        if target.is_dir() and not link:  # Recursively fix permissions of copied directories
            target.chmod(paths.Mode.rwx)

            for dirpath, dirnames, _ in os.walk(target):
                for dirname in dirnames:
                    Path(dirpath, dirname).chmod(paths.Mode.rwx)
452 # =======================================================================================================
453 # === Arguments pre-processing ==========================================================================
455 def _coerce_exp(self, exp: ExpIdentifier) -> E:
456 match exp:
457 case E():
458 return exp
459 case str():
460 return self._get_exp(exp)
461 case _:
462 raise TypeError("Expected Experiment, ExpApi or experiment ID as string.")
464 @classmethod
465 @string_utils.retval_matches(patterns.EXP_ID)
466 def _coerce_exp_id(cls, exp: ExpIdentifier) -> str:
467 match exp:
468 case E():
469 return exp.id
470 case str():
471 return exp
472 case _:
473 raise TypeError("Expected Experiment, ExpApi or experiment ID as string.")
475 def _coerce_run(self, run: RunIdentifier) -> R:
476 match run:
477 case R():
478 return run
479 case str():
480 return self._get_run(run)
481 case _:
482 raise TypeError("Expected Run, RunApi or run ID as string.")
484 @classmethod
485 @string_utils.retval_matches(patterns.RUN_ID)
486 def _coerce_run_id(cls, run: RunIdentifier) -> str:
487 match run:
488 case R():
489 return run.id
490 case str():
491 return run
492 case _:
493 raise TypeError("Expected Run, RunApi or run ID as string.")
495 def _coerce_model(self, model: ModelIdentifier) -> M:
496 match model:
497 case M():
498 return model
499 case str():
500 return self._get_model(model)
501 case _:
502 raise TypeError("Expected Model, ModelApi or model name as string.")
504 @classmethod
505 @string_utils.retval_matches(patterns.MODEL_NAME)
506 def _coerce_model_name(cls, model: ModelIdentifier) -> str:
507 match model:
508 case M():
509 return model.name
510 case str():
511 return model
512 case _:
513 raise TypeError("Expected Model, ModelApi or model name as string.")
515 def _coerce_mv(self, mv: ModelVersionIdentifier) -> V:
516 match mv:
517 case V():
518 return mv
519 case (str(), str()):
520 return self._get_mv(*mv)
521 case _:
522 raise TypeError("Expected ModelVersion or tuple or (name, version) strings.")
524 @classmethod
525 @string_utils.retval_matches(patterns.MODEL_NAME, index=0)
526 @string_utils.retval_matches(patterns.MODEL_VERSION, index=1)
527 def _coerce_mv_tuple(cls, mv: ModelVersionIdentifier) -> Tuple[str, str]:
528 match mv:
529 case V():
530 return mv.model.name, mv.version
531 case (str(), str()):
532 return mv
533 case _:
534 raise TypeError("Expected ModelVersion or tuple or (name, version) strings.")
536 @classmethod
537 def _valid_path_in_run(cls, path_in_run: str, allow_empty: bool = False) -> str:
538 """Validate the `path_in_run` for an artifact or model.
540 - Cannot be empty, unless specified
541 - Slashes are trimmed
542 - Cannot do path climbing (e.g.: "../")
543 - Cannot do current path referencing (e.g.: "./")
545 Valid path example: "a/b/c"
546 """
547 if (path_in_run := str(path_in_run).strip("/")) and os.path.abspath(path := "/root/" + path_in_run) == path:
548 return path_in_run
549 if allow_empty and path_in_run == "":
550 return path_in_run
551 raise paths.IllegalPath(f"`path_in_run={path_in_run}`")
    # =======================================================================================================
    # === Experiment tracking ===============================================================================

    @decorators.online
    def _create_exp(self, name: str, tags: Mapping) -> E:
        """Create experiment and return its metadata."""
        return self._impl_create_exp(name, tags)

    @decorators.online
    def _create_model(self, name: str, tags: Mapping) -> M:
        """Create registered model and return its metadata (name is validated first)."""
        return self._impl_create_model(self._coerce_model_name(name), tags)
    @decorators.online
    def _create_run(
        self,
        exp: ExpIdentifier,
        name: str | None,
        repo: urls.Url | str | None,
        tags: Mapping,
        status: schema.RunStatus,
        parent: RunIdentifier | None,
    ) -> R:
        """Create run with start at current UTC time.

        :param exp: Parent experiment (ID, object or API handle).
        :param name: Optional run name.
        :param repo: Optional artifacts repo URL (local paths are resolved).
        :param tags: Run tags to set right after creation.
        :param status: Initial run status.
        :param parent: Optional parent run.
        :return: Metadata of the newly created run (fetched back after tags/status are set).
        """
        if repo is not None:
            repo = urls.parse_url(repo, resolve_if_local=True)
        run_id = self._impl_create_run(
            self._coerce_exp_id(exp), name, repo, parent_run_id=self._coerce_run_id(parent) if parent else None
        )
        self._update_run_tags(run_id, tags)
        self._set_run_status(run_id, status)
        return self._get_run(run_id)
    # Online-only mutators: each coerces its identifier(s) then delegates to the plugin impl.

    @decorators.online
    def _set_run_status(self, run: RunIdentifier, status: schema.RunStatus):
        """Set Run status."""
        self._impl_set_run_status(self._coerce_run_id(run), status)

    @decorators.online
    def _set_run_end_time(self, run: RunIdentifier, end_time: datetime):
        """Set Run end time."""
        self._impl_set_run_end_time(self._coerce_run_id(run), end_time)

    @decorators.online
    def _update_exp_tags(self, exp: ExpIdentifier, tags: Mapping):
        """Update Experiment tags."""
        self._impl_update_exp_tags(self._coerce_exp_id(exp), tags)

    @decorators.online
    def _update_run_tags(self, run: RunIdentifier, tags: Mapping):
        """Update Run tags."""
        self._impl_update_run_tags(self._coerce_run_id(run), tags)

    @decorators.online
    def _update_model_tags(self, model: ModelIdentifier, tags: Mapping):
        """Update Model tags."""
        self._impl_update_model_tags(self._coerce_model_name(model), tags)

    @decorators.online
    def _update_mv_tags(self, mv: ModelVersionIdentifier, tags: Mapping):
        """Update ModelVersion tags."""
        self._impl_update_mv_tags(*self._coerce_mv_tuple(mv), tags=tags)

    @decorators.online
    def _log_run_params(self, run: RunIdentifier, params: Mapping):
        """Log run params."""
        self._impl_log_run_params(self._coerce_run_id(run), params)

    @decorators.online
    def _log_run_metrics(self, run: RunIdentifier, metrics: Mapping):
        """Log run metrics."""
        self._impl_log_run_metrics(self._coerce_run_id(run), metrics)
    # =======================================================================================================
    # === Model registry ====================================================================================

    @decorators.online
    def _register_mv(
        self, model: ModelIdentifier, run: RunIdentifier, path_in_run: str, version: str | None, tags: Mapping
    ) -> V:
        """Register a run artifact as a model version and return its metadata.

        :param path_in_run: Plain relative path inside run artifacts (must be non-empty).
        """
        path_in_run = self._valid_path_in_run(path_in_run)
        return self._impl_register_mv(self._coerce_model(model), self._coerce_run(run), path_in_run, version, tags)
    # =======================================================================================================
    # === Abstract Methods ==================================================================================

    # Plugin hooks: concrete API implementations provide the backend-specific behavior.

    @abstractmethod
    def _impl_default_cache_dir(self) -> Path:
        """Get default cache dir based on the current MLflow API settings."""

    @abstractmethod
    def _impl_get_exp_url(self, exp_id: str) -> urls.Url:
        """Get Experiment URL."""

    @abstractmethod
    def _impl_get_run_url(self, run_id: str, exp_id: str) -> urls.Url:
        """Get Run URL."""

    @abstractmethod
    def _impl_get_model_url(self, name: str) -> urls.Url:
        """Get URL to registered model."""

    @abstractmethod
    def _impl_get_mv_url(self, name: str, version: str) -> urls.Url:
        """Get model version URL."""

    @abstractmethod
    def _impl_fetch_exp(self, exp_id: str) -> E:
        """Get Experiment by ID."""

    @abstractmethod
    def _impl_fetch_run(self, run_id: str) -> R:
        """Get Run by ID."""

    @abstractmethod
    def _impl_fetch_model(self, name: str) -> M:
        """Get registered Model by name."""

    @abstractmethod
    def _impl_fetch_mv(self, name_and_version: Tuple[str, str]) -> V:
        """Get ModelVersion by name and version."""

    @abstractmethod
    def _impl_find_experiments(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[E]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""

    @abstractmethod
    def _impl_find_runs(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[R]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""

    @abstractmethod
    def _impl_find_models(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[M]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""

    @abstractmethod
    def _impl_find_mv(
        self, query: mongo.Query, sorting: mongo.Sorting
    ) -> Tuple[mongo.Query, mongo.Sorting, iter_utils.Paginator[V]]:
        """Push down MongoDB query where possible and return query remainder with results iterator."""

    @abstractmethod
    def _impl_find_child_runs(self, run: R) -> Iterator[R]:
        """Find child runs."""

    @abstractmethod
    def _impl_create_exp(self, name: str, tags: Mapping) -> E:
        """Create experiment and return its metadata."""

    @abstractmethod
    def _impl_create_model(self, name: str, tags: Mapping) -> M:
        """Create registered model and return its metadata."""

    @abstractmethod
    def _impl_create_run(
        self, exp_id: str, name: str | None, repo: urls.Url | None, parent_run_id: str | None = None
    ) -> str:
        """Create experiment run and return the new run ID."""

    @abstractmethod
    def _impl_set_run_status(self, run_id: str, status: schema.RunStatus):
        """Set Run status."""

    @abstractmethod
    def _impl_set_run_end_time(self, run_id: str, end_time: datetime):
        """Set Run end time."""

    @abstractmethod
    def _impl_update_exp_tags(self, exp_id: str, tags: Mapping):
        """Update Exp tags."""

    @abstractmethod
    def _impl_update_run_tags(self, run_id: str, tags: Mapping):
        """Update Run tags."""

    @abstractmethod
    def _impl_update_model_tags(self, name: str, tags: Mapping):
        """Update Model tags."""

    @abstractmethod
    def _impl_update_mv_tags(self, name: str, version: str, tags: Mapping):
        """Update ModelVersion tags."""

    @abstractmethod
    def _impl_log_run_params(self, run_id: str, params: Mapping):
        """Log run params."""

    @abstractmethod
    def _impl_log_run_metrics(self, run_id: str, metrics: Mapping):
        """Log run metrics."""

    @abstractmethod
    def _impl_register_mv(self, model: M, run: R, path_in_run: str, version: str | None, tags: Mapping) -> V:
        """Register model version."""
    # =======================================================================================================
    # === Public Methods ====================================================================================

    def clean_all_cache(self):
        """Clean all cached metadata and artifacts."""
        self._clean_all_cache()

    def clean_temp_artifacts(self):
        """Clean temporary artifacts."""
        self._clean_temp_artifacts()
    @pydantic.validate_arguments
    def clean_cached_run_artifact(self, run: RunIdentifier, path_in_run: str = ""):
        """Clean cached artifact for specified run.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        """
        self._clean_run_artifact(self._coerce_run_id(run), path_in_run)

    @pydantic.validate_arguments
    def clean_cached_model_artifact(self, model_version: ModelVersionIdentifier):
        """Clean cached artifact for specified model version.

        Delegates to :meth:`clean_cached_run_artifact` using the version's source run and path.

        :param model_version: Model version object or `(name, version)` tuple.
        """
        mv = self._coerce_mv(model_version)
        self.clean_cached_run_artifact(mv.run, mv.path_in_run)
    @pydantic.validate_arguments
    def list_run_artifacts(self, run: RunIdentifier, path_in_run: str = "") -> transfer.LsResult:
        """List run artifacts in repo.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        """
        return self._list_run_artifacts(run, path_in_run)
783 @pydantic.validate_arguments
784 def list_model_artifact(self, model_version: ModelVersionIdentifier, path_suffix: str = "") -> transfer.LsResult:
785 """List model version artifacts in repo.
787 :param model_version: Model version object or `(name, version)` tuple.
788 :param path_suffix: Plain relative path inside model artifact dir (e.g.: `a/b/c`).
789 """
790 return self.list_run_artifacts(
791 run=(mv := self._coerce_mv(model_version)).run,
792 path_in_run=mv.path_in_run + "/" + path_suffix.strip("/"),
793 )
    @pydantic.validate_arguments
    def cache_run_artifact(self, run: RunIdentifier, path_in_run: str = "") -> Path:
        """Pull run artifact from MLflow server to local cache.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        :return: Local path to the cached artifact.
        """
        return self._pull_run_artifact(run, path_in_run)

    @pydantic.validate_arguments
    def cache_model_artifact(self, model_version: ModelVersionIdentifier) -> Path:
        """Pull model version artifact from MLflow server to local cache.

        :param model_version: Model version object or `(name, version)` tuple.
        :return: Local path to the cached artifact.
        """
        mv = self._coerce_mv(model_version)
        return self.cache_run_artifact(mv.run, mv.path_in_run)
    @pydantic.validate_arguments
    def get_run_artifact(self, run: RunIdentifier, path_in_run: str = "") -> Path:
        """Get local path to run artifact.

        Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.

        :param run: Run ID or object.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        """
        return self._get_run_artifact(self._coerce_run_id(run), path_in_run)

    @pydantic.validate_arguments
    def get_model_artifact(self, model_version: ModelVersionIdentifier) -> Path:
        """Get local path to model artifact.

        Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.

        :param model_version: Model version object or `(name, version)` tuple.
        """
        mv = self._coerce_mv(model_version)
        return self.get_run_artifact(mv.run, mv.path_in_run)
    @pydantic.validate_arguments
    def place_run_artifact(
        self,
        run: RunIdentifier,
        target: Path,
        path_in_run: str = "",
        overwrite: bool = False,
        link: bool = True,
    ):
        """Place run artifact on target path.

        Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.
        The resulting files are always write-protected, but directories are not.

        :param run: Run ID or object.
        :param target: Target path.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        :param overwrite: Overwrite target path if exists.
        :param link: Use symbolic link instead of copy.
        """
        self._place_run_artifact(self._coerce_run_id(run), path_in_run, target, link, overwrite)

    @pydantic.validate_arguments
    def place_model_artifact(
        self,
        model_version: ModelVersionIdentifier,
        target: Path,
        overwrite: bool = False,
        link: bool = True,
    ):
        """Place model version artifact on target path.

        Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.

        :param model_version: Model version object or `(name, version)` tuple.
        :param target: Target path.
        :param overwrite: Overwrite target path if exists.
        :param link: Use symbolic link instead of copy.
        """
        mv = self._coerce_mv(model_version)
        self.place_run_artifact(mv.run, target, mv.path_in_run, overwrite, link)
    @pydantic.validate_arguments
    def export_run_artifact(
        self,
        run: RunIdentifier,
        target: Path,
        path_in_run: str = "",
    ) -> Path:
        """Export run artifact cache to target path while keeping the original cache structure.

        The target path can then be used as cache dir by the `generic` MLflow API in offline mode.

        :param run: Run ID or object.
        :param target: Cache export path.
        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
        """
        # Reject targets that nest with the live cache dir in either direction,
        # otherwise the copy below would read from and write to overlapping trees.
        if paths.is_sub_dir(target, self.cache_dir) or paths.is_sub_dir(self.cache_dir, target):
            raise paths.IllegalPath(f"Cannot export cache to itself, its subdirs or parents: {target}")
        cache = self._get_run_artifact(run, path_in_run)  # may trigger a cache pull first
        # Re-derive the artifact's cache path on a copy of this API whose cache root is
        # `target`, so the exported tree mirrors the original cache layout exactly.
        target = self.model_copy(update={"cache_dir": target})._get_run_artifact_cache_path(run, path_in_run)
        paths.place_path(cache, target, mode="copy", overwrite=True)
        paths.rchmod(target, paths.Mode.rwx)  # Exported caches are not write-protected
        return target
900 @pydantic.validate_arguments
901 def export_model_artifact(
902 self,
903 model_version: ModelVersionIdentifier,
904 target: Path,
905 ) -> Path:
906 """Export model version artifact cache to target path while keeping the original cache structure.
908 The target path can then be used as cache dir by the `generic` MLflow API in offline mode.
910 :param model_version: Model version object or `(name, version)` tuple.
911 :param target: Cache export path.
912 """
913 mv = self._coerce_mv(model_version)
914 return self.export_run_artifact(mv.run, target, mv.path_in_run)
916 @pydantic.validate_arguments
917 def load_run_artifact(self, run: RunIdentifier, loader: Callable[[Path], A], path_in_run: str = "") -> A:
918 """Load run artifact.
920 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.
922 :param run: Run ID or object.
923 :param loader: Loader callback.
924 :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)
925 """
926 return loader(self._get_run_artifact(self._coerce_run_id(run), path_in_run))
928 @pydantic.validate_arguments
929 def load_model_artifact(self, model_version: ModelVersionIdentifier, loader: Callable[[Path], A]) -> A:
930 """Load model version artifact.
932 Triggers a cache pull on a cache miss or if :attr:`always_pull_artifacts`.
934 :param model_version: Model version object or `(name, version)` tuple.
935 :param loader: Loader callback.
936 """
937 mv = self._coerce_mv(model_version)
938 logger.info("Loading model: %s v%s", mv.model.name, mv.version)
939 return self.load_run_artifact(mv.run, loader, mv.path_in_run)
    @decorators.online
    @pydantic.validate_arguments
    def log_run_artifact(
        self,
        run: RunIdentifier,
        source: Path | Callable[[Path], None],
        path_in_run: str | None = None,
        keep_the_source: bool | None = None,
        allow_duplication: bool | None = None,
        use_cache: bool | None = None,
    ):
        """Publish artifact file or dir to experiment run.

        The flags :paramref:`keep_the_source`, :paramref:`allow_duplication` and :paramref:`use_cache` are
        experimental and may conflict with one another. It is recommended to leave them unspecified, so this
        method will do a best-effort to use cache if it makes sense to, keep the source files if it makes
        sense to (possibly as a symbolic link) and avoid duplicated disk usage when possible.

        :param run: | Run ID or object.

        :param source: | Path to artifact file or dir, or a dumper callback.
                       | If it's a callback and the upload is interrupted, the temporary artifact is kept.

        :param path_in_run: Plain relative path inside run artifacts (e.g.: `a/b/c`)

            - If `source` is a `Path`: Defaults to file or dir name.
            - If `source` is a callback: No default available.

        :param keep_the_source:
            - If `source` is a `Path`: Keep that file or dir (defaults to `True`).
            - If `source` is a callback: Keep the temporary artifact, even after a successful upload (defaults to `False`).

        :param allow_duplication: | If `False`, a `source` file or dir may be replaced with a symbolic link to the local cache in order to avoid duplicated disk usage.
                                  | Defaults to `True` if :paramref:`keep_the_source` is `True` and the run artifacts repo is local.

        :param use_cache: | If `True`, keep artifact in local cache after publishing.
                          | Defaults to `True` if the run artifacts repo is remote.
        """
        tmp = None

        # If `source` is a callback, let it dump the artifact into a temp dir first.
        if using_dumper := callable(source):
            logger.debug("Using temporary artifact path: %s", tmp := self._get_temp_artifacts_dir())
            source(source := tmp.joinpath("artifact"))  # `source` is rebound to the dumped path

        # Normalize `source` and refuse anything living inside the artifact cache itself.
        if (source := Path(source).expanduser().resolve()).is_relative_to(self._run_artifacts_cache):
            raise paths.IllegalPath(f"Source path points to artifact cache: {source}")

        # Dumper output is temporary by default; an explicit source path is kept by default.
        if keep_the_source is None:
            keep_the_source = False if using_dumper else True  # noqa: SIM211

        try:
            if using_dumper:
                assert path_in_run, "When using an artifact dumper, `path_in_run` must be specified."
            else:
                path_in_run = path_in_run or source.name
            path_in_run = self._valid_path_in_run(path_in_run, allow_empty=False)

            run = self._coerce_run(run)
            target = urls.urljoin(run.repo_url, path_in_run)

            if repo_is_local := urls.is_local(target):
                # Local repo: "publishing" is a plain file move/copy, so caching would be redundant.
                if use_cache is None:
                    use_cache = False

                if allow_duplication is None:
                    allow_duplication = True if keep_the_source or use_cache else False  # noqa: SIM211,SIM210

                if keep_the_source:
                    if allow_duplication:
                        mode = "copy"
                    else:
                        raise RuntimeError("Cannot keep the source without duplication when artifacts repo is local")
                else:
                    mode = "move"

                paths.place_path(
                    source,
                    target.path,
                    mode=typing.cast(paths.PathOperation, mode),
                    overwrite=True,
                    move_abs_links=False,
                )
            else:
                # Remote repo: upload via file transfer; cache by default to avoid re-pulling later.
                if use_cache is None:
                    use_cache = True

                if allow_duplication is None:
                    allow_duplication = False

                self.file_transfer.push_files(source, target)
        except BaseException as exc:
            raise exceptions.FailedToPublishArtifact(source) from exc

        logger.debug(f"Artifact successfully published to '{target}'")

        # Optionally seed the local cache with the just-published artifact.
        if use_cache:
            with self._lock_run_artifact(run.id, path_in_run, allow_base_resolve=False) as cache:
                if repo_is_local:
                    if allow_duplication:
                        paths.place_path(target.path, cache, mode="copy", overwrite=True)
                    else:
                        raise RuntimeError("Cannot cache artifact without duplication when run artifacts repo is local")
                elif keep_the_source:
                    if allow_duplication:
                        paths.place_path(source, cache, mode="copy", overwrite=True)
                    else:
                        # Avoid duplication: move the source into the cache, then leave a
                        # symlink behind so the caller still "keeps" it.
                        logger.warning("Keeping the `source` as a symbolic link to the cached artifact")
                        logger.debug(f"{source} -> {cache}")
                        paths.place_path(source, cache, mode="move", overwrite=True)
                        paths.place_path(cache, source, mode="link", overwrite=True)
                else:
                    paths.place_path(source, cache, mode="move", overwrite=True)

        # Clean up the source (and the dumper temp dir, if any) unless it should be kept.
        if not keep_the_source:
            paths.ensure_non_existing(source, force=True)
            if tmp is not None:
                paths.ensure_non_existing(tmp, force=True)
    @pydantic.validate_arguments
    def log_model_version(
        self,
        model: ModelIdentifier,
        run: RunIdentifier,
        source: Path | Callable[[Path], None],
        path_in_run: str | None = None,
        keep_the_source: bool | None = None,
        allow_duplication: bool | None = None,
        use_cache: bool | None = None,
        version: str | None = None,
        tags: Mapping | None = None,
    ) -> ModelVersionApi:
        """Publish artifact file or dir as model version inside the specified experiment run.

        :param model: | Model name or object.

        :param run: | Run ID or object.

        :param source: | See :paramref:`log_run_artifact.source`

        :param path_in_run: | Plain relative path inside run artifacts (e.g.: `a/b/c`).
                            | Defaults to model name.

        :param keep_the_source: | See :paramref:`log_run_artifact.keep_the_source`

        :param allow_duplication: | See :paramref:`log_run_artifact.allow_duplication`

        :param use_cache: | See :paramref:`log_run_artifact.use_cache`

        :param version: | Arbitrary model version
                        | (not supported by all backends).

        :param tags: | Model version tags.
                     | See :class:`schema.ModelVersion.tags`

        :return: New model version metadata with API handle.
        """
        logger.info("Logging version of model '%s'", model_name := self._coerce_model_name(model))
        # Default artifact location is derived from the (encoded) model name.
        path_in_run = path_in_run or patterns.encode_model_name(model_name)
        # Publish the artifact first, then register the model version pointing at it.
        self.log_run_artifact(run, source, path_in_run, keep_the_source, allow_duplication, use_cache)
        return ModelVersionApi(**self._register_mv(model, run, path_in_run, version, tags or {})).using(self)
1102 @pydantic.validate_arguments
1103 def get_exp_url(self, exp: ExpIdentifier) -> str:
1104 """Get Experiment URL.
1106 :param exp: Exp ID or object.
1107 """
1108 return str(self._impl_get_exp_url(self._coerce_exp_id(exp)))
1110 @pydantic.validate_arguments
1111 def get_run_url(self, run: RunIdentifier, exp: ExpIdentifier | None = None) -> str:
1112 """Get Run URL.
1114 :param run: Run ID or object.
1115 :param exp: Exp ID or object.
1117 Caveats:
1118 - :paramref:`exp` must be specified on :attr:`~BaseMlflowApi.offline_mode`
1119 if :paramref:`run` is an ID and the run metadata is not in cache.
1120 """
1121 exp = self._coerce_run(run).exp if exp is None else exp
1122 return str(self._impl_get_run_url(self._coerce_run_id(run), self._coerce_exp_id(exp)))
1124 @pydantic.validate_arguments
1125 def get_model_url(self, model: ModelIdentifier) -> str:
1126 """Get URL to registered model.
1128 :param model: Model name or object.
1129 """
1130 return str(self._impl_get_model_url(self._coerce_model_name(model)))
1132 @pydantic.validate_arguments
1133 def get_model_version_url(self, model_version: ModelVersionIdentifier) -> str:
1134 """Get model version URL.
1136 :param model_version: Model version object or `(name, version)` tuple.
1137 """
1138 return str(self._impl_get_mv_url(*self._coerce_mv_tuple(model_version)))
1140 @pydantic.validate_arguments
1141 def get_exp(self, exp: ExpIdentifier, **cache_opts: bool) -> ExpApi:
1142 """Get Experiment API by ID.
1144 :param exp: Exp ID or object.
1145 """
1146 return ExpApi(**self._get_exp(self._coerce_exp_id(exp), **cache_opts)).using(self)
1148 @pydantic.validate_arguments
1149 def get_run(self, run: RunIdentifier, **cache_opts: bool) -> RunApi:
1150 """Get Run API by ID.
1152 :param run: Run ID or object.
1153 """
1154 return RunApi(**self._get_run(self._coerce_run_id(run), **cache_opts)).using(self)
1156 @pydantic.validate_arguments
1157 def get_model(self, model: ModelIdentifier, **cache_opts: bool) -> ModelApi:
1158 """Get Model API by name.
1160 :param model: Model name or object.
1161 """
1162 return ModelApi(**self._get_model(self._coerce_model_name(model), **cache_opts)).using(self)
1164 @pydantic.validate_arguments
1165 def get_model_version(self, model_version: ModelVersionIdentifier, **cache_opts: bool) -> ModelVersionApi:
1166 """Get ModelVersion API by name and version.
1168 :param model_version: Model version object or `(name, version)` tuple.
1169 """
1170 return ModelVersionApi(**self._get_mv(self._coerce_mv_tuple(model_version), **cache_opts)).using(self)
1172 @pydantic.validate_arguments
1173 def find_exps(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[ExpApi]:
1174 """Search experiments with query in MongoDB query language.
1176 :param query: Query in MongoDB query language.
1177 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).
1178 """
1179 return (ExpApi(**x).using(self) for x in self._find_experiments(query or {}, sorting or []))
1181 @pydantic.validate_arguments
1182 def find_runs(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[RunApi]:
1183 """Search runs with query in MongoDB query language.
1185 :param query: Query in MongoDB query language.
1186 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).
1187 """
1188 return (RunApi(**x).using(self) for x in self._find_runs(query or {}, sorting or []))
1190 @pydantic.validate_arguments
1191 def find_models(self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None) -> Iterator[ModelApi]:
1192 """Search registered models with query in MongoDB query language.
1194 :param query: Query in MongoDB query language.
1195 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).
1196 """
1197 return (ModelApi(**x).using(self) for x in self._find_models(query or {}, sorting or []))
1199 @pydantic.validate_arguments
1200 def find_model_versions(
1201 self, query: mongo.Query | None = None, sorting: mongo.Sorting | None = None
1202 ) -> Iterator[ModelVersionApi]:
1203 """Search model versions with query in MongoDB query language.
1205 :param query: Query in MongoDB query language.
1206 :param sorting: Sorting criteria (e.g.: `[("asc_field", 1), ("desc_field", -1)]`).
1207 """
1208 return (ModelVersionApi(**x).using(self) for x in self._find_mv(query or {}, sorting or []))
1210 @pydantic.validate_arguments
1211 def find_child_runs(self, parent: RunIdentifier) -> Iterator[RunApi]:
1212 """Find child runs.
1214 :param parent: Run ID or object.
1215 """
1216 return (RunApi(**x).using(self) for x in self._impl_find_child_runs(self._coerce_run(parent)))
1218 @pydantic.validate_arguments
1219 def cache_exp_meta(self, exp: ExpIdentifier) -> ExpApi:
1220 """Get latest Experiment metadata and save to local cache.
1222 :param exp: Experiment ID or object.
1223 """
1224 return self.get_exp(exp, force_cache_refresh=True)
1226 @pydantic.validate_arguments
1227 def cache_run_meta(self, run: RunIdentifier) -> RunApi:
1228 """Get latest Run metadata and save to local cache.
1230 :param run: Run ID or object.
1231 """
1232 return self.get_run(run, force_cache_refresh=True)
1234 @pydantic.validate_arguments
1235 def cache_model_meta(self, model: ModelIdentifier) -> ModelApi:
1236 """Get latest Model metadata and save to local cache.
1238 :param model: Model name or object.
1239 """
1240 return self.get_model(model, force_cache_refresh=True)
1242 @pydantic.validate_arguments
1243 def cache_model_version_meta(self, model_version: ModelVersionIdentifier) -> ModelVersionApi:
1244 """Get latest model version metadata and save to local cache.
1246 :param model_version: Model version object or `(name, version)` tuple.
1247 """
1248 return self.get_model_version(model_version, force_cache_refresh=True)
1250 @pydantic.validate_arguments
1251 def export_exp_meta(self, exp: ExpIdentifier, target: Path) -> ExpApi:
1252 """Export experiment metadata cache to target.
1254 :param exp: Experiment ID or object.
1255 :param target: Cache export path.
1256 """
1257 self._export_meta(exp := self.get_exp(id_ := self._coerce_exp_id(exp)), self._get_exp_cache(id_), target)
1258 return exp
1260 @pydantic.validate_arguments
1261 def export_run_meta(self, run: RunIdentifier, target: Path) -> RunApi:
1262 """Export run metadata cache to target.
1264 :param run: Run ID or object.
1265 :param target: Cache export path.
1266 """
1267 self._export_meta(run := self.get_run(id_ := self._coerce_run_id(run)), self._get_run_cache(id_), target)
1268 self.export_exp_meta(run.exp, target)
1269 return run
1271 @pydantic.validate_arguments
1272 def export_model_meta(self, model: ModelIdentifier, target: Path) -> ModelApi:
1273 """Export model metadata cache to target.
1275 :param model: Model name or object.
1276 :param target: Cache export path.
1277 """
1278 name = self._coerce_model_name(model)
1279 self._export_meta(model := self.get_model(name), self._get_model_cache(name), target)
1280 return model
1282 @pydantic.validate_arguments
1283 def export_model_version_meta(self, mv: ModelVersionIdentifier, target: Path) -> ModelVersionApi:
1284 """Export model version metadata cache to target.
1286 :param mv: Model version object or `(name, version)` tuple.
1287 :param target: Cache export path.
1288 """
1289 tup = self._coerce_mv_tuple(mv)
1290 self._export_meta(mv := self.get_model_version(tup), self._get_mv_cache(*tup), target)
1291 self.export_model_meta(mv.model, target)
1292 return mv
1294 @pydantic.validate_arguments
1295 def create_exp(self, name: str, tags: Mapping | None = None) -> ExpApi:
1296 """Create Experiment and return its API.
1298 :param name: See :attr:`schema.Experiment.name`.
1299 :param tags: See :attr:`schema.Experiment.tags`.
1300 """
1301 return ExpApi(**self._create_exp(name, tags or {})).using(self)
1303 @pydantic.validate_arguments
1304 def get_or_create_exp(self, name: str) -> ExpApi:
1305 """Get or create Experiment and return its API.
1307 :param name: See :attr:`schema.Experiment.name`.
1308 """
1309 for exp in self._find_experiments({"name": name}, []):
1310 break
1311 else:
1312 exp = self._create_exp(name, tags={})
1314 return ExpApi(**exp).using(self)
1316 @pydantic.validate_arguments
1317 def create_model(self, name: str, tags: Mapping | None = None) -> ModelApi:
1318 """Create registered model and return its API.
1320 :param name: See :attr:`schema.Model.name`.
1321 :param tags: See :attr:`schema.Model.tags`.
1322 """
1323 return ModelApi(**self._create_model(name, tags or {})).using(self)
1325 @pydantic.validate_arguments
1326 def get_or_create_model(self, name: str) -> ModelApi:
1327 """Get or create registered Model and return its API.
1329 :param name: See :attr:`schema.Model.name`.
1330 """
1331 for model in self._find_models({"name": name}, []):
1332 break
1333 else:
1334 model = self._create_model(name, tags={})
1336 return ModelApi(**model).using(self)
1338 @pydantic.validate_arguments
1339 def create_run(
1340 self,
1341 exp: ExpIdentifier,
1342 name: str | None = None,
1343 tags: Mapping | None = None,
1344 repo: str | urls.Url | None = None,
1345 parent: RunIdentifier | None = None,
1346 ) -> RunApi:
1347 """Declare a new experiment run to be used later.
1349 :param exp: Experiment ID or object.
1350 :param name: See :attr:`schema.Run.name`.
1351 :param tags: See :attr:`schema.Run.tags`.
1352 :param repo: (Experimental) Cloud storage URL to be used as alternative run artifacts repository.
1353 :param parent: Parent run ID or object.
1354 """
1355 return RunApi(**self._create_run(exp, name, repo, tags or {}, schema.RunStatus.SCHEDULED, parent)).using(self)
1357 @pydantic.validate_arguments
1358 def start_run(
1359 self,
1360 exp: ExpIdentifier,
1361 name: str | None = None,
1362 tags: Mapping | None = None,
1363 repo: str | urls.Url | None = None,
1364 parent: RunIdentifier | None = None,
1365 ) -> RunApi:
1366 """Start a new experiment run.
1368 :param exp: Experiment ID or object.
1369 :param name: See :attr:`schema.Run.name`.
1370 :param tags: See :attr:`schema.Run.tags`.
1371 :param repo: (Experimental) Cloud storage URL to be used as alternative run artifacts repository.
1372 :param parent: Parent run ID or object.
1373 """
1374 return RunApi(**self._create_run(exp, name, repo, tags or {}, schema.RunStatus.RUNNING, parent)).using(self)
1376 @pydantic.validate_arguments
1377 def resume_run(self, run: RunIdentifier) -> RunApi:
1378 """Resume a previous experiment run.
1380 :param run: Run ID or object.
1381 """
1382 self._set_run_status(run_id := self._coerce_run_id(run), schema.RunStatus.RUNNING)
1383 return self.get_run(run_id)
1385 @pydantic.validate_arguments
1386 def end_run(self, run: RunIdentifier, succeeded: bool = True) -> RunApi:
1387 """End experiment run.
1389 :param run: Run ID or object.
1390 :param succeeded: Whether the run was successful.
1391 """
1392 status = schema.RunStatus.FINISHED if succeeded else schema.RunStatus.FAILED
1393 self._set_run_status(run_id := self._coerce_run_id(run), status)
1394 self._set_run_end_time(run_id, datetime.now())
1395 return self.get_run(run_id)
    @pydantic.validate_arguments
    def set_tags_on_exp(self, exp: ExpIdentifier, tags: Mapping) -> None:
        """Set tags on experiment.

        :param exp: Experiment ID or object.
        :param tags: See :attr:`schema.Experiment.tags`.
        """
        self._update_exp_tags(exp, tags)
    @pydantic.validate_arguments
    def set_tags_on_run(self, run: RunIdentifier, tags: Mapping) -> None:
        """Set tags on experiment run.

        :param run: Run ID or object.
        :param tags: See :attr:`schema.Run.tags`.
        """
        self._update_run_tags(run, tags)
    @pydantic.validate_arguments
    def set_tags_on_model(self, model: ModelIdentifier, tags: Mapping) -> None:
        """Set tags on registered model.

        :param model: Model name or object.
        :param tags: See :attr:`schema.Model.tags`.
        """
        self._update_model_tags(model, tags)
    @pydantic.validate_arguments
    def set_tags_on_model_version(self, model_version: ModelVersionIdentifier, tags: Mapping) -> None:
        """Set tags on model version.

        :param model_version: Model version object or `(name, version)` tuple.
        :param tags: See :attr:`schema.ModelVersion.tags`.
        """
        self._update_mv_tags(model_version, tags)
    @pydantic.validate_arguments
    def log_params(self, run: RunIdentifier, params: Mapping) -> None:
        """Log params to experiment run.

        :param run: Run ID or object.
        :param params: See :attr:`schema.Run.params`.
        """
        self._log_run_params(run, params)
    @pydantic.validate_arguments
    def log_metrics(self, run: RunIdentifier, metrics: Mapping) -> None:
        """Log metrics to experiment run.

        :param run: Run ID or object.
        :param metrics: See :attr:`schema.Run.metrics`.
        """
        self._log_run_metrics(run, metrics)