Coverage for src/mlopus/utils/typing_utils.py: 80%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import types 

2import typing 

3from typing import Any, TypeVar, Type, Tuple, Iterator 

4 

T = TypeVar("T")  # Any type

NoneType = type(None)  # Class of the `None` singleton

8 

9 

10def is_optional(annotation: type | Type[T]) -> bool: 

11 """Tell if typing annotation is an optional.""" 

12 return ( 

13 (origin := typing.get_origin(annotation)) is types.UnionType 

14 or origin is typing.Union 

15 and any(arg is NoneType for arg in typing.get_args(annotation)) 

16 ) 

17 

18 

def assert_isinstance(subject: Any, type_: type):
    """Raise `TypeError` unless subject is an instance of the given type."""
    if isinstance(subject, type_):
        return
    raise TypeError(f"Expected an instance of {type_}: {subject}")

23 

24 

def assert_issubclass(subject: Any, type_: type):
    """Raise `TypeError` unless subject passes `safe_issubclass` against the given type."""
    if safe_issubclass(subject, type_):
        return
    raise TypeError(f"Expected a subclass of {type_}: {subject}")

29 

30 

def as_type(subject: Any, of: Type[T] | None = None, strict: bool = False) -> Type[T] | type | None:
    """Coerce subject to type.

    A type variable with a bound is replaced by that bound, and a parameterized
    generic is replaced by its origin. If the result is not a type, `None` is
    returned, or `TypeError` is raised when `strict` is set. If `of` is given,
    the resulting type must be a subclass of it, otherwise `TypeError` is raised.
    """
    # A bound type variable stands for its bound.
    if isinstance(subject, TypeVar):
        upper = subject.__bound__
        if upper:
            subject = upper

    # A parameterized generic stands for its origin.
    resolved = get_origin(subject)
    if resolved:
        subject = resolved

    if isinstance(subject, type):
        if of is None or issubclass(subject, of):
            return subject
        raise TypeError(f"Expected a subclass of {of}: {subject}")

    if strict:
        raise TypeError(f"Cannot coerce to type: {subject}")
    return None

49 

50 

def safe_issubclass(subject: Any, bound: type) -> bool:
    """Replacement for `issubclass` that works with generic type aliases (e.g.: Foo[T]).

    Example:
        class Foo(Generic[T]): pass

        issubclass(Foo[int], Foo)  # Raises: TypeError

        safe_issubclass(Foo[int], Foo)  # Returns: True

    :param subject: Class, generic alias or arbitrary object to test.
    :param bound: Class that the subject (or its origin) must derive from.
    :return: True if the subject's origin, or the subject itself when it is a
        plain class, is a subclass of `bound`. False otherwise, including when
        the subject is not a class at all.
    """
    # Resolve aliases before the plain-class check: on Python 3.10,
    # `isinstance(list[int], type)` is True although passing `list[int]`
    # to `issubclass` raises TypeError.
    if isinstance(origin := typing.get_origin(subject), type):
        return issubclass(origin, bound)

    if isinstance(subject, type):
        return issubclass(subject, bound)

    return False

68 

69 

def get_origin(subject: type) -> type:
    """Return the origin of a parameterized generic type.

    When the subject has a non-empty `__pydantic_generic_metadata__` mapping
    (Pydantic v2), its "origin" entry is returned. Otherwise the lookup is
    delegated to `typing.get_origin`.
    """
    metadata = getattr(subject, "__pydantic_generic_metadata__", None)
    if not metadata:
        return typing.get_origin(subject)
    return metadata["origin"]

76 

77 

def get_args(subject: type) -> Tuple[type, ...]:
    """Return the type param args of a parameterized generic type.

    When the subject has a non-empty `__pydantic_generic_metadata__` mapping
    (Pydantic v2), its "args" entry is returned. Otherwise the lookup is
    delegated to `typing.get_args`.
    """
    metadata = getattr(subject, "__pydantic_generic_metadata__", None)
    if not metadata:
        return typing.get_args(subject)
    return metadata["args"]

84 

85 

def iter_parameterized_bases(cls: type) -> Iterator[Tuple[type, Tuple[type, ...]]]:
    """Iterate pairs of (type_origin, type_param_args) for all parameterized generic types in the class bases.

    Bases are walked recursively, depth-first, in declaration order, so the
    sequence of pairs is the same on every run. A base reachable through more
    than one inheritance path is visited once per path.

    :param cls: Class whose bases are inspected.
    :return: Iterator of (origin, args) pairs, one per base for which `get_args`
        returns a non-empty result.
    """
    # `dict.fromkeys` de-duplicates the same candidates a `set` would,
    # but keeps insertion order instead of an arbitrary hash order.
    for base in dict.fromkeys((*cls.__bases__, cls.__base__)):
        if base is None:  # `__base__` is None at the root of the hierarchy (`object`)
            continue
        if args := get_args(base):
            yield get_origin(base), args
        yield from iter_parameterized_bases(base)