inox.nn.module
Base modules

Classes
- Module: Base class for all modules.

Descriptions
- class inox.nn.module.Module(**kwargs)
Base class for all modules.
A module is a PyTree whose attributes are branches, meaning that you can assign any PyTree-compatible object (tuple, list, dict, …), including other modules, as a regular attribute. Parametric functions, such as neural networks, should subclass Module.

```python
import jax
import inox
import inox.nn as nn

class Classifier(nn.Module):
    def __init__(self, key, in_features, num_classes):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(keys[0], in_features, 64)
        self.l2 = nn.Linear(keys[1], 64, 64)
        self.l3 = nn.Linear(keys[2], 64, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()

        self.return_logits = True  # static leaf

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return self.softmax(x)

key = jax.random.key(0)
model = Classifier(key, in_features=16, num_classes=3)  # example feature/class sizes
```
Modules automatically detect non-array leaves and consider them as static. This results in module instances compatible with native JAX transformations (jax.jit, jax.vmap, jax.grad, …) out of the box.

```python
import optax

@jax.jit
def loss_fn(model, x, y):
    logits = jax.vmap(model)(x)
    loss = optax.softmax_cross_entropy(logits, y)

    return jax.numpy.mean(loss)

grads = jax.grad(loss_fn)(model, x, y)
```
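For instance, because non-array attributes are static, flattening a module with jax.tree_util should expose only its array leaves. A minimal sketch, reusing the Classifier model defined above:

```python
import jax

# Non-array attributes, such as the boolean `return_logits`, belong to the
# static structure, so only arrays (weights and biases) appear as leaves.
leaves = jax.tree_util.tree_leaves(model)

print(len(leaves))  # number of array leaves
print(all(hasattr(leaf, 'shape') for leaf in leaves))  # expected: True
```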
- Parameters:
kwargs – A name-value mapping.
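As a sketch of how the constructor's name-value mapping can be used (an assumption based on the signature above, not an excerpt from the library's documentation), a bare Module can act as a lightweight PyTree container whose keyword arguments become attributes:

```python
import jax.numpy as jnp
import inox.nn as nn

# Assumption: each keyword argument becomes an attribute (a branch of the PyTree).
container = nn.Module(scale=jnp.ones(3), label='demo')

print(container.scale)  # array branch
print(container.label)  # non-array attribute, treated as static
```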
- pure()
Returns a functionally pure copy of the module.
JAX transformations are designed to work on pure functions. Some neural network layers, including batch normalization, are not pure as they hold a state which is updated as part of the layer’s execution.
The Module.pure method separates the functional definition of the module from its state, that is, its parameters and buffers, which prevents unnoticed state mutations during training.

```python
# Impure
output = model(*args)

# Pure
stateless, state = model.pure()
output, mutations = stateless.apply(state, *args)
state['buffers'].update(mutations)  # only buffers can mutate
```
Using the functional definition of modules is not only safer, but also convenient for training with Optax optimizers when the model contains buffers.
```python
stateless, state = model.pure()
params, buffers = state['params'], state['buffers']

optimizer = optax.adamw(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def step(params, buffers, opt_state, x, y):  # gradient descent step
    def ell(params):
        state = dict(params=params, buffers=buffers)
        logits, mutations = stateless.apply(state, x)
        loss = optax.softmax_cross_entropy(logits, y)

        return jax.numpy.mean(loss), mutations

    grads, mutations = jax.grad(ell, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    buffers.update(mutations)

    return params, buffers, opt_state

for x, y in trainset:  # training loop
    params, buffers, opt_state = step(params, buffers, opt_state, x, y)

model = stateless.impure(dict(params=params, buffers=buffers))
```
- impure(state)
Returns a functionally impure copy of the module.
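As shown at the end of the training example above, impure is the inverse of pure: it re-binds a state to the stateless definition. A minimal round-trip sketch, assuming the model and an input x from the previous examples:

```python
stateless, state = model.pure()   # split the definition from its state
model = stateless.impure(state)   # re-bind the state to recover a regular module
output = model(x)                 # the rebuilt module is used as usual
```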
- apply(state, *args, method=None, **kwargs)
Applies a module method for a given state.
- Returns:
The method’s output and the state mutations.
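A short usage sketch, assuming the stateless and state pair returned by pure() above; the named-method call is an assumption based on the method argument in the signature, not a documented example:

```python
# Default: method=None applies the module itself (__call__).
output, mutations = stateless.apply(state, x)

# Assumption: a method can be selected by name; 'encode' is a hypothetical method.
features, mutations = stateless.apply(state, x, method='encode')
```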
- train(mode=True)
Toggles between training and evaluation modes.
This is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

```python
model.train(False)  # turns off dropout
```
- Parameters:
mode (bool) – Whether to turn training mode on or off.
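A small sketch of toggling modes around evaluation, assuming the model and a batch x from the examples above:

```python
model.train(False)          # evaluation mode: training-only behaviour (e.g. dropout) is disabled
preds = jax.vmap(model)(x)  # run inference
model.train(True)           # switch back to training mode
```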