inox.tree_util
Extended utilities for tree-like data structures
Functions
tree_mask | Masks the static leaves of a tree.
tree_unmask | Unmasks the static leaves of a masked tree.
tree_partition | Flattens a tree and partitions the leaves.
tree_combine | Reconstructs a tree from the tree definition and leaf partitions.
tree_repr | Creates a pretty representation of a tree.
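Classes
Namespace | PyTree class for name-value mappings.
Partial | A version of functools.partial that is a PyTree.
Static | Wraps a hashable value as a leafless PyTree.
Descriptions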
- class inox.tree_util.Namespace(**kwargs)
PyTree class for name-value mappings.
- Parameters:
kwargs – A name-value mapping.
Example
>>> tree = Namespace(a=1, b='2'); tree
Namespace(
  a = 1,
  b = '2'
)
>>> tree.c = [3, False]; tree
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree_util.tree_leaves(tree)
[1, '2', 3, False]
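As a rough illustration of what a PyTree class for name-value mappings involves, the stdlib-only sketch below splits a namespace into children (the values) and static metadata (the ordered names), then rebuilds it. ToyNamespace and its flatten/unflatten methods are hypothetical; the split mirrors the children/auxiliary-data protocol of JAX custom nodes but is not inox's implementation.

```python
from typing import Any, List, Tuple


class ToyNamespace:
    """Hypothetical name-value container; not inox.tree_util.Namespace."""

    def __init__(self, **kwargs: Any):
        # Keep attributes in a plain dict so insertion order is preserved.
        object.__setattr__(self, "_items", dict(kwargs))

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails.
        items = self.__dict__.get("_items", {})
        if name in items:
            return items[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self._items[name] = value

    def flatten(self) -> Tuple[List[Any], Tuple[str, ...]]:
        # Children: the values. Static metadata: the ordered names.
        return list(self._items.values()), tuple(self._items)

    @classmethod
    def unflatten(cls, names: Tuple[str, ...], children: List[Any]) -> "ToyNamespace":
        return cls(**dict(zip(names, children)))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ToyNamespace) and self._items == other._items


tree = ToyNamespace(a=1, b="2")
tree.c = [3, False]
children, names = tree.flatten()
print(children)  # [1, '2', [3, False]]
print(names)     # ('a', 'b', 'c')
```

Here the children are only the direct values; a full leaf traversal would also descend into [3, False], which is how the example above arrives at [1, '2', 3, False].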
- class inox.tree_util.Partial(func, *args, **kwds)
A version of functools.partial that is a PyTree.
- Parameters:
Examples
>>> increment = Partial(jax.numpy.add, 1)
>>> increment(2)
Array(3, dtype=int32, weak_type=True)
>>> println = Partial(print, sep='\n')
>>> println('Hello', 'World!')
Hello
World!
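Being a PyTree means a Partial can be decomposed and rebuilt with transformed arguments. The stdlib-only sketch below assumes the split used by jax.tree_util.Partial, where the function is static metadata and only the bound positional and keyword arguments are children; the description above does not say whether inox.tree_util.Partial splits things exactly this way. ToyPartial is a hypothetical name.

```python
import functools
import operator
from typing import Any, Callable, Dict, Tuple

Children = Tuple[Tuple[Any, ...], Dict[str, Any]]


class ToyPartial(functools.partial):
    """Hypothetical PyTree-style partial; not inox.tree_util.Partial."""

    def flatten(self) -> Tuple[Children, Callable]:
        # Children: the bound arguments. Static metadata: the function itself.
        return (self.args, self.keywords), self.func

    @classmethod
    def unflatten(cls, func: Callable, children: Children) -> "ToyPartial":
        args, keywords = children
        return cls(func, *args, **keywords)


increment = ToyPartial(operator.add, 1)
children, func = increment.flatten()
args, keywords = children
# Transform the children (here, scale the bound arguments) and rebuild.
add_ten = ToyPartial.unflatten(func, (tuple(10 * a for a in args), keywords))
print(increment(2), add_ten(2))  # 3 12
```

Keeping the function out of the children means a leaf-wise map only ever touches data, never the callable itself.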
- class inox.tree_util.Static(value)
Wraps a hashable value as a leafless PyTree.
- Parameters:
value (Hashable) – A hashable value to wrap.
Example
>>> tree = Static((0, 'one', None))
>>> tree.value
(0, 'one', None)
>>> jax.tree_util.tree_leaves(tree)
[]
>>> jax.tree_util.tree_structure(tree)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
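Since Static exposes no leaves, its value has to be kept in the tree structure, which is what the PyTreeDef above shows. Tree structures get compared and hashed (JAX, for example, uses the input structure as part of the cache key for jax.jit), which is presumably why the value must be hashable. The stdlib-only sketch below mimics that behavior; ToyStatic and toy_flatten are hypothetical names, not inox's implementation.

```python
from dataclasses import dataclass
from typing import Any, Hashable, List, Tuple


@dataclass(frozen=True)
class ToyStatic:
    """Hypothetical stand-in for Static: a wrapper with no children."""
    value: Hashable


def toy_flatten(tree: Any) -> Tuple[List[Any], Hashable]:
    """Return (leaves, structure), where structure is a hashable tuple."""
    if isinstance(tree, ToyStatic):
        # No leaves: the wrapped value is stored in the structure itself.
        return [], ("static", tree.value)
    if isinstance(tree, (list, tuple)):
        leaves, structures = [], []
        for child in tree:
            child_leaves, child_structure = toy_flatten(child)
            leaves.extend(child_leaves)
            structures.append(child_structure)
        return leaves, (type(tree).__name__, tuple(structures))
    # Anything else is an ordinary leaf, marked by a placeholder.
    return [tree], "*"


leaves, structure = toy_flatten([ToyStatic((0, "one", None)), 1.0])
print(leaves)     # [1.0]
print(structure)  # ('list', (('static', (0, 'one', None)), '*'))

# Equal static values give equal, hashable structures, so they can key a cache.
cache = {structure: "compiled"}
_, same = toy_flatten([ToyStatic((0, "one", None)), 2.0])
_, other = toy_flatten([ToyStatic("two"), 1.0])
print(same in cache, other in cache)  # True False
```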
- inox.tree_util.tree_mask(tree, is_static=None)
Masks the static leaves of a tree.
The structure of the tree remains unchanged, but leaves that are considered static are wrapped into a Static instance, which hides them from jax.tree_util.tree_leaves and jax.tree_util.tree_map.
See also
- Parameters:
- Returns:
The masked tree.
- Return type:
PyTree
Example
>>> tree = [1, jax.numpy.arange(2), 'three']
>>> jax.tree_util.tree_leaves(tree)
[1, Array([0, 1], dtype=int32), 'three']
>>> tree = tree_mask(tree); tree
[Static(1), Array([0, 1], dtype=int32), Static('three')]
>>> jax.tree_util.tree_leaves(tree)
[Array([0, 1], dtype=int32)]
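The masking step can be sketched without JAX as a structure-preserving map that wraps selected leaves. In this hypothetical sketch, floats stand in for arrays and the is_static predicate is passed explicitly; the actual default of tree_mask is not described above, although the example suggests that non-array leaves are treated as static. ToyStatic, toy_leaves and toy_mask are illustrative names only.

```python
from dataclasses import dataclass
from typing import Any, Callable, Hashable, List


@dataclass(frozen=True)
class ToyStatic:
    """Hypothetical stand-in for Static: hides its value from traversals."""
    value: Hashable


def toy_leaves(tree: Any) -> List[Any]:
    """Collect leaves depth-first; ToyStatic nodes contribute none."""
    if isinstance(tree, ToyStatic):
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in toy_leaves(child)]
    if isinstance(tree, dict):
        return [leaf for child in tree.values() for leaf in toy_leaves(child)]
    return [tree]


def toy_mask(tree: Any, is_static: Callable[[Any], bool]) -> Any:
    """Wrap every leaf for which is_static is true, keeping the structure."""
    if isinstance(tree, ToyStatic):
        return tree  # already masked, leave as-is
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_mask(child, is_static) for child in tree)
    if isinstance(tree, dict):
        return {key: toy_mask(child, is_static) for key, child in tree.items()}
    return ToyStatic(tree) if is_static(tree) else tree


# Floats play the role of arrays; everything else is considered static.
tree = [1, 0.5, "three"]
masked = toy_mask(tree, lambda leaf: not isinstance(leaf, float))
print(masked)              # [ToyStatic(value=1), 0.5, ToyStatic(value='three')]
print(toy_leaves(masked))  # [0.5]
```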
- inox.tree_util.tree_unmask(tree)
Unmasks the static leaves of a masked tree.
See also
- Parameters:
tree (PyTree) – The masked tree to unmask.
- Returns:
The unmasked tree.
- Return type:
PyTree
Example
>>> tree = [Static(1), jax.numpy.arange(2), Static('three')]
>>> tree_unmask(tree)
[1, Array([0, 1], dtype=int32), 'three']
- inox.tree_util.tree_partition(tree, *filters, is_leaf=None)
Flattens a tree and partitions the leaves.
The leaves are partitioned into a set of path-leaf mappings. Each mapping contains the leaves of the subset of nodes satisfying the corresponding filtering constraint. The last mapping is dedicated to leaves that do not satisfy any constraint.
See also
- Parameters:
- Returns:
The tree definition and leaf partitions.
- Return type:
Example
>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, leaves = tree_partition(tree)
>>> leaves
{'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
>>> treedef, arrays, others = tree_partition(tree, jax.Array)
>>> arrays
{'.b': Array([0, 1], dtype=int32)}
>>> others
{'.a': 1, '.c[0]': 'three', '.c[1]': False}
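The partitioning step can be sketched with the standard library alone: walk the tree, record each leaf under a path string like '.a' or '.c[0]', and drop it into the first mapping whose filter it satisfies, with leftovers going to the last mapping. This is a simplified, hypothetical sketch (flatten_with_paths and partition are not inox functions): it only inspects leaves, whereas the description above speaks of nodes, it does not return a tree definition, and the first-match rule for leaves satisfying several filters is an assumption. SimpleNamespace stands in for Namespace and a float for the array.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union

Filter = Union[type, Callable[[Any], bool]]


def flatten_with_paths(tree: Any, path: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs depth-first, with path strings like '.c[0]'."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from flatten_with_paths(child, f"{path}[{key!r}]")
    elif isinstance(tree, (list, tuple)):
        for index, child in enumerate(tree):
            yield from flatten_with_paths(child, f"{path}[{index}]")
    elif isinstance(tree, SimpleNamespace):
        for name, child in vars(tree).items():
            yield from flatten_with_paths(child, f"{path}.{name}")
    else:
        yield path, tree


def partition(tree: Any, *filters: Filter) -> List[Dict[str, Any]]:
    """Return one path-leaf dict per filter, plus a last dict for the rest."""
    parts: List[Dict[str, Any]] = [{} for _ in range(len(filters) + 1)]
    for path, leaf in flatten_with_paths(tree):
        for part, f in zip(parts, filters):
            # A type filter is an isinstance check; anything else is a predicate.
            matched = isinstance(leaf, f) if isinstance(f, type) else f(leaf)
            if matched:
                part[path] = leaf
                break  # assumption: a leaf goes to the first filter it matches
        else:
            parts[-1][path] = leaf
    return parts


tree = SimpleNamespace(a=1, b=2.5, c=["three", False])
floats, others = partition(tree, float)
print(floats)  # {'.b': 2.5}
print(others)  # {'.a': 1, '.c[0]': 'three', '.c[1]': False}
```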
- inox.tree_util.tree_combine(treedef, *leaves)
Reconstructs a tree from the tree definition and leaf partitions.
See also
- Parameters:
- Returns:
The reconstructed tree.
- Return type:
PyTree
Example
>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, arrays, others = tree_partition(tree, jax.Array)
>>> others = {key: str(leaf).upper() for key, leaf in others.items()}
>>> tree_combine(treedef, arrays, others)
Namespace(
  a = '1',
  b = int32[2],
  c = ['THREE', 'FALSE']
)
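Recombination can be sketched in the same stdlib-only style. Here the tree definition is modelled, purely as an assumption for illustration, by the tree itself with each leaf replaced by its path string; combine merges the partitions and looks every path up again. skeleton and combine are hypothetical helpers, not inox's tree_combine, and the example mirrors the one above with SimpleNamespace for Namespace and a float standing in for the array.

```python
from types import SimpleNamespace
from typing import Any, Dict


def skeleton(tree: Any, path: str = "") -> Any:
    """Toy tree definition: the tree with every leaf replaced by its path."""
    if isinstance(tree, dict):
        return {k: skeleton(v, f"{path}[{k!r}]") for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(skeleton(v, f"{path}[{i}]") for i, v in enumerate(tree))
    if isinstance(tree, SimpleNamespace):
        return SimpleNamespace(
            **{k: skeleton(v, f"{path}.{k}") for k, v in vars(tree).items()}
        )
    return path


def combine(treedef: Any, *partitions: Dict[str, Any]) -> Any:
    """Merge the path-leaf partitions and put each leaf back at its path."""
    leaves: Dict[str, Any] = {}
    for part in partitions:
        leaves.update(part)

    def fill(node: Any) -> Any:
        if isinstance(node, dict):
            return {k: fill(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(fill(v) for v in node)
        if isinstance(node, SimpleNamespace):
            return SimpleNamespace(**{k: fill(v) for k, v in vars(node).items()})
        return leaves[node]  # node is a path string; KeyError if a leaf is missing

    return fill(treedef)


tree = SimpleNamespace(a=1, b=2.5, c=["three", False])
treedef = skeleton(tree)
floats = {".b": 2.5}
others = {".a": "1", ".c[0]": "THREE", ".c[1]": "FALSE"}
print(combine(treedef, floats, others))
# namespace(a='1', b=2.5, c=['THREE', 'FALSE'])
```

As long as the partitions together cover every path, the original structure comes back; a missing path surfaces as a KeyError rather than a silently malformed tree.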
- inox.tree_util.tree_repr(tree, linewidth=88, typeonly=True, **kwargs)
Creates a pretty representation of a tree.
- Parameters:
- Returns:
The representation string.
- Return type:
Example
>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jax.numpy.arange(7)}]
>>> print(tree_repr(tree))
[
  1,
  'two',
  (True, False),
  [0, 1, 2, 3, 4],
  {'6': int32[7]}
]