inox.nn.module#

Base modules

Classes#

Module

Base class for all modules.

ModuleDef

Abstraction for the static definition of a module.

Parameter

Wrapper to indicate optimizable arrays.

Descriptions#

class inox.nn.module.Module(**kwargs)#

Base class for all modules.

A module is a PyTree whose branches are its attributes. A branch can be any PyTree-compatible object (tuple, list, dict, …), including other modules. Parametric functions, such as neural networks, should subclass Module and indicate their parameters with Parameter.

import jax
import jax.random as jrd
import inox
import inox.nn as nn

class Linear(nn.Module):
    def __init__(self, key, in_features, out_features):
        keys = jrd.split(key, 2)

        self.weight = nn.Parameter(jrd.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jrd.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, key, in_features, num_classes):
        keys = jrd.split(key, 3)

        self.l1 = Linear(keys[0], in_features, 64)
        self.l2 = Linear(keys[1], 64, 64)
        self.l3 = Linear(keys[2], 64, num_classes)
        self.relu = nn.ReLU()

        self.return_logits = True  # static leaf

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(key, in_features=16, num_classes=3)  # e.g. 16 features, 3 classes
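
A branch does not have to be a module or an array; any PyTree-compatible container can hold submodules. Below is a minimal sketch (reusing the Linear layer above; the MLP class and its names are illustrative, not part of inox):

class MLP(nn.Module):
    def __init__(self, key, sizes):
        keys = jrd.split(key, len(sizes) - 1)

        # A plain Python list of submodules is itself a valid branch.
        self.layers = [
            Linear(k, m, n)
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))

        return self.layers[-1](x)

mlp = MLP(key, sizes=[16, 64, 64, 3])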

Modules automatically detect non-array leaves and treat them as static (part of the tree structure). This makes module instances compatible with native JAX transformations (jax.jit, jax.vmap, jax.grad, …) out of the box.

import optax

@jax.jit
def loss_fn(model, x, y):
    logits = jax.vmap(model)(x)
    loss = optax.softmax_cross_entropy(logits, y)

    return jax.numpy.mean(loss)

grads = jax.grad(loss_fn)(model, x, y)  # x, y: a batch of inputs and one-hot labels
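
Since non-array attributes are folded into the tree structure, flattening a module only yields arrays. A small sanity check (a sketch, assuming the Classifier defined above):

leaves, treedef = jax.tree_util.tree_flatten(model)

# Every leaf is an array; static attributes such as return_logits
# live in the treedef, not among the leaves.
assert all(isinstance(leaf, jax.Array) for leaf in leaves)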

However, JAX transformations are designed to work on pure functions. Some neural network layers, including batch normalization, are not pure: they hold a state that is updated as part of the layer's execution. In this case, training with a functionally pure definition of the model is safer, and also convenient, as some internal arrays do not require gradients.

modef, params, others = model.pure(nn.Parameter)
optimizer = optax.adamw(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def step(params, others, opt_state, x, y):  # gradient descent step
    def loss_fn(params):
        model = modef(params, others)  # rebuild the module from its state
        logits = jax.vmap(model)(x)
        loss = optax.softmax_cross_entropy(logits, y)
        _, _, others = model.pure(nn.Parameter)  # collect the (possibly updated) non-parameter state

        return jax.numpy.mean(loss), others

    grads, others = jax.grad(loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, others, opt_state

for x, y in trainset:  # training loop
    params, others, opt_state = step(params, others, opt_state, x, y)

model = modef(params, others)
Parameters:

kwargs – A name-value mapping.

train(mode=True)#

Toggles between training and evaluation modes.

This method is primarily useful for (sub)modules that behave differently during training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

Parameters:

mode (bool) – Whether to turn training mode on or off.

Example

>>> model.train(False)  # turns off dropout
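
A slightly longer sketch of the usual pattern (x_test is a hypothetical batch of inputs):

>>> model.train(True)   # training mode: stochastic layers active, running statistics updated
>>> ...                 # training loop
>>> model.train(False)  # evaluation mode: deterministic forward passes
>>> preds = jax.vmap(model)(x_test)
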
pure(*filters)#

Splits the functional definition of the module from its state.

The state is represented by a path-array mapping split into several collections. Each collection contains the leaves of the subset of nodes satisfying a filtering constraint. The last collection is dedicated to leaves that do not satisfy any constraint.

See also

ModuleDef

Parameters:

filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

Returns:

The module definition and state collection(s).

Return type:

Tuple[ModuleDef, Dict[str, Array]]

Examples

>>> modef, state = model.pure()
>>> clone = modef(state)

>>> modef, params, others = model.pure(nn.Parameter)
>>> updates, opt_state = optimizer.update(grads, opt_state, params)
>>> params = optax.apply_updates(params, updates)
>>> model = modef(params, others)

>>> model.l2.frozen = True  # flag an arbitrary submodule as frozen
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> modef, frozen, others = model.pure(filtr)
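
Several filters can be combined; each filter yields its own collection, and the leaves satisfying no constraint come last. A sketch reusing filtr from above (how leaves matching several filters are assigned is up to the library):

>>> modef, params, frozen, others = model.pure(nn.Parameter, filtr)
>>> model = modef(params, frozen, others)
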
class inox.nn.module.ModuleDef(module)#

Abstraction for the static definition of a module.

See also

Module.pure

Parameters:

module (Module) – A module instance.

__call__(*state)#
Parameters:

state (Dict[str, Array]) – A set of state collections.

Returns:

A new instance of the module.

Return type:

Module
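
Example

A sketch mirroring the Module.pure examples above:

>>> modef, params, others = model.pure(nn.Parameter)
>>> clone = modef(params, others)  # a new instance with the same structure and state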

class inox.nn.module.Parameter(value)#

Wrapper to indicate optimizable arrays.

All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.

Parameters:

value (Array) – An array.

Example

>>> weight = Parameter(jax.numpy.ones((3, 5)))
>>> bias = Parameter(jax.numpy.zeros(5))
>>> def linear(x):
...     return x @ weight() + bias()
__call__()#
Returns:

The wrapped array.

Return type:

Array
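
Example

Calling the wrapper returns the underlying array unchanged:

>>> weight = Parameter(jax.numpy.ones((3, 5)))
>>> weight().shape
(3, 5)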