inox.tree

Extended utilities for tree-like data structures.

Classes

Namespace

PyTree class for name-value mappings.

Partial

A version of functools.partial that is a PyTree.

Static

Wraps a hashable value as a leafless PyTree.

Functions

mask_static

Masks the static leaves of a tree.

unmask_static

Unmasks the static leaves of a masked tree.

partition

Flattens a tree and partitions the leaves.

combine

Reconstructs a tree from the tree definition and leaf partitions.

prepr

Creates a pretty representation of a tree.

Descriptions

class inox.tree.Namespace(**kwargs)

PyTree class for name-value mappings.

Parameters:

kwargs – A name-value mapping.

Example

>>> tree = Namespace(a=1, b='2'); tree
Namespace(
  a = 1,
  b = '2'
)
>>> tree.c = [3, False]; tree
Namespace(
  a = 1,
  b = '2',
  c = [3, False]
)
>>> jax.tree.leaves(tree)
[1, '2', 3, False]

class inox.tree.Partial(func, *args, **kwds)

A version of functools.partial that is a PyTree.

Parameters:
  • func (Callable) – A function.

  • args (Any) – Positional arguments for future calls.

  • kwds (Any) – Keyword arguments for future calls.

Examples

>>> increment = Partial(jax.numpy.add, 1)
>>> increment(2)
Array(3, dtype=int32, weak_type=True)
>>> println = Partial(print, sep='\n')
>>> println('Hello', 'World!')
Hello
World!
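
Because Partial is described as a version of functools.partial, the standard-library class can illustrate how stored and call-time arguments combine. The sketch below uses only functools.partial and assumes that Partial follows the same calling convention; the describe function is a hypothetical helper introduced for illustration.

```python
from functools import partial


def describe(prefix, name, sep=": ", end="!"):
    """Hypothetical helper that joins its arguments into one string."""
    return f"{prefix}{sep}{name}{end}"


# Stored positional arguments are passed before the call-time ones.
greet = partial(describe, "Hello", sep=", ")
result_default = greet("World")  # same as describe("Hello", "World", sep=", ")

# Call-time keyword arguments override the stored ones.
result_override = greet("World", sep=" - ", end="?")

# The wrapped function and the stored arguments stay inspectable.
stored_args = greet.args
stored_keywords = greet.keywords
```

Unlike functools.partial, Partial is a PyTree, so it can be handled by tree utilities such as jax.tree.leaves.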

class inox.tree.Static(value)

Wraps a hashable value as a leafless PyTree.

Parameters:

value (Hashable) – A hashable value to wrap.

Example

>>> tree = Static((0, 'one', None))
>>> tree.value
(0, 'one', None)
>>> jax.tree.leaves(tree)
[]
>>> jax.tree.structure(tree)
PyTreeDef(CustomNode(Static[(0, 'one', None)], []))

inox.tree.mask_static(tree, is_static=None)

Masks the static leaves of a tree.

The structure of the tree remains unchanged, but leaves that are considered static are masked, which hides them from jax.tree.leaves and jax.tree.map. The operation is idempotent: masking an already masked tree yields the same tree.

Parameters:
  • tree (PyTree) – The tree to mask.

  • is_static (Callable[[Any], bool] | None) – A predicate for what to consider static. If None, all non-array leaves are considered static.

Returns:

The masked tree.

Return type:

PyTree

Example

>>> tree = [1, jax.numpy.arange(2), 'three']
>>> jax.tree.leaves(tree)
[1, Array([0, 1], dtype=int32), 'three']
>>> tree = inox.tree.mask_static(tree); tree
[Mask(1), Array([0, 1], dtype=int32), Mask('three')]
>>> jax.tree.leaves(tree)
[Array([0, 1], dtype=int32)]
>>> inox.tree.unmask_static(tree)
[1, Array([0, 1], dtype=int32), 'three']
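
The masking behaviour described above (structure preserved, static leaves hidden, repeated masking harmless) can be sketched in plain Python. This is not how inox implements it: the Mask class, the mask, unmask and visible_leaves helpers, and the use of floats as a stand-in for arrays are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Mask:
    """Hypothetical wrapper marking a leaf as static (not the inox class)."""
    value: Any


def mask(tree, is_static: Callable[[Any], bool]):
    """Wrap static leaves in Mask, leaving the container structure unchanged."""
    if isinstance(tree, Mask):  # already masked: keep as is (idempotence)
        return tree
    if isinstance(tree, dict):
        return {k: mask(v, is_static) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(mask(v, is_static) for v in tree)
    return Mask(tree) if is_static(tree) else tree


def unmask(tree):
    """Replace every Mask by the value it wraps."""
    if isinstance(tree, Mask):
        return tree.value
    if isinstance(tree, dict):
        return {k: unmask(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(unmask(v) for v in tree)
    return tree


def visible_leaves(tree):
    """Collect the leaves that are not hidden behind a Mask."""
    if isinstance(tree, Mask):
        return []
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in visible_leaves(v)]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in visible_leaves(v)]
    return [tree]


def not_float(x):
    """Stand-in predicate: everything except floats is static."""
    return not isinstance(x, float)


tree = [1, (0.5, 1.5), {"name": "three"}]
once = mask(tree, not_float)
twice = mask(once, not_float)
```

In this sketch, once and twice compare equal, only the floats remain visible as leaves, and unmask(once) restores the original tree.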

inox.tree.unmask_static(tree)

Unmasks the static leaves of a masked tree.

Parameters:

tree (PyTree) – The masked tree to unmask.

Returns:

The unmasked tree.

Return type:

PyTree

inox.tree.partition(tree, *filters, is_leaf=None)

Flattens a tree and partitions the leaves.

The leaves are partitioned into a set of path-leaf mappings: one per filter, followed by a final mapping for leaves that do not satisfy any constraint. A leaf is placed in the mapping of its oldest (closest to the root) ancestor that satisfies a constraint, where the leaf itself also counts as a candidate. If a node satisfies several constraints, the first one is selected.

Parameters:
  • tree (PyTree) – The tree to flatten.

  • filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

  • is_leaf (Callable[[Any], bool] | None) – A predicate for what to consider as a leaf.

Returns:

The tree definition and leaf partitions, with one mapping per filter followed by a final mapping for the remaining leaves.

Return type:

Tuple[PyTreeDef, Dict[str, Any]]

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, leaves = inox.tree.partition(tree)
>>> leaves
{'.a': 1, '.b': Array([0, 1], dtype=int32), '.c[0]': 'three', '.c[1]': False}
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> arrays
{'.b': Array([0, 1], dtype=int32)}
>>> others
{'.a': 1, '.c[0]': 'three', '.c[1]': False}
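
The selection rule (oldest matching ancestor wins, first matching filter wins, unmatched leaves go last) can be sketched on nested dicts, lists and tuples in plain Python. This is an illustrative reading of the description, not the inox implementation: partition_sketch, as_predicate and the path strings it produces are hypothetical, and the sketch returns only the mappings, without a tree definition.

```python
from typing import Any, Callable, Dict, List, Union

Filter = Union[type, Callable[[Any], bool]]


def as_predicate(f: Filter) -> Callable[[Any], bool]:
    # Types become isinstance checks, as stated for the filters parameter.
    if isinstance(f, type):
        return lambda x: isinstance(x, f)
    return f


def partition_sketch(tree, *filters: Filter) -> List[Dict[str, Any]]:
    predicates = [as_predicate(f) for f in filters]
    # One mapping per filter, plus a last one for unmatched leaves.
    parts: List[Dict[str, Any]] = [{} for _ in range(len(predicates) + 1)]

    def first_match(node) -> int:
        for i, pred in enumerate(predicates):
            if pred(node):
                return i  # the first satisfied constraint is selected
        return -1  # no constraint satisfied

    def visit(node, path: str, bucket: int) -> None:
        # Once an ancestor has matched, its descendants inherit the bucket.
        if bucket < 0:
            bucket = first_match(node)
        if isinstance(node, dict):
            for key, child in node.items():
                visit(child, f"{path}[{key!r}]", bucket)
        elif isinstance(node, (list, tuple)):
            for i, child in enumerate(node):
                visit(child, f"{path}[{i}]", bucket)
        else:
            parts[bucket][path] = node  # index -1 is the last mapping

    visit(tree, "", -1)
    return parts


tree = {"a": 1, "b": [2.0, "x"], "c": ("y", 3)}
lists, ints, rest = partition_sketch(tree, list, int)
```

Here every leaf under "b" lands in lists because its parent list matches the first filter, even though none of those leaves is a list itself; the integers outside that list land in ints, and the string "y" falls through to rest.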

inox.tree.combine(treedef, *leaves)

Reconstructs a tree from the tree definition and leaf partitions.

Parameters:
  • treedef (PyTreeDef) – The tree definition.

  • leaves (Dict[str, Any]) – The set of leaf partitions.

Returns:

The reconstructed tree.

Return type:

PyTree

Example

>>> tree = Namespace(a=1, b=jax.numpy.arange(2), c=['three', False])
>>> treedef, arrays, others = inox.tree.partition(tree, jax.Array)
>>> others = {key: str(leaf).upper() for key, leaf in others.items()}
>>> inox.tree.combine(treedef, arrays, others)
Namespace(
  a = '1',
  b = int32[2],
  c = ['THREE', 'FALSE']
)

inox.tree.prepr(tree, linewidth=88, typeonly=True, **kwargs)

Creates a pretty representation of a tree.

Parameters:
  • tree (PyTree) – The tree to represent.

  • linewidth (int) – The maximum line width before elements of tuples, lists and dicts are represented on separate lines.

  • typeonly (bool) – Whether to represent the type of arrays instead of their elements.

Returns:

The representation string.

Return type:

str

Example

>>> tree = [1, 'two', (True, False), list(range(5)), {'6': jax.numpy.arange(7)}]
>>> tree.append(Namespace(eight=None))
>>> print(inox.tree.prepr(tree))
[
  1,
  'two',
  (True, False),
  [0, 1, 2, 3, 4],
  {'6': int32[7]},
  Namespace(
    eight = None
  )
]
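
One plausible reading of the linewidth parameter is sketched below in plain Python: a container is kept on one line when its representation fits within the width, and is otherwise split with one element per line. The layout function is hypothetical and does not reproduce the actual behaviour of prepr (for instance, it knows nothing about Namespace or arrays, and it does not treat one-element tuples specially).

```python
def layout(obj, linewidth=88, indent=0):
    """Return repr(obj) on one line if it fits, else one element per line."""
    flat = repr(obj)
    is_container = isinstance(obj, (list, tuple, dict))
    if not is_container or indent + len(flat) <= linewidth:
        return flat

    pad = " " * (indent + 2)
    if isinstance(obj, dict):
        items = [
            f"{key!r}: {layout(value, linewidth, indent + 2)}"
            for key, value in obj.items()
        ]
        opening, closing = "{", "}"
    else:
        items = [layout(value, linewidth, indent + 2) for value in obj]
        opening, closing = ("[", "]") if isinstance(obj, list) else ("(", ")")

    body = ",\n".join(pad + item for item in items)
    return opening + "\n" + body + "\n" + " " * indent + closing


tree = [1, "two", (True, False), list(range(5))]
wide = layout(tree)                  # fits within 88 characters
narrow = layout(tree, linewidth=20)  # outer list is split, inner ones fit
```

With the default width the whole list stays on a single line; with linewidth=20 the outer list is split across lines while the inner tuple and list, which still fit, are kept intact.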