Coverage for src/mlopus/utils/dicts.py: 83%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-13 14:49 +0000

1from copy import deepcopy 

2from typing import Any, Sequence, Mapping, Tuple, Hashable, Dict, TypeVar, List, Iterable 

3 

# Generic type variable; lets a function signature tie a dict's key type to a key argument.
T = TypeVar("T")

# Shorthand for a dict with string keys and values of any type.
AnyDict = Dict[str, Any]

7 

8 

class _Missing:
    """Type of the `_MISSING` sentinel, used to tell "argument not given" apart from `None`."""


# Single shared instance; compare against it by identity (`x is _MISSING`).
_MISSING = _Missing()

14 

15 

def set_if_empty(_dict: dict, key: str, val: Any) -> dict:
    """Set key to val in dict if current val is absent, None or an empty container.

    Booleans and numbers are real values, not empty ones: `False`, `0` and `0.0`
    are never replaced. Any other falsy value (None, "", [], {}, ...) is.

    :param _dict: Dict to update in place.
    :param key: Key to fill in.
    :param val: Value to store when the current one is considered empty.
    :return: The same `_dict` object that was passed in.
    """
    current = _dict.get(key)

    # `bool` is a subclass of `int`, so this also covers True/False.
    is_number = isinstance(current, (int, float, complex))

    if not is_number and not current:
        _dict[key] = val

    return _dict

21 

22 

def set_reserved_key(_dict: Dict[T, Any] | None, key: T, val: Any) -> Dict[T, Any]:
    """Store `val` under `key`, refusing to overwrite an existing entry.

    A new dict is created when `_dict` is None; otherwise `_dict` is updated
    in place. The dict that received the key is returned.

    :raises KeyError: If `key` is already present.
    """
    target = {} if _dict is None else _dict

    if key in target:
        raise KeyError(f"Reserved key: {key}")

    target[key] = val
    return target

29 

30 

def map_leaf_vals(data: dict, mapper: callable) -> dict:
    """Return a new dict with `mapper` applied to every leaf-value of `data`.

    Nested dicts are processed recursively. For tuple, list and set values the
    mapper is applied to each element (elements are not recursed into) and the
    container type is kept. Any other value is passed to the mapper as is.
    """

    def _convert(item):
        if isinstance(item, dict):
            return map_leaf_vals(item, mapper)
        if isinstance(item, (tuple, list, set)):
            # Rebuild the same container type from the mapped elements.
            return type(item)(mapper(elem) for elem in item)
        return mapper(item)

    return {name: _convert(item) for name, item in data.items()}

44 

45 

def get_nested(_dict: Mapping, keys: Sequence[Hashable], default: Any = _MISSING) -> Any:
    """Given keys [a, b, c], return `_dict[a][b][c]`.

    :param _dict: Root mapping to look into.
    :param keys: Path of keys, outermost first. An empty path returns `_dict` itself.
    :param default: Value returned when the path cannot be resolved.
        If omitted, a `KeyError` is raised instead.
    :return: The value found at the end of the path, or `default`.
    :raises KeyError: If the path cannot be resolved and no default was given.
        The error carries the shortest prefix of `keys` that failed.
    """
    target = _dict

    for idx, key in enumerate(keys):
        try:
            target = target[key]
        except (LookupError, TypeError) as error:
            # LookupError: the key (or index) is absent at this level.
            # TypeError: the path runs through a value that cannot be subscripted
            # with this key (e.g. a scalar leaf), so the path does not exist either.
            if default is _MISSING:
                raise KeyError(keys[0 : idx + 1]) from error  # noqa
            return default

    return target

59 

60 

def set_nested(_dict: Dict[Hashable, Any], keys: Sequence[Hashable], value: Any) -> Dict[Hashable, Any]:
    """Assign `value` at the nested path `keys`, i.e. `_dict[a][b][c] = value` for keys [a, b, c].

    Missing intermediate levels are created as empty dicts. The dict is
    modified in place and also returned.
    """
    node = _dict

    # Walk down to the parent of the final key, creating levels on the way.
    for step in keys[:-1]:
        if step not in node:
            node[step] = {}
        node = node[step]

    leaf_key = keys[-1]
    node[leaf_key] = value
    return _dict

72 

73 

def has_nested(_dict: Dict[Hashable, Any], keys: Sequence[Hashable]) -> bool:
    """Tell whether the nested path `keys` ([a, b, c] -> `_dict[a][b][c]`) exists."""
    try:
        get_nested(_dict, keys)
    except KeyError:
        # The lookup reports an unresolved path as a KeyError.
        return False
    else:
        return True

81 

82 

def new_nested(keys: Sequence[Hashable], value: Any) -> Dict[Hashable, Any]:
    """Build a fresh nested dict from a key path: keys [a, b, c] give {a: {b: {c: value}}}."""
    nested: Dict[Hashable, Any] = {}
    return set_nested(nested, keys, value)

86 

87 

def filter_empty_leaves(dict_: Mapping) -> dict:
    """Filter out leaf-values that are None or empty containers.

    Booleans and numbers are always kept, including `False`, `0` and `0.0`.
    Any other falsy leaf (None, "", [], ...) is dropped. The result is a new
    nested dict rebuilt from the leaves that were kept.
    """

    def _is_empty(leaf: Any) -> bool:
        # `bool` is a subclass of `int`, so this also covers True/False.
        if isinstance(leaf, (int, float, complex)):
            return False
        return not leaf

    return unflatten((path, leaf) for path, leaf in flatten(dict_).items() if not _is_empty(leaf))

91 

92 

def deep_merge(*dicts: dict) -> dict:
    """Merge dicts at the level of leaf-values.

    Dicts are applied left to right into a new dict, so later values win.
    The inputs are not modified and leaf values are deep-copied.

    A value is treated as nested (merged key by key) if it is a non-empty
    mapping, or if it is an empty mapping and the result already holds a
    mapping at that key. Every other value is a leaf and replaces whatever
    the result holds there. A nested value also replaces an existing
    non-mapping leaf instead of failing.

    :param dicts: Mappings to merge, in order of increasing precedence.
    :return: A new dict with the merged content.
    """
    retval = {}

    def _update(tgt: dict, src: Mapping):
        for key, val in src.items():
            current = tgt.get(key)

            if isinstance(val, Mapping) and (val or isinstance(current, Mapping)):
                # Treat value as nested if it's a non-empty dict or if the target is already nested
                if not isinstance(current, Mapping):
                    # Nothing (or only a scalar leaf) there yet: start a fresh level
                    tgt[key] = current = {}
                _update(current, val)
            else:
                # Treat value as a leaf (scalar) otherwise
                tgt[key] = deepcopy(val)

    for _dict in dicts:
        _update(retval, _dict)

    return retval

112 

113 

def flatten(_dict: Mapping) -> Dict[Tuple[str, ...], Any]:
    """Collapse a nested mapping into a single-level dict keyed by key-path tuples.

    Every non-mapping value becomes one entry whose key is the tuple of keys
    leading to it. Mappings are always descended into, so an empty nested
    mapping contributes no entry.
    """

    def _walk(node: Mapping, path: Tuple[Hashable, ...]) -> Iterable[Tuple[Tuple[Hashable, ...], Any]]:
        for name, item in node.items():
            item_path = path + (name,)

            if isinstance(item, Mapping):
                yield from _walk(item, item_path)
            else:
                yield item_path, item

    return dict(_walk(_dict, ()))

131 

132 

def unflatten(_dict: Iterable[Tuple[Tuple[str, ...], Any]] | Mapping[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Expand tuple keys into nested dicts: {(a, b): v} becomes {a: {b: v}}.

    Accepts either a mapping or an iterable of (key, value) pairs.
    Entries whose key is not a tuple are skipped.
    """
    pairs = _dict.items() if isinstance(_dict, Mapping) else _dict
    nested = {}

    for path, leaf in pairs:
        if not isinstance(path, tuple):
            continue
        set_nested(nested, path, leaf)

    return nested