inox.nn.share

Sharing modules

In JAX, implicit shared references to a node in a tree (several occurrences of an object with the same Python id) are treated as distinct objects. While this design choice is reasonable, it makes it difficult to express that two layers in a module are identical and to preserve their shared identity through transformations.
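The behavior described above can be observed with plain JAX pytree utilities. In this minimal sketch, one array appears twice in a dict. Flattening yields one leaf per occurrence, and after a transformation the two entries come back as separate arrays, so the original sharing is lost.

```python
import jax
import jax.numpy as jnp

w = jnp.ones((2, 3))
tree = {"a": w, "b": w}  # the same array object appears twice

# flattening does not track identity: one leaf per occurrence
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(len(leaves))  # 2

# after a transformation, each occurrence is rebuilt independently
double = jax.jit(lambda t: jax.tree_util.tree_map(lambda x: x * 2, t))
out = double(tree)
print(out["a"] is out["b"])  # False
```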

Inox provides a mechanism to express such identity explicitly. When a Scope module is flattened, if several Reference instances in the tree share the same identification tag, the first occurrence (in depth-first order) is preserved, while the following occurrences are pruned. When the module is unflattened, the pruned occurrences are filled in with the preserved occurrence, which restores their shared identity.
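The prune-and-fill idea can be sketched in pure Python on nested dicts and lists. This is not Inox's implementation. `Ref`, `Pruned`, `prune` and `fill` are hypothetical names chosen for illustration, and the sketch assumes trees made only of dicts, lists, references and plain leaves.

```python
from dataclasses import dataclass
from typing import Any, Hashable


@dataclass(eq=False)
class Ref:
    """Hypothetical stand-in for a tagged reference."""
    tag: Hashable
    obj: Any


@dataclass(frozen=True)
class Pruned:
    """Placeholder left where a duplicate reference was removed."""
    tag: Hashable


def prune(tree, seen=None):
    """Depth-first walk: keep the first Ref per tag, replace later ones."""
    seen = set() if seen is None else seen
    if isinstance(tree, Ref):
        if tree.tag in seen:
            return Pruned(tree.tag)  # the duplicate carries no data
        seen.add(tree.tag)
        return Ref(tree.tag, prune(tree.obj, seen))
    if isinstance(tree, dict):
        return {k: prune(v, seen) for k, v in tree.items()}
    if isinstance(tree, list):
        return [prune(v, seen) for v in tree]
    return tree


def fill(tree, refs=None):
    """Same depth-first order: rebuild Refs, point placeholders at them."""
    refs = {} if refs is None else refs
    if isinstance(tree, Pruned):
        return refs[tree.tag]  # the preserved occurrence was visited first
    if isinstance(tree, Ref):
        ref = Ref(tree.tag, fill(tree.obj, refs))
        refs[tree.tag] = ref
        return ref
    if isinstance(tree, dict):
        return {k: fill(v, refs) for k, v in tree.items()}
    if isinstance(tree, list):
        return [fill(v, refs) for v in tree]
    return tree


shared = Ref("l1", {"weight": [1.0, 2.0]})
tree = {"l1": shared, "l2": shared}

pruned = prune(tree)         # {'l1': Ref(...), 'l2': Pruned(tag='l1')}
restored = fill(pruned)
print(restored["l1"] is restored["l2"])  # True
```

Because the pruned placeholder holds no data, a tree utility walking `pruned` sees the weights only once, which mirrors why `l2` is absent from the partitioned parameters in the example below.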

import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class WeightSharingMLP(nn.Scope):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(in_features=64, out_features=64, key=keys[0])
        self.l3 = nn.Linear(in_features=64, out_features=64, key=keys[1])
        self.l4 = nn.Linear(in_features=64, out_features=64, key=keys[2])

        self.l1 = nn.Reference('l1', self.l1)
        self.l2 = self.l1  # tied layer
        self.l3.weight = nn.Reference('l3.weight', self.l3.weight)
        self.l4.weight = self.l3.weight  # tied parameter

        self.relu = nn.ReLU()

    def __call__(self, x):  # standard __call__
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))
        x = self.l4(self.relu(x))

        return x

key = jax.random.key(0)
model = WeightSharingMLP(key)
static, params, others = model.partition(nn.Parameter)

print(inox.tree.prepr(params))  # does not contain 'l2' and 'l4.weight'
{
  '.l1.value.bias.value': float32[64],
  '.l1.value.weight.value': float32[64, 64],
  '.l3.bias.value': float32[64],
  '.l3.weight.value.value': float32[64, 64],
  '.l4.bias.value': float32[64]
}
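Keeping a single leaf for a tied layer matters during training. If flattening kept two copies, an optimizer would update them independently and they would drift apart. With one leaf, its gradient is the sum of the contributions from every place it is used. The sketch below checks this with plain JAX (no Inox); `layer`, `tied_net` and `untied_net` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def layer(w, x):
    return jnp.tanh(w @ x)


def tied_net(params, x):
    # one weight matrix used by two consecutive layers (tied)
    w = params["w"]
    return jnp.sum(layer(w, layer(w, x)))


def untied_net(params, x):
    # same computation, but each layer owns its own copy of the weight
    return jnp.sum(layer(params["w2"], layer(params["w1"], x)))


w = jnp.arange(16.0).reshape(4, 4) / 16.0
x = jnp.ones(4)

g_tied = jax.grad(tied_net)({"w": w}, x)
g_untied = jax.grad(untied_net)({"w1": w, "w2": w}, x)

# the tied weight receives the sum of the gradients through both layers
print(jnp.allclose(g_tied["w"], g_untied["w1"] + g_untied["w2"]))
```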

Classes

Scope

Subclass of inox.nn.module.Module which handles shared object references within its scope.

Reference

Creates a reference to an object.

Descriptions

class inox.nn.share.Scope(**kwargs)

Subclass of inox.nn.module.Module which handles shared object references within its scope.

Parameters:

kwargs – A name-value mapping.

class inox.nn.share.Reference(tag, obj)

Creates a reference to an object.

A Reference instance forwards __call__, __iter__, __getattr__, and __getitem__ operations to the object it references. For arithmetic operations (+, *, …), use ref.obj directly instead.

Parameters:
  • tag (Hashable) – An identification tag.

  • obj (Any) – The object to reference.
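The forwarding rule, and why arithmetic is excluded, can be illustrated with a small proxy class. This is a hypothetical sketch (`ForwardingRef` is not part of Inox), assuming forwarding is done with ordinary special methods. Python looks up special methods such as `__add__` or `__len__` on the type, never through `__getattr__`, so any operator the proxy does not define explicitly fails instead of being forwarded.

```python
from typing import Any, Hashable


class ForwardingRef:
    """Hypothetical proxy that forwards a few operations to `obj`."""

    def __init__(self, tag: Hashable, obj: Any):
        self.tag = tag
        self.obj = obj

    def __getattr__(self, name):
        # only reached when normal lookup fails, e.g. `ref.count`
        if name in ("tag", "obj"):
            raise AttributeError(name)  # avoid infinite recursion
        return getattr(self.obj, name)

    def __getitem__(self, key):
        return self.obj[key]

    def __iter__(self):
        return iter(self.obj)  # also makes `x in ref` work

    def __call__(self, *args, **kwargs):
        return self.obj(*args, **kwargs)

    def __len__(self):
        # must be explicit: len() never goes through __getattr__
        return len(self.obj)

    def __repr__(self):
        return "*" + repr(self.obj)  # asterisk marks a reference


ref = ForwardingRef("dummy-list", ["zero", (2, 3)])
print(ref)               # *['zero', (2, 3)]
print(len(ref))          # 2
print("zero" in ref)     # True
print(ref[1])            # (2, 3)
print(ref.count("zero")) # 1, forwarded attribute

num = ForwardingRef("n", 3)
try:
    num + 1  # no __add__ on the proxy type
except TypeError:
    print(num.obj + 1)   # 4, use the wrapped object directly
```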

Example

>>> dummy = ["zero", nn.Parameter(jax.numpy.ones((2, 3)))]
>>> dummy = Reference("dummy-list", dummy)
>>> dummy  # repr preceded by asterisk
*['zero', Parameter(float32[2, 3])]
>>> len(dummy)
2
>>> "zero" in dummy
True
>>> dummy[1].shape
(2, 3)