inox.nn.share
Sharing modules
In a vanilla inox.nn.module.Module, shared references to the same layer or
parameter would be treated as separate copies and their weights would not be tied. The
Scope class correctly handles such cases when shared references are made explicit
with Reference.
import jax
import inox
import inox.nn as nn

class WeightSharingMLP(nn.Scope):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(in_features=64, out_features=64, key=keys[0])
        self.l3 = nn.Linear(in_features=64, out_features=64, key=keys[1])
        self.l4 = nn.Linear(in_features=64, out_features=64, key=keys[2])

        # Wrap the layer in a Reference, then reuse the same object.
        self.l1 = nn.Reference('l1', self.l1)
        self.l2 = self.l1  # tied layer

        # The same pattern works for a single parameter.
        self.l3.weight = nn.Reference('l3.weight', self.l3.weight)
        self.l4.weight = self.l3.weight  # tied parameter

        self.relu = nn.ReLU()

    def __call__(self, x):  # standard __call__
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))
        x = self.l4(self.relu(x))

        return x

key = jax.random.key(0)
model = WeightSharingMLP(key)

static, params, others = model.partition(nn.Parameter)

print(inox.tree_repr(params))  # does not contain 'l2' or 'l4.weight'
{
  '.l1.value.bias.value': float32[64],
  '.l1.value.weight.value': float32[64, 64],
  '.l3.bias.value': float32[64],
  '.l3.weight.value.value': float32[64, 64],
  '.l4.bias.value': float32[64]
}
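The reason tying is lost without a Scope can be illustrated with plain Python. The sketch below is not inox or JAX code: `flatten` and `unflatten` are hypothetical helpers that mimic how a tree of parameters is turned into a list of leaves and rebuilt. A leaf reachable through two paths is listed twice, so after a functional update the two paths hold independent copies.

```python
# Hypothetical helpers, for illustration only (not part of inox or JAX).

def flatten(tree, prefix=''):
    """Return a list of (path, leaf) pairs for a nested dict."""
    leaves = []
    for name, child in tree.items():
        path = f'{prefix}.{name}'
        if isinstance(child, dict):
            leaves.extend(flatten(child, path))
        else:
            leaves.append((path, child))
    return leaves

def unflatten(leaves):
    """Rebuild a nested dict from (path, leaf) pairs."""
    tree = {}
    for path, leaf in leaves:
        *parents, last = path.strip('.').split('.')
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[last] = leaf
    return tree

shared = [1.0, 2.0]  # stands in for a weight array
tree = {'l1': {'weight': shared}, 'l2': {'weight': shared}}

leaves = flatten(tree)  # the shared list appears once per path

# A functional update builds a new object per leaf, as a gradient step would.
updated = unflatten([(path, [w - 0.1 for w in leaf]) for path, leaf in leaves])

tied_before = tree['l1']['weight'] is tree['l2']['weight']
tied_after = updated['l1']['weight'] is updated['l2']['weight']
print(len(leaves), tied_before, tied_after)
```

The two entries start out as the same object but come back as distinct ones, which is the situation that Scope and Reference are meant to avoid.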
Classes
Scope | Subclass of inox.nn.module.Module which handles shared object references within its scope.
Reference | Creates a reference to a value.
Descriptions
- class inox.nn.share.Scope(**kwargs)
Subclass of inox.nn.module.Module which handles shared object references within its scope.
All references with the same identification tag in a scope are considered to be the same, and all copies but one are pruned during the flattening of the scope tree. Cyclic references are allowed, with the exception of a scope referencing itself.
Warning
Shared references and in-place mutations are very hard to combine properly. Conversely, Reference works seamlessly with inox.nn.state utils.
- Parameters:
kwargs – A name-value mapping.
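The pruning rule described for Scope can be sketched in plain Python. This is a simplified illustration under stated assumptions, not inox's implementation: `Ref`, `flatten` and `unflatten` are hypothetical names, and the tree is reduced to a single-level dict. Only the first field carrying a given tag contributes a leaf. The others are recorded in a name-to-tag table and restored as the same object on rebuild.

```python
# Hypothetical sketch of tag-based pruning; inox's actual code may differ.

class Ref:
    """Pairs an identification tag with a value."""

    def __init__(self, tag, value):
        self.tag = tag
        self.value = value

def flatten(fields):
    """Split fields into unique leaves and a name -> tag table."""
    leaves, tags, seen = {}, {}, set()
    for name, item in fields.items():
        if isinstance(item, Ref):
            tags[name] = item.tag
            if item.tag in seen:
                continue  # pruned: an earlier field already carries this tag
            seen.add(item.tag)
            leaves[name] = item.value
        else:
            leaves[name] = item
    return leaves, tags

def unflatten(leaves, tags):
    """Rebuild the fields, pointing every tagged name at one shared Ref."""
    refs = {tags[name]: Ref(tags[name], value)
            for name, value in leaves.items() if name in tags}
    fields = dict(leaves)
    for name, tag in tags.items():
        fields[name] = refs[tag]
    return fields

l1 = Ref('l1', [1.0, 2.0])
fields = {'l1': l1, 'l2': l1, 'bias': [0.0]}

leaves, tags = flatten(fields)

# Functional update of the unique leaves only.
leaves = {name: [w - 0.1 for w in leaf] for name, leaf in leaves.items()}
rebuilt = unflatten(leaves, tags)

print(sorted(leaves))                   # 'l2' is not among the leaves
print(rebuilt['l1'] is rebuilt['l2'])   # the two fields are still tied
```

Because each tag yields a single leaf, an optimizer sees one set of weights per shared object and the tie survives the round trip.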
- class inox.nn.share.Reference(tag, value)
Creates a reference to a value.
A Reference instance forwards __call__ and __getattr__ operations to the value it references. For indexing (__getitem__) or arithmetic operations (+, *, …), use ref.value directly instead.
See also
Example
>>> weight = Reference('my-ref', Parameter(jnp.ones((3, 5))))
>>> weight  # repr preceded by an asterisk
*Parameter(float[3, 5])
>>> weight.shape
(3, 5)
>>> weight()
Array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32)
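The forwarding behaviour described above can be reproduced with a small plain-Python wrapper. This is a sketch, assuming the forwarding relies on `__call__` and `__getattr__` as stated. `Forwarder` and `Weight` are hypothetical classes, not inox's Reference or Parameter. It also shows one plausible reason indexing and arithmetic need `ref.value`: Python looks up special methods such as `__getitem__` and `__add__` on the wrapper's type, so `__getattr__` never sees them.

```python
class Forwarder:
    """Hypothetical stand-in for a reference wrapper (not inox's Reference)."""

    def __init__(self, tag, value):
        self.tag = tag
        self.value = value

    def __call__(self, *args, **kwargs):
        # Calls are passed through to the wrapped value.
        return self.value(*args, **kwargs)

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, i.e. for the value's attributes.
        return getattr(self.value, name)

class Weight:
    """Hypothetical value supporting calls, indexing and addition."""

    shape = (3, 5)

    def __call__(self):
        return 'called'

    def __getitem__(self, index):
        return index

    def __add__(self, other):
        return other

ref = Forwarder('my-ref', Weight())

forwarded = (ref.shape, ref())  # attribute access and calls are forwarded

def raises_type_error(operation):
    """Return True if the operation raises TypeError."""
    try:
        operation()
    except TypeError:
        return True
    return False

indexing_fails = raises_type_error(lambda: ref[0])
addition_fails = raises_type_error(lambda: ref + 1)

# Going through .value reaches the wrapped object's special methods.
direct = (ref.value[0], ref.value + 1)
print(forwarded, indexing_fails, addition_fails, direct)
```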