inox.nn.share#
Sharing modules
In a vanilla inox.nn.module.Module, shared references to the same layer or
parameter would be treated as separate copies and their weights would not be tied. The
Scope class correctly handles such cases when shared references are made explicit
with Reference.
import inox
import inox.nn as nn
import jax
import jax.numpy as jnp
class WeightSharingMLP(nn.Scope):
    """Four-layer MLP whose layers share weights through tagged references.

    ``l2`` is the very same referenced layer as ``l1`` (tied layer), while
    ``l4`` reuses only the referenced weight of ``l3`` (tied parameter).
    """

    def __init__(self, key):
        # One PRNG key per freshly initialized linear layer.
        key_l1, key_l3, key_l4 = jax.random.split(key, 3)

        # Wrap the first layer in a tagged reference so it can be reused.
        tied_layer = nn.Reference(
            'l1',
            nn.Linear(in_features=64, out_features=64, key=key_l1),
        )

        self.l1 = tied_layer
        self.l3 = nn.Linear(in_features=64, out_features=64, key=key_l3)
        self.l4 = nn.Linear(in_features=64, out_features=64, key=key_l4)
        self.l2 = tied_layer  # tied layer

        # Share a single weight between `l3` and `l4`; their biases stay separate.
        tied_weight = nn.Reference('l3.weight', self.l3.weight)
        self.l3.weight = tied_weight
        self.l4.weight = tied_weight  # tied parameter

        self.relu = nn.ReLU()

    def __call__(self, x):  # standard __call__
        """Apply the four layers in order, with a ReLU between consecutive layers."""
        x = self.l1(x)
        for layer in (self.l2, self.l3, self.l4):
            x = layer(self.relu(x))
        return x
# Build the example model from a fixed PRNG key.
key = jax.random.key(0)
model = WeightSharingMLP(key)
# Partition the model on nn.Parameter and print the resulting `params` tree.
static, params, others = model.partition(nn.Parameter)
print(inox.tree_repr(params)) # does not contain 'l2' and 'l4.weight'
{
'.l1.value.bias.value': float32[64],
'.l1.value.weight.value': float32[64, 64],
'.l3.bias.value': float32[64],
'.l3.weight.value.value': float32[64, 64],
'.l4.bias.value': float32[64]
}
Classes#
Scope | Subclass of inox.nn.module.Module which handles shared object references within its scope.
Reference | Creates a reference to a value.
Descriptions#
- class inox.nn.share.Scope(**kwargs)#
Subclass of
inox.nn.module.Module
which handles shared object references within its scope. All references with the same identification tag in a scope are considered to be the same and all but one copies are pruned during the flattening of the scope tree. Cyclic references are allowed, with the exception of a scope referencing itself.
Warning
Shared references and in-place mutations are very hard to combine properly. Conversely,
Reference works seamlessly with inox.nn.state utils.
- Parameters:
kwargs – A name-value mapping.
- class inox.nn.share.Reference(tag, value)#
Creates a reference to a value.
A Reference instance forwards __call__ and __getattr__ operations to the value it
references. For indexing (__getitem__) or arithmetic operations (+, *, …), use
ref.value directly instead.
See also
Example
>>> weight = Reference('my-ref', Parameter(jax.numpy.ones((3, 5))))
>>> weight  # repr preceded by an asterisk
*Parameter(float[3, 5])
>>> weight.shape
(3, 5)
>>> weight()
Array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32)