inox.nn.module#
Base modules
In Inox, a module is a PyTree whose branches are its attributes. A branch can be any PyTree-compatible object (bool, str, list, dict, …), including other modules. Parametric functions, such as neural networks, should subclass Module and indicate their parameters with Parameter.
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Linear(nn.Module):
    def __init__(self, in_features, out_features, key):
        keys = jax.random.split(key, 2)

        self.weight = nn.Parameter(jax.random.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jax.random.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.l1 = Linear(in_features, 64, key=keys[0])
        self.l2 = Linear(64, 64, key=keys[1])
        self.l3 = Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()
        self.return_logits = True

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(16, 3, key)
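Once constructed, the model is applied to data like a regular function, and its attributes can be modified like those of any Python object. A minimal usage sketch (the input batch below is illustrative):

x = jax.random.normal(jax.random.key(1), (4, 16))  # illustrative batch of 4 inputs

logits = model(x)            # shape (4, 3)

model.return_logits = False  # attributes can be reassigned after construction
probs = model(x)             # rows sum to 1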
Classes#
Module | Base class for all modules.
ModuleDef | Abstraction for the static definition of a module.
Parameter | Wrapper to indicate an optimizable array.
ComplexParameter | Wrapper to indicate an optimizable complex array.
Descriptions#
- class inox.nn.module.Module(**kwargs)#
Base class for all modules.
- Parameters:
kwargs – A name-value mapping.
- train(mode=True)#
Toggles between training and evaluation modes.
This method is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.
- Parameters:
mode (bool) – Whether to turn training mode on or off.
Example
>>> model.train(False) # turns off dropout
- partition(*filters)#
Splits the static definition of the module from its arrays.
The arrays are partitioned into a set of path-array mappings. Each mapping contains the arrays of the subset of nodes satisfying the corresponding filtering constraint. The last mapping is dedicated to arrays that do not satisfy any constraint.
See also
- Parameters:
filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.
- Returns:
The module static definition and array partitions.
Examples
>>> static, arrays = model.partition()
>>> clone = static(arrays)

>>> static, params, others = model.partition(nn.Parameter)
>>> params, opt_state = optimizer.update(grads, opt_state, params)
>>> model = static(params, others)

>>> model.path[2].layer.frozen = True
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> static, frozen, others = model.partition(filtr)
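Partitioning is also how gradients are restricted to the parameters: the loss is written over the parameter partition and the module is rebuilt inside it. A sketch building on the classifier defined above, assuming some input batch x and targets y shaped like the model's output:

static, params, others = model.partition(nn.Parameter)

def loss_fn(params, x, y):
    model = static(params, others)        # rebuild the module from its partitions
    return jnp.mean((model(x) - y) ** 2)  # illustrative squared error loss

grads = jax.grad(loss_fn)(params, x, y)   # gradients w.r.t. the parameters only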
- class inox.nn.module.ModuleDef(treedef)#
Abstraction for the static definition of a module.
See also
- Parameters:
treedef (PyTreeDef) – A module tree definition.
- class inox.nn.module.Parameter(value)#
Wrapper to indicate an optimizable array.
All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.
- Parameters:
value (Array) – An array.
Example
>>> weight = Parameter(jax.numpy.ones((3, 5))); weight
Parameter(float32[3, 5])
>>> y = x @ weight()
- class inox.nn.module.ComplexParameter(value)#
Wrapper to indicate an optimizable complex array.
The real and imaginary parts are stored as separate floating point arrays to enable gradient-based optimization.
- Parameters:
value (Array) – A complex array.
Example
>>> value = jax.numpy.ones((3, 5)) + 1j * jax.numpy.zeros((3, 5))
>>> weight = ComplexParameter(value); weight
ComplexParameter(complex64[3, 5])
>>> y = x @ weight()
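Because the stored leaves are real-valued, the wrapper works with gradient transformations that expect floating point arrays. A small check of this storage layout, assuming the wrapper flattens to its real and imaginary parts as leaves:

import jax
import jax.numpy as jnp
import inox.nn as nn

weight = nn.ComplexParameter(jnp.ones((3, 5)) + 1j * jnp.zeros((3, 5)))

# Inspect the PyTree leaves of the wrapper.
# Expected: two floating point leaves (the real and imaginary parts).
print([leaf.dtype for leaf in jax.tree_util.tree_leaves(weight)])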