inox.nn.module

Base modules

In Inox, a module is a PyTree whose branches are its attributes. A branch can be any PyTree-compatible object (bool, str, list, dict, …), including other modules. Parametric functions, such as neural networks, should subclass Module and indicate their parameters with Parameter.

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Linear(nn.Module):
    def __init__(self, in_features, out_features, key):
        keys = jax.random.split(key, 2)

        self.weight = nn.Parameter(jax.random.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jax.random.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.l1 = Linear(in_features, 64, key=keys[0])
        self.l2 = Linear(64, 64, key=keys[1])
        self.l3 = Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

        self.return_logits = True

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(16, 3, key)

Classes

Module

Base class for all modules.

ModuleDef

Abstraction for the static definition of a module.

Parameter

Wrapper to indicate an optimizable array.

ComplexParameter

Wrapper to indicate an optimizable complex array.

Descriptions

class inox.nn.module.Module(**kwargs)

Base class for all modules.

Parameters:

kwargs – A name-value mapping.

train(mode=True)

Toggles between training and evaluation modes.

This method is primarily useful for (sub)modules that behave differently during training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

Parameters:

mode (bool) – Whether to turn training mode on or off.

Example

>>> model = Module(dropout=nn.TrainingDropout())
>>> model
Module(
  dropout = TrainingDropout(
    p = float32[]
  )
)
>>> model.train(False)
>>> model
Module(
  dropout = TrainingDropout(
    p = float32[],
    training = False
  )
)
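
The effect of train on nested submodules can be pictured with a small pure-Python sketch. It is not Inox's implementation: Node and set_training are hypothetical names, and the sketch only assumes that the toggle reaches every (sub)module that defines a training attribute, as the example above suggests.

```python
class Node:
    """Hypothetical stand-in for a module: a plain object with attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def set_training(node, mode=True):
    """Recursively set `training` on every node that defines it."""
    if isinstance(node, Node):
        if hasattr(node, "training"):
            node.training = mode
        children = vars(node).values()
    elif isinstance(node, (list, tuple)):
        children = node
    elif isinstance(node, dict):
        children = node.values()
    else:
        return  # leaf (number, string, ...): nothing to toggle
    for child in children:
        set_training(child, mode)


dropout = Node(p=0.5, training=True)
model = Node(layers=[Node(scale=2.0), dropout], name="demo")

set_training(model, False)
print(dropout.training)  # False
```

Nodes without a training attribute, such as the root and the first layer here, are left untouched.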

partition(*filters)

Splits the static definition of the module from its arrays.

The arrays are partitioned into a set of path-array mappings. The mapping in which an array is contained is chosen according to its oldest (closest to the root) ancestor that satisfies a constraint. If a node satisfies several constraints, the first one is selected. The last mapping is dedicated to arrays that do not satisfy any constraint.

Parameters:

filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

Returns:

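The module's static definition, followed by the array partitions: one mapping per filter, then a final mapping for the arrays that satisfy no constraint.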

Return type:

Tuple[ModuleDef, Dict[str, Array]]

Examples

>>> keys = jax.random.split(jax.random.key(0), 2)
>>> model = Module(layers=[
...     nn.Linear(3, 64, key=keys[0]),
...     nn.ReLU(),
...     nn.TrainingDropout(),
...     nn.Linear(64, 5, key=keys[1]),
... ])
>>> static, arrays = model.partition()
>>> print(inox.tree.prepr(arrays))
{
  '.layers[0].bias.value': float32[64],
  '.layers[0].weight.value': float32[3, 64],
  '.layers[2].p': float32[],
  '.layers[3].bias.value': float32[5],
  '.layers[3].weight.value': float32[64, 5]
}
>>> static, params, others = model.partition(nn.Parameter)
>>> print(inox.tree.prepr(params))
{
  '.layers[0].bias.value': float32[64],
  '.layers[0].weight.value': float32[3, 64],
  '.layers[3].bias.value': float32[5],
  '.layers[3].weight.value': float32[64, 5]
}
>>> print(inox.tree.prepr(others))
{'.layers[2].p': float32[]}
>>> grads = jax.tree.map(jax.numpy.ones_like, params)
>>> params = jax.tree.map(lambda x, y: x + 0.01 * y, params, grads)
>>> model = static(params, others)  # updated copy
>>> model.layers[3].frozen = True
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> static, frozen, params, others = model.partition(filtr, nn.Parameter)
>>> print(inox.tree.prepr(frozen))
{'.layers[3].bias.value': float32[5], '.layers[3].weight.value': float32[64, 5]}
>>> print(inox.tree.prepr(params))
{'.layers[0].bias.value': float32[64], '.layers[0].weight.value': float32[3, 64]}
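
The assignment rule (oldest matching ancestor, first matching filter, last mapping for the rest) can be sketched in plain Python. This is an illustration under stated assumptions, not Inox's implementation: Param and this partition function are hypothetical, floats stand in for arrays, dicts stand in for modules, and the static definition is left out.

```python
class Param:
    """Hypothetical stand-in for a parameter wrapper."""

    def __init__(self, value):
        self.value = value


def partition(tree, *filters):
    """Split float leaves into len(filters) + 1 path-value mappings."""
    # Types become isinstance checks, as described for `filters`.
    tests = [
        (lambda x, t=f: isinstance(x, t)) if isinstance(f, type) else f
        for f in filters
    ]
    parts = [{} for _ in range(len(tests) + 1)]

    def walk(node, path, slot):
        if slot is None:  # no ancestor matched yet: test this node
            slot = next((i for i, t in enumerate(tests) if t(node)), None)
        if isinstance(node, Param):
            walk(node.value, path + ".value", slot)
        elif isinstance(node, dict):
            for name, child in node.items():
                walk(child, f"{path}.{name}", slot)
        elif isinstance(node, list):
            for i, child in enumerate(node):
                walk(child, f"{path}[{i}]", slot)
        elif isinstance(node, float):  # stand-in for an array leaf
            parts[-1 if slot is None else slot][path] = node
        # any other leaf (str, bool, ...) is static and is skipped

    walk(tree, "", None)
    return parts


model = {
    "layers": [
        {"weight": Param(0.1), "bias": Param(0.2)},
        {"p": 0.5},
        {"weight": Param(0.3), "bias": Param(0.4), "frozen": True},
    ]
}

is_frozen = lambda x: isinstance(x, dict) and x.get("frozen", False)
frozen, params, others = partition(model, is_frozen, Param)

print(sorted(frozen))  # ['.layers[2].bias.value', '.layers[2].weight.value']
print(sorted(params))  # ['.layers[0].bias.value', '.layers[0].weight.value']
print(sorted(others))  # ['.layers[1].p']
```

Because the frozen layer matches the first filter before its parameters are visited, its leaves never reach the Param filter, which mirrors the frozen/params split in the example above.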

class inox.nn.module.ModuleDef(treedef, leaves)

Abstraction for the static definition of a module.

See also

Module.partition

Parameters:
  • treedef (PyTreeDef) – A module tree definition.

  • leaves (Dict[str, Any]) – The static (non-array) leaves of the module.

treedef: PyTreeDef

Alias for field number 0

leaves: Dict[str, Any]

Alias for field number 1

__call__(*arrays)

Parameters:

arrays (Dict[str, Array]) – A set of array partitions.

Returns:

A new instance of the module.

Return type:

Module
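
As a rough mental model, calling a static definition merges the array partitions back into one path-array mapping and recombines it with the static leaves. The sketch below is a flat, pure-Python analogue and not Inox's implementation: FlatDef is a hypothetical name, its paths field only loosely plays the role of treedef, and the result is a plain dict rather than a Module.

```python
from typing import Any, Dict, NamedTuple, Tuple


class FlatDef(NamedTuple):
    """Hypothetical static definition: expected paths plus static leaves."""

    paths: Tuple[str, ...]
    leaves: Dict[str, Any]

    def __call__(self, *partitions):
        merged = {}
        for part in partitions:
            overlap = merged.keys() & part.keys()
            if overlap:
                raise ValueError(f"duplicate paths: {sorted(overlap)}")
            merged.update(part)
        missing = set(self.paths) - merged.keys()
        if missing:
            raise ValueError(f"missing paths: {sorted(missing)}")
        # Restore the original path order and add the static leaves back.
        return {**{p: merged[p] for p in self.paths}, **self.leaves}


static = FlatDef(
    paths=(".l1.weight.value", ".l1.bias.value", ".drop.p"),
    leaves={".drop.training": True},
)
params = {".l1.bias.value": 0.2, ".l1.weight.value": 0.1}
others = {".drop.p": 0.5}

rebuilt = static(params, others)
print(list(rebuilt))
# ['.l1.weight.value', '.l1.bias.value', '.drop.p', '.drop.training']
```

In this sketch the order of the partitions does not matter, since each value is looked up by its path; only a missing or duplicated path is treated as an error.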

class inox.nn.module.Parameter(value)

Wrapper to indicate an optimizable array.

All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.

Parameters:

value (Array) – An array.

Example

>>> weight = Parameter(jax.numpy.ones((3, 5))); weight
Parameter(float32[3, 5])
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = x @ weight()

__call__()

Returns:

The wrapped array.

Return type:

Array

class inox.nn.module.ComplexParameter(value)

Wrapper to indicate an optimizable complex array.

The real and imaginary parts are stored as separate floating point arrays to enable gradient-based optimization.

Parameters:

value (Array) – A complex array.

Example

>>> value = jax.numpy.ones((3, 5)) + 1j * jax.numpy.zeros((3, 5))
>>> weight = ComplexParameter(value); weight
ComplexParameter(complex64[3, 5])
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = x @ weight()
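
The idea of storing the real and imaginary parts separately can be illustrated without any array library. The sketch below is not Inox's implementation: SplitComplex is a hypothetical name, Python lists of floats stand in for floating point arrays, and the gradients are made-up numbers chosen only to show the update.

```python
class SplitComplex:
    """Hypothetical wrapper that keeps real and imaginary parts apart."""

    def __init__(self, values):
        # Two real-valued buffers: these are what an optimizer would update.
        self.real = [complex(v).real for v in values]
        self.imag = [complex(v).imag for v in values]

    def __call__(self):
        # Rebuild the complex values on demand.
        return [complex(r, i) for r, i in zip(self.real, self.imag)]


w = SplitComplex([1 + 0j, 0.5 - 2j])

# A plain gradient step only ever touches real-valued numbers.
lr = 0.1
grad_real, grad_imag = [1.0, 1.0], [0.0, 1.0]
w.real = [r - lr * g for r, g in zip(w.real, grad_real)]
w.imag = [i - lr * g for i, g in zip(w.imag, grad_imag)]

print(w())
```

The optimizer never sees a complex number; the complex values exist only in the output of the call, which matches how weight() is used in the example above.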