
Inox#

Inox is a minimal JAX library for neural networks with an intuitive PyTorch-like syntax. As with Equinox, modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.

Inox aims to be a leaner version of Equinox, retaining only its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects, like NNX and Serket, to provide a versatile interface.
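Because modules are PyTrees, standard JAX tree utilities apply to them out of the box. The following is a minimal sketch of this idea; the exact leaves that are printed depend on Inox's internal parameter representation.

import jax
import inox.nn as nn

# A single linear layer is itself a PyTree whose array leaves are its parameters.
layer = nn.Linear(3, 3, key=jax.random.key(0))
print(jax.tree_util.tree_leaves(layer))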

Installation#

The inox package is available on PyPI and can be installed via pip.

pip install inox

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/inox

Getting started#

Modules are defined with an intuitive PyTorch-like syntax,

import jax
import inox
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)

and are compatible with JAX transformations.

X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
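Because the model and its gradients are PyTrees with the same structure, a gradient step amounts to mapping over both trees. The snippet below is a hedged sketch of a plain SGD update, not an official Inox training API, and it assumes that all leaves of the model are arrays at this point.

lr = 1e-3  # learning rate, an arbitrary choice for illustration

# Subtract the scaled gradient from each parameter, leaf by leaf.
model = jax.tree_util.tree_map(lambda p, g: p - lr * g, model, grads)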

However, if a module's tree representation contains non-array leaves, such as strings or boolean flags, it becomes incompatible with native JAX transformations. For this reason, Inox provides lifted transformations that treat all non-array leaves as static.

model.name = 'stainless'

@inox.jit
def loss_fn(model, x, y):
    pred = inox.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
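As a quick sanity check, the lifted loss_fn can now be evaluated even though the model carries a string attribute; with jax.jit, the same call would fail when tracing the non-array leaf.

print(loss_fn(model, X, Y))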