Coverage for src/mlopus/mlflow/api/common/transfer.py: 78%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import inspect 

2import re 

3from datetime import datetime 

4from pathlib import Path 

5from typing import Any, Iterable, List 

6 

7from mlopus.utils import import_utils, pydantic, urls 

8 

9 

# Metadata of a single object (file or dir), parsed from the entries returned by the
# transfer tool's `ls` function.
# NOTE(review): field names presumably mirror rclone's `lsjson` output -- confirm.
class ObjMeta(pydantic.BaseModel):
    Name: str
    Size: int
    IsDir: bool
    MimeType: str
    ModTime: datetime

    @classmethod
    def parse_many(cls, objs: Iterable[pydantic.ModelLike]) -> Iterable["ObjMeta"]:
        """Lazily parse each element of `objs` into an :class:`ObjMeta`.

        This is a generator: nothing is consumed from `objs` until the result is iterated.
        """
        yield from map(cls.parse_obj, objs)

21 

22 

# Result type of a listing: a list of entries when the URL is a dir,
# or a single entry when the URL points to a file.
LsResult = List[ObjMeta] | ObjMeta

24 

25 

26class FileTransfer(pydantic.BaseModel): 

27 """File transfer wrapper for MLflow API.""" 

28 

29 prog_bar: bool = pydantic.Field(default=True, description="Show progress bar when transfering files.") 

30 tool: Any = pydantic.Field( 

31 default="rclone_python.rclone", 

32 description=( 

33 "Fully qualified path of module, class or object that exposes the methods/functions " 

34 "`ls`, `copyto` and `sync`, with signatures compatible with the ones exposed in " 

35 "`rclone_python.rclone <https://pypi.org/project/rclone-python>`_." 

36 ), 

37 ) 

38 extra_args: dict[str, list[str]] = pydantic.Field( 

39 default={"sync": ["--copy-links"]}, 

40 description="Dict of extra arguments to pass to each of the functions exposed by the :attr:`tool`.", 

41 ) 

42 use_scheme: str | None = pydantic.Field( 

43 default=None, 

44 description="Replace remote URL schemes with this one. Incompatible with :attr:`map_scheme`.", 

45 ) 

46 map_scheme: dict[str | re.Pattern, str] | None = pydantic.Field( 

47 default=None, 

48 description=( 

49 "Replace remote URL schemes with the first value in this mapping whose key (regexp) matches the URL. " 

50 "Incompatible with :attr:`use_scheme`." 

51 ), 

52 ) 

53 

54 @pydantic.validator("map_scheme") # noqa 

55 @classmethod 

56 def _compile_map_scheme_regex(cls, v: dict | None) -> dict: 

57 return {re.compile(k) if isinstance(k, str) else k: v for k, v in (v or {}).items()} 

58 

59 @pydantic.root_validator(mode="after") 

60 def _scheme_rules_compatibility(self): 

61 assert False in (bool(self.use_scheme), bool(self.map_scheme)), "`use_scheme` and `map_scheme` are incompatible" 

62 return self 

63 

64 @pydantic.root_validator 

65 def _find_tool(self): 

66 self.tool = import_utils.find_attr(self.tool) if isinstance(self.tool, str) else self.tool 

67 return self 

68 

69 def _translate_scheme(self, url: urls.UrlLike) -> urls.UrlLike: 

70 if urls.is_local(url): 

71 return url 

72 

73 scheme = None 

74 if self.use_scheme: 

75 scheme = self.use_scheme 

76 elif self.map_scheme: 

77 for pattern, new_scheme in self.map_scheme.items(): 

78 if pattern.match(str(url)): 

79 scheme = new_scheme 

80 break 

81 

82 return urls.parse_url(url)._replace(scheme=scheme) if scheme else url # noqa 

83 

84 def ls(self, url: urls.UrlLike) -> LsResult: 

85 """If `url` is a dir, list the objects in it. If it's a file, return the file metadata.""" 

86 objs = list(ObjMeta.parse_many(self._tool("ls", url := str(self._translate_scheme(url))))) 

87 

88 if len(objs) == 1 and (not (one_obj := objs[0]).IsDir and one_obj.Name == Path(url).name): 

89 return one_obj 

90 

91 return objs 

92 

93 def is_file(self, url: urls.Url) -> bool: 

94 """Check if URL points to a file. If False, it may be a dir or not exist.""" 

95 return not isinstance(self.ls(url), list) 

96 

97 def pull_files(self, src: urls.Url, tgt: Path): 

98 """Pull files from `src` to `tgt`.""" 

99 match self.ls(src): 

100 case []: 

101 raise FileNotFoundError(src) 

102 case list(): 

103 func = "sync" 

104 case ObjMeta(): 

105 func = "copyto" 

106 case _: 

107 raise NotImplementedError("src=%s (%s)", src, type(src)) 

108 

109 src = self._translate_scheme(src) 

110 self._tool(func, *(str(x).rstrip("/") for x in (src, tgt))) 

111 

112 def push_files(self, src: Path, tgt: urls.Url): 

113 """Push files from `src` to `tgt`.""" 

114 tgt = self._translate_scheme(tgt) 

115 self._tool( 

116 "copyto" if src.is_file() else "sync", 

117 *(str(x).rstrip("/") for x in (src.expanduser().resolve(), tgt)), 

118 ) 

119 

120 def _tool(self, func: str, *args, **kwargs): 

121 call = getattr(self.tool, func) 

122 

123 if self.prog_bar and "show_progress" in inspect.signature(call).parameters: 

124 kwargs["show_progress"] = True 

125 

126 return call( 

127 *args, 

128 **kwargs, 

129 args=self.extra_args.get(func) or None, 

130 )