inox.tree_util

Extended utilities for tree-like data structures

Functions

tree_mask

Masks the static leaves of a tree.

tree_unmask

Unmasks the static leaves of a masked tree.

tree_partition

Flattens a tree and partitions the leaves.

tree_combine

Reconstructs a tree from the tree definition and leaf partitions.

tree_repr

Creates a pretty representation of a tree.

Classes

Namespace

PyTree class for name-value mappings.

Partial

A version of functools.partial that is a PyTree.

Static

Wraps a hashable value as a leafless PyTree.

Descriptions

class inox.tree_util.Namespace(**kwargs)

PyTree class for name-value mappings.

Parameters:

kwargs – A name-value mapping.

Example

>>> tree = Namespace(a=1, b='2'); tree
Namespace(
  a = 1,
  b = '2'
)
>>> tree.c = [3, False]; tree
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree_util.tree_leaves(tree)
[1, '2', 3, False]
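
To make the name/value split concrete, here is a plain-Python sketch that uses neither JAX nor inox. `ToyNamespace` and `toy_leaves` are hypothetical names. The sketch assumes attribute values are visited in insertion order (consistent with the example above) and treats lists, tuples and dicts as nodes and everything else as a leaf. The attribute names act as static structure, so only the values show up as leaves.

```python
from typing import Any, List


class ToyNamespace:
    """Hypothetical stand-in for Namespace: values are children, names are structure."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


def toy_leaves(tree: Any) -> List[Any]:
    """Collect leaves depth-first; lists, tuples, dicts and ToyNamespace are nodes."""
    if isinstance(tree, ToyNamespace):
        children = list(vars(tree).values())  # insertion order in this sketch
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    elif isinstance(tree, dict):
        children = [tree[key] for key in sorted(tree)]  # JAX also sorts dict keys
    else:
        return [tree]  # anything that is not a node is a leaf
    return [leaf for child in children for leaf in toy_leaves(child)]


tree = ToyNamespace(a=1, b="2")
tree.c = [3, False]  # attributes can be added after construction
print(toy_leaves(tree))  # [1, '2', 3, False]
```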

class inox.tree_util.Partial(func, *args, **kwds)

A version of functools.partial that is a PyTree.

Parameters:
  • func (Callable) – A function.

  • args (Any) – Positional arguments for future calls.

  • kwds (Any) – Keyword arguments for future calls.

Examples

>>> increment = Partial(jax.numpy.add, 1)
>>> increment(2)
Array(3, dtype=int32, weak_type=True)
>>> println = Partial(print, sep='\n')
>>> println('Hello', 'World!')
Hello
World!
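
The benefit of a PyTree partial is that its bound arguments are leaves while the wrapped function belongs to the structure, so tree transformations can act on the arguments without touching the function. The following plain-Python sketch subclasses functools.partial to show that split; it is not inox's implementation. `ToyPartial`, `toy_map` and `affine` are hypothetical, and `toy_map` only mimics what jax.tree_util.tree_map would do on a registered node.

```python
import functools
from typing import Any, Callable, List, Tuple

Aux = Tuple[Callable, int, Tuple[str, ...]]


class ToyPartial(functools.partial):
    """Hypothetical sketch: a functools.partial whose bound arguments are children."""

    def flatten(self) -> Tuple[List[Any], Aux]:
        names = tuple(self.keywords)
        children = [*self.args, *(self.keywords[name] for name in names)]
        # The function, the number of positional arguments and the keyword
        # names are static metadata; only the bound values are children.
        return children, (self.func, len(self.args), names)

    @classmethod
    def unflatten(cls, aux: Aux, children: List[Any]) -> "ToyPartial":
        func, nargs, names = aux
        return cls(func, *children[:nargs], **dict(zip(names, children[nargs:])))


def toy_map(fn: Callable[[Any], Any], p: ToyPartial) -> ToyPartial:
    """Apply fn to every bound argument and rebuild the partial."""
    children, aux = p.flatten()
    return ToyPartial.unflatten(aux, [fn(child) for child in children])


def affine(x, scale=1, shift=0):
    return x * scale + shift


line = ToyPartial(affine, scale=2, shift=1)
steeper = toy_map(lambda v: v + 1, line)  # scale=3, shift=2; affine is untouched
print(line(5), steeper(5))  # 11 17
```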

class inox.tree_util.Static(value)

Wraps a hashable value as a leafless PyTree.

Parameters:

value (Hashable) – A hashable value to wrap.

Example

>>> tree = Static((0, 'one', None))
>>> tree.value
(0, 'one', None)
>>> jax.tree_util.tree_leaves(tree)
[]
>>> jax.tree_util.tree_structure(tree)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
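
A leafless node carries its value in the tree structure (the auxiliary data, in JAX terms) rather than in its leaves, as the PyTreeDef above shows. JAX compares and hashes tree structures, for example when `jax.jit` looks up a cached compilation, which is presumably why the wrapped value must be hashable. The plain-Python sketch below mimics that lookup with a dictionary keyed on the structure. `ToyStatic`, `structure_key` and `compile_for` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Dict, Hashable, List, Tuple

Key = Tuple[str, int, Hashable]


@dataclass(frozen=True)
class ToyStatic:
    """Hypothetical sketch of a leafless wrapper: the value lives in the structure."""

    value: Hashable

    def flatten(self) -> Tuple[List[Any], Hashable]:
        # No children, hence no leaves: the wrapped value is auxiliary data.
        return [], self.value


def structure_key(node: ToyStatic) -> Key:
    """Toy stand-in for a tree definition: node type, child count and aux data."""
    children, aux = node.flatten()
    return (type(node).__name__, len(children), aux)


compile_cache: Dict[Key, str] = {}


def compile_for(node: ToyStatic) -> str:
    """Mimic a jit-style cache keyed on structure; the lookup hashes the aux data."""
    key = structure_key(node)
    if key not in compile_cache:
        compile_cache[key] = f"program specialized on {node.value!r}"
    return compile_cache[key]


compile_for(ToyStatic((0, "one", None)))
compile_for(ToyStatic((0, "one", None)))  # equal value, so the cache is hit
print(len(compile_cache))  # 1
```

Wrapping an unhashable value such as a list fails as soon as the structure is used as a cache key, with a TypeError.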

inox.tree_util.tree_mask(tree, is_static=None)

Masks the static leaves of a tree.

The structure of the tree remains unchanged, but each leaf considered static is wrapped in a Static instance, which hides it from jax.tree_util.tree_leaves and jax.tree_util.tree_map.

See also

tree_unmask

Parameters:
  • tree (PyTree) – The tree to mask.

  • is_static (Callable[[Any], bool] | None) – A predicate for what to consider static. If None, all non-array leaves are considered static.

Returns:

The masked tree.

Return type:

PyTree

Example

>>> tree = [1, jax.numpy.arange(2), 'three']
>>> jax.tree_util.tree_leaves(tree)
[1, Array([0, 1], dtype=int32), 'three']
>>> tree = tree_mask(tree); tree
[Static(1), Array([0, 1], dtype=int32), Static('three')]
>>> jax.tree_util.tree_leaves(tree)
[Array([0, 1], dtype=int32)]
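
The masking step can be sketched in plain Python, assuming lists, tuples and dicts are the only containers and using the standard library's `array.array` as a stand-in for `jax.Array`. `ToyStatic`, `toy_tree_mask` and `toy_leaves` are hypothetical names; the real function works on any registered PyTree.

```python
from array import array
from dataclasses import dataclass
from typing import Any, Callable, Hashable, List, Optional


@dataclass(frozen=True)
class ToyStatic:
    """Hypothetical leafless wrapper, standing in for Static."""

    value: Hashable


def is_non_array(leaf: Any) -> bool:
    # array.array plays the role of jax.Array in this sketch.
    return not isinstance(leaf, array)


def toy_tree_mask(tree: Any, is_static: Optional[Callable[[Any], bool]] = None) -> Any:
    """Wrap static leaves in ToyStatic while keeping the container structure."""
    if is_static is None:
        is_static = is_non_array  # default: every non-array leaf is static
    if isinstance(tree, ToyStatic):
        return tree  # already masked
    if isinstance(tree, list):
        return [toy_tree_mask(x, is_static) for x in tree]
    if isinstance(tree, tuple):
        return tuple(toy_tree_mask(x, is_static) for x in tree)
    if isinstance(tree, dict):
        return {k: toy_tree_mask(v, is_static) for k, v in tree.items()}
    return ToyStatic(tree) if is_static(tree) else tree


def toy_leaves(tree: Any) -> List[Any]:
    """Collect leaves; ToyStatic nodes have no children and contribute nothing."""
    if isinstance(tree, ToyStatic):
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for x in tree for leaf in toy_leaves(x)]
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in toy_leaves(tree[k])]
    return [tree]


tree = [1, array("i", [0, 1]), "three"]
masked = toy_tree_mask(tree)
print(masked)  # [ToyStatic(value=1), array('i', [0, 1]), ToyStatic(value='three')]
print(toy_leaves(masked))  # [array('i', [0, 1])]
```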

inox.tree_util.tree_unmask(tree)

Unmasks the static leaves of a masked tree.

See also

tree_mask

Parameters:

tree (PyTree) – The masked tree to unmask.

Returns:

The unmasked tree.

Return type:

PyTree

Example

>>> tree = [Static(1), jax.numpy.arange(2), Static('three')]
>>> tree_unmask(tree)
[1, Array([0, 1], dtype=int32), 'three']
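
Unmasking is the inverse walk: every wrapper is replaced by the value it holds and everything else is left in place. Below is a plain-Python sketch, assuming lists, tuples and dicts are the only containers and using `array.array` as a stand-in for `jax.Array`; `ToyStatic` and `toy_tree_unmask` are hypothetical names.

```python
from array import array
from dataclasses import dataclass
from typing import Any, Hashable


@dataclass(frozen=True)
class ToyStatic:
    """Hypothetical leafless wrapper, standing in for Static."""

    value: Hashable


def toy_tree_unmask(tree: Any) -> Any:
    """Replace every ToyStatic node by its wrapped value, keeping the structure."""
    if isinstance(tree, ToyStatic):
        return tree.value  # the wrapped value is returned as-is, not traversed
    if isinstance(tree, list):
        return [toy_tree_unmask(x) for x in tree]
    if isinstance(tree, tuple):
        return tuple(toy_tree_unmask(x) for x in tree)
    if isinstance(tree, dict):
        return {k: toy_tree_unmask(v) for k, v in tree.items()}
    return tree


masked = [ToyStatic(1), array("i", [0, 1]), ToyStatic("three")]
print(toy_tree_unmask(masked))  # [1, array('i', [0, 1]), 'three']
```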

inox.tree_util.tree_partition(tree, *filters, is_leaf=None)

Flattens a tree and partitions the leaves.

The leaves are partitioned into path-leaf mappings, one per filtering constraint and in the order the filters are given. Each mapping contains the leaves of the nodes satisfying the corresponding constraint. The last mapping is dedicated to leaves that do not satisfy any constraint.

See also

tree_combine

Parameters:
  • tree (PyTree) – The tree to flatten.

  • filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

  • is_leaf (Callable[[Any], bool] | None) – A predicate for what to consider as a leaf.

Returns:

The tree definition, followed by one leaf partition per filter and a last partition for the remaining leaves.

Return type:

Tuple[PyTreeDef, Dict[str, Any]]

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, leaves = tree_partition(tree)
>>> leaves
{'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
>>> treedef, arrays, others = tree_partition(tree, jax.Array)
>>> arrays
{'.b': Array([0, 1], dtype=int32)}
>>> others
{'.a': 1, '.c[0]': 'three', '.c[1]': False}
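
The path keys above follow a `.name` / `[index]` convention. The plain-Python sketch below reproduces that bookkeeping for `types.SimpleNamespace` (standing in for Namespace), lists, tuples and dicts, with `array.array` standing in for `jax.Array`. All names are hypothetical, the toy "tree definition" is simply the tree with each leaf replaced by a `LEAF` placeholder, `is_leaf` is not modeled, and sending a leaf to the first filter it satisfies is an assumption of the sketch rather than documented behavior.

```python
from array import array
from types import SimpleNamespace
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union

Filter = Union[type, Callable[[Any], bool]]
LEAF = object()  # marks the position of a leaf in the toy tree definition


def flatten_with_paths(tree: Any, path: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs using .name, [i] and ['key'] path segments."""
    if isinstance(tree, SimpleNamespace):
        for name, value in vars(tree).items():
            yield from flatten_with_paths(value, f"{path}.{name}")
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            yield from flatten_with_paths(value, f"{path}[{i}]")
    elif isinstance(tree, dict):
        for key, value in tree.items():
            yield from flatten_with_paths(value, f"{path}[{key!r}]")
    else:
        yield path, tree


def skeleton(tree: Any) -> Any:
    """Toy tree definition: the same containers with every leaf replaced by LEAF."""
    if isinstance(tree, SimpleNamespace):
        return SimpleNamespace(**{k: skeleton(v) for k, v in vars(tree).items()})
    if isinstance(tree, list):
        return [skeleton(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(skeleton(v) for v in tree)
    if isinstance(tree, dict):
        return {k: skeleton(v) for k, v in tree.items()}
    return LEAF


def toy_tree_partition(tree: Any, *filters: Filter) -> Tuple[Any, ...]:
    """Return the toy tree definition, one dict per filter and a dict of the rest."""
    predicates: List[Callable[[Any], bool]] = [
        (lambda x, t=f: isinstance(x, t)) if isinstance(f, type) else f
        for f in filters
    ]
    partitions: List[Dict[str, Any]] = [{} for _ in range(len(filters) + 1)]
    for path, leaf in flatten_with_paths(tree):
        # Assumption of this sketch: a leaf goes to the first filter it satisfies.
        index = next((i for i, p in enumerate(predicates) if p(leaf)), len(filters))
        partitions[index][path] = leaf
    return (skeleton(tree), *partitions)


tree = SimpleNamespace(a=1, b=array("i", [0, 1]), c=["three", False])
treedef, arrays, others = toy_tree_partition(tree, array)
print(arrays)  # {'.b': array('i', [0, 1])}
print(others)  # {'.a': 1, '.c[0]': 'three', '.c[1]': False}
```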

inox.tree_util.tree_combine(treedef, *leaves)

Reconstructs a tree from the tree definition and leaf partitions.

See also

tree_partition

Parameters:
  • treedef (PyTreeDef) – The tree definition.

  • leaves (Dict[str, Any]) – The set of leaf partitions.

Returns:

The reconstructed tree.

Return type:

PyTree

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, arrays, others = tree_partition(tree, jax.Array)
>>> others = {key: str(leaf).upper() for key, leaf in others.items()}
>>> tree_combine(treedef, arrays, others)
Namespace(
  a = '1',
  b = int32[2],
  c = ['THREE', 'FALSE']
)
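
Reconstruction only needs the structure and a path for every leaf position, so the partitions can be merged and each leaf looked up by its path. The plain-Python sketch below uses a hand-written skeleton (containers holding a `LEAF` placeholder) in place of a real PyTreeDef. `toy_tree_combine` and `LEAF` are hypothetical, `types.SimpleNamespace` stands in for Namespace and `array.array` for `jax.Array`.

```python
from array import array
from types import SimpleNamespace
from typing import Any, Dict

LEAF = object()  # placeholder marking a leaf position in the toy tree definition


def toy_tree_combine(treedef: Any, *leaves: Dict[str, Any]) -> Any:
    """Rebuild a tree by looking up each leaf position, by path, in the partitions."""
    merged: Dict[str, Any] = {}
    for partition in leaves:
        merged.update(partition)

    def build(node: Any, path: str) -> Any:
        if node is LEAF:
            return merged[path]  # KeyError if no partition holds this leaf
        if isinstance(node, SimpleNamespace):
            fields = {k: build(v, f"{path}.{k}") for k, v in vars(node).items()}
            return SimpleNamespace(**fields)
        if isinstance(node, list):
            return [build(v, f"{path}[{i}]") for i, v in enumerate(node)]
        if isinstance(node, tuple):
            return tuple(build(v, f"{path}[{i}]") for i, v in enumerate(node))
        if isinstance(node, dict):
            return {k: build(v, f"{path}[{k!r}]") for k, v in node.items()}
        raise TypeError(f"unexpected node in tree definition: {node!r}")

    return build(treedef, "")


# Hand-written toy tree definition, standing in for the one tree_partition returns.
treedef = SimpleNamespace(a=LEAF, b=LEAF, c=[LEAF, LEAF])
arrays = {".b": array("i", [0, 1])}
others = {".a": 1, ".c[0]": "three", ".c[1]": False}
others = {key: str(leaf).upper() for key, leaf in others.items()}
print(toy_tree_combine(treedef, arrays, others))
# namespace(a='1', b=array('i', [0, 1]), c=['THREE', 'FALSE'])
```

In this sketch, leaving out a partition raises a KeyError at the first missing path, which makes a dropped partition easy to notice.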

inox.tree_util.tree_repr(tree, linewidth=88, typeonly=True, **kwargs)

Creates a pretty representation of a tree.

Parameters:
  • tree (PyTree) – The tree to represent.

  • linewidth (int) – The maximum line width before elements of tuples, lists and dicts are represented on separate lines.

  • typeonly (bool) – Whether to represent the type of arrays instead of their elements.

Returns:

The representation string.

Return type:

str

Example

>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jax.numpy.arange(7)}]
>>> print(tree_repr(tree))
[
  1,
  'two',
  (True, False),
  [0, 1, 2, 3, 4],
  {'6': int32[7]}
]