Coverage for src/mlopus/utils/pydantic.py: 88%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import functools 

2import inspect 

3from collections.abc import Mapping 

4from typing import Any, Type, TypeVar, Dict 

5 

6from pydantic import ( 

7 BaseModel as _BaseModel, 

8 create_model, 

9 fields, 

10 Field, 

11 field_validator, 

12 model_validator, 

13 ValidationError, 

14 types, 

15 validate_call, 

16) 

17from typing_extensions import Self 

18 

19from mlopus.utils import typing_utils, common 

20 

T = TypeVar("T")  # Any type

__all__ = [
    "types",
    "fields",
    "BaseModel",  # Pydantic V2 BaseModel, patched with V1-style conveniences (`dict`, `parse_obj`)
    "create_model",
    "Field",
    "validator",  # V1-style shim over `field_validator`
    "root_validator",  # V1-style shim over `model_validator`
    "field_validator",
    "model_validator",
    "ValidationError",
]

P = TypeVar("P", bound=_BaseModel)  # Any type of `BaseModel` (patched or not)

ModelLike = Mapping | _BaseModel  # Anything that can be parsed into a `BaseModel`

39 

40 

def root_validator(*args, **kwargs):
    """Pydantic V1-style `root_validator` implemented on top of V2's `model_validator`.

    Works both bare (``@root_validator``) and called (``@root_validator(pre=True)``).
    V1-only keyword arguments are translated or discarded before delegating:

    - ``pre=True`` selects ``mode="before"``; otherwise ``mode="after"`` is used,
      unless ``mode`` is passed explicitly.
    - ``allow_reuse`` is discarded.
    - ``skip_on_failure`` is discarded. ``model_validator`` does not accept it; V2
      ``after`` validators only run once field validation succeeded, which is the
      behavior ``skip_on_failure=True`` requested in V1.
    """
    # Bare usage: the decorated function is the only positional argument.
    if not kwargs and len(args) == 1:
        return root_validator()(args[0])
    kwargs.pop("allow_reuse", None)
    kwargs.pop("skip_on_failure", None)  # Passing it through would raise `TypeError`.
    kwargs.setdefault("mode", "before" if kwargs.pop("pre", False) else "after")
    return model_validator(*args, **kwargs)

47 

48 

def validator(field: str, *args, **kwargs):
    """Pydantic V1-style `validator` implemented on top of V2's `field_validator`.

    ``pre=True`` selects ``mode="before"`` (otherwise ``"after"``), unless ``mode``
    is given explicitly. The V1-only ``allow_reuse`` flag is discarded. Extra field
    names and remaining keyword arguments are forwarded unchanged.
    """
    kwargs.pop("allow_reuse", None)
    run_before = kwargs.pop("pre", False)
    kwargs.setdefault("mode", "before" if run_before else "after")
    return field_validator(field, *args, **kwargs)

53 

54 

class BaseModel(_BaseModel):
    """Patch for pydantic BaseModel.

    Adds V1-style aliases (`dict`, `parse_obj`) that delegate to their V2 equivalents,
    a compact `repr`/`str`, and permissive class config defaults.
    """

    class Config:
        """Pydantic class config."""

        coerce_numbers_to_str = True  # Fixes ValidationError when `str` is expected and `int` is passed
        repr_empty: bool = True  # If `False`, skip fields with empty values in representation
        arbitrary_types_allowed = True  # Fixes: RuntimeError: no validator found for <class '...'>
        ignored_types = (functools.cached_property,)  # Fixes: TypeError: cannot pickle '_thread.RLock' object
        protected_namespaces = ()  # Fixes: UserWarning: Field "model_*" has conflict with protected namespace "model_"

    def __repr__(self):
        """Representation skips fields if:
        - Field conf has `repr=False` or `exclude=True`.
        - Field value is empty and class conf has `repr_empty=False`.
        """
        # `model_config` is the merged config of the whole class hierarchy. Reading it
        # (instead of `self.Config`) honors overrides made via `model_config`. It also
        # avoids an AttributeError when a subclass defines its own `Config` without
        # `repr_empty`.
        repr_empty = self.model_config.get("repr_empty", True)

        args = []
        # Read fields from the class: instance access to `model_fields` is deprecated (pydantic 2.11).
        for name, info in type(self).model_fields.items():
            if not info.repr or info.exclude:
                continue
            value = getattr(self, name)
            if repr_empty or not common.is_empty(value):
                args.append(f"{name}={value}")
        return f"{self.__class__.__name__}({', '.join(args)})"

    def __str__(self):
        """String matches representation."""
        return repr(self)

    def dict(self, *args, **kwargs):
        """Replace deprecated `dict` with `model_dump`."""
        return self.model_dump(*args, **kwargs)

    @classmethod
    def parse_obj(cls, obj: Any) -> Self:
        """Replace deprecated `parse_obj` with `model_validate`."""
        return cls.model_validate(obj)

91 

92 

class EmptyStrAsMissing(_BaseModel):
    """Mixin for BaseModel that treats empty-string inputs as missing values.

    Input keys whose value is ``""`` are dropped before field validation. The
    corresponding fields then behave as if they were not provided.
    """

    @root_validator(pre=True)  # noqa
    @classmethod
    def _handle_empty_str(cls, values: Any) -> Any:
        """Handles empty strings in input as missing values."""
        # `before` validators receive the raw input, which is not necessarily a mapping
        # (e.g. an already-built model instance). Leave such input for pydantic to handle.
        if not isinstance(values, Mapping):
            return values
        return {k: v for k, v in values.items() if v != ""}

101 

102 

class EmptyDictAsMissing(_BaseModel):
    """Mixin for BaseModel that treats empty-dict inputs as missing values.

    Input keys whose value is ``{}`` are dropped before field validation. The
    corresponding fields then behave as if they were not provided.
    """

    @root_validator(pre=True)  # noqa
    @classmethod
    def _handle_empty_dict(cls, values: Any) -> Any:
        """Handles empty dicts in input as missing values."""
        # `before` validators receive the raw input, which is not necessarily a mapping
        # (e.g. an already-built model instance). Leave such input for pydantic to handle.
        if not isinstance(values, Mapping):
            return values
        return {k: v for k, v in values.items() if v != {}}

111 

112 

class ExcludeEmptyMixin(_BaseModel):
    """Mixin for BaseModel that omits empty fields from `model_dump` output."""

    def model_dump(self, **kwargs) -> dict:
        """Ignores empty fields when serializing to dict.

        Fields whose current value is empty according to `common.is_empty` are added
        to the `exclude` argument before delegating to pydantic's `model_dump`. All
        other keyword arguments are forwarded unchanged.

        The caller's `exclude` argument is never mutated:

        - Mapping `exclude`: copied. Falsy entries are dropped and nested specs
          (e.g. ``{"sub": {"x"}}``) are preserved; empty fields map to ``True``.
        - Any other iterable (``set``, ``frozenset``, ``list``, ...) or ``None``: a new
          ``set`` is built, containing the given names plus the empty fields.
        """
        # Read fields from the class: instance access to `model_fields` is deprecated (pydantic 2.11).
        empty = [name for name in type(self).model_fields if common.is_empty(getattr(self, name))]
        exclude = kwargs.get("exclude")

        if isinstance(exclude, Mapping):
            exclude = {k: v for k, v in exclude.items() if v}
            for name in empty:
                exclude[name] = True
        else:
            exclude = set(exclude or ()) | set(empty)

        return super().model_dump(**kwargs | {"exclude": exclude})

131 

132 

class HashableMixin:
    """Mixin for BaseModel that makes instances hashable by identity."""

    def __hash__(self):
        """Fixes: TypeError: unhashable type.

        The hash is the object's `id`, so distinct live instances hash differently.
        If the class this is mixed into defines a value-based `__eq__`, equal
        instances may therefore still have different hashes.
        """
        object_identity = id(self)
        return object_identity

139 

140 

class SignatureMixin:
    """Mixin for BaseModel that exposes `__signature__` on instances."""

    def __getattribute__(self, attr: str) -> Any:
        """Fixes: AttributeError: '__signature__' attribute of '...' is class-only.

        An instance lookup of `__signature__` returns the signature of the bound
        `__init__`. Every other attribute lookup is delegated unchanged.
        """
        if attr != "__signature__":
            return super().__getattribute__(attr)
        bound_init = self.__init__
        return inspect.signature(bound_init)

149 

150 

class MappingMixin(_BaseModel, Mapping):
    """Mixin that allows passing BaseModel instances as kwargs with the '**' operator.

    Example:
        class Foo(MappingMixin):
            x: int = 1
            y: int = 2

        foo = Foo()

        dict(**foo, z=3)  # Returns: {"x": 1, "y": 2, "z": 3}
    """

    def __init__(self, *args, **kwargs):
        # Fix for `RuntimeError(Could not convert dictionary to <class>)` in `pydantic.validate_arguments`
        # when the function expects a `Mapping` and receives a pydantic object with the trait `MappingMixin`.
        # Any single positional mapping (not only a `dict`) is used as the field values.
        # NOTE(review): any other positional arguments are silently ignored.
        if not kwargs and len(args) == 1 and isinstance(arg := args[0], Mapping):
            kwargs = arg
        super().__init__(**kwargs)

    def __iter__(self):
        """Iterate over field names (the mapping keys)."""
        # Read fields from the class: instance access to `model_fields` is deprecated (pydantic 2.11).
        return iter(type(self).model_fields)

    def __getitem__(self, __key):
        """Return the attribute named `__key`; raise `KeyError` if there is none.

        The `Mapping` protocol (`get`, `in`, ...) relies on missing keys raising
        `KeyError`, so errors from `getattr` are translated.
        """
        if not isinstance(__key, str):
            raise KeyError(__key)
        try:
            return getattr(self, __key)
        except AttributeError:
            raise KeyError(__key) from None

    def __len__(self):
        """Number of fields (the mapping size)."""
        return len(type(self).model_fields)

179 

180 

class BaseParamsMixin(_BaseModel):
    """Mixin for BaseModel that stores a mapping of parameterized generic bases and their respective type args."""

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """Record the parameterized generic bases of each new subclass in `__parameterized_bases__`."""
        cls.__parameterized_bases__ = dict(typing_utils.iter_parameterized_bases(cls))
        # Keep the hook cooperative so other bases/mixins defining it still run.
        super().__pydantic_init_subclass__(**kwargs)

    @classmethod
    def _find_base_param(cls, of_base: type, at_pos: int, as_type_of: Type[T] | None = None) -> Type[T]:
        """Return a type arg of the first recorded parameterized base that subclasses `of_base`.

        :param of_base: Base class to look for among the recorded parameterized bases.
        :param at_pos: Index of the type arg to return from that base's params.
        :param as_type_of: Forwarded as `of=` to `typing_utils.as_type(..., strict=True)`.
        :raises TypeError: If no recorded parameterized base is a subclass of `of_base`.
        """
        for base, params in cls.__parameterized_bases__.items():
            if issubclass(base, of_base):
                break
        else:
            raise TypeError(f"Cannot find parameterized base of type {of_base}")
        return typing_utils.as_type(params[at_pos], of=as_type_of, strict=True)

196 

197 

def create_model_from_data(
    name: str, data: dict, __base__: Type[P] | None = None, use_defaults: bool = True, **kwargs
) -> Type[P]:
    """Infer pydantic model from data.

    Each key of `data` becomes a field of the created model:

    - dict values become nested models named `key.capitalize()`. These are built
      recursively with the same extra `kwargs`, but without `__base__` and with
      `use_defaults=True`.
    - `None` values become fields of type `Any` (default `None`).
    - Other values become fields of type `type(value)` (default `value`).

    :param name: Name of the created model class.
    :param data: Example data to infer fields from.
    :param __base__: Base class of the top-level model (forwarded to `create_model`).
    :param use_defaults: If True, each field defaults to its value in `data`. Otherwise
        the top-level fields are declared with a bare `Field()` (no default).
    :param kwargs: Extra arguments forwarded to `create_model` at every nesting level.
    """
    _fields = {}

    for key, value in data.items():
        if isinstance(value, dict):
            type_ = create_model_from_data(key.capitalize(), value, **kwargs)
            # The nested model is created without `__base__`, i.e. on pydantic's own
            # `BaseModel`, where `parse_obj` is deprecated; `model_validate` is the V2 API.
            default = type_.model_validate(value)
        elif value is None:
            type_, default = Any, None
        else:
            type_, default = type(value), value

        _fields[key] = (type_, default if use_defaults else Field())

    return create_model(name, **_fields, **kwargs, __base__=__base__)

216 

217 

def create_obj_from_data(
    name: str, data: dict, __base__: Type[P] | None = None, use_defaults_in_model: bool = False, **kwargs
) -> P:
    """Infer pydantic model from data and parse it.

    :param name: Name of the created model class.
    :param data: Data used both to infer the model and to instantiate it.
    :param __base__: Base class of the inferred model.
    :param use_defaults_in_model: Forwarded as `use_defaults` to `create_model_from_data`.
    :param kwargs: Extra arguments forwarded to `create_model_from_data`.
    :return: Instance of the inferred model, validated from `data`.
    """
    model = create_model_from_data(name, data, **kwargs, __base__=__base__, use_defaults=use_defaults_in_model)
    # `model_validate` is the V2 API. `parse_obj` is deprecated on pydantic's own
    # `BaseModel`, which is the base used when `__base__` is None.
    return model.model_validate(data)

224 

225 

def force_set_attr(obj, key: str, val: Any):
    """Set attribute `key` of `obj` to `val` through `object.__setattr__`.

    Any `__setattr__` override on the object's class (e.g. validation on
    assignment) is bypassed.
    """
    setter = object.__setattr__
    setter(obj, key, val)

229 

230 

def is_model_cls(type_: type) -> bool:
    """Tell whether `type_` subclasses pydantic's `BaseModel` (patched or not).

    The check goes through `typing_utils.safe_issubclass`.
    """
    pydantic_base = _BaseModel
    return typing_utils.safe_issubclass(type_, pydantic_base)

234 

235 

def is_model_obj(obj: Any) -> bool:
    """Tell whether `obj` is an instance of a pydantic `BaseModel` subclass."""
    obj_type = type(obj)
    return is_model_cls(obj_type)

239 

240 

def as_model_cls(type_: type) -> Type[P] | None:
    """Return `type_` unchanged if it is a pydantic model class, otherwise `None`."""
    if is_model_cls(type_):
        return type_
    return None

244 

245 

def as_model_obj(obj: Any) -> P | None:
    """Return `obj` unchanged if it is a pydantic model instance, otherwise `None`."""
    if is_model_obj(obj):
        return obj
    return None

249 

250 

def validate_arguments(_func: callable = None, *, config: Dict[str, Any] = None):
    """Patch of `validate_arguments` that allows skipping the return type validation.

    Usable bare (``@validate_arguments``) or called (``@validate_arguments(config=...)``).
    The function is wrapped with pydantic's `validate_call`, and `config` is passed
    through unchanged.

    The ``"return"`` entry is removed from the function's ``__annotations__`` (in place)
    before wrapping when `config["validate_return"]` is falsy. It is also removed when
    that key is absent and the return annotation is a string (e.g. an alias to a type
    that hasn't been defined yet).
    """
    config = config or {}

    if _func is None:
        # Called form: wait for the decorated function.
        return functools.partial(validate_arguments, config=config)

    annotations = _func.__annotations__
    return_is_str_alias = isinstance(annotations.get("return"), str)
    keep_return = config.get("validate_return", not return_is_str_alias)
    if not keep_return:
        annotations.pop("return", None)

    return validate_call(config=config)(_func)