inox.nn.module

Base modules

Classes

Module

Base class for all modules.

Buffer

Container for non-optimizable arrays.

Descriptions

class inox.nn.module.Module(**kwargs)

Base class for all modules.

A module is a PyTree whose attributes are branches. This means you can assign any PyTree-compatible object (tuple, list, dict, …), including other modules, as a regular attribute. Parametric functions, such as neural networks, should subclass Module.

import jax
import inox
import inox.nn as nn

class Classifier(nn.Module):
    def __init__(self, key, in_features, num_classes):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(keys[0], in_features, 64)
        self.l2 = nn.Linear(keys[1], 64, 64)
        self.l3 = nn.Linear(keys[2], 64, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()

        self.return_logits = True  # static leaf

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return self.softmax(x)

key = jax.random.key(0)
model = Classifier(key, in_features=16, num_classes=10)  # e.g. 16 input features, 10 classes
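The following plain-Python sketch illustrates what "attributes are branches" means. It walks a module-like object and yields every leaf together with its attribute path. It is not inox's implementation, because real PyTree flattening is done by JAX (jax.tree_util). ToyModule, flatten and Block are hypothetical names used only for this illustration.

```python
class ToyModule:
    """Hypothetical stand-in for a module base class (not inox.nn.Module)."""


def flatten(node, path=()):
    """Yield (path, leaf) pairs, treating modules, dicts, lists and tuples as branches."""
    if isinstance(node, ToyModule):
        children = vars(node).items()  # attributes are branches
    elif isinstance(node, dict):
        children = node.items()
    elif isinstance(node, (list, tuple)):
        children = enumerate(node)
    else:
        yield path, node  # anything else is a leaf
        return
    for name, child in children:
        yield from flatten(child, path + (name,))


class Block(ToyModule):
    def __init__(self):
        self.weights = [1.0, 2.0]  # list branch
        self.config = {"act": "relu"}  # dict branch
        self.inner = ToyModule()  # nested module branch
        self.inner.scale = 0.5


leaves = dict(flatten(Block()))
print(leaves)
# {('weights', 0): 1.0, ('weights', 1): 2.0, ('config', 'act'): 'relu', ('inner', 'scale'): 0.5}
```

Nested modules, lists and dicts are handled by the same recursion. This is why a submodule such as self.l1 in the Classifier example needs no special treatment.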

Modules automatically detect non-array leaves and treat them as static. As a result, module instances are compatible with native JAX transformations (jax.jit, jax.vmap, jax.grad, …) out of the box.

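import optax

@jax.jit
def loss_fn(model, x, y):
    # x: batch of inputs, y: matching one-hot labels
    logits = jax.vmap(model)(x)  # map the model over the batch axis
    loss = optax.softmax_cross_entropy(logits, y)

    return jax.numpy.mean(loss)

grads = jax.grad(loss_fn)(model, x, y)  # a pytree with the same structure as model

Parameters: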

kwargs – A name-value mapping.
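The split between array leaves and static leaves can be sketched in plain Python as follows. This is an illustration under stated assumptions, not inox's code. Array is a hypothetical stand-in for jax.Array so the snippet runs without JAX, and partition and combine are hypothetical helpers. Roughly speaking, a transformation such as jax.jit only needs to trace the array part, while the static part (here return_logits and name) travels with the tree structure.

```python
class Array:
    """Hypothetical stand-in for jax.Array, so the sketch runs without JAX."""

    def __init__(self, value):
        self.value = value


def partition(leaves):
    """Split a flat name -> leaf mapping into array leaves and static leaves."""
    arrays = {k: v for k, v in leaves.items() if isinstance(v, Array)}
    static = {k: v for k, v in leaves.items() if not isinstance(v, Array)}
    return arrays, static


def combine(arrays, static):
    """Merge both halves back into a single name -> leaf mapping."""
    return {**static, **arrays}


leaves = {
    "weight": Array([[0.1, 0.2]]),
    "bias": Array([0.0]),
    "return_logits": True,  # non-array leaf -> static
    "name": "l1",  # non-array leaf -> static
}
arrays, static = partition(leaves)
print(sorted(arrays))  # ['bias', 'weight']
print(static)  # {'return_logits': True, 'name': 'l1'}
```

Recombining the two halves gives back the original mapping, so no information is lost by the split.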

pure()

Returns a functionally pure copy of the module.

JAX transformations are designed to work on pure functions. Some neural network layers, including batch normalization, are not pure because they hold a state that is updated during the layer’s execution.

The Module.pure method separates the functional definition of the module from its state, that is, its parameters and buffers. This prevents unnoticed state mutations during training.

# Impure
output = model(*args)

# Pure
stateless, state = model.pure()
output, mutations = stateless.apply(state, *args)
state['buffers'].update(mutations)  # only buffers can mutate
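The params/buffers split can be imitated on a toy object to build intuition. In this hypothetical sketch, Buffer and ToyNorm are plain-Python stand-ins rather than the inox classes. Every non-Buffer attribute is treated as a parameter, and buffer leaves are keyed by a dotted path. A real implementation would also have to set static leaves aside. The last lines reproduce the state['buffers'].update(mutations) pattern from the snippet above.

```python
class Buffer:
    """Hypothetical stand-in for inox.nn.module.Buffer: holds non-optimizable values."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class ToyNorm:
    """Hypothetical layer with one parameter and one buffer."""

    def __init__(self):
        self.scale = 1.0
        self.stats = Buffer(mean=0.0)


def split_state(module):
    """Build a {'params': ..., 'buffers': ...} dict from the module's attributes."""
    params, buffers = {}, {}
    for name, value in vars(module).items():
        if isinstance(value, Buffer):
            for key, leaf in vars(value).items():
                buffers[f"{name}.{key}"] = leaf  # e.g. "stats.mean"
        else:
            params[name] = value
    return {"params": params, "buffers": buffers}


state = split_state(ToyNorm())
mutations = {"stats.mean": 0.1}  # e.g. what an apply call might report
state["buffers"].update(mutations)  # only buffers are overwritten
print(state)  # {'params': {'scale': 1.0}, 'buffers': {'stats.mean': 0.1}}
```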

Using the functional definition of modules is not only safer but also handy for training with Optax optimizers when the model contains buffers.

stateless, state = model.pure()
params, buffers = state['params'], state['buffers']
optimizer = optax.adamw(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def step(params, buffers, opt_state, x, y):  # gradient descent step
    def ell(params):
        state = dict(params=params, buffers=buffers)
        logits, mutations = stateless.apply(state, x)
        loss = optax.softmax_cross_entropy(logits, y)

        return jax.numpy.mean(loss), mutations

    grads, mutations = jax.grad(ell, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    buffers.update(mutations)

    return params, buffers, opt_state

for x, y in trainset:  # training loop over (x, y) batches
    params, buffers, opt_state = step(params, buffers, opt_state, x, y)

model = stateless.impure(dict(params=params, buffers=buffers))  # back to an impure module

Returns:

A stateless copy of the module and the state dictionary.

Return type:

Tuple[Module, Dict[str, Dict[str, Array]]]

impure(state)

Returns a functionally impure copy of the module.

Parameters:

state (Dict[str, Dict[str, Array]]) – The state dictionary.

Returns:

A copy of the module where parameters and buffers have been put back in place.
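Conceptually, impure is the inverse operation: it takes a state dictionary and writes the values back into a copy of the module. The plain-Python sketch below illustrates this idea. It uses hypothetical classes (Buffer, ToyNorm) and a hypothetical put_back helper. It also assumes that buffer entries are keyed as "attribute.leaf", which is a convention of this sketch, not of inox.

```python
import copy


class Buffer:
    """Hypothetical stand-in for a buffer container."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class ToyNorm:
    """Hypothetical layer with one parameter and one buffer."""

    def __init__(self):
        self.scale = 1.0
        self.stats = Buffer(mean=0.0)


def put_back(module, state):
    """Return a copy of module whose params and buffers come from state."""
    new = copy.deepcopy(module)  # the original module is left untouched
    for name, value in state["params"].items():
        setattr(new, name, value)
    for path, value in state["buffers"].items():
        owner, key = path.split(".")  # "stats.mean" -> ("stats", "mean")
        setattr(getattr(new, owner), key, value)
    return new


original = ToyNorm()
updated = put_back(original, {"params": {"scale": 2.0}, "buffers": {"stats.mean": 0.3}})
print(updated.scale, updated.stats.mean)  # 2.0 0.3
print(original.scale, original.stats.mean)  # 1.0 0.0
```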

apply(state, *args, method=None, **kwargs)

Applies a module method for a given state.

Parameters:
  • state (Dict[str, Dict[str, Array]]) – The state dictionary.

  • method (str | Callable | None) – The method to apply. Either a name (string) or a callable. If None, __call__ is applied instead.

  • args – The positional arguments of the method.

  • kwargs – The keyword arguments of the method.

Returns:

The method’s output and the state mutations.

Return type:

Tuple[Any, Dict[str, Array]]
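The contract of apply is to run a method against a given state and return its output together with the buffer changes. That contract can be sketched with plain functions. Everything here is hypothetical: apply_sketch and norm_forward are illustrative names, and the forward pass updates a running mean with momentum 0.9. Mutations are detected by comparing values, which only works for simple Python scalars.

```python
def apply_sketch(fn, state, *args):
    """Run fn(params, buffers, *args) on a copy of the buffers.

    Returns (output, mutations), where mutations holds only the buffers that changed.
    """
    old = state["buffers"]
    buffers = dict(old)  # copy: the caller's state is never modified here
    output = fn(state["params"], buffers, *args)
    mutations = {k: v for k, v in buffers.items() if k not in old or v != old[k]}
    return output, mutations


def norm_forward(params, buffers, x):
    """Hypothetical forward pass that also updates a running-mean buffer."""
    buffers["mean"] = 0.9 * buffers["mean"] + 0.1 * x
    return params["scale"] * (x - buffers["mean"])


state = {"params": {"scale": 2.0}, "buffers": {"mean": 0.0, "count": 0}}
output, mutations = apply_sketch(norm_forward, state, 1.0)
print(output, mutations)  # 1.8 {'mean': 0.1}
state["buffers"].update(mutations)  # commit the changes explicitly
```

Because the caller decides when to commit the mutations, the state dictionary only changes where the code says so.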

train(mode=True)

Toggles between training and evaluation modes.

This is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

model.train(False)  # turns off dropout

Parameters:

mode (bool) – Whether to turn training mode on or off.
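A recursive mode toggle of this kind can be sketched as below. ToyModule, Dropout and Net are hypothetical plain-Python classes, not inox's. The sketch assumes that each module stores a boolean training attribute and that submodules are direct attributes; a fuller version would also look inside lists, tuples and dicts. The keep argument is a hypothetical 0/1 mask standing in for a random draw.

```python
class ToyModule:
    """Hypothetical base class with a recursive train() toggle."""

    training = True  # class-level default, overridden per instance by train()

    def train(self, mode=True):
        self.training = mode
        for child in list(vars(self).values()):
            if isinstance(child, ToyModule):
                child.train(mode)  # propagate the mode to submodules


class Dropout(ToyModule):
    def __init__(self, rate):
        self.rate = rate

    def __call__(self, x, keep):
        if not self.training:
            return x  # evaluation mode: identity
        return x * keep / (1 - self.rate)  # inverted dropout scaling


class Net(ToyModule):
    def __init__(self):
        self.drop = Dropout(0.5)


net = Net()
print(net.drop(3.0, keep=1))  # 6.0 (training mode by default)
net.train(False)
print(net.drop(3.0, keep=1))  # 3.0 (dropout turned off)
```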

class inox.nn.module.Buffer(**kwargs)

Container for non-optimizable arrays.

All arrays in a module that do not require gradient updates, such as constants or running statistics, should be leaves of a Buffer instance.

Parameters:

kwargs – A name-value mapping.
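Buffer leaves are not optimizable, so an optimizer step only has to touch the parameter part of the state. The plain-Python gradient-descent step below illustrates this under simplifying assumptions: it uses scalar leaves instead of arrays, a hypothetical sgd_step helper, and the {'params': ..., 'buffers': ...} layout returned by Module.pure.

```python
def sgd_step(state, grads, lr=0.1):
    """One gradient-descent step on params; buffers are passed through unchanged."""
    params = {k: v - lr * grads[k] for k, v in state["params"].items()}
    return {"params": params, "buffers": state["buffers"]}


state = {
    "params": {"w": 1.0, "b": 0.5},
    "buffers": {"running_mean": 3.0},  # running statistic: no gradient
}
grads = {"w": 2.0, "b": -1.0}  # gradients exist only for params
state = sgd_step(state, grads)
print(state["params"], state["buffers"])
```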