inox.nn.state

Stateful modules

In Inox, in-place module mutations are not prohibited, but they are discouraged because they often cause silent errors under JAX transformations. Instead, it is safer to externalize the state of modules and handle mutations explicitly.
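To make the failure mode concrete, here is a minimal sketch in plain JAX (no Inox) of how an in-place mutation can go wrong under jax.jit, next to an externalized alternative. The Counter class and the step function are hypothetical names introduced only for illustration. The mutation is a Python side effect that runs while the function is traced, so later calls reuse the compiled result and the counter silently stops advancing.

```python
import jax
import jax.numpy as jnp

class Counter:
    """Hypothetical module that mutates itself in place."""

    def __init__(self):
        self.count = 0

    def __call__(self, x):
        self.count += 1  # Python side effect, executed only during tracing
        return x + self.count

counter = Counter()
jitted = jax.jit(lambda x: counter(x))

x = jnp.zeros(())
outputs = [float(jitted(x)) for _ in range(3)]

# Externalized alternative: the count is an explicit input and output.
@jax.jit
def step(x, count):
    count = count + 1
    return x + count, count

count = jnp.zeros((), dtype=jnp.int32)
pure_outputs = []
for _ in range(3):
    y, count = step(x, count)
    pure_outputs.append(float(y))
```

In this sketch, the three jitted calls all return 1.0 and counter.count stays at 1, whereas the externalized version returns 1.0, 2.0 and 3.0 because the count is threaded through each call.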

The inox.nn.state module provides a simple interface to declare the state of modules and apply state updates. During initialization, mutable arrays are wrapped in StateEntry instances. After initialization, these arrays are pulled out and replaced with hashable StateKey instances using the export_state function. The state is represented by a dictionary which is used and updated during the module’s execution.

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Moments(nn.Module):
    def __init__(self, features):
        self.first = nn.StateEntry(jnp.zeros(features))
        self.second = nn.StateEntry(jnp.ones(features))

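    def __call__(self, x, state):
        # After export_state, self.first and self.second are keys into the state.
        first = state[self.first]
        second = state[self.second]

        # Exponential moving averages of x and x**2, written to a new dictionary.
        state = nn.update_state(state, {
            self.first: 0.9 * first + 0.1 * x,
            self.second: 0.9 * second + 0.1 * x**2,
        })

        return state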

class MLP(nn.Module):
    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.in_stats = Moments(in_features)
        self.out_stats = Moments(num_classes)

        self.l1 = nn.Linear(in_features, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x, state):
        state = self.in_stats(x, state)

        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        x = self.relu(x)
        x = self.l3(x)

        state = self.out_stats(x, state)

        return x, state

key = jax.random.key(0)
model = MLP(16, 3, key)
model, state = nn.export_state(model)

x = jnp.ones(16)  # illustrative input with in_features = 16
y, state = model(x, state)

Classes

StateEntry

Wrapper to indicate a state entry.

StateKey

Wrapper to indicate a state key.

Functions

update_state

Creates a copy of the state dictionary and updates it.

export_state

Pulls the state entries out of a tree.

Descriptions

class inox.nn.state.StateEntry(value)

Wrapper to indicate a state entry.

Parameters:

value (Any) – The value to wrap, typically a mutable array.

value: Any

Alias for field number 0

class inox.nn.state.StateKey(key)

Wrapper to indicate a state key.

Parameters:

key (Hashable) – A hashable key.

key: Hashable

Alias for field number 0

inox.nn.state.update_state(state, mutation)

Creates a copy of the state dictionary and updates it.

Parameters:
  • state (Dict) – The state dictionary.

  • mutation (Dict) – A dictionary of entries to add or overwrite in the copy.

Returns:

The updated state dictionary.

Return type:

Dict
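The copy-then-update behavior described above can be sketched in plain Python. The function update_state_sketch is a hypothetical stand-in, written under the assumption that the copy is shallow and that the mutation is applied like dict.update; it is not the library's implementation.

```python
def update_state_sketch(state: dict, mutation: dict) -> dict:
    """Hypothetical stand-in for update_state: shallow copy, then update."""
    new_state = dict(state)     # shallow copy, the input dictionary is left untouched
    new_state.update(mutation)  # overwrite or add the mutated entries
    return new_state

old = {"first": 0.0, "second": 1.0}
new = update_state_sketch(old, {"first": 0.5})
```

Because the input dictionary is never modified, a caller that still holds the old state keeps seeing the old values, which is what makes the pattern safe to use inside traced functions.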

inox.nn.state.export_state(tree)

Pulls the state entries out of a tree.

State entries are replaced by state keys which can be used to index the state dictionary.

Parameters:

tree (PyTree) – A tree or module.

Returns:

The stateless tree and the state dictionary.

Return type:

Tuple[PyTree, Dict]

Example

>>> tree = {'a': 1, 'b': StateEntry(jax.numpy.zeros(2))}
>>> tree, state = export_state(tree)
>>> tree
{'a': 1, 'b': StateKey("['b']")}
>>> state
{StateKey("['b']"): Array([0., 0.], dtype=float32)}
>>> state[tree['b']]
Array([0., 0.], dtype=float32)
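To illustrate how entries could be swapped for keys in a nested structure, here is a rough pure-Python sketch restricted to nested dictionaries. Entry, Key and export_state_sketch are hypothetical stand-ins for StateEntry, StateKey and export_state, and the path-based key format is an assumption modeled on the output shown in the example above; the actual function handles general trees and modules.

```python
from typing import Any, Dict, Hashable, NamedTuple, Tuple

class Entry(NamedTuple):
    """Hypothetical stand-in for StateEntry."""
    value: Any

class Key(NamedTuple):
    """Hypothetical stand-in for StateKey."""
    key: Hashable

def export_state_sketch(tree: Any, path: str = "") -> Tuple[Any, Dict[Key, Any]]:
    """Replaces every Entry in a nested dict by a Key and collects the values."""
    if isinstance(tree, Entry):
        key = Key(path)
        return key, {key: tree.value}
    if isinstance(tree, dict):
        stateless, state = {}, {}
        for name, child in tree.items():
            new_child, child_state = export_state_sketch(child, f"{path}[{name!r}]")
            stateless[name] = new_child
            state.update(child_state)
        return stateless, state
    return tree, {}  # any other leaf is kept as-is

tree = {"a": 1, "b": {"c": Entry((0.0, 0.0))}}
tree, state = export_state_sketch(tree)
```

In this sketch, the nested entry ends up under a key built from its path, so state[tree["b"]["c"]] retrieves the value that was pulled out of the tree.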