Coverage for src/mlopus/utils/typing_utils.py: 80%
45 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1import types
2import typing
3from typing import Any, TypeVar, Type, Tuple, Iterator
5T = TypeVar("T") # Any type
7NoneType = type(None)
10def is_optional(annotation: type | Type[T]) -> bool:
11 """Tell if typing annotation is an optional."""
12 return (
13 (origin := typing.get_origin(annotation)) is types.UnionType
14 or origin is typing.Union
15 and any(arg is NoneType for arg in typing.get_args(annotation))
16 )
def assert_isinstance(subject: Any, type_: type):
    """Raise ``TypeError`` unless ``subject`` is an instance of ``type_``.

    ``type_`` is handed straight to the builtin ``isinstance``, so anything that
    builtin accepts (a class, or a tuple of classes) works here too.
    """
    if isinstance(subject, type_):
        return
    raise TypeError(f"Expected an instance of {type_}: {subject}")
def assert_issubclass(subject: Any, type_: type):
    """Raise ``TypeError`` unless ``subject`` is a subclass of ``type_``.

    The check is delegated to ``safe_issubclass``, so a parameterized generic alias
    (e.g. ``Foo[int]``) is judged by its origin class, and a non-class subject fails
    the check rather than making ``issubclass`` itself raise.
    """
    if safe_issubclass(subject, type_):
        return
    raise TypeError(f"Expected a subclass of {type_}: {subject}")
def as_type(subject: Any, of: Type[T] | None = None, strict: bool = False) -> Type[T] | type | None:
    """Coerce a type-like annotation to a concrete class.

    Resolution steps:
      1. A ``TypeVar`` with a bound is replaced by its bound.
      2. A parameterized generic (e.g. ``list[int]``, ``Foo[int]``) is replaced by its
         origin, as reported by this module's ``get_origin``.
      3. Anything that is still not a class cannot be coerced. Union annotations
         (``X | Y``, ``Optional[X]``) fall in this category: their origin is the
         union construct itself, not a meaningful class.

    Args:
        subject: Annotation or class to coerce.
        of: Optional expected base class of the result (checked with ``issubclass``).
        strict: Raise instead of returning ``None`` when coercion is impossible.

    Returns:
        The coerced class, or ``None`` when coercion is impossible and ``strict`` is false.

    Raises:
        TypeError: If coercion is impossible and ``strict`` is true, or if the coerced
            class is not a subclass of ``of``.
    """
    if isinstance(subject, TypeVar) and (bound := subject.__bound__):
        subject = bound

    if origin := get_origin(subject):
        subject = origin

    # Union origins are rejected explicitly: `types.UnionType` (the origin of `X | Y`)
    # is itself a class and would otherwise be returned as if it were the coerced type.
    is_union = subject is types.UnionType or subject is typing.Union

    if is_union or not isinstance(subject, type):
        if strict:
            raise TypeError(f"Cannot coerce to type: {subject}")
        return None

    if of is not None and not issubclass(subject, of):
        raise TypeError(f"Expected a subclass of {of}: {subject}")

    return subject
def safe_issubclass(subject: Any, bound: type) -> bool:
    """Replacement for `issubclass` that also accepts parameterized generic aliases (e.g.: Foo[T]).

    A parameterized alias is checked through its origin class. A subject that is
    neither a class nor an alias of a class yields ``False`` instead of raising.

    Example:
        class Foo(Generic[T]): pass

        issubclass(Foo[int], Foo)  # Raises: TypeError

        safe_issubclass(Foo[int], Foo)  # Returns: True
    """
    # Resolve the origin first: on older Python versions builtin aliases such as
    # `list[int]` pass `isinstance(..., type)` yet are rejected by `issubclass`.
    origin = typing.get_origin(subject)
    candidate = origin if isinstance(origin, type) else subject

    return isinstance(candidate, type) and issubclass(candidate, bound)
def get_origin(subject: type) -> type:
    """Get the unparameterized origin of a generic type, e.g. ``list`` for ``list[int]``.

    Subjects exposing a truthy ``__pydantic_generic_metadata__`` mapping (the Pydantic v2
    edge case) report their origin through its ``"origin"`` entry. Every other subject
    is delegated to ``typing.get_origin``, which gives ``None`` when there is no origin.
    """
    metadata = getattr(subject, "__pydantic_generic_metadata__", None)
    if not metadata:
        return typing.get_origin(subject)
    return metadata["origin"]
def get_args(subject: type) -> Tuple[type, ...]:
    """Get the type arguments of a parameterized generic type, e.g. ``(int,)`` for ``list[int]``.

    Subjects exposing a truthy ``__pydantic_generic_metadata__`` mapping (the Pydantic v2
    edge case) report their arguments through its ``"args"`` entry. Every other subject
    is delegated to ``typing.get_args``, which gives ``()`` when there are no arguments.
    """
    metadata = getattr(subject, "__pydantic_generic_metadata__", None)
    return metadata["args"] if metadata else typing.get_args(subject)
def iter_parameterized_bases(cls: type) -> Iterator[Tuple[type, Tuple[type, ...]]]:
    """Iterate pairs of (type_origin, type_param_args) for all parameterized generic types in the class bases.

    The hierarchy is walked depth-first, visiting each class's direct bases in
    declaration order, so the output order is reproducible across runs. A base reached
    through several branches (diamond inheritance) is visited once per branch, so its
    pair may be yielded more than once.

    Only bases for which this module's ``get_args`` reports type arguments are yielded;
    their origin comes from this module's ``get_origin``.
    """
    # `dict.fromkeys` de-duplicates while keeping declaration order. Iterating a `set`
    # of classes (hashed by identity by default) made the output order vary between runs.
    # `__base__` is included alongside `__bases__`; it is `None` for `object`, hence the skip.
    for base in dict.fromkeys((*cls.__bases__, cls.__base__)):
        if base is None:
            continue
        if args := get_args(base):
            yield get_origin(base), args
        yield from iter_parameterized_bases(base)