inox.tree_util

Extended utilities for tree-like data structures.

Functions

tree_repr

Creates a pretty representation of a tree.

Classes

Namespace

PyTree class for name-value mappings.

Static

Wraps a hashable value as a leafless PyTree.

Auto

Subclass of Namespace that automatically detects non-array leaves and treats them as static.

Descriptions

class inox.tree_util.Namespace(**kwargs)

PyTree class for name-value mappings.

Parameters:

kwargs – A name-value mapping.

Example

>>> ns = Namespace(a=1, b='2'); ns
Namespace(
  a = 1,
  b = '2'
)
>>> ns.c = [3, False]; ns
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree_util.tree_leaves(ns)
[1, '2', 3, False]
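The flatten/unflatten protocol behind a name-value PyTree such as Namespace can be sketched in plain Python. This is not inox's implementation: `MiniNamespace` and `leaves` are hypothetical names, visiting names in sorted order is an assumption, and the class only provides the `tree_flatten`/`tree_unflatten` methods that `jax.tree_util.register_pytree_node_class` expects, without actually registering it, so the sketch runs without JAX.

```python
class MiniNamespace:
    """Hypothetical name-value container exposing JAX's class flatten protocol.

    JAX would need jax.tree_util.register_pytree_node_class(MiniNamespace)
    to treat it as a PyTree; the sketch skips that so it runs without JAX.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def tree_flatten(self):
        # Assumption: names are visited in sorted order, like JAX dict keys.
        names = tuple(sorted(self.__dict__))
        children = tuple(self.__dict__[name] for name in names)
        return children, names  # (children, auxiliary data)

    @classmethod
    def tree_unflatten(cls, names, children):
        # Rebuild the mapping from the stored names and the (new) children.
        return cls(**dict(zip(names, children)))


def leaves(tree):
    """Collect leaves depth-first; None has no leaves, as in JAX."""
    if tree is None:
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    if isinstance(tree, MiniNamespace):
        children, _ = tree.tree_flatten()
        return leaves(children)
    return [tree]
```

For `MiniNamespace(a=1, b='2')` with `c = [3, False]` added afterwards, `leaves` returns `[1, '2', 3, False]`, the same order as the example above, because the list value is itself flattened into its elements.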
class inox.tree_util.Static(value)

Wraps a hashable value as a leafless PyTree.

Parameters:

value (Hashable) – A hashable value to wrap.

Example

>>> x = Static((0, 'one', None))
>>> x.value
(0, 'one', None)
>>> jax.tree_util.tree_leaves(x)
[]
>>> jax.tree_util.tree_structure(x)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
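One way to make a wrapper leafless is to return no children from flattening and carry the wrapped value as auxiliary data. JAX stores auxiliary data in the tree structure and compares it (for instance when looking up jit compilation caches), which is one reason the value must be hashable. The sketch below is a hypothetical, stdlib-only `MiniStatic` following JAX's `tree_flatten`/`tree_unflatten` class protocol; it is not inox's code.

```python
class MiniStatic:
    """Hypothetical leafless wrapper: the value is auxiliary data, not a child."""

    def __init__(self, value):
        hash(value)  # fail early with TypeError if the value is unhashable
        self.value = value

    def tree_flatten(self):
        # No children means no leaves; the value travels in the tree structure.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        # Children are always empty here; only the stored value is needed.
        return cls(value)
```

Because `tree_flatten` returns an empty tuple of children, a PyTree built this way has no leaves, which is consistent with the empty `tree_leaves` result above.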
class inox.tree_util.Auto(**kwargs)

Subclass of Namespace that automatically detects non-array leaves and treats them as static.

Important

Bare object() leaves are never considered static, even though they are not arrays.

Parameters:

kwargs – A name-value mapping.

Example

>>> auto = Auto(a=1, b=jnp.array(2.0)); auto
Auto(
  a = 1,
  b = float32[]
)
>>> auto.c = ['3', jnp.arange(4)]; auto
Auto(
  a = 1,
  b = float32[],
  c = ['3', int32[4]]
)
>>> jax.tree_util.tree_leaves(auto)  # only arrays
[Array(2., dtype=float32, weak_type=True), Array([0, 1, 2, 3], dtype=int32)]
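The per-leaf rule can be sketched as a partition of leaves into a dynamic part (arrays) and a static part (everything else), with bare `object()` leaves kept dynamic as the note above states. In this stdlib-only sketch, `array.array` stands in for `jax.Array`, and `iter_leaves` and `split_auto` are hypothetical helpers; inox's actual Auto may detect and store leaves differently.

```python
from array import array


def iter_leaves(tree, path=()):
    """Yield (path, leaf) pairs, recursing into lists, tuples and dicts."""
    if isinstance(tree, (list, tuple)):
        for i, child in enumerate(tree):
            yield from iter_leaves(child, path + (i,))
    elif isinstance(tree, dict):
        for key in sorted(tree):
            yield from iter_leaves(tree[key], path + (key,))
    else:
        yield path, tree


def split_auto(fields, is_array=lambda x: isinstance(x, array)):
    """Partition leaves into dynamic (array-like) and static (everything else).

    Assumption: array.array stands in for jax.Array in this sketch.
    """
    dynamic, static = {}, {}
    for path, leaf in iter_leaves(fields):
        # Arrays stay dynamic; so do bare object() sentinels, per the note above.
        if is_array(leaf) or type(leaf) is object:
            dynamic[path] = leaf
        else:
            static[path] = leaf
    return dynamic, static
```

Mirroring the example above, the fields `{'a': 1, 'b': array('f', [2.0]), 'c': ['3', array('i', range(4))]}` split into dynamic leaves at paths `('b',)` and `('c', 1)`, while `1` and `'3'` end up static.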
inox.tree_util.tree_repr(x, /, linewidth=88, typeonly=True, **kwargs)

Creates a pretty representation of a tree.

Parameters:
  • x (PyTree) – The tree to represent.

  • linewidth (int) – The maximum line width before elements of tuples, lists and dicts are represented on separate lines.

  • typeonly (bool) – Whether to represent the type of arrays instead of their elements.

Returns:

The representation string.

Return type:

str

Example

>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jnp.arange(7)}]
>>> print(tree_repr(tree))
[
  1,
  'two',
  (True, False),
  [0, 1, 2, 3, 4],
  {'6': int32[7]}
]
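A simplified pretty-printer illustrates the two documented parameters: with `typeonly`, an array is rendered as its dtype followed by its shape (as in `float32[]` and `int32[7]`), and `linewidth` decides whether a container fits on one line. This is a hypothetical sketch, not inox's implementation. `FakeArray` and `mini_repr` are assumed names, and the simple "fits within linewidth" test is an assumption that may lay out some trees differently from `tree_repr`; for example, the output above is split across lines even though its one-line form is shorter than 88 characters.

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Hypothetical stand-in for an array: only dtype and shape are kept."""

    dtype: str
    shape: tuple


def mini_repr(x, linewidth=88, typeonly=True, indent="  "):
    """Simplified tree pretty-printer (a sketch, not inox's tree_repr)."""
    if typeonly and hasattr(x, "dtype") and hasattr(x, "shape"):
        # e.g. float32[] for a scalar, int32[7] for a length-7 vector
        return f"{x.dtype}[{','.join(map(str, x.shape))}]"
    if isinstance(x, dict):
        items = [f"{k!r}: {mini_repr(v, linewidth, typeonly, indent)}" for k, v in x.items()]
        bra, ket = "{", "}"
    elif isinstance(x, (list, tuple)):
        items = [mini_repr(v, linewidth, typeonly, indent) for v in x]
        bra, ket = ("[", "]") if isinstance(x, list) else ("(", ")")
    else:
        return repr(x)
    # A one-element tuple needs its trailing comma.
    trailing = "," if isinstance(x, tuple) and len(items) == 1 else ""
    flat = bra + ", ".join(items) + trailing + ket
    # Keep the container on one line if it fits and no child spans several lines.
    # Simplification: indentation is not counted against linewidth.
    if "\n" not in flat and len(flat) <= linewidth:
        return flat
    # Otherwise put one child per line, indenting nested lines as well.
    body = ",\n".join(indent + item.replace("\n", "\n" + indent) for item in items)
    return f"{bra}\n{body}{trailing}\n{ket}"
```

With the default `linewidth=88`, the sketch prints the example tree on a single line; with `linewidth=40`, it produces the same one-element-per-line layout as the example above.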