Coverage for src/mlopus/utils/dicts.py: 83%
81 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-13 14:49 +0000
1from copy import deepcopy
2from typing import Any, Sequence, Mapping, Tuple, Hashable, Dict, TypeVar, List, Iterable
# Generic type variable; used below to tie a dict's key type to the type of a key argument.
T = TypeVar("T")

# Alias for a dict with string keys and arbitrary values.
AnyDict = Dict[str, Any]


class _Missing:
    """Type of the module-private `_MISSING` sentinel."""

    pass


# Sentinel meaning "no default was given", so that `None` stays usable as an explicit
# default value (see `get_nested`).
_MISSING = _Missing()
def set_if_empty(_dict: dict, key: str, val: Any) -> dict:
    """Set key to val in dict if current val is absent, None or an empty container.

    Values that are falsy but are neither None nor an empty sized container
    (e.g. `0`, `0.0` or `False`) count as real values and are preserved.

    :param _dict: Dict to be updated in place.
    :param key: Key to be set.
    :param val: Value to assign if the current one is considered empty.
    :return: The same dict object that was passed in.
    """
    current = _dict.get(key)

    try:
        # Emptiness is decided by length, not truthiness, so that numeric zeros are kept.
        is_empty = current is None or len(current) == 0
    except TypeError:
        # Objects without a length (numbers, bools, arbitrary objects) are never "empty".
        is_empty = False

    if is_empty:
        _dict[key] = val
    return _dict
def set_reserved_key(_dict: Dict[T, Any] | None, key: T, val: Any) -> Dict[T, Any]:
    """Insert `key` into the dict, refusing to overwrite an existing entry.

    :param _dict: Dict to be updated in place. A new dict is created if None.
    :param key: Key that must not be present yet.
    :param val: Value to store under `key`.
    :return: The updated dict (the one passed in, or the newly created one).
    :raises KeyError: If `key` is already present in the dict.
    """
    target = {} if _dict is None else _dict

    if key in target:
        raise KeyError(f"Reserved key: {key}")

    target[key] = val
    return target
def map_leaf_vals(data: dict, mapper: callable) -> dict:
    """Build a new dict with `mapper` applied to every leaf-value of `data`.

    Nested dicts are processed recursively. For tuples, lists and sets the mapper
    is applied to each element directly (elements are not inspected any further)
    and the container type is preserved. Any other value is passed to the mapper as is.

    :param data: Dict whose leaf-values are to be mapped. It is not modified.
    :param mapper: Function called with each leaf-value.
    :return: New dict with the same keys and mapped values.
    """

    def _convert(val):
        if isinstance(val, dict):
            return map_leaf_vals(val, mapper)
        if isinstance(val, (tuple, list, set)):
            return type(val)(mapper(item) for item in val)
        return mapper(val)

    return {key: _convert(val) for key, val in data.items()}
def get_nested(_dict: Mapping, keys: Sequence[Hashable], default: Any = _MISSING) -> Any:
    """Given keys [a, b, c], return _dict[a][b][c]

    A path counts as missing when a key is absent or when it runs through a value
    that cannot be subscripted with the next key (e.g. a scalar or a list).

    :param _dict: Mapping to be traversed.
    :param keys: Sequence of keys, one for each nesting level.
    :param default: Value returned if the path is missing. If omitted, a missing path raises.
    :return: The value found at the end of the path, or `default`.
    :raises KeyError: If the path is missing and no default was given. The error
        argument is the prefix of `keys` up to and including the key that failed.
    """
    target = _dict

    for idx, key in enumerate(keys):
        try:
            target = target[key]
        except (LookupError, TypeError) as exc:
            # LookupError covers KeyError (mappings) and IndexError (sequences);
            # TypeError is raised when the current value is not subscriptable by this key.
            if default is _MISSING:
                raise KeyError(keys[0 : idx + 1]) from exc  # noqa
            return default

    return target
def set_nested(_dict: Dict[Hashable, Any], keys: Sequence[Hashable], value: Any) -> Dict[Hashable, Any]:
    """Assign `value` at the nested path `keys`, i.e. `_dict[a][b][c] = value` for keys [a, b, c].

    Intermediate levels that are absent are created as empty dicts.

    :param _dict: Dict to be updated in place.
    :param keys: Non-empty sequence of keys, one for each nesting level.
    :param value: Value to store at the end of the path.
    :return: The same dict object that was passed in.
    :raises IndexError: If `keys` is empty.
    """
    leaf_key = keys[-1]
    node = _dict

    # Walk down every level but the last one, creating missing branches on the way
    for branch in keys[:-1]:
        if branch not in node:
            node[branch] = {}
        node = node[branch]

    node[leaf_key] = value
    return _dict
def has_nested(_dict: Dict[Hashable, Any], keys: Sequence[Hashable]) -> bool:
    """Tell whether the nested path `_dict[a][b][c]` exists for keys [a, b, c].

    :param _dict: Dict to be inspected.
    :param keys: Sequence of keys, one for each nesting level.
    :return: True if `get_nested` can resolve the path, False if it raises KeyError.
    """
    try:
        get_nested(_dict, keys)
    except KeyError:
        return False
    else:
        return True
def new_nested(keys: Sequence[Hashable], value: Any) -> Dict[Hashable, Any]:
    """Create a fresh nested dict holding `value` at path `keys`: [a, b, c] gives {a: {b: {c: value}}}

    :param keys: Sequence of keys, one for each nesting level.
    :param value: Value to store at the innermost level.
    :return: The result of `set_nested` applied to a new empty dict.
    """
    fresh: Dict[Hashable, Any] = {}
    return set_nested(fresh, keys, value)
def filter_empty_leaves(dict_: Mapping) -> dict:
    """Filter out leaf-values that are None or empty iterables.

    Values that are falsy but are neither None nor an empty sized container
    (e.g. `0`, `0.0` or `False`) are real values and are kept.

    :param dict_: Possibly nested mapping. It is not modified.
    :return: New nested dict containing only the non-empty leaves.
    """

    def _is_empty(val: Any) -> bool:
        if val is None:
            return True
        try:
            # Emptiness is decided by length, not truthiness, so that numeric zeros are kept.
            return len(val) == 0
        except TypeError:
            # Objects without a length (numbers, bools, arbitrary objects) are never "empty".
            return False

    return unflatten((k, v) for k, v in flatten(dict_).items() if not _is_empty(v))
def deep_merge(*dicts: dict):
    """Merge dicts at the level of leaf-values.

    Dicts are applied from left to right, so values from later dicts win.
    A non-empty mapping is merged key by key into the result. An empty mapping
    is merged as a no-op if the result already holds a mapping at that key and is
    stored as a leaf otherwise. Leaves are deep-copied, so the result shares no
    mutable state with the inputs.

    A nested mapping may override a scalar set by an earlier dict, and vice versa.

    :param dicts: Dicts to be merged. They are not modified.
    :return: New dict with the merged content.
    """
    retval = {}

    def _update(tgt: dict, src: Mapping):
        for key, val in src.items():
            current = tgt.get(key)

            if isinstance(val, Mapping) and (val or isinstance(current, Mapping)):
                # Treat value as nested if it's a non-empty dict or if the target is already nested
                if not isinstance(current, Mapping):
                    # Absent key or scalar leaf: start a fresh branch to merge into
                    current = tgt[key] = {}
                _update(current, val)
            else:
                # Treat value as a leaf (scalar) otherwise
                tgt[key] = deepcopy(val)

    for _dict in dicts:
        _update(retval, _dict)

    return retval
def flatten(_dict: Mapping) -> Dict[Tuple[str, ...], Any]:
    """Collapse a nested mapping into a single-level dict keyed by tuples of nested keys.

    Every value that is not a mapping becomes an entry whose key is the path leading to it.
    Nested mappings with no entries contribute nothing to the result.

    :param _dict: Possibly nested mapping. It is not modified.
    :return: New flat dict with tuple keys, in depth-first order.
    """

    def _walk(node: Mapping, path: Tuple[Hashable, ...]):
        for name, item in node.items():
            full_path = (*path, name)
            if isinstance(item, Mapping):
                yield from _walk(item, full_path)
            else:
                yield full_path, item

    return dict(_walk(_dict, ()))
def unflatten(_dict: Iterable[Tuple[Tuple[str, ...], Any]] | Mapping[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Expand tuple keys into nested dicts, e.g. {(a, b): 1} becomes {a: {b: 1}}.

    :param _dict: Mapping with tuple keys, or iterable of (tuple key, value) pairs.
    :return: New nested dict. Entries whose key is not a tuple are skipped.
    """
    pairs = _dict.items() if isinstance(_dict, Mapping) else _dict
    result = {}

    for path, leaf in pairs:
        if not isinstance(path, tuple):
            continue  # Only tuple keys describe a nested path
        set_nested(result, path, leaf)

    return result