Coverage for src/mlopus/utils/import_utils.py: 94%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1import importlib 

2import typing 

3from typing import List, TypeVar, Type 

4 

5import importlib_metadata 

6 

7from mlopus.utils import typing_utils 

8 

# Generic type variable used in the signatures of the lookup/plugin helpers below.
T = TypeVar("T")  # Any type

# Alias of `importlib_metadata.EntryPoint`, used in the return annotation of `list_plugins`.
# NOTE(review): presumably also re-exported so callers can annotate entry points without
# importing `importlib_metadata` themselves — confirm against callers.
EntryPoint = importlib_metadata.EntryPoint

12 

13 

14def fq_name(type_: type | Type[T]) -> str: 

15 """Get fully qualified type name to be used with `find_type`.""" 

16 return "%s:%s" % (type_.__module__, type_.__qualname__) 

17 

18 

19def find_attr(fq_name_: str, type_: Type[T] | None = None) -> T: 

20 """Find attribute by fully qualified name (e.g.: `package.module:Class.attribute`).""" 

21 if ":" in fq_name_: 

22 mod_path, attr_path = fq_name_.split(":") 

23 else: 

24 mod_path, attr_path = fq_name_, None 

25 

26 cursor = importlib.import_module(mod_path) 

27 

28 if attr_path: 

29 for part in attr_path.split("."): 

30 cursor = getattr(cursor, part) 

31 

32 if type_ is not None: 

33 typing_utils.assert_isinstance(cursor, type_) 

34 

35 return cursor 

36 

37 

def find_type(fq_name_: str, type_: Type[T] | None = None) -> Type[T]:
    """Resolve a class object from its fully qualified name.

    The target (e.g. ``package.module:Class.InnerClass``) is located via `find_attr`,
    which is asked to verify that it is an instance of ``type``, i.e. a class.
    When ``type_`` is supplied, the class is additionally checked with
    ``typing_utils.assert_issubclass(cls, type_)``.
    """
    cls = typing.cast(Type[T], find_attr(fq_name_, type))

    if type_ is not None:
        typing_utils.assert_issubclass(cls, type_)

    return cls

46 

47 

48def list_plugins(group: str) -> List[EntryPoint]: 

49 """List plugin objects in group.""" 

50 return list(importlib_metadata.entry_points(group=group)) 

51 

52 

53def load_plugin(group: str, name: str, type_: Type[T] | None = None) -> Type[T]: 

54 """Load named plugin from specified group.""" 

55 if (n := len(plugins := importlib_metadata.entry_points(group=group, name=name))) != 1: 

56 raise RuntimeError(f"Expected exactly 1 plugin named '{name}' in group '{group}', fround {n}: {plugins}") 

57 

58 val = list(plugins)[0].load() 

59 

60 if type_ is not None: 

61 typing_utils.assert_issubclass(val, type_) 

62 

63 return val