inox.nn.module

Base modules

In Inox, a module is a PyTree whose branches are its attributes. A branch can be any PyTree-compatible object (bool, str, list, dict, …), including other modules. Parametric functions, such as neural networks, should subclass Module and indicate their parameters with Parameter.

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Linear(nn.Module):
    def __init__(self, in_features, out_features, key):
        keys = jax.random.split(key, 2)

        self.weight = nn.Parameter(jax.random.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jax.random.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.l1 = Linear(in_features, 64, key=keys[0])
        self.l2 = Linear(64, 64, key=keys[1])
        self.l3 = Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

        self.return_logits = True

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(16, 3, key)
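To illustrate what it means for a module's branches to be its attributes, the sketch below registers a toy class with JAX's PyTree machinery by hand. `ToyModule` is a hypothetical stand-in, not Inox's implementation. It only assumes that flattening a module amounts to flattening its attributes in a fixed (here, sorted) order.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves


@register_pytree_node_class
class ToyModule:
    """Hypothetical stand-in for a module: its attributes are its branches."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def tree_flatten(self):
        # Sort attribute names so that flattening is deterministic.
        names = tuple(sorted(self.__dict__))
        children = [self.__dict__[name] for name in names]
        return children, names  # names are the static (auxiliary) data

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(**dict(zip(names, children)))


inner = ToyModule(weight=jnp.ones((2, 3)), bias=jnp.zeros(3))
outer = ToyModule(layer=inner, scale=jnp.array(2.0))

# JAX utilities see every array nested in the attributes, including submodules.
leaves = tree_leaves(outer)  # [bias, weight, scale]
doubled = jax.tree_util.tree_map(lambda x: 2 * x, outer)
```

Because JAX treats the toy class as a PyTree, `tree_map` returns a new `ToyModule` with every array doubled, and the nested module is traversed like any other branch.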

Classes

Module

Base class for all modules.

ModuleDef

Abstraction for the static definition of a module.

Parameter

Wrapper to indicate an optimizable array.

ComplexParameter

Wrapper to indicate an optimizable complex array.

Descriptions

class inox.nn.module.Module(**kwargs)

Base class for all modules.

Parameters:

kwargs – A name-value mapping.

train(mode=True)

Toggles between training and evaluation modes.

This method is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.

Parameters:

mode (bool) – Whether to turn training mode on or off.

Example

>>> model.train(False)  # turns off dropout
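One plausible way for a mode toggle to reach every submodule is a recursive walk over attributes. The sketch below assumes a boolean `training` flag on dropout-like modules; `ModeSketch`, `DropoutSketch` and `NetSketch` are hypothetical names, and Inox's own `train` may be implemented differently.

```python
class ModeSketch:
    """Hypothetical base class with a recursive training-mode toggle."""

    def train(self, mode: bool = True) -> None:
        # Update this module's own flag, if it has one (assumed name: `training`).
        if hasattr(self, "training"):
            self.training = mode
        # Recurse into submodules, including those stored in lists or tuples.
        for value in vars(self).values():
            children = value if isinstance(value, (list, tuple)) else (value,)
            for child in children:
                if isinstance(child, ModeSketch):
                    child.train(mode)


class DropoutSketch(ModeSketch):
    def __init__(self, rate: float = 0.5):
        self.rate = rate
        self.training = True  # hypothetical flag a dropout forward pass would read


class NetSketch(ModeSketch):
    def __init__(self):
        self.layers = [DropoutSketch(), DropoutSketch()]
        self.head = DropoutSketch(rate=0.1)


net = NetSketch()
net.train(False)  # every DropoutSketch in the tree now has training == False
```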

partition(*filters)

Splits the static definition of the module from its arrays.

The arrays are partitioned into a sequence of path-array mappings: one per filtering constraint, in the order the constraints are given, plus a final mapping for the arrays that do not satisfy any constraint. Each mapping contains the arrays of the subset of nodes satisfying the corresponding constraint.

Parameters:

filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.

Returns:

The module static definition and array partitions.

Return type:

Tuple[ModuleDef, Dict[str, Array], ...]

Examples

>>> static, arrays = model.partition()
>>> clone = static(arrays)
>>> static, params, others = model.partition(nn.Parameter)
>>> params, opt_state = optimizer.update(grads, opt_state, params)
>>> model = static(params, others)
>>> model.path[2].layer.frozen = True
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> static, frozen, others = model.partition(filtr)
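The partitioning rule can be sketched on a plain PyTree with `jax.tree_util.tree_flatten_with_path`. This is a hypothetical re-implementation for illustration only: it turns types into `isinstance` checks, keeps any node that satisfies a check as a single leaf, sends each leaf to the first check it satisfies (an assumption; the description does not say how overlapping filters are resolved), and collects the remaining leaves in a final mapping keyed by path strings.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

import jax
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path

Filter = Union[type, Callable[[Any], bool]]


def partition_sketch(tree: Any, *filters: Filter) -> Tuple[Any, List[Dict[str, Any]]]:
    """Hypothetical re-implementation of filter-based partitioning on a PyTree."""
    # Types become isinstance constraints, as described above.
    checks = [
        (lambda node, t=f: isinstance(node, t)) if isinstance(f, type) else f
        for f in filters
    ]
    # A node that satisfies any check is kept whole, as a single leaf.
    is_leaf = lambda node: any(check(node) for check in checks)
    path_leaves, treedef = tree_flatten_with_path(tree, is_leaf=is_leaf)

    groups: List[Dict[str, Any]] = [{} for _ in range(len(checks) + 1)]
    for path, leaf in path_leaves:
        # First satisfied check wins (assumption); the last group gets the rest.
        index = next((i for i, check in enumerate(checks) if check(leaf)), len(checks))
        groups[index][keystr(path)] = leaf
    return treedef, groups


tree = {"encoder": {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)}, "step": jnp.array(0)}
is_vector = lambda x: getattr(x, "ndim", None) == 1
treedef, (vectors, others) = partition_sketch(tree, is_vector)

# A type filter: every leaf here is a jax.Array, so the last mapping stays empty.
_, (arrays, rest) = partition_sketch(tree, jax.Array)
```

Here `vectors` holds the single one-dimensional array (`b`), while `others` holds `w` and `step`.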

class inox.nn.module.ModuleDef(treedef)

Abstraction for the static definition of a module.

See also

Module.partition

Parameters:

treedef (PyTreeDef) – A module tree definition.

__call__(*arrays)

Parameters:

arrays (Dict[str, Array]) – A set of array partitions.

Returns:

A new instance of the module.

Return type:

Module
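Conversely, calling a static definition merges the partitions back into a single tree. The sketch below mimics that with a raw JAX `PyTreeDef`. `split_sketch` and `rebuild_sketch` are hypothetical helpers, and keeping the list of leaf paths alongside the definition is an assumption of this sketch, not necessarily how ModuleDef stores it.

```python
from typing import Any, Dict, List, Tuple

import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path


def split_sketch(tree: Any) -> Tuple[Any, List[str], Dict[str, Any]]:
    """Hypothetical helper: flatten a PyTree into its definition, leaf paths and arrays."""
    path_leaves, treedef = tree_flatten_with_path(tree)
    paths = [keystr(path) for path, _ in path_leaves]
    arrays = {p: leaf for p, (_, leaf) in zip(paths, path_leaves)}
    return treedef, paths, arrays


def rebuild_sketch(treedef: Any, paths: List[str], *partitions: Dict[str, Any]) -> Any:
    """Hypothetical analogue of calling a static definition on array partitions."""
    merged: Dict[str, Any] = {}
    for partition in partitions:
        merged.update(partition)
    # Feed the leaves back in the order the tree definition expects.
    return treedef.unflatten([merged[p] for p in paths])


tree = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
treedef, paths, arrays = split_sketch(tree)

# Split the arrays into two partitions, then assemble a new tree from them.
params = {p: a for p, a in arrays.items() if a.ndim == 2}
others = {p: a for p, a in arrays.items() if a.ndim != 2}
clone = rebuild_sketch(treedef, paths, params, others)
```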

class inox.nn.module.Parameter(value)

Wrapper to indicate an optimizable array.

All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.

Parameters:

value (Array) – An array.

Example

>>> weight = Parameter(jax.numpy.ones((3, 5))); weight
Parameter(float32[3, 5])
>>> y = x @ weight()

__call__()

Returns:

The wrapped array.

Return type:

Array
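Wrapping an array in a PyTree node does not block differentiation: `jax.grad` returns gradients with the same structure as its argument, so the gradient of a wrapped weight is itself a wrapped array. The sketch below shows this with `ParamSketch`, a hypothetical stand-in for Parameter rather than its actual implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ParamSketch:
    """Hypothetical minimal wrapper marking an array as optimizable."""

    def __init__(self, value):
        self.value = value

    def __call__(self):
        # Calling the wrapper returns the wrapped array.
        return self.value

    def tree_flatten(self):
        return (self.value,), None  # one child, no static data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


weight = ParamSketch(jnp.ones((3, 5)))
x = jnp.ones((2, 3))


def loss(w):
    return jnp.sum(x @ w())


# The gradient has the same structure as the argument: a ParamSketch.
grads = jax.grad(loss)(weight)
```

Each entry of the gradient equals the sum of the corresponding input column over the batch, which is 2.0 here since `x` is all ones with two rows.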

class inox.nn.module.ComplexParameter(value)

Wrapper to indicate an optimizable complex array.

The real and imaginary parts are stored as separate floating-point arrays to enable gradient-based optimization.

Parameters:

value (Array) – A complex array.

Example

>>> value = jax.numpy.ones((3, 5)) + 1j * jax.numpy.zeros((3, 5))
>>> weight = ComplexParameter(value); weight
ComplexParameter(complex64[3, 5])
>>> y = x @ weight()
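A sketch of the storage scheme described above, assuming the complex value is rebuilt with `jax.lax.complex` whenever it is used. `split_complex` and `merge_complex` are hypothetical helpers; the point is that the optimizable leaves are real `float32` arrays, so a real-valued loss yields ordinary real gradients for each part.

```python
import jax
import jax.numpy as jnp


def split_complex(value):
    """Hypothetical helper: store a complex array as two real arrays."""
    return jnp.real(value), jnp.imag(value)


def merge_complex(real, imag):
    """Hypothetical helper: rebuild the complex array from its parts."""
    return jax.lax.complex(real, imag)


value = jnp.ones((3, 5)) + 1j * jnp.zeros((3, 5))
real, imag = split_complex(value)  # two float32 arrays of shape (3, 5)


def loss(parts):
    z = merge_complex(*parts)
    # Real-valued objective: sum of squared magnitudes |z|^2 = re^2 + im^2.
    return jnp.sum(jnp.real(z * jnp.conj(z)))


# Gradients are taken with respect to the real and imaginary parts separately.
grad_real, grad_imag = jax.grad(loss)((real, imag))
```

For this loss the gradients are `2 * real` and `2 * imag`, i.e. all twos and all zeros for the value above.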