inox.nn.state

Stateful modules

In Inox, in-place module mutations are not prohibited, but they are not recommended as they often lead to silent errors around JAX transformations. Instead, it is safer to externalize the state of modules and handle mutations explicitly.

The inox.nn.state module provides a simple interface to declare the state of modules and apply state updates.

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class Moments(nn.Module):
    r"""Tracks running estimates of the first and second moments of its inputs.

    Both moments are exponential moving averages with decay 0.9. They are
    declared as state entries, so each call returns an updated state
    dictionary instead of mutating the module in place.

    Arguments:
        features: The shape of the tracked moments (number of features).
    """

    def __init__(self, features):
        self.first = nn.StateEntry(jnp.zeros(features))
        self.second = nn.StateEntry(jnp.ones(features))

    def __call__(self, x, state):
        r"""Returns a copy of ``state`` with both moment estimates updated by ``x``."""

        # After `nn.export_state`, `self.first` and `self.second` are state keys
        # that index the externalized state dictionary.
        first = state[self.first]
        second = state[self.second]

        # `update_state` is not imported by name; access it through the `nn`
        # alias like the other state helpers. It returns an updated copy.
        state = nn.update_state(state, {
            self.first: 0.9 * first + 0.1 * x,
            self.second: 0.9 * second + 0.1 * x**2,
        })

        return state

class MLP(nn.Module):
    r"""Three-layer perceptron that records statistics of its inputs and outputs.

    Arguments:
        in_features: The number of input features.
        num_classes: The number of output features.
        key: A PRNG key used to initialize the linear layers.
    """

    def __init__(self, in_features, num_classes, key):
        keys = jax.random.split(key, 3)

        self.in_stats = Moments(in_features)
        self.out_stats = Moments(num_classes)

        self.l1 = nn.Linear(in_features, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, num_classes, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x, state):
        r"""Returns the network output and the updated state dictionary."""

        state = self.in_stats(x, state)

        x = self.l1(x)
        x = self.relu(x)  # the activation must be applied to the hidden features
        x = self.l2(x)
        x = self.relu(x)
        x = self.l3(x)

        state = self.out_stats(x, state)

        return x, state

key = jax.random.key(0)
model = MLP(16, 3, key)
model, state = nn.export_state(model)

# An input must exist before calling the model; it has `in_features` (16)
# features and is drawn with a different key than the one used for init.
x = jax.random.normal(jax.random.key(1), (16,))
y, state = model(x, state)

Functions

update_state

Creates a copy of the state dictionary and updates it.

export_state

Pulls the state entries out of a tree.

Classes

StateEntry

Wrapper to indicate a state entry.

StateKey

Wrapper to indicate a state key.

Descriptions

class inox.nn.state.StateEntry(value)

Wrapper to indicate a state entry.

Parameters:

value (Any) – The value held by the entry; export_state moves it into the state dictionary.

value: Any

Alias for field number 0

class inox.nn.state.StateKey(key)

Wrapper to indicate a state key.

Parameters:

key (Hashable) – A hashable key.

key: Hashable

Alias for field number 0

inox.nn.state.update_state(state, mutation)

Creates a copy of the state dictionary and updates it.

Parameters:
  • state (Dict) – The state dictionary.

  • mutation (Dict) – The entries to insert into, or overwrite in, the copied state dictionary.

Returns:

The updated state dictionary.

Return type:

Dict

inox.nn.state.export_state(tree)

Pulls the state entries out of a tree.

State entries are replaced by state keys which can be used to index the state dictionary.

Parameters:

tree (PyTree) – A tree or module.

Returns:

The stateless tree and the state dictionary.

Return type:

Tuple[PyTree, Dict]

Example

>>> tree = {'a': 1, 'b': StateEntry(jax.numpy.zeros(2))}
>>> tree, state = export_state(tree)
>>> tree
{'a': 1, 'b': StateKey("['b']")}
>>> state
{StateKey("['b']"): Array([0., 0.], dtype=float32)}
>>> state[tree['b']]
Array([0., 0.], dtype=float32)