Coverage for src/mlopus/utils/pydantic.py: 88%
120 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import functools
2import inspect
3from collections.abc import Mapping
4from typing import Any, Type, TypeVar, Dict
6from pydantic import (
7 BaseModel as _BaseModel,
8 create_model,
9 fields,
10 Field,
11 field_validator,
12 model_validator,
13 ValidationError,
14 types,
15 validate_call,
16)
17from typing_extensions import Self
19from mlopus.utils import typing_utils, common
T = TypeVar("T")  # Any type

__all__ = [
    "types",
    "fields",
    "BaseModel",  # Pydantic (V2) BaseModel, patched with V1-style conveniences
    "create_model",
    "Field",
    "validator",  # V1-style shim over pydantic V2 `field_validator`
    "root_validator",  # V1-style shim over pydantic V2 `model_validator`
    "field_validator",
    "model_validator",
    "ValidationError",
]

P = TypeVar("P", bound=_BaseModel)  # Any type of `BaseModel` (patched or not)

ModelLike = Mapping | _BaseModel  # Anything that can be parsed into a `BaseModel`
def root_validator(*args, **kwargs):
    """Pydantic-V1-style `root_validator` implemented on top of V2's `model_validator`.

    Supports both the bare form (`@root_validator`) and the called form
    (`@root_validator(pre=..., ...)`). V1-only keyword arguments are translated:
    - `pre=True` -> `mode="before"`, otherwise `mode="after"` (unless `mode` is given explicitly);
    - `allow_reuse` and `skip_on_failure` are dropped, since `model_validator` does not accept them.
    """
    if not kwargs and len(args) == 1:
        # Bare decorator usage: `args[0]` is the decorated function.
        return root_validator()(args[0])
    kwargs.pop("allow_reuse", None)
    # V2 "after" model validators only run once the model's fields validated successfully,
    # which matches V1's `skip_on_failure=True`. Passing the keyword through would raise TypeError.
    # NOTE(review): V1's `skip_on_failure=False` semantics cannot be reproduced here.
    kwargs.pop("skip_on_failure", None)
    kwargs.setdefault("mode", "before" if kwargs.pop("pre", False) else "after")
    return model_validator(*args, **kwargs)
def validator(field: str, *args, **kwargs):
    """Pydantic-V1-style `validator` implemented on top of V2's `field_validator`.

    V1-only keyword arguments are translated before delegating:
    - `allow_reuse` is discarded;
    - `pre` is always consumed and mapped to `mode="before"` when truthy and to
      `mode="after"` otherwise, unless the caller passed `mode` explicitly.
    All remaining positional and keyword arguments are forwarded unchanged.
    """
    kwargs.pop("allow_reuse", None)
    is_pre = kwargs.pop("pre", False)
    if "mode" not in kwargs:
        kwargs["mode"] = "before" if is_pre else "after"
    return field_validator(field, *args, **kwargs)
class BaseModel(_BaseModel):
    """Patch for pydantic BaseModel.

    Adds V1-style `dict`/`parse_obj` aliases for the V2 API and a `__repr__`/`__str__`
    that can hide empty fields (see `Config.repr_empty`).
    """

    class Config:
        """Pydantic class config."""

        coerce_numbers_to_str = True  # Fixes ValidationError when `str` is expected and `int` is passed
        repr_empty: bool = True  # If `False`, skip fields with empty values in representation
        arbitrary_types_allowed = True  # Fixes: RuntimeError: no validator found for <class '...'>
        ignored_types = (functools.cached_property,)  # Fixes: TypeError: cannot pickle '_thread.RLock' object
        protected_namespaces = ()  # Fixes: UserWarning: Field "model_*" has conflict with protected namespace "model_"

    def __repr__(self):
        """Representation skips fields if:
        - Field conf has `repr=False` or `exclude=True`.
        - Field value is empty and class conf has `repr_empty=False`.
        """
        # Read `repr_empty` from the merged `model_config` instead of `self.Config`.
        # A subclass that declares its own `class Config` without `repr_empty` would
        # otherwise raise AttributeError, because nested classes are not inherited.
        repr_empty = self.model_config.get("repr_empty", True)
        args = [
            f"{k}={v}"  # noqa
            for k, f in type(self).model_fields.items()
            if f.repr and not f.exclude and (not common.is_empty(v := getattr(self, k)) or repr_empty)
        ]
        return "%s(%s)" % (self.__class__.__name__, ", ".join(args))

    def __str__(self):
        """String matches representation."""
        return repr(self)

    def dict(self, *args, **kwargs):
        """Replace deprecated `dict` with `model_dump`."""
        return self.model_dump(*args, **kwargs)

    @classmethod
    def parse_obj(cls, obj: Any) -> Self:
        """Replace deprecated `parse_obj` with `model_validate`."""
        return cls.model_validate(obj)
class EmptyStrAsMissing(_BaseModel):
    """Mixin for BaseModel that treats empty-string input values as missing."""

    @root_validator(pre=True)  # noqa
    @classmethod
    def _handle_empty_str(cls, values: Any) -> Any:
        """Handles empty strings in input as missing values.

        Keys whose value is `""` are dropped from mapping input, so the fields
        fall back to their defaults. Non-mapping input is returned unchanged.
        Pydantic's own input-type validation then handles it, instead of an
        `AttributeError` on `.items()`.
        """
        if not isinstance(values, Mapping):
            return values
        return {k: v for k, v in values.items() if v != ""}
class EmptyDictAsMissing(_BaseModel):
    """Mixin for BaseModel that treats empty-dict input values as missing."""

    @root_validator(pre=True)  # noqa
    @classmethod
    def _handle_empty_dict(cls, values: Any) -> Any:
        """Handles empty dicts in input as missing values.

        Keys whose value is `{}` are dropped from mapping input, so the fields
        fall back to their defaults. Non-mapping input is returned unchanged.
        Pydantic's own input-type validation then handles it, instead of an
        `AttributeError` on `.items()`.
        """
        if not isinstance(values, Mapping):
            return values
        return {k: v for k, v in values.items() if v != {}}
class ExcludeEmptyMixin(_BaseModel):
    """Mixin for BaseModel that omits empty fields from `model_dump` output."""

    def model_dump(self, **kwargs) -> dict:
        """Ignores empty fields when serializing to dict.

        The caller's `exclude` argument is never mutated; a new set/dict is built
        and forwarded. Both `exclude` forms are supported:
        - set-like of field names: the empty fields are added to a copy of it;
        - dict of `field -> True | nested spec`: the empty fields are mapped to `True`,
          entries with falsy values are dropped, and nested specs are kept as-is.
        """
        empty = {name for name in type(self).model_fields if common.is_empty(getattr(self, name))}
        exclude = kwargs.get("exclude")

        if isinstance(exclude, dict):
            # Keep nested exclusion specs intact. Flattening them into a set of keys
            # would exclude the whole sub-field instead of only its nested part.
            exclude = {k: v for k, v in exclude.items() if v} | dict.fromkeys(empty, True)
        else:
            exclude = set(exclude or ()) | empty

        return super().model_dump(**kwargs | {"exclude": exclude})
class HashableMixin:
    """Mixin for BaseModel that makes instances hashable by object identity."""

    def __hash__(self):
        """Fixes: TypeError: unhashable type.

        The hash is the object's `id()`, so distinct instances hash differently.
        """
        # NOTE(review): if the combined class defines a value-based `__eq__`, equal
        # objects will have different hashes. Confirm identity semantics are intended.
        identity = id(self)
        return identity
class SignatureMixin:
    """Mixin for BaseModel that makes `__signature__` readable on instances."""

    def __getattribute__(self, attr: str) -> Any:
        """Fixes: AttributeError: '__signature__' attribute of '...' is class-only.

        Every attribute except `__signature__` is resolved through the normal
        lookup. `__signature__` is computed on demand from the bound `__init__`.
        """
        if attr != "__signature__":
            return super().__getattribute__(attr)
        return inspect.signature(self.__init__)
class MappingMixin(_BaseModel, Mapping):
    """Mixin that allows passing BaseModel instances as kwargs with the '**' operator.

    Keys are the model's field names; values are the corresponding attribute values.

    Example:
        class Foo(MappingMixin):
            x: int = 1
            y: int = 2

        foo = Foo()

        dict(**foo, z=3)  # Returns: {"x": 1, "y": 2, "z": 3}
    """

    def __init__(self, *args, **kwargs):
        # Fix for `RuntimeError(Could not convert dictionary to <class>)` in `pydantic.validate_arguments`
        # when the function expects a `Mapping` and receives a pydantic object with the trait `MappingMixin`.
        # A single positional mapping (dict or another `MappingMixin`) is used as the field values.
        if not kwargs and len(args) == 1 and isinstance(arg := args[0], Mapping):
            kwargs = dict(arg)
        elif args:
            # Any other positional arguments used to be silently discarded; reject them instead.
            raise TypeError(
                f"{type(self).__name__} accepts keyword arguments or a single mapping as positional argument"
            )
        super().__init__(**kwargs)

    def __iter__(self):
        """Iterate over the model's field names."""
        return iter(type(self).model_fields)

    def __getitem__(self, __key):
        """Return the attribute named `__key`; raise `KeyError` if it does not exist."""
        try:
            return getattr(self, __key)
        except AttributeError:
            # `Mapping.__contains__` and `Mapping.get` rely on KeyError for missing keys.
            raise KeyError(__key) from None

    def __len__(self):
        """Number of model fields."""
        return len(type(self).model_fields)
class BaseParamsMixin(_BaseModel):
    """Mixin for BaseModel that stores a mapping of parameterized generic bases and their respective type args."""

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """Record `cls.__parameterized_bases__` for every new subclass.

        The mapping is built from `typing_utils.iter_parameterized_bases(cls)`.
        """
        cls.__parameterized_bases__ = dict(typing_utils.iter_parameterized_bases(cls))
        # Cooperate with other bases in the MRO that also hook subclass initialization.
        super().__pydantic_init_subclass__(**kwargs)

    @classmethod
    def _find_base_param(cls, of_base: type, at_pos: int, as_type_of: Type[T] | None = None) -> Type[T]:
        """Return the type argument at `at_pos` of a parameterized base that subclasses `of_base`.

        The first match in `__parameterized_bases__` iteration order is used. The argument is
        converted with `typing_utils.as_type(..., of=as_type_of, strict=True)`.

        :raises TypeError: If no recorded parameterized base is a subclass of `of_base`.
        """
        for base, params in cls.__parameterized_bases__.items():
            if issubclass(base, of_base):
                break
        else:
            raise TypeError(f"Cannot find parameterized base of type {of_base}")
        return typing_utils.as_type(params[at_pos], of=as_type_of, strict=True)
def create_model_from_data(
    name: str, data: dict, __base__: Type[P] | None = None, use_defaults: bool = True, **kwargs
) -> Type[P]:
    """Infer pydantic model from data.

    Each key of `data` becomes a field:
    - dict values become nested models named after the capitalized key. They are built
      recursively with the same `use_defaults` and `**kwargs`, but without `__base__`.
    - `None` values become fields of type `Any` (default `None`).
    - any other value becomes a field of type `type(value)`.

    :param name: Name of the created model class.
    :param data: Sample data the fields are inferred from.
    :param __base__: Optional base class of the top-level model.
    :param use_defaults: If True, sample values become field defaults; otherwise every field is required.
    :param kwargs: Extra keyword arguments forwarded to `pydantic.create_model`.
    :return: The created model class.
    """
    _fields = {}

    for key, value in data.items():
        if isinstance(value, dict):
            # Propagate `use_defaults`, so nested fields follow the same policy as top-level ones
            # (otherwise sample values would leak in as nested defaults).
            type_ = create_model_from_data(key.capitalize(), value, use_defaults=use_defaults, **kwargs)
            # `model_validate` is the V2 API; the nested model has no patched base, so `parse_obj` is deprecated there.
            default = type_.model_validate(value) if use_defaults else None
        elif value is None:
            type_, default = Any, None
        else:
            type_, default = type(value), value

        _fields[key] = (type_, default if use_defaults else Field())

    return create_model(name, **_fields, **kwargs, __base__=__base__)
def create_obj_from_data(
    name: str, data: dict, __base__: Type[P] | None = None, use_defaults_in_model: bool = False, **kwargs
) -> P:
    """Infer pydantic model from data and parse it.

    :param name: Name of the inferred model class.
    :param data: Data used both to infer the model and to instantiate it.
    :param __base__: Optional base class of the inferred model.
    :param use_defaults_in_model: Forwarded as `use_defaults` to `create_model_from_data`.
    :param kwargs: Extra keyword arguments forwarded to `create_model_from_data`.
    :return: Instance of the inferred model, validated from `data`.
    """
    model = create_model_from_data(name, data, **kwargs, __base__=__base__, use_defaults=use_defaults_in_model)
    # Use the V2 entry point: with the default `__base__=None` the model is a plain pydantic
    # model, where `parse_obj` is deprecated.
    return model.model_validate(data)
def force_set_attr(obj, key: str, val: Any):
    """Low-level attribute set on object (bypasses validations).

    Calls `object.__setattr__` directly, so any `__setattr__` override on the
    object's class is skipped.
    """
    raw_setattr = object.__setattr__
    raw_setattr(obj, key, val)
def is_model_cls(type_: type) -> bool:
    """Tell whether `type_` is a subclass of pydantic's `BaseModel` (patched or not).

    Delegates to `typing_utils.safe_issubclass`.
    """
    model_base = _BaseModel
    return typing_utils.safe_issubclass(type_, model_base)
def is_model_obj(obj: Any) -> bool:
    """Tell whether `obj` is an instance of pydantic's `BaseModel` (patched or not).

    The check is performed on `type(obj)` via `is_model_cls`.
    """
    obj_type = type(obj)
    return is_model_cls(obj_type)
def as_model_cls(type_: type) -> Type[P] | None:
    """Return `type_` unchanged if it is a pydantic model class, otherwise `None`."""
    if not is_model_cls(type_):
        return None
    return type_
def as_model_obj(obj: Any) -> P | None:
    """Return `obj` unchanged if it is a pydantic model instance, otherwise `None`."""
    if not is_model_obj(obj):
        return None
    return obj
def validate_arguments(_func: callable = None, *, config: Dict[str, Any] = None):
    """Patch of `validate_arguments` that allows skipping the return type validation.

    Return type validation is turned off by default when the function's
    return type is a string alias to a type that hasn't been defined yet.
    An explicit `config["validate_return"]` overrides that default. When return
    validation is off, the `"return"` entry is removed from `_func.__annotations__`
    (in place) before the function is wrapped with `validate_call(config=config)`.

    Usable both bare (`@validate_arguments`) and called (`@validate_arguments(config=...)`).
    """
    config = config or {}

    if _func is None:
        # Called form: return a decorator that remembers the config.
        return functools.partial(validate_arguments, config=config)

    return_hint = _func.__annotations__.get("return")
    check_return = config.get("validate_return", not isinstance(return_hint, str))
    if not check_return:
        _func.__annotations__.pop("return", None)

    return validate_call(config=config)(_func)