inox.nn.state
Stateful modules
In Inox, in-place module mutations are not prohibited, but they are not recommended because they often lead to silent errors around JAX transformations. Instead, it is safer to externalize the state of modules and handle mutations explicitly.
The `inox.nn.state` module provides a simple interface to declare the state of modules and apply state updates.
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp
class Moments(nn.Module):
    r"""Keeps running estimates of the first and second moments of its input.

    The estimates are declared as state entries instead of plain attributes, so
    that updating them returns a new state dictionary rather than mutating the
    module in place.

    Arguments:
        features: The number of features (size of each moment vector).
    """

    def __init__(self, features):
        # Both entries become state keys once the module goes through
        # `nn.export_state`; their values then live in the state dictionary.
        self.first = nn.StateEntry(jnp.zeros(features))
        self.second = nn.StateEntry(jnp.ones(features))

    def __call__(self, x, state):
        r"""Returns a copy of `state` with both moment estimates updated by `x`.

        Arguments:
            x: An input array; presumably of shape `(features,)` or broadcastable
                to it -- confirm against callers.
            state: The state dictionary produced by `nn.export_state`.

        Returns:
            The updated state dictionary.
        """
        first = state[self.first]
        second = state[self.second]

        # Exponential moving averages with decay 0.9. `update_state` copies the
        # dictionary, so the caller's `state` is left untouched.
        state = nn.update_state(state, {
            self.first: 0.9 * first + 0.1 * x,
            self.second: 0.9 * second + 0.1 * x**2,
        })

        return state
class MLP(nn.Module):
    r"""Three-layer perceptron that records statistics of its inputs and outputs.

    Arguments:
        in_features: The number of input features.
        num_classes: The number of output features.
        key: A PRNG key used to initialize the linear layers.
    """

    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        # Stateful sub-modules tracking moments of the inputs and outputs.
        self.in_stats = Moments(in_features)
        self.out_stats = Moments(num_classes)

        self.l1 = nn.Linear(in_features, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x, state):
        r"""Applies the network and returns its output with the updated state.

        Arguments:
            x: The input array, with `in_features` features.
            state: The state dictionary produced by `nn.export_state`.

        Returns:
            The output array and the updated state dictionary.
        """
        state = self.in_stats(x, state)

        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.relu(x)
        x = self.l3(x)

        state = self.out_stats(x, state)

        return x, state
key = jax.random.key(0)
# Use separate subkeys for parameter initialization and input data.
init_key, data_key = jax.random.split(key)

model = MLP(16, 3, init_key)
# Replace state entries by state keys and gather their values in `state`.
model, state = nn.export_state(model)

x = jax.random.normal(data_key, (16,))
y, state = model(x, state)
Functions

| `update_state` | Creates a copy of the state dictionary and updates it. |
| `export_state` | Pulls the state entries out of a tree. |
Classes

| `StateEntry` | Wrapper to indicate a state entry. |
| `StateKey` | Wrapper to indicate a state key. |
Descriptions
- class inox.nn.state.StateEntry(value)
Wrapper to indicate a state entry.
- Parameters:
value (Any) – A value.
- class inox.nn.state.StateKey(key)
Wrapper to indicate a state key.
- Parameters:
key (Hashable) – A hashable key.
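- inox.nn.state.update_state(state, mutation)
Creates a copy of the state dictionary and updates it. The input dictionary is not modified.
- Parameters:
state (Dict[StateKey, Any]) – A state dictionary, as returned by export_state.
mutation (Dict[StateKey, Any]) – The new values, indexed by state key.
- Returns:
The updated copy of the state dictionary.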
- inox.nn.state.export_state(tree)
Pulls the state entries out of a tree.
State entries are replaced by state keys which can be used to index the state dictionary.
- Parameters:
tree (PyTree) – A tree or module.
- Returns:
The stateless tree and the state dictionary.
- Return type:
Tuple[PyTree, Dict[StateKey, Any]]
Example
>>> tree = {'a': 1, 'b': StateEntry(jax.numpy.zeros(2))}
>>> tree, state = export_state(tree)
>>> tree
{'a': 1, 'b': StateKey("['b']")}
>>> state
{StateKey("['b']"): Array([0., 0.], dtype=float32)}
>>> state[tree['b']]
Array([0., 0.], dtype=float32)