inox.nn.share

Sharing modules

In a vanilla inox.nn.module.Module, shared references to the same layer or parameter would be treated as separate copies and their weights would not be tied. The Scope class correctly handles such cases when shared references are made explicit with Reference.

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class WeightSharingMLP(nn.Scope):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(in_features=64, out_features=64, key=keys[0])
        self.l3 = nn.Linear(in_features=64, out_features=64, key=keys[1])
        self.l4 = nn.Linear(in_features=64, out_features=64, key=keys[2])

        self.l1 = nn.Reference('l1', self.l1)
        self.l2 = self.l1  # tied layer
        self.l3.weight = nn.Reference('l3.weight', self.l3.weight)
        self.l4.weight = self.l3.weight  # tied parameter

        self.relu = nn.ReLU()

    def __call__(self, x):  # standard __call__
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))
        x = self.l4(self.relu(x))

        return x

key = jax.random.key(0)
model = WeightSharingMLP(key)
static, params, others = model.partition(nn.Parameter)

print(inox.tree_repr(params))  # does not contain 'l2' and 'l4.weight'
{
  '.l1.value.bias.value': float32[64],
  '.l1.value.weight.value': float32[64, 64],
  '.l3.bias.value': float32[64],
  '.l3.weight.value.value': float32[64, 64],
  '.l4.bias.value': float32[64]
}
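
To illustrate why untagged sharing is lost, here is a minimal pure-Python sketch. It assumes a toy model made of nested dicts and a toy flatten function; the names flatten, model, grads and new_leaves are hypothetical and none of this is inox's or JAX's implementation. The point is only that a path-based flattening lists a shared weight once per path, so a per-leaf update lets the two entries drift apart.

```python
def flatten(tree, prefix=""):
    """Flatten nested dicts into {path: leaf}; a shared leaf appears once per path."""
    if isinstance(tree, dict):
        leaves = {}
        for name, child in tree.items():
            leaves.update(flatten(child, f"{prefix}.{name}"))
        return leaves
    return {prefix: tree}


shared = [1.0, 2.0]  # one weight, referenced by two layers
model = {"l1": {"weight": shared}, "l2": {"weight": shared}}

leaves = flatten(model)
print(sorted(leaves))  # both paths are listed, although they hold the same object

# Hypothetical per-leaf gradients: each path receives its own update.
grads = {".l1.weight": [0.1, 0.1], ".l2.weight": [0.3, 0.3]}
new_leaves = {
    path: [w - g for w, g in zip(leaf, grads[path])]
    for path, leaf in leaves.items()
}

# After the update the two "copies" no longer agree: the tie is broken.
print(new_leaves[".l1.weight"] == new_leaves[".l2.weight"])
```

In the example above, tagging the layer with Reference is what lets the scope keep a single entry for l1 and l2.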

Classes

Scope

Subclass of inox.nn.module.Module which handles shared object references within its scope.

Reference

Creates a reference to a value.

Descriptions

class inox.nn.share.Scope(**kwargs)

Subclass of inox.nn.module.Module which handles shared object references within its scope.

All references with the same identification tag in a scope are considered to be the same, and all copies but one are pruned when the scope tree is flattened. Cyclic references are allowed, with the exception of a scope referencing itself.
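
The pruning rule can be sketched in pure Python. This is an illustration under stated assumptions, not how Scope is implemented: Ref is a hypothetical stand-in for a tagged reference, the model is a nested dict, and the toy flatten keeps the first occurrence of each tag and drops later ones.

```python
from typing import Any, Hashable, NamedTuple


class Ref(NamedTuple):  # hypothetical stand-in for a tagged reference
    tag: Hashable
    value: Any


def flatten(tree, prefix="", seen=None):
    """Flatten nested dicts into {path: leaf}, keeping one copy per tag."""
    seen = set() if seen is None else seen
    if isinstance(tree, Ref):
        if tree.tag in seen:
            return {}  # this tag was already visited: prune the duplicate
        seen.add(tree.tag)
        return flatten(tree.value, f"{prefix}.value", seen)
    if isinstance(tree, dict):
        leaves = {}
        for name, child in tree.items():
            leaves.update(flatten(child, f"{prefix}.{name}", seen))
        return leaves
    return {prefix: tree}


l1 = Ref("l1", {"weight": [1.0, 2.0], "bias": [0.0]})
model = {"l1": l1, "l2": l1, "l3": {"bias": [0.0]}}

leaves = flatten(model)
print(sorted(leaves))  # no entry under '.l2'
```

Because identity is decided by the tag rather than by the Python object, two references are merged only if they carry equal tags.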

Warning

Shared references and in-place mutations are very hard to combine properly. In contrast, Reference works seamlessly with the inox.nn.state utilities.

Parameters:

kwargs – A name-value mapping.

class inox.nn.share.Reference(tag, value)

Creates a reference to a value.

A Reference instance forwards __call__ and __getattr__ operations to the value it references. For indexing (__getitem__) or arithmetic operations (+, *, …), use ref.value directly instead.
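
The forwarding behaviour can be pictured with a small pure-Python sketch. It assumes a named tuple with tag and value fields and uses a hypothetical class Ref; it is not inox's Reference, only an illustration of forwarding __call__ and __getattr__ while leaving indexing to the tuple itself.

```python
from typing import Any, Hashable, NamedTuple


class Ref(NamedTuple):  # hypothetical stand-in, not inox's Reference
    tag: Hashable
    value: Any

    def __call__(self, *args, **kwargs):
        # Forward calls to the wrapped value.
        return self.value(*args, **kwargs)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so `tag` and `value` still resolve.
        return getattr(self.value, name)


word = Ref("greeting", "hello")
length = Ref("len-fn", len)

upper = word.upper()       # attribute access is forwarded to the str
size = length([1, 2, 3])   # the call is forwarded to len
first = word.value[0]      # indexing goes through .value
```

In this sketch, word[0] is plain tuple indexing and returns the tag rather than the first character, which is one reason to index through .value.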

See also

Scope

Parameters:
  • tag (Hashable) – An identification tag.

  • value (Any) – The value to reference.

Example

>>> weight = Reference('my-ref', Parameter(jax.numpy.ones((3, 5))))
>>> weight  # repr preceded by an asterisk
*Parameter(float32[3, 5])
>>> weight.shape
(3, 5)
>>> weight()
Array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32)

tag: Hashable

Alias for field number 0

value: Any

Alias for field number 1