inox.tree
Extended utilities for tree-like data structures.
Descriptions
- class inox.tree.Namespace(**kwargs)
PyTree class for name-value mappings.
- Parameters:
kwargs – A name-value mapping.
Example
>>> tree = Namespace(a=1, b='2'); tree
Namespace(
  a = 1,
  b = '2'
)
>>> tree.c = [3, False]; tree
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree.leaves(tree)
[1, '2', 3, False]
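To make the flattening concrete, here is a minimal stdlib-only sketch, not inox's implementation. `NamespaceSketch` and `leaves` are hypothetical names. The class follows the `tree_flatten`/`tree_unflatten` protocol that `jax.tree_util.register_pytree_node_class` expects, with the values as children and the names as auxiliary data. Sorting the names is an assumption of the sketch.

```python
from typing import Any, List, Tuple


class NamespaceSketch:
    """Hypothetical stand-in for inox.tree.Namespace (not the library code)."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)

    def tree_flatten(self) -> Tuple[List[Any], Tuple[str, ...]]:
        # Children are the values; auxiliary data are the names (sorted here).
        names = tuple(sorted(self.__dict__))
        return [self.__dict__[name] for name in names], names

    @classmethod
    def tree_unflatten(cls, names: Tuple[str, ...], children: List[Any]) -> "NamespaceSketch":
        return cls(**dict(zip(names, children)))


def leaves(tree: Any) -> List[Any]:
    """Hypothetical depth-first leaf collector mimicking jax.tree.leaves."""
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    if hasattr(tree, "tree_flatten"):
        children, _ = tree.tree_flatten()
        return [leaf for child in children for leaf in leaves(child)]
    return [tree]  # anything else is a leaf


tree = NamespaceSketch(a=1, b='2')
tree.c = [3, False]
print(leaves(tree))  # [1, '2', 3, False]
```

Keeping the names out of the children is what lets a tree map transform the values while the attribute names stay fixed.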
- class inox.tree.Partial(func, *args, **kwds)
A version of functools.partial that is a PyTree.
- Parameters:
Examples
>>> increment = Partial(jax.numpy.add, 1)
>>> increment(2)
Array(3, dtype=int32, weak_type=True)
>>> println = Partial(print, sep='\n')
>>> println('Hello', 'World!')
Hello
World!
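A hypothetical stdlib sketch of what being a PyTree means for a partial: the bound arguments become children that a tree map can transform, while calling the object behaves like functools.partial. `PartialSketch` is an illustrative name. Keeping `func` in the auxiliary data is an assumption of the sketch, not a statement about how inox stores it.

```python
import operator
from typing import Any, Callable


class PartialSketch:
    """Hypothetical PyTree-style partial (not inox's implementation)."""

    def __init__(self, func: Callable, *args: Any, **kwds: Any) -> None:
        self.func, self.args, self.kwds = func, args, kwds

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        # Same semantics as functools.partial: stored positional arguments come
        # first, and call-time keywords override stored ones.
        return self.func(*self.args, *args, **{**self.kwds, **kwds})

    def tree_flatten(self):
        # Children: bound arguments (keywords in sorted name order).
        # Auxiliary data: the function and how to split the children again.
        names = tuple(sorted(self.kwds))
        children = [*self.args, *(self.kwds[name] for name in names)]
        return children, (self.func, len(self.args), names)

    @classmethod
    def tree_unflatten(cls, aux, children):
        func, nargs, names = aux
        return cls(func, *children[:nargs], **dict(zip(names, children[nargs:])))


increment = PartialSketch(operator.add, 1)
print(increment(2))  # 3

# Transform the bound arguments as a tree map would, then rebuild.
children, aux = increment.tree_flatten()
doubled = PartialSketch.tree_unflatten(aux, [2 * leaf for leaf in children])
print(doubled(2))  # 4
```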
- class inox.tree.Static(value)
Wraps a hashable value as a leafless PyTree.
- Parameters:
value (Hashable) – A hashable value to wrap.
Example
>>> tree = Static((0, 'one', None))
>>> tree.value
(0, 'one', None)
>>> jax.tree.leaves(tree)
[]
>>> jax.tree.structure(tree)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
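A minimal sketch of a leafless wrapper, assuming the same `tree_flatten`/`tree_unflatten` protocol: the value goes into the auxiliary data and there are no children, so leaf-collecting functions see nothing. JAX compares tree definitions, auxiliary data included, which makes a hashable and comparable value a natural requirement. The early `hash` check is a choice of this hypothetical `StaticSketch`.

```python
from typing import Any, Hashable, List, Tuple


class StaticSketch:
    """Hypothetical leafless wrapper: the value lives in the auxiliary data."""

    def __init__(self, value: Hashable) -> None:
        hash(value)  # fail early with TypeError if the value is not hashable
        self.value = value

    def tree_flatten(self) -> Tuple[List[Any], Hashable]:
        return [], self.value  # no children, hence no leaves

    @classmethod
    def tree_unflatten(cls, value: Hashable, children: List[Any]) -> "StaticSketch":
        return cls(value)


tree = StaticSketch((0, 'one', None))
children, aux = tree.tree_flatten()
print(children, aux)  # [] (0, 'one', None)

try:
    StaticSketch([0, 1])  # lists are not hashable
except TypeError as error:
    print('rejected:', error)
```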
- inox.tree.mask(tree, is_static=None)
Masks the static leaves of a tree.
The structure of the tree remains unchanged, but leaves that are considered static are masked, which hides them from jax.tree.leaves and jax.tree.map. Masking is idempotent: applying inox.tree.mask more than once yields the same tree as applying it once.
See also
- Parameters:
- Returns:
The masked tree.
- Return type:
PyTree
Example
>>> tree = [1, jax.numpy.arange(2), 'three']
>>> jax.tree.leaves(tree)
[1, Array([0, 1], dtype=int32), 'three']
>>> tree = mask(tree); tree
[Mask(1), Array([0, 1], dtype=int32), Mask('three')]
>>> jax.tree.leaves(tree)
[Array([0, 1], dtype=int32)]
>>> unmask(tree)
[1, Array([0, 1], dtype=int32), 'three']
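The masking behaviour can be sketched with the standard library alone. In this hypothetical version, `array.array` stands in for `jax.Array`, and a leaf is static by default when it is not an array (an assumption that matches the example above, where `1` and `'three'` are masked). `MaskSketch` plays the role of the `Mask` wrapper, and only lists, tuples and dicts are traversed.

```python
from array import array
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class MaskSketch:
    """Hypothetical marker for a masked (hidden) leaf."""
    value: Any


def _not_an_array(leaf: Any) -> bool:
    # Assumed default: every non-array leaf is static.
    return not isinstance(leaf, array)


def mask_sketch(tree: Any, is_static: Optional[Callable[[Any], bool]] = None) -> Any:
    if is_static is None:
        is_static = _not_an_array
    if isinstance(tree, MaskSketch):
        return tree  # already masked: masking again changes nothing
    if isinstance(tree, (list, tuple)):
        return type(tree)(mask_sketch(child, is_static) for child in tree)
    if isinstance(tree, dict):
        return {key: mask_sketch(child, is_static) for key, child in tree.items()}
    return MaskSketch(tree) if is_static(tree) else tree


def unmask_sketch(tree: Any) -> Any:
    if isinstance(tree, MaskSketch):
        return tree.value
    if isinstance(tree, (list, tuple)):
        return type(tree)(unmask_sketch(child) for child in tree)
    if isinstance(tree, dict):
        return {key: unmask_sketch(child) for key, child in tree.items()}
    return tree


tree = [1, array('i', [0, 1]), 'three']
masked = mask_sketch(tree)
print(masked)  # [MaskSketch(value=1), array('i', [0, 1]), MaskSketch(value='three')]
print(unmask_sketch(masked) == tree)  # True
```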
- inox.tree.unmask(tree)
Unmasks the masked leaves of a tree.
See also
- Parameters:
tree (PyTree) – The masked tree to unmask.
- Returns:
The unmasked tree.
- Return type:
PyTree
- inox.tree.partition(tree, *filters, is_leaf=None)
Flattens a tree and partitions its leaves.
The leaves are partitioned into path-leaf mappings, one per filter plus a final one. A leaf goes to the mapping of its oldest (closest to the root) ancestor, possibly the leaf itself, that satisfies a filter; if that node satisfies several filters, the first one is selected. The last mapping holds the leaves for which no such ancestor exists.
See also
- Parameters:
- Returns:
The tree definition and leaf partitions.
- Return type:
Example
>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, leaves = inox.tree.partition(tree)
>>> leaves
{'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> arrays
{'.b': Array([0, 1], dtype=int32)}
>>> others
{'.a': 1, '.c[0]': 'three', '.c[1]': False}
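A stdlib sketch of the assignment rule, with hypothetical names. `partition_sketch` walks nested dicts, lists and tuples, builds paths in the same `.key`/`[i]` style as the example (dict keys are written like attributes here, an assumption), and returns a skeleton in place of the tree definition. A filter is assumed to be either a type or a predicate.

```python
from typing import Any, Callable, Dict, List, Union

Filter = Union[type, Callable[[Any], bool]]


def _matches(node: Any, filt: Filter) -> bool:
    # Assumption: a filter is either a type (isinstance check) or a predicate.
    return isinstance(node, filt) if isinstance(filt, type) else bool(filt(node))


def partition_sketch(tree: Any, *filters: Filter):
    """Hypothetical sketch returning (skeleton, *mappings): one mapping per
    filter, plus a final one for leaves that no filter selects."""
    parts: List[Dict[str, Any]] = [{} for _ in range(len(filters) + 1)]
    unassigned = len(filters)

    def visit(node: Any, path: str, index: int) -> Any:
        if index == unassigned:
            # No ancestor was selected yet: the first matching filter wins.
            index = next((i for i, f in enumerate(filters) if _matches(node, f)), unassigned)
        if isinstance(node, dict):
            return {key: visit(child, f'{path}.{key}', index) for key, child in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(visit(child, f'{path}[{i}]', index) for i, child in enumerate(node))
        parts[index][path] = node
        return path  # the skeleton records where the leaf was

    skeleton = visit(tree, '', unassigned)
    return (skeleton, *parts)


tree = {'a': 1, 'b': 2.5, 'c': ['three', False]}
skeleton, floats, others = partition_sketch(tree, float)
print(floats)  # {'.b': 2.5}
print(others)  # {'.a': 1, '.c[0]': 'three', '.c[1]': False}

# Ancestor rule: the list under 'c' is selected as a whole, so its False leaf
# lands in the list mapping even though the first filter matches it directly.
_, falses, lists, rest = partition_sketch(tree, lambda x: x is False, list)
print(lists)  # {'.c[0]': 'three', '.c[1]': False}
```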
- inox.tree.combine(treedef, *leaves)
Reconstructs a tree from the tree definition and leaf partitions.
See also
- Parameters:
- Returns:
The reconstructed tree.
- Return type:
PyTree
Example
>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> others = {key: str(leaf).upper() for key, leaf in others.items()}
>>> inox.tree.combine(treedef, arrays, others)
Namespace(
  a = '1',
  b = int32[2],
  c = ['THREE', 'FALSE']
)
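The inverse operation can be sketched as filling a skeleton. The hypothetical `combine_sketch` assumes a skeleton whose leaves are path strings (a stand-in for a real tree definition) and merges the mappings before looking each path up, so leaves may be edited between partitioning and combining as long as their paths are kept.

```python
from typing import Any, Dict


def combine_sketch(skeleton: Any, *parts: Dict[str, Any]) -> Any:
    """Hypothetical inverse of a path-based partition: the skeleton stores a
    path string where each leaf was, and the mappings supply the leaves."""
    merged: Dict[str, Any] = {}
    for part in parts:
        merged.update(part)

    def rebuild(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: rebuild(child) for key, child in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(rebuild(child) for child in node)
        return merged[node]  # a path with no leaf raises KeyError

    return rebuild(skeleton)


skeleton = {'a': '.a', 'b': '.b', 'c': ['.c[0]', '.c[1]']}
floats = {'.b': 2.5}
others = {'.a': 1, '.c[0]': 'three', '.c[1]': False}
others = {key: str(leaf).upper() for key, leaf in others.items()}
print(combine_sketch(skeleton, floats, others))
# {'a': '1', 'b': 2.5, 'c': ['THREE', 'FALSE']}
```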
- inox.tree.prepr(tree, linewidth=88, typeonly=True, **kwargs)
Creates a pretty representation of a tree.
- Parameters:
- Returns:
The representation string.
- Return type:
Example
>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jax.numpy.arange(7)}]
>>> tree.append(tree)
>>> print(inox.tree.prepr(tree))
[1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': int32[7]}, [...]]
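The example shows two behaviours: a self-referencing list is printed as `[...]`, and with `typeonly=True` an array is summarised by its type and shape (`int32[7]`). Below is a hypothetical stdlib sketch of just those two behaviours, with `array.array` standing in for `jax.Array` (its type code, e.g. `'i'`, replaces the dtype). `prepr_sketch` ignores `linewidth` and `kwargs` and always produces a single line.

```python
from array import array
from typing import Any, Set


def prepr_sketch(tree: Any, typeonly: bool = True) -> str:
    """Hypothetical one-line pretty representation of nested containers."""
    active: Set[int] = set()  # ids of the containers currently being printed

    def rec(node: Any) -> str:
        if typeonly and isinstance(node, array):
            # Summarise the array by its type code and length, e.g. 'i[7]'.
            return f'{node.typecode}[{len(node)}]'
        if not isinstance(node, (list, tuple, dict)):
            return repr(node)
        if isinstance(node, list):
            left, right = '[', ']'
        elif isinstance(node, tuple):
            left, right = '(', ')'
        else:
            left, right = '{', '}'
        if id(node) in active:
            return left + '...' + right  # reference cycle: stop recursing
        active.add(id(node))
        if isinstance(node, dict):
            items = [f'{key!r}: {rec(value)}' for key, value in node.items()]
        else:
            items = [rec(child) for child in node]
        active.discard(id(node))  # shared but acyclic references print normally
        if isinstance(node, tuple) and len(items) == 1:
            items[0] += ','  # one-element tuples need a trailing comma
        return left + ', '.join(items) + right

    return rec(tree)


tree = [1, 'two', (True, False), list(range(5)), {'6': array('i', range(7))}]
tree.append(tree)
print(prepr_sketch(tree))
# [1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': i[7]}, [...]]
```

Tracking only the containers on the current recursion path, rather than every container seen so far, is what distinguishes a true cycle from an object that merely appears twice.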