inox.nn.module#
Base modules
Classes#
- Module(**kwargs): Base class for all modules.
- ModuleDef(module): Abstraction for the static definition of a module.
- Parameter(value): Wrapper to indicate optimizable arrays.

Descriptions#
- class inox.nn.module.Module(**kwargs)#
Base class for all modules.
A module is a PyTree whose branches are its attributes. A branch can be any PyTree-compatible object (tuple, list, dict, …), including other modules. Parametric functions, such as neural networks, should subclass Module and indicate their parameters with Parameter.

import jax
import jax.random as jrd
import inox
import inox.nn as nn

class Linear(nn.Module):
    def __init__(self, key, in_features, out_features):
        keys = jrd.split(key, 2)

        self.weight = nn.Parameter(jrd.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jrd.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, key, in_features, num_classes):
        keys = jrd.split(key, 3)

        self.l1 = Linear(keys[0], in_features, 64)
        self.l2 = Linear(keys[1], 64, 64)
        self.l3 = Linear(keys[2], 64, num_classes)
        self.relu = nn.ReLU()

        self.return_logits = True  # static leaf

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(key, in_features=16, num_classes=3)
Modules automatically detect non-array leaves and consider them as static (part of the tree structure). This results in module instances compatible with native JAX transformations (jax.jit, jax.vmap, jax.grad, …) out of the box.

import optax

@jax.jit
def loss_fn(model, x, y):
    logits = jax.vmap(model)(x)
    loss = optax.softmax_cross_entropy(logits, y)

    return jax.numpy.mean(loss)

grads = jax.grad(loss_fn)(model, x, y)
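To make the static/dynamic split concrete, one can inspect the leaves of the module directly. A minimal sketch, assuming the `model` defined above; because non-array attributes such as `return_logits` are folded into the tree structure, only arrays show up as leaves:

import jax

leaves = jax.tree_util.tree_leaves(model)
print(all(isinstance(leaf, jax.Array) for leaf in leaves))  # True: every leaf is an array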
However, JAX transformations are designed to work on pure functions. Some neural network layers, including batch normalization, are not pure: they hold a state that is updated as part of the layer's execution. In such cases, training with a functionally pure definition of the model is not only safer, but also convenient, since some internal arrays do not require gradients.
modef, params, others = model.pure(nn.Parameter)
optimizer = optax.adamw(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def step(params, others, opt_state, x, y):  # gradient descent step
    def loss_fn(params):
        model = modef(params, others)
        logits = jax.vmap(model)(x)
        loss = optax.softmax_cross_entropy(logits, y)
        _, _, others = model.pure(nn.Parameter)

        return jax.numpy.mean(loss), others

    grads, others = jax.grad(loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, others, opt_state

for x, y in trainset:  # training loop
    params, others, opt_state = step(params, others, opt_state, x, y)

model = modef(params, others)
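Once training is done, the rebuilt model can be used for inference as usual. A minimal sketch, assuming `x` is a held-out batch of inputs; `train(False)` only has an effect for mode-dependent layers such as dropout or batch normalization:

model.train(False)  # switch to evaluation mode
logits = jax.jit(jax.vmap(model))(x)
preds = jax.numpy.argmax(logits, axis=-1)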
- Parameters:
kwargs – A name-value mapping.
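Since the constructor takes a name-value mapping, a plain Module can also be assembled directly from keyword arguments. A minimal sketch, assuming that keyword arguments are assigned as attributes, as the signature suggests:

import jax.numpy as jnp
import inox.nn as nn

module = nn.Module(
    weight=nn.Parameter(jnp.ones((3, 5))),  # dynamic leaf (array)
    flag=True,                              # non-array leaf, treated as static
)
print(module.flag)  # True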
- train(mode=True)#
Toggles between training and evaluation modes.
This method is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

- Parameters:
mode (bool) – Whether to turn training mode on or off.
Example
>>> model.train(False) # turns off dropout
- pure(*filters)#
Splits the functional definition of the module from its state.
The state is represented by a mapping from leaf paths to arrays, split into several collections. Each collection contains the leaves of the subset of nodes satisfying a filtering constraint. The last collection is dedicated to leaves that do not satisfy any constraint.
See also
ModuleDef
- Parameters:
filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.
- Returns:
The module definition and state collection(s).
Examples
>>> modef, state = model.pure()
>>> clone = modef(state)
>>> modef, params, others = model.pure(nn.Parameter)
>>> updates, opt_state = optimizer.update(grads, opt_state, params)
>>> params = optax.apply_updates(params, updates)
>>> model = modef(params, others)
>>> model.path[2].layer.frozen = True
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> modef, frozen, others = model.pure(filtr)
- class inox.nn.module.ModuleDef(module)#
Abstraction for the static definition of a module.
See also
Module.pure
- Parameters:
module (Module) – A module instance.
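Because a ModuleDef captures only the static structure of a module, it can be closed over by transformed functions and used to rebuild the module inside a trace, as in the step function above. A minimal sketch, assuming the `model` and a batch `x` from the previous examples:

modef, state = model.pure()

@jax.jit
def apply(state, x):
    model = modef(state)  # rebuild the module from its state inside the trace
    return model(x)

y = apply(state, x)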
- class inox.nn.module.Parameter(value)#
Wrapper to indicate optimizable arrays.
All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.

- Parameters:
value (Array) – An array.
Example
>>> weight = Parameter(jax.numpy.ones((3, 5)))
>>> bias = Parameter(jax.numpy.zeros(5))
>>> def linear(x):
...     return x @ weight() + bias()
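As the example shows, a Parameter is retrieved by calling it, which returns the wrapped array. A minimal usage sketch of the linear function above, with an input shape chosen for illustration:

import jax.numpy as jnp

x = jnp.ones(3)    # weight has shape (3, 5), so x must have 3 features
y = linear(x)      # weight() and bias() return the underlying arrays
print(y.shape)     # (5,)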