inox.tree_util
Extended utilities for tree-like data structures
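Functions
- tree_repr: Creates a pretty representation of a tree.
Classes
- Namespace: PyTree class for name-value mappings.
- Static: Wraps a hashable value as a leafless PyTree.
- Auto: Subclass of Namespace that automatically detects non-array leaves and considers them as static.
Descriptions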
- class inox.tree_util.Namespace(**kwargs)
PyTree class for name-value mappings.
- Parameters:
kwargs – A name-value mapping.
Example
>>> ns = Namespace(a=1, b='2'); ns
Namespace(
  a = 1,
  b = '2'
)
>>> ns.c = [3, False]; ns
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree_util.tree_leaves(ns)
[1, '2', 3, False]
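To illustrate what a PyTree class for name-value mappings involves, here is a minimal sketch built on jax.tree_util.register_pytree_node_class. The class name MiniNamespace is hypothetical and this is not inox's implementation; it assumes the names are stored as (hashable) auxiliary data while the values are flattened as children.

```python
import jax


@jax.tree_util.register_pytree_node_class
class MiniNamespace:
    """Hypothetical name-value PyTree (a sketch, not inox's Namespace)."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def tree_flatten(self):
        # Names go into the auxiliary data, values become children.
        # Sorting the names makes the structure independent of insertion order.
        names = tuple(sorted(self.__dict__))
        children = tuple(self.__dict__[name] for name in names)
        return children, names

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(**dict(zip(names, children)))


ns = MiniNamespace(a=1, b='2')
ns.c = [3, False]
print(jax.tree_util.tree_leaves(ns))  # [1, '2', 3, False]

# Tree functions map over the values while preserving the names.
tenfold = jax.tree_util.tree_map(lambda v: 10 * v, MiniNamespace(x=1, y=[2, 3]))
print(tenfold.x, tenfold.y)  # 10 [20, 30]
```

Sorting the names is a design choice of this sketch: two instances holding the same names share one tree structure regardless of the order in which their attributes were set.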
- class inox.tree_util.Static(value)
Wraps a hashable value as a leafless PyTree.
- Parameters:
value (Hashable) – A hashable value to wrap.
Example
>>> x = Static((0, 'one', None))
>>> x.value
(0, 'one', None)
>>> jax.tree_util.tree_leaves(x)
[]
>>> jax.tree_util.tree_structure(x)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
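A leafless PyTree can be sketched by returning no children from tree_flatten and carrying the wrapped value in the auxiliary data. The class MiniStatic and the function apply below are hypothetical illustrations, not inox's code; the sketch relies on jax.jit treating a tree's structure, auxiliary data included, as part of what it specializes on.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class MiniStatic:
    """Hypothetical leafless PyTree (a sketch, not inox's Static)."""

    def __init__(self, value):
        hash(value)  # raise TypeError early if the value is not hashable
        self.value = value

    def tree_flatten(self):
        # No children: the value travels in the auxiliary data, which JAX
        # hashes and compares instead of tracing.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


@jax.jit
def apply(x, config):
    # config.value is a concrete Python tuple while tracing, so it can
    # drive ordinary Python control flow.
    factor, mode = config.value
    return x * factor if mode == 'mul' else x + factor


print(jax.tree_util.tree_leaves(MiniStatic((0, 'one', None))))  # []
print(apply(jnp.ones(2), MiniStatic((3.0, 'mul'))))  # [3. 3.]
print(apply(jnp.ones(2), MiniStatic((3.0, 'add'))))  # [4. 4.]
```

Because the wrapped value is hashed and compared rather than traced, passing a different value (here 'mul' versus 'add') leads to a separate trace and compilation.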
- class inox.tree_util.Auto(**kwargs)
Subclass of Namespace that automatically detects non-array leaves and considers them as static.
Important
object() leaves are never considered static.
- Parameters:
kwargs – A name-value mapping.
Example
>>> auto = Auto(a=1, b=jnp.array(2.0)); auto
Auto(
  a = 1,
  b = float32[]
)
>>> auto.c = ['3', jnp.arange(4)]; auto
Auto(
  a = 1,
  b = float32[],
  c = ['3', int32[4]]
)
>>> jax.tree_util.tree_leaves(auto)  # only arrays
[Array(2., dtype=float32, weak_type=True), Array([0, 1, 2, 3], dtype=int32)]
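The automatic split between array and non-array leaves can be sketched by flattening the values, keeping jax.Array leaves as children and moving every other leaf into the auxiliary data. MiniAuto and is_dynamic are hypothetical names, and the rule used here (only jax.Array instances and bare object() leaves are dynamic) is an assumption modeled on the description above rather than inox's exact criterion.

```python
import jax
import jax.numpy as jnp


def is_dynamic(leaf):
    # Arrays (including tracers) stay dynamic; bare object() leaves are also
    # kept dynamic, mirroring the rule documented for Auto.
    return isinstance(leaf, jax.Array) or type(leaf) is object


@jax.tree_util.register_pytree_node_class
class MiniAuto:
    """Hypothetical auto-partitioning namespace (a sketch, not inox's Auto)."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def tree_flatten(self):
        names = tuple(sorted(self.__dict__))
        leaves, treedef = jax.tree_util.tree_flatten(
            [self.__dict__[name] for name in names]
        )
        mask = tuple(is_dynamic(leaf) for leaf in leaves)
        dynamic = [leaf for leaf, d in zip(leaves, mask) if d]
        # Non-array leaves move into the auxiliary data, so they must be hashable.
        static = tuple(leaf for leaf, d in zip(leaves, mask) if not d)
        return dynamic, (names, treedef, mask, static)

    @classmethod
    def tree_unflatten(cls, aux, dynamic):
        names, treedef, mask, static = aux
        dynamic_it, static_it = iter(dynamic), iter(static)
        # Interleave dynamic and static leaves back into their original order.
        leaves = [next(dynamic_it) if d else next(static_it) for d in mask]
        values = jax.tree_util.tree_unflatten(treedef, leaves)
        return cls(**dict(zip(names, values)))


auto = MiniAuto(a=1, b=jnp.array(2.0))
auto.c = ['3', jnp.arange(4)]
print(len(jax.tree_util.tree_leaves(auto)))  # 2 (only the arrays)


def loss(model):
    # model.a is a plain Python int here, not a traced value.
    return model.a * model.w ** 2


grads = jax.grad(loss)(MiniAuto(a=3, w=jnp.array(2.0)))
print(grads.a, grads.w)  # 3 12.0
```

The jax.grad call shows the practical benefit: an integer leaf would be rejected as a differentiation input, but here a is static and comes back unchanged in the gradient tree. Note that a NumPy array is not a jax.Array, so this sketch would treat it as static, which is rarely what you want for numerical data.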
- inox.tree_util.tree_repr(x, /, linewidth=88, typeonly=True, **kwargs)
Creates a pretty representation of a tree.
- Parameters:
- Returns:
The representation string.
- Return type:
str
Example
>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jnp.arange(7)}]
>>> print(tree_repr(tree))
[
  1,
  'two',
  (True, False),
  [0, 1, 2, 3, 4],
  {'6': int32[7]}
]
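The typeonly behaviour, where an array is shown as its dtype followed by its shape (int32[7], float32[]), and the linewidth-driven line breaking can be sketched with a small recursive function. mini_tree_repr and its indent parameter are hypothetical; the rule of trying a single line first and otherwise placing one item per indented line is an assumption and need not reproduce inox's exact layout.

```python
import jax.numpy as jnp


def mini_tree_repr(x, linewidth=88, typeonly=True, indent=2):
    """Hypothetical tree pretty-printer (a sketch, not inox's tree_repr)."""
    # Anything with a shape and a dtype is treated as an array.
    if hasattr(x, 'shape') and hasattr(x, 'dtype'):
        if typeonly:
            return f"{x.dtype}[{','.join(map(str, x.shape))}]"
        return repr(x)

    if isinstance(x, dict):
        items = [
            f"{k!r}: {mini_tree_repr(v, linewidth, typeonly, indent)}"
            for k, v in x.items()
        ]
        bra, ket = '{', '}'
    elif isinstance(x, (list, tuple)):
        items = [mini_tree_repr(v, linewidth, typeonly, indent) for v in x]
        bra, ket = ('(', ')') if isinstance(x, tuple) else ('[', ']')
        if isinstance(x, tuple) and len(items) == 1:
            items[0] += ','  # keep (a,) distinguishable from (a)
    else:
        return repr(x)

    # Prefer a single line; otherwise put one item per indented line.
    line = bra + ', '.join(items) + ket
    if len(line) <= linewidth and '\n' not in line:
        return line
    pad = ' ' * indent
    body = ',\n'.join(pad + item.replace('\n', '\n' + pad) for item in items)
    return f"{bra}\n{body}\n{ket}"


tree = [1, 'two', (True, False), list(range(5)), {'6': jnp.arange(7)}]
print(mini_tree_repr(tree))  # [1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': int32[7]}]
print(mini_tree_repr(tree, linewidth=20))  # one item per line
```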