inox.tree

Extended utilities for tree-like data structures.

Classes

Namespace

PyTree class for name-value mappings.

Partial

A version of functools.partial that is a PyTree.

Static

Wraps a hashable value as a leafless PyTree.

Functions

mask

Masks the static leaves of a tree.

unmask

Unmasks the masked leaves of a tree.

partition

Flattens a tree and partitions its leaves.

combine

Reconstructs a tree from the tree definition and leaf partitions.

prepr

Creates a pretty representation of a tree.

Descriptions

class inox.tree.Namespace(**kwargs)

PyTree class for name-value mappings.

Parameters:

kwargs – A name-value mapping.

Example

>>> tree = Namespace(a=1, b='2'); tree
Namespace(
  a = 1,
  b = '2'
)
>>> tree.c = [3, False]; tree
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree.leaves(tree)
[1, '2', 3, False]
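To illustrate what a name-value PyTree involves, the sketch below implements the flatten/unflatten pair that jax.tree_util.register_pytree_node_class expects. The values become children and the names become static auxiliary data. The class NameValueTree is hypothetical and is not inox's implementation. Registration is left out so the snippet runs without JAX.

```python
from typing import Any


class NameValueTree:
    """Hypothetical stand-in for a name-value PyTree (not inox's implementation)."""

    def __init__(self, **kwargs: Any):
        # Attributes live in the instance dict, which preserves insertion order.
        self.__dict__.update(kwargs)

    def tree_flatten(self):
        # Children are the values; the names are static auxiliary data.
        names = tuple(self.__dict__)
        values = tuple(self.__dict__[name] for name in names)
        return values, names

    @classmethod
    def tree_unflatten(cls, names, values):
        # Rebuild an instance from the names (aux data) and values (children).
        return cls(**dict(zip(names, values)))

    def __eq__(self, other):
        return type(other) is type(self) and self.__dict__ == other.__dict__

    def __repr__(self):
        items = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
        return f"{type(self).__name__}({items})"


tree = NameValueTree(a=1, b="2")
tree.c = [3, False]  # new attributes become new children
values, names = tree.tree_flatten()
print(names)   # ('a', 'b', 'c')
print(values)  # (1, '2', [3, False])
print(NameValueTree.tree_unflatten(names, values) == tree)  # True
```

Once registered, JAX would also recurse into each value. The list stored under c would then be flattened further, which is consistent with the [1, '2', 3, False] output above.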
class inox.tree.Partial(func, *args, **kwds)

A version of functools.partial that is a PyTree.

Parameters:
  • func (Callable) – A function.

  • args (Any) – Positional arguments for future calls.

  • kwds (Any) – Keyword arguments for future calls.

Examples

>>> increment = Partial(jax.numpy.add, 1)
>>> increment(2)
Array(3, dtype=int32, weak_type=True)
>>> println = Partial(print, sep='\n')
>>> println('Hello', 'World!')
Hello
World!
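The sketch below shows one way to make a partial application a PyTree. The bound positional and keyword arguments are children, so transformations can see their leaves, while the function is kept as auxiliary data. Treating the function as static is an assumption of this sketch. TreePartial is a hypothetical name, not inox's implementation.

```python
import operator
from typing import Any, Callable


class TreePartial:
    """Hypothetical partial application whose bound arguments are PyTree children."""

    def __init__(self, func: Callable, *args: Any, **kwds: Any):
        self.func = func
        self.args = args
        self.kwds = kwds

    def __call__(self, *args: Any, **kwds: Any):
        # Call-time keyword arguments override stored ones, like functools.partial.
        return self.func(*self.args, *args, **{**self.kwds, **kwds})

    def tree_flatten(self):
        # Assumption: the function is static aux data; the arguments are children.
        return (self.args, self.kwds), self.func

    @classmethod
    def tree_unflatten(cls, func, children):
        args, kwds = children
        return cls(func, *args, **kwds)


increment = TreePartial(operator.add, 1)
print(increment(2))  # 3

children, func = increment.tree_flatten()
rebuilt = TreePartial.tree_unflatten(func, children)
print(rebuilt(41))  # 42
```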
class inox.tree.Static(value)

Wraps a hashable value as a leafless PyTree.

Parameters:

value (Hashable) – A hashable value to wrap.

Example

>>> tree = Static((0, 'one', None))
>>> tree.value
(0, 'one', None)
>>> jax.tree.leaves(tree)
[]
>>> jax.tree.structure(tree)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))
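A leafless PyTree can be sketched by returning no children and storing the wrapped value as auxiliary data. JAX compares and hashes tree definitions, and auxiliary data is part of them, which is presumably why the value must be hashable. StaticValue is a hypothetical name. Registration with JAX is omitted so the sketch runs without it.

```python
from dataclasses import dataclass
from typing import Hashable


@dataclass(frozen=True)
class StaticValue:
    """Hypothetical leafless wrapper: the value lives in the tree definition."""

    value: Hashable

    def tree_flatten(self):
        # No children, so this node contributes no leaves; the value is aux data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


wrapped = StaticValue((0, "one", None))
children, aux = wrapped.tree_flatten()
print(children)  # ()
print(aux)  # (0, 'one', None)
print(StaticValue.tree_unflatten(aux, children) == wrapped)  # True
```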
inox.tree.mask(tree, is_static=None)

Masks the static leaves of a tree.

The structure of the tree remains unchanged, but leaves that are considered static are masked, which hides them from jax.tree.leaves and jax.tree.map. Applying inox.tree.mask more than once yields the same tree as applying it once.

See also

inox.tree.unmask

Parameters:
  • tree (PyTree) – The tree to mask.

  • is_static (Callable[[Any], bool]) – A predicate for what to consider static. If None, all non-array leaves are considered static.

Returns:

The masked tree.

Return type:

PyTree

Example

>>> tree = [1, jax.numpy.arange(2), 'three']
>>> jax.tree.leaves(tree)
[1, Array([0, 1], dtype=int32), 'three']
>>> tree = mask(tree); tree
[Mask(1), Array([0, 1], dtype=int32), Mask('three')]
>>> jax.tree.leaves(tree)
[Array([0, 1], dtype=int32)]
>>> unmask(tree)
[1, Array([0, 1], dtype=int32), 'three']
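The masking behavior described above can be sketched in plain Python for nested lists, tuples and dicts. Masked, mask_sketch and unmask_sketch are hypothetical names. The default predicate uses a duck-typed array check (presence of shape and dtype) as a stand-in for a real array test. In inox the mask wrapper presumably contributes no leaves, much like Static. That registration is omitted here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Masked:
    """Hypothetical wrapper marking a leaf as masked (hidden)."""

    value: Any


def looks_like_array(x: Any) -> bool:
    # Assumption: a duck-typed stand-in for a real array check.
    return hasattr(x, "shape") and hasattr(x, "dtype")


def mask_sketch(tree: Any, is_static: Optional[Callable[[Any], bool]] = None) -> Any:
    # Default: every non-array leaf is static.
    static = is_static if is_static is not None else (lambda x: not looks_like_array(x))
    if isinstance(tree, Masked):
        return tree  # already masked, so masking again changes nothing
    if isinstance(tree, (list, tuple)):
        # Assumes plain lists and tuples (namedtuples are not handled).
        return type(tree)(mask_sketch(x, static) for x in tree)
    if isinstance(tree, dict):
        return {k: mask_sketch(v, static) for k, v in tree.items()}
    return Masked(tree) if static(tree) else tree


def unmask_sketch(tree: Any) -> Any:
    if isinstance(tree, Masked):
        return tree.value
    if isinstance(tree, (list, tuple)):
        return type(tree)(unmask_sketch(x) for x in tree)
    if isinstance(tree, dict):
        return {k: unmask_sketch(v) for k, v in tree.items()}
    return tree


class FakeArray:
    """Object with array-like attributes, for demonstration only."""

    shape = (2,)
    dtype = "int32"


arr = FakeArray()
tree = [1, arr, "three"]
masked = mask_sketch(tree)
print(masked[0], masked[2])  # Masked(value=1) Masked(value='three')
print(masked[1] is arr)  # True
print(mask_sketch(masked) == masked)  # True, masking is idempotent
print(unmask_sketch(masked) == tree)  # True
```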
inox.tree.unmask(tree)

Unmasks the masked leaves of a tree.

See also

inox.tree.mask

Parameters:

tree (PyTree) – The masked tree to unmask.

Returns:

The unmasked tree.

Return type:

PyTree

inox.tree.partition(tree, *filters, is_leaf=None)

Flattens a tree and partitions its leaves.

The leaves are partitioned into a set of path-leaf mappings. The mapping that receives a leaf is determined by its oldest (closest to the root) ancestor that satisfies a constraint, the leaf itself included. If a node satisfies several constraints, the first one, in the order of filters, is selected. The last mapping collects the leaves that do not satisfy any constraint.

Parameters:
  • tree (PyTree) – The tree to flatten.

  • filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

  • is_leaf (Callable[[Any], bool]) – A predicate for what to consider as a leaf.

Returns:

The tree definition, followed by one path-leaf mapping per filter and a last mapping for the remaining leaves.

Return type:

Tuple[PyTreeDef, Dict[str, Any]]

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, leaves = inox.tree.partition(tree)
>>> leaves
{'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> arrays
{'.b': Array([0, 1], dtype=int32)}
>>> others
{'.a': 1, '.c[0]': 'three', '.c[1]': False}
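The partitioning rule can be illustrated with a simplified, dependency-free sketch over nested dicts, lists and tuples. Unlike inox.tree.partition, partition_sketch (a hypothetical name) returns only the mappings, not a tree definition. It formats paths in JAX's key-string style for dicts (['a']) and sequences ([0]). It also treats None as a leaf, whereas JAX treats None as an empty node.

```python
from typing import Any, Dict, List, Optional


def partition_sketch(tree: Any, *filters: Any) -> List[Dict[str, Any]]:
    """Hypothetical, simplified partition over nested dicts, lists and tuples."""
    # Types become isinstance predicates; callables are used as-is.
    preds = [
        (lambda x, t=f: isinstance(x, t)) if isinstance(f, type) else f
        for f in filters
    ]
    # One mapping per filter, plus a last one for unmatched leaves.
    parts: List[Dict[str, Any]] = [{} for _ in range(len(preds) + 1)]

    def visit(node: Any, path: str, bucket: Optional[int]) -> None:
        if bucket is None:
            # The closest-to-root matching node decides; its first match wins.
            bucket = next((i for i, p in enumerate(preds) if p(node)), None)
        if isinstance(node, dict):
            for key, child in node.items():
                visit(child, f"{path}[{key!r}]", bucket)
        elif isinstance(node, (list, tuple)):
            for i, child in enumerate(node):
                visit(child, f"{path}[{i}]", bucket)
        else:
            parts[len(preds) if bucket is None else bucket][path] = node

    visit(tree, "", None)
    return parts


tree = {"a": 1, "b": (2.0, 3.0), "c": ["three", False], "d": 4.0}
floats, tuples, others = partition_sketch(tree, float, tuple)
print(floats)  # {"['d']": 4.0}
print(tuples)  # {"['b'][0]": 2.0, "['b'][1]": 3.0}
print(others)  # {"['a']": 1, "['c'][0]": 'three', "['c'][1]": False}
```

The floats under b land in the tuple mapping even though float is listed first. This happens because b, an ancestor closer to the root, already satisfies the tuple constraint.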
inox.tree.combine(treedef, *leaves)

Reconstructs a tree from the tree definition and leaf partitions.

Parameters:
  • treedef (PyTreeDef) – The tree definition.

  • leaves (Dict[str, Any]) – The leaf partitions, as path-leaf mappings such as those returned by inox.tree.partition.

Returns:

The reconstructed tree.

Return type:

PyTree

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> others = {key: str(leaf).upper() for key, leaf in others.items()}
>>> inox.tree.combine(treedef, arrays, others)
Namespace(
  a = '1',
  b = int32[2],
  c = ['THREE', 'FALSE']
)
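Every leaf is keyed by its path, and paths are unique. Combining therefore amounts to merging the partitions and placing each leaf back at its path. The sketch below illustrates this using a skeleton as a stand-in for a PyTreeDef: the skeleton is the tree with each leaf replaced by a Slot carrying its path. Slot, skeleton and combine_sketch are hypothetical names, not inox's implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder marking where a leaf goes in a skeleton."""

    path: str


def skeleton(tree: Any, path: str = "") -> Any:
    # Replace every leaf with a Slot holding its path (a stand-in for a treedef).
    if isinstance(tree, dict):
        return {k: skeleton(v, f"{path}[{k!r}]") for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(skeleton(x, f"{path}[{i}]") for i, x in enumerate(tree))
    return Slot(path)


def combine_sketch(skel: Any, *partitions: Dict[str, Any]) -> Any:
    # Paths are unique, so the partitions can simply be merged in any order.
    merged: Dict[str, Any] = {}
    for part in partitions:
        merged.update(part)

    def fill(node: Any) -> Any:
        if isinstance(node, Slot):
            return merged[node.path]
        if isinstance(node, dict):
            return {k: fill(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(fill(x) for x in node)
        return node

    return fill(skel)


tree = {"a": 1, "b": 2.5, "c": ["three", False]}
skel = skeleton(tree)
floats = {"['b']": 2.5}
others = {"['a']": "1", "['c'][0]": "THREE", "['c'][1]": "FALSE"}
print(combine_sketch(skel, floats, others))
# {'a': '1', 'b': 2.5, 'c': ['THREE', 'FALSE']}
```

In this sketch a missing path raises a KeyError. How inox.tree.combine behaves in that case is not covered here.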
inox.tree.prepr(tree, linewidth=88, typeonly=True, **kwargs)

Creates a pretty representation of a tree.

Parameters:
  • tree (PyTree) – The tree to represent.

  • linewidth (int) – The maximum line width before elements of tuples, lists and dicts are represented on separate lines.

  • typeonly (bool) – Whether to represent arrays by their type, i.e. dtype and shape (e.g. int32[7]), instead of their elements.

Returns:

The representation string.

Return type:

str

Example

>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jax.numpy.arange(7)}]
>>> tree.append(tree)
>>> print(inox.tree.prepr(tree))
[1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': int32[7]}, [...]]
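The line-width rule can be sketched as follows. A container is rendered on one line if the result fits within linewidth, and otherwise with one element per line. A container that reappears inside itself is elided, as in the [...] of the example above. prepr_sketch is a hypothetical name. It handles only lists, tuples and dicts, and it ignores typeonly because it does not special-case arrays.

```python
from typing import Any, FrozenSet

BRACKETS = {list: ("[", "]"), tuple: ("(", ")"), dict: ("{", "}")}


def prepr_sketch(obj: Any, linewidth: int = 88, indent: int = 2,
                 _seen: FrozenSet[int] = frozenset()) -> str:
    """Hypothetical pretty printer: inline if it fits, one element per line otherwise."""
    kind = type(obj)
    if kind not in BRACKETS:
        return repr(obj)
    open_, close = BRACKETS[kind]
    if id(obj) in _seen:
        return f"{open_}...{close}"  # the container contains itself
    seen = _seen | {id(obj)}  # ids of the containers currently being rendered
    if kind is dict:
        items = [f"{k!r}: {prepr_sketch(v, linewidth, indent, seen)}" for k, v in obj.items()]
    else:
        items = [prepr_sketch(x, linewidth, indent, seen) for x in obj]
    if kind is tuple and len(items) == 1:
        items[0] += ","  # keep single-element tuples unambiguous
    inline = open_ + ", ".join(items) + close
    if len(inline) <= linewidth and "\n" not in inline:
        return inline
    # Too wide: one element per line, indenting nested multi-line elements.
    pad = " " * indent
    body = ",\n".join(pad + item.replace("\n", "\n" + pad) for item in items)
    return f"{open_}\n{body}\n{close}"


tree = [1, "two", (True, False), list(range(5)), {"6": 7}]
tree.append(tree)
print(prepr_sketch(tree))
# [1, 'two', (True, False), [0, 1, 2, 3, 4], {'6': 7}, [...]]
print(prepr_sketch([1, [2, 3]], linewidth=8))
# [
#   1,
#   [2, 3]
# ]
```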