inox.nn.module
Base modules
In Inox, a module is a PyTree whose branches are its attributes. A branch can be any
PyTree-compatible object (bool, str, list, dict, …), including
other modules. Parametric functions, such as neural networks, should subclass
Module and indicate their parameters with Parameter.
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Linear(nn.Module):
    def __init__(self, in_features, out_features, key):
        keys = jax.random.split(key, 2)

        self.weight = nn.Parameter(jax.random.normal(keys[0], (in_features, out_features)))
        self.bias = nn.Parameter(jax.random.normal(keys[1], (out_features,)))

    def __call__(self, x):
        return x @ self.weight() + self.bias()

class Classifier(nn.Module):
    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.l1 = Linear(in_features, 64, key=keys[0])
        self.l2 = Linear(64, 64, key=keys[1])
        self.l3 = Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

        self.return_logits = True

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        if self.return_logits:
            return x
        else:
            return jax.nn.softmax(x)

key = jax.random.key(0)
model = Classifier(16, 3, key)
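To make the "module as PyTree" idea concrete, here is a minimal sketch in plain JAX, without Inox. The `Affine` class and its `name` attribute are hypothetical and exist only for illustration. It registers a class by hand so that its array attributes become tree leaves. Inox does this bookkeeping for `Module` subclasses automatically, and its treatment of non-array attributes may differ from the choice made here (the string is kept as static auxiliary data).

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Affine:
    """Hypothetical toy module: arrays are children, the name is static."""

    def __init__(self, weight, bias, name="affine"):
        self.weight = weight
        self.bias = bias
        self.name = name

    def tree_flatten(self):
        # Children are traversed by JAX; auxiliary data is carried along unchanged.
        return (self.weight, self.bias), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        weight, bias = children
        return cls(weight, bias, name)

    def __call__(self, x):
        return x @ self.weight + self.bias

layer = Affine(jnp.ones((3, 2)), jnp.zeros(2))

# The attributes holding arrays are the leaves of the tree.
leaves = jax.tree_util.tree_leaves(layer)

# Mapping over the tree returns a new Affine with transformed arrays.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, layer)
y = doubled(jnp.ones((4, 3)))
```

Because the object is a PyTree, it can be passed directly to transformations such as `jax.jit` or `jax.grad`, which is the property the module abstraction relies on.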
Classes
Module | Base class for all modules.
ModuleDef | Abstraction for the static definition of a module.
Parameter | Wrapper to indicate an optimizable array.
ComplexParameter | Wrapper to indicate an optimizable complex array.
Descriptions
- class inox.nn.module.Module(**kwargs)
Base class for all modules.
- Parameters:
kwargs – A name-value mapping.
- train(mode=True)
Toggles between training and evaluation modes.
This method is primarily useful for (sub)modules that behave differently at training and evaluation, such as inox.nn.dropout.TrainingDropout and inox.nn.normalization.BatchNorm.
- Parameters:
mode (bool) – Whether to turn training mode on or off.
Example
>>> model = Module(dropout=nn.TrainingDropout())
>>> model
Module(
  dropout = TrainingDropout(
    p = float32[]
  )
)
>>> model.train(False)
>>> model
Module(
  dropout = TrainingDropout(
    p = float32[],
    training = False
  )
)
- partition(*filters)
Splits the static definition of the module from its arrays.
The arrays are partitioned into a set of path-array mappings. The mapping in which an array is contained is chosen according to its oldest (closest to the root) ancestor that satisfies a constraint. If a node satisfies several constraints, the first one is selected. The last mapping is dedicated to arrays that do not satisfy any constraint.
See also
- Parameters:
filters (type | Callable[[Any], bool]) – A set of filtering constraints. Types are transformed into isinstance constraints.
- Returns:
The module static definition and array partitions.
- Return type:
Examples
>>> keys = jax.random.split(jax.random.key(0), 2)
>>> model = Module(layers=[
...     nn.Linear(3, 64, key=keys[0]),
...     nn.ReLU(),
...     nn.TrainingDropout(),
...     nn.Linear(64, 5, key=keys[1]),
... ])
>>> static, arrays = model.partition()
>>> print(inox.tree.prepr(arrays))
{
  '.layers[0].bias.value': float32[64],
  '.layers[0].weight.value': float32[3, 64],
  '.layers[2].p': float32[],
  '.layers[3].bias.value': float32[5],
  '.layers[3].weight.value': float32[64, 5]
}

>>> static, params, others = model.partition(nn.Parameter)
>>> print(inox.tree.prepr(params))
{
  '.layers[0].bias.value': float32[64],
  '.layers[0].weight.value': float32[3, 64],
  '.layers[3].bias.value': float32[5],
  '.layers[3].weight.value': float32[64, 5]
}
>>> print(inox.tree.prepr(others))
{'.layers[2].p': float32[]}

>>> grads = jax.tree.map(jax.numpy.ones_like, params)
>>> params = jax.tree.map(lambda x, y: x + 0.01 * y, params, grads)
>>> model = static(params, others)  # updated copy

>>> model.layers[3].frozen = True
>>> filtr = lambda x: getattr(x, 'frozen', False)
>>> static, frozen, params, others = model.partition(filtr, nn.Parameter)
>>> print(inox.tree.prepr(frozen))
{'.layers[3].bias.value': float32[5], '.layers[3].weight.value': float32[64, 5]}
>>> print(inox.tree.prepr(params))
{'.layers[0].bias.value': float32[64], '.layers[0].weight.value': float32[3, 64]}
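The assignment rule described for `partition` (oldest matching ancestor wins, first matching filter wins, unmatched leaves go to the last mapping) can be sketched in pure Python on nested containers. This is not Inox's implementation: `partition_leaves`, `Layer` and `Param` are hypothetical names, plain floats stand in for arrays, and the static definition that `partition` also returns is left out.

```python
class Layer(dict):
    """Hypothetical container node; instances accept attributes such as `frozen`."""

class Param(dict):
    """Hypothetical stand-in for a parameter wrapper holding a `value` leaf."""

def partition_leaves(tree, *filters):
    # Types become isinstance checks; callables are used as they are.
    tests = [
        (lambda node, t=f: isinstance(node, t)) if isinstance(f, type) else f
        for f in filters
    ]
    # One mapping per filter, plus a last one for leaves that match nothing.
    parts = [{} for _ in range(len(tests) + 1)]

    def visit(node, path, slot):
        if slot is None:  # no ancestor matched yet, so test this node
            for i, test in enumerate(tests):
                if test(node):
                    slot = i  # first satisfied filter wins
                    break
        if isinstance(node, dict):
            for k, v in node.items():
                visit(v, f"{path}.{k}", slot)
        elif isinstance(node, (list, tuple)):
            for i, v in enumerate(node):
                visit(v, f"{path}[{i}]", slot)
        else:
            parts[-1 if slot is None else slot][path] = node

    visit(tree, "", None)
    return parts

tree = {
    "layers": [
        Layer(weight=Param(value=1.0), bias=Param(value=2.0)),
        Layer(p=0.5),
        Layer(weight=Param(value=3.0), bias=Param(value=4.0)),
    ]
}
tree["layers"][2].frozen = True

is_frozen = lambda node: getattr(node, "frozen", False)
frozen, params, others = partition_leaves(tree, is_frozen, Param)
```

The parameters of the frozen layer also satisfy the `Param` filter, but their ancestor matched `is_frozen` first, so they land in `frozen`. This mirrors the last example above.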
- class inox.nn.module.ModuleDef(treedef, leaves)
Abstraction for the static definition of a module.
See also
- Parameters:
- treedef: PyTreeDef
Alias for field number 0
- class inox.nn.module.Parameter(value)
Wrapper to indicate an optimizable array.
All arrays that require gradient updates in a Module should be wrapped in a Parameter instance.
- Parameters:
value (Array) – An array.
Example
>>> weight = Parameter(jax.numpy.ones((3, 5))); weight
Parameter(float32[3, 5])
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = x @ weight()
- class inox.nn.module.ComplexParameter(value)
Wrapper to indicate an optimizable complex array.
The real and imaginary parts are stored as separate floating point arrays to enable gradient-based optimization.
- Parameters:
value (Array) – A complex array.
Example
>>> value = jax.numpy.ones((3, 5)) + 1j * jax.numpy.zeros((3, 5))
>>> weight = ComplexParameter(value); weight
ComplexParameter(complex64[3, 5])
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = x @ weight()
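The idea of keeping the real and imaginary parts as separate floating point arrays can be sketched in plain JAX, independently of Inox. The helper `to_complex`, the dictionary layout and the toy loss are assumptions made for illustration and say nothing about how `ComplexParameter` is implemented internally.

```python
import jax
import jax.numpy as jnp

def to_complex(parts):
    """Rebuilds the complex array from its real and imaginary parts."""
    return parts["real"] + 1j * parts["imag"]

def loss(parts, target):
    # Squared distance to a complex target; the output is real-valued.
    diff = to_complex(parts) - target
    return jnp.sum(diff.real**2 + diff.imag**2)

value = jnp.ones(3) + 1j * jnp.zeros(3)
parts = {"real": value.real, "imag": value.imag}  # two float arrays
target = jnp.full((3,), 0.5 + 0.5j)

# Gradients are taken with respect to the real-valued parts only.
grads = jax.grad(loss)(parts, target)

# One plain gradient descent step, then recombine into a complex array.
parts = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, parts, grads)
updated = to_complex(parts)
```

With this layout the optimizer only ever sees real arrays and real gradients, so the same update rule works for real and complex parameters alike.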