inox.nn.share
Sharing modules
In JAX, implicit shared references to a node in a tree (several entries pointing
to the same Python object) are treated as distinct objects. While this design choice
is reasonable, it makes it difficult to express that two layers in a module are
identical and to preserve their shared identity through transformations.
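This behavior can be observed with plain JAX. The following is a minimal sketch that relies only on jax.tree_util; the names layer, tree and mapped are illustrative and do not come from Inox.

```python
import jax
import jax.numpy as jnp

# One dict object that appears twice in the same tree.
layer = {"w": jnp.ones((2, 2))}
tree = [layer, layer]

# Before any transformation, both entries are the same Python object.
shared_before = tree[0] is tree[1]

# tree_map flattens the tree and rebuilds it; each occurrence is rebuilt separately.
mapped = jax.tree_util.tree_map(lambda x: 2 * x, tree)
shared_after = mapped[0] is mapped[1]

# The shared array is also counted as two separate leaves.
num_leaves = len(jax.tree_util.tree_leaves(tree))

print(shared_before, shared_after, num_leaves)
```

After the mapping, the two entries hold equal values but are no longer the same object, so the tie between them is lost.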
Inox provides a mechanism to express such identity explicitly. During flattening of a
Scope module, if several Reference instances in the tree share the
same identification tag, the first occurrence (in depth-first order) is preserved, while
the following occurrences are pruned. During unflattening, the pruned occurrences are
filled in with the preserved occurrence, which preserves their shared identity.
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp

class WeightSharingMLP(nn.Scope):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(in_features=64, out_features=64, key=keys[0])
        self.l3 = nn.Linear(in_features=64, out_features=64, key=keys[1])
        self.l4 = nn.Linear(in_features=64, out_features=64, key=keys[2])

        self.l1 = nn.Reference('l1', self.l1)
        self.l2 = self.l1  # tied layer

        self.l3.weight = nn.Reference('l3.weight', self.l3.weight)
        self.l4.weight = self.l3.weight  # tied parameter

        self.relu = nn.ReLU()

    def __call__(self, x):  # standard __call__
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))
        x = self.l4(self.relu(x))

        return x

key = jax.random.key(0)
model = WeightSharingMLP(key)

static, params, others = model.partition(nn.Parameter)

print(inox.tree.prepr(params))  # does not contain 'l2' and 'l4.weight'

{
  '.l1.value.bias.value': float32[64],
  '.l1.value.weight.value': float32[64, 64],
  '.l3.bias.value': float32[64],
  '.l3.weight.value.value': float32[64, 64],
  '.l4.bias.value': float32[64]
}
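The pruning and filling steps described above can be pictured with a toy model in pure Python. This is only a sketch of the idea, not Inox's implementation: the Ref class and the flatten and unflatten helpers are hypothetical, trees are assumed to be nested dicts, and a reference is treated as a leaf.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Ref:
    """Hypothetical tagged reference (stand-in for illustration only)."""
    tag: str
    obj: Any


def flatten(tree):
    """Keep the first occurrence of each tag (depth-first) and prune the others."""
    kept, pruned, first_seen = {}, {}, {}

    def visit(node, path):
        if isinstance(node, dict):
            for name, child in node.items():
                visit(child, path + (name,))
        elif isinstance(node, Ref) and node.tag in first_seen:
            pruned[path] = first_seen[node.tag]  # remember where the kept copy lives
        else:
            if isinstance(node, Ref):
                first_seen[node.tag] = path
            kept[path] = node

    visit(tree, ())
    return kept, pruned


def unflatten(kept, pruned):
    """Rebuild the nested dict, filling pruned paths with the preserved object."""
    entries = dict(kept)
    for path, source in pruned.items():
        entries[path] = kept[source]
    tree = {}
    for path, node in entries.items():
        level = tree
        for name in path[:-1]:
            level = level.setdefault(name, {})
        level[path[-1]] = node
    return tree


shared = Ref("l1", [1.0, 2.0])
tree = {"l1": shared, "l2": shared, "head": {"tied": shared, "bias": 0.5}}

kept, pruned = flatten(tree)  # only 'l1' and 'head.bias' are kept


def double(node):
    """Simulate a transformation that rebuilds every preserved entry."""
    if isinstance(node, Ref):
        return Ref(node.tag, [2 * v for v in node.obj])
    return 2 * node


restored = unflatten({path: double(node) for path, node in kept.items()}, pruned)
```

Because only one copy of each tagged entry goes through the transformation, the restored tree holds a single updated object at every tied location.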
Classes
Scope | Subclass of inox.nn.module.Module which handles shared object references within its scope. |
Reference | Creates a reference to an object. |
Descriptions
- class inox.nn.share.Scope(**kwargs)
  Subclass of inox.nn.module.Module which handles shared object references within its scope.
  - Parameters:
    kwargs – A name-value mapping.
- class inox.nn.share.Reference(tag, obj)
  Creates a reference to an object.
  A Reference instance forwards __call__, __iter__, __getattr__, and __getitem__ operations to the object it references. For arithmetic operations (+, *, …), use ref.obj directly instead.
  Example
  >>> dummy = ["zero", nn.Parameter(jax.numpy.ones((2, 3)))]
  >>> dummy = Reference("dummy-list", dummy)
  >>> dummy  # repr preceded by asterisk
  *['zero', Parameter(float32[2, 3])]
  >>> len(dummy)
  2
  >>> "zero" in dummy
  True
  >>> dummy[1].shape
  (2, 3)
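The forwarding behavior described above can be approximated with a small proxy class in pure Python. This is a sketch under assumptions, not the actual Reference class: Proxy is a hypothetical name, and it forwards only the four operations listed. It also shows why arithmetic has to go through the wrapped object: Python looks up operator methods such as __add__ on the type, so __getattr__ does not intercept them.

```python
class Proxy:
    """Hypothetical stand-in that forwards a few operations to a wrapped object."""

    def __init__(self, tag, obj):
        self.tag = tag
        self.obj = obj

    def __call__(self, *args, **kwargs):
        return self.obj(*args, **kwargs)

    def __iter__(self):
        return iter(self.obj)

    def __getitem__(self, key):
        return self.obj[key]

    def __getattr__(self, name):
        # Called only when normal lookup fails; guard against recursion
        # if 'tag' or 'obj' has not been set yet.
        if name in ("tag", "obj"):
            raise AttributeError(name)
        return getattr(self.obj, name)


words = Proxy("words", ["zero", "one"])

contains = "zero" in words    # falls back to __iter__
second = words[1]             # __getitem__
count = words.count("zero")   # __getattr__ forwards list.count

upper = Proxy("upper", str.upper)
called = upper("abc")         # __call__

# Operators are looked up on the type, so they are not forwarded.
try:
    words + ["two"]
    arithmetic_forwarded = True
except TypeError:
    arithmetic_forwarded = False

extended = words.obj + ["two"]  # operate on the wrapped object directly
```

In this sketch, applying + to the proxy raises TypeError, while the same operation on words.obj succeeds.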