Coverage for src/mlopus/mlflow/api/common/transfer.py: 78%
74 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import inspect
2import re
3from datetime import datetime
4from pathlib import Path
5from typing import Any, Iterable, List
7from mlopus.utils import import_utils, pydantic, urls
class ObjMeta(pydantic.BaseModel):
    """Metadata for one object returned by the transfer tool's ``ls`` function.

    Field names are capitalized to mirror the keys of the tool's listing entries
    (presumably rclone's JSON listing format — verify against the configured tool).
    """

    Name: str
    Size: int
    IsDir: bool
    MimeType: str
    ModTime: datetime

    @classmethod
    def parse_many(cls, objs: Iterable[pydantic.ModelLike]) -> Iterable["ObjMeta"]:
        """Lazily convert every item of `objs` into an :class:`ObjMeta`.

        This is a generator: nothing is parsed until the result is iterated.
        """
        yield from map(cls.parse_obj, objs)
# Return type of `FileTransfer.ls`: a list of entries for a dir, or a single entry for a file.
LsResult = (List[ObjMeta] | ObjMeta)
26class FileTransfer(pydantic.BaseModel):
27 """File transfer wrapper for MLflow API."""
29 prog_bar: bool = pydantic.Field(default=True, description="Show progress bar when transfering files.")
30 tool: Any = pydantic.Field(
31 default="rclone_python.rclone",
32 description=(
33 "Fully qualified path of module, class or object that exposes the methods/functions "
34 "`ls`, `copyto` and `sync`, with signatures compatible with the ones exposed in "
35 "`rclone_python.rclone <https://pypi.org/project/rclone-python>`_."
36 ),
37 )
38 extra_args: dict[str, list[str]] = pydantic.Field(
39 default={"sync": ["--copy-links"]},
40 description="Dict of extra arguments to pass to each of the functions exposed by the :attr:`tool`.",
41 )
42 use_scheme: str | None = pydantic.Field(
43 default=None,
44 description="Replace remote URL schemes with this one. Incompatible with :attr:`map_scheme`.",
45 )
46 map_scheme: dict[str | re.Pattern, str] | None = pydantic.Field(
47 default=None,
48 description=(
49 "Replace remote URL schemes with the first value in this mapping whose key (regexp) matches the URL. "
50 "Incompatible with :attr:`use_scheme`."
51 ),
52 )
54 @pydantic.validator("map_scheme") # noqa
55 @classmethod
56 def _compile_map_scheme_regex(cls, v: dict | None) -> dict:
57 return {re.compile(k) if isinstance(k, str) else k: v for k, v in (v or {}).items()}
59 @pydantic.root_validator(mode="after")
60 def _scheme_rules_compatibility(self):
61 assert False in (bool(self.use_scheme), bool(self.map_scheme)), "`use_scheme` and `map_scheme` are incompatible"
62 return self
64 @pydantic.root_validator
65 def _find_tool(self):
66 self.tool = import_utils.find_attr(self.tool) if isinstance(self.tool, str) else self.tool
67 return self
69 def _translate_scheme(self, url: urls.UrlLike) -> urls.UrlLike:
70 if urls.is_local(url):
71 return url
73 scheme = None
74 if self.use_scheme:
75 scheme = self.use_scheme
76 elif self.map_scheme:
77 for pattern, new_scheme in self.map_scheme.items():
78 if pattern.match(str(url)):
79 scheme = new_scheme
80 break
82 return urls.parse_url(url)._replace(scheme=scheme) if scheme else url # noqa
84 def ls(self, url: urls.UrlLike) -> LsResult:
85 """If `url` is a dir, list the objects in it. If it's a file, return the file metadata."""
86 objs = list(ObjMeta.parse_many(self._tool("ls", url := str(self._translate_scheme(url)))))
88 if len(objs) == 1 and (not (one_obj := objs[0]).IsDir and one_obj.Name == Path(url).name):
89 return one_obj
91 return objs
93 def is_file(self, url: urls.Url) -> bool:
94 """Check if URL points to a file. If False, it may be a dir or not exist."""
95 return not isinstance(self.ls(url), list)
97 def pull_files(self, src: urls.Url, tgt: Path):
98 """Pull files from `src` to `tgt`."""
99 match self.ls(src):
100 case []:
101 raise FileNotFoundError(src)
102 case list():
103 func = "sync"
104 case ObjMeta():
105 func = "copyto"
106 case _:
107 raise NotImplementedError("src=%s (%s)", src, type(src))
109 src = self._translate_scheme(src)
110 self._tool(func, *(str(x).rstrip("/") for x in (src, tgt)))
112 def push_files(self, src: Path, tgt: urls.Url):
113 """Push files from `src` to `tgt`."""
114 tgt = self._translate_scheme(tgt)
115 self._tool(
116 "copyto" if src.is_file() else "sync",
117 *(str(x).rstrip("/") for x in (src.expanduser().resolve(), tgt)),
118 )
120 def _tool(self, func: str, *args, **kwargs):
121 call = getattr(self.tool, func)
123 if self.prog_bar and "show_progress" in inspect.signature(call).parameters:
124 kwargs["show_progress"] = True
126 return call(
127 *args,
128 **kwargs,
129 args=self.extra_args.get(func) or None,
130 )