1. CNN on MNIST

This tutorial demonstrates how to build a simple convolutional neural network (CNN) with Inox, and train it to classify digits. It is intended for those who are new to JAX and Inox, or simply curious.

Unlike PyTorch, which is a centralized framework, the JAX ecosystem is distributed as a collection of packages, each tackling a well-defined task. This tutorial uses Inox to build the network, Optax to optimize the network, and 🤗 Datasets to load the MNIST data.

# !pip install jax[cpu] inox optax datasets[jax] tqdm

import jax
import inox
import inox.nn as nn
import optax

from datasets import load_dataset
from tqdm import tqdm

jax.config.update('jax_platform_name', 'cpu')
jax.numpy.set_printoptions(suppress=True)

1.1. Data

JAX does not provide built-in datasets and dataloaders, as there are already many alternatives. We load the MNIST dataset using 🤗 Datasets.

mnist = load_dataset('mnist')
mnist['train'][0]['image']
[Image: a handwritten digit from the MNIST training set]

We convert the images into NumPy arrays, which JAX accepts as inputs, and define a pre-processing function that rescales the 8-bit pixel values (integers from 0 to 255) to the interval \([0, 1)\).

mnist_np = mnist.with_format('numpy')

def process(x):
    return x / 256
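A quick check of the rescaling, using a few float pixel values as hypothetical inputs: because 255 is the largest 8-bit intensity, dividing by 256 keeps every value strictly below 1, so the outputs lie in \([0, 1)\).

```python
import numpy as np

def process(x):
    # Same rescaling as above: divide the 8-bit pixel intensities by 256.
    return x / 256

# Darkest, mid-gray and brightest 8-bit intensities, stored as floats.
pixels = np.array([0.0, 128.0, 255.0], dtype=np.float32)
scaled = process(pixels)  # 0 -> 0.0, 128 -> 0.5, 255 -> 255/256 = 0.99609375
print(scaled)
```

Dividing by 255 instead would map the brightest pixel exactly to 1; for training, either convention is fine as long as it is applied consistently.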

1.2. Model

Our model is a simple convolutional neural network. We define its architecture as a sequence of functions, often called layers, some of which (like convolutions and linear maps) have learnable parameters.

A few remarks:

  1. Like TensorFlow, Inox adopts a channel-last convention for axes, meaning that a batch of images is expected to have a shape \((N, H, W, C)\), where \(C\) is the number of channels.

  2. Rearrange and Repeat are thin wrappers around einops’s rearrange and repeat that enable intuitive and efficient axis manipulations.

  3. Some layers, like Linear and Conv, require a random number generator (RNG) key for initialization.

class CNN(nn.Module):
    def __init__(self, key):
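        keys = jax.random.split(key, 4)  # one key per parametric layer

        self.layers = [
            nn.Repeat('H W -> H W C', C=1),  # (28, 28) -> (28, 28, 1)
            nn.Conv(keys[0], in_channels=1, out_channels=4, kernel_size=[3, 3]),  # -> (26, 26, 4)
            nn.ReLU(),
            nn.Conv(keys[1], in_channels=4, out_channels=4, kernel_size=[3, 3]),  # -> (24, 24, 4)
            nn.MaxPool(window_size=[2, 2]),  # -> (12, 12, 4)
            nn.Rearrange('H W C -> (H W C)'),  # -> (576,)
            nn.Linear(keys[2], in_features=576, out_features=256),  # -> (256,)
            nn.ReLU(),
            nn.Linear(keys[3], in_features=256, out_features=10),  # -> (10,), one logit per digit
        ]

    def __call__(self, x):
        # x is a single (H, W) image; the layers are applied in order.
        for layer in self.layers:
            x = layer(x)
        return x

    def predict(self, x):
        # Convert the logits into class probabilities.
        return jax.nn.softmax(self(x))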
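The value `in_features=576` of the first Linear layer follows from the spatial sizes after each layer. The sketch below uses the usual output-size formula for a convolution or pooling window, assuming (as the model's printed representation shows) unit stride and no padding for the convolutions, and a pooling stride equal to the window size.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of one spatial axis for a convolution or pooling window."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

h = 28                               # MNIST images are 28 x 28
h = conv_out(h, kernel=3)            # first Conv: 26
h = conv_out(h, kernel=3)            # second Conv: 24
h = conv_out(h, kernel=2, stride=2)  # MaxPool, 2 x 2 window with stride 2: 12
channels = 4                         # out_channels of the second Conv
print(h * h * channels)              # 576, the in_features of the first Linear
```

If the input resolution or a kernel size changes, recomputing this product tells you the new `in_features`.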

Our model is a PyTree, that is, a nested collection of Python objects. Some of these objects are JAX arrays, like the convolution kernels, while others are arbitrary objects, like the pattern strings. Inox modules also have a readable representation, which summarizes each array by its dtype and shape.

model = CNN(jax.random.key(0))
model
CNN(
  layers = [
    Repeat(
      lengths = {'C': 1},
      pattern = 'H W -> H W C'
    ),
    Conv(
      bias = float32[4],
      dilation = [1, 1],
      groups = 1,
      kernel = float32[3, 3, 1, 4],
      kernel_size = [3, 3],
      padding = [(0, 0), (0, 0)],
      stride = [1, 1]
    ),
    ReLU(),
    Conv(
      bias = float32[4],
      dilation = [1, 1],
      groups = 1,
      kernel = float32[3, 3, 4, 4],
      kernel_size = [3, 3],
      padding = [(0, 0), (0, 0)],
      stride = [1, 1]
    ),
    MaxPool(
      padding = [(0, 0), (0, 0)],
      stride = [2, 2],
      window_size = [2, 2]
    ),
    Rearrange(
      lengths = {},
      pattern = 'H W C -> (H W C)'
    ),
    Linear(
      bias = float32[256],
      weight = float32[576, 256]
    ),
    ReLU(),
    Linear(
      bias = float32[10],
      weight = float32[256, 10]
    )
  ]
)
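The representation lists every array shape, so the number of trainable parameters can be counted by hand. The sketch below copies those shapes into a dictionary (the keys mirror the paths printed later in this tutorial) and multiplies out each shape.

```python
from math import prod

# Parameter shapes, as listed in the model representation above.
shapes = {
    '.layers[1].kernel': (3, 3, 1, 4), '.layers[1].bias': (4,),
    '.layers[3].kernel': (3, 3, 4, 4), '.layers[3].bias': (4,),
    '.layers[6].weight': (576, 256), '.layers[6].bias': (256,),
    '.layers[8].weight': (256, 10), '.layers[8].bias': (10,),
}

counts = {name: prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
print(total)  # 150470
```

The first Linear layer alone holds about 98% of these parameters. With the model at hand, the same total should come out of `sum(leaf.size for leaf in jax.tree_util.tree_leaves(model))`, since its leaves are exactly these arrays.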

Now that our model is built, we can use it to make predictions. However, since it has not been trained yet, it cannot classify the digits: as the next cell shows, the probabilities it assigns to the ten digits (0 to 9) are roughly uniform.

x = mnist_np['train'][0]['image']
y = mnist_np['train'][0]['label']

model.predict(process(x))
Array([0.10137317, 0.09839384, 0.1020477 , 0.09809756, 0.09756261,
       0.10181981, 0.09610692, 0.09600347, 0.11052851, 0.0980664 ],
      dtype=float32)

We can quantify the quality of our model’s predictions with their cross entropy. The cross entropy is never negative and approaches zero as the probability assigned to the correct label approaches one, which makes it a good training objective.

optax.softmax_cross_entropy_with_integer_labels(
    logits=model(process(x)),
    labels=y,
)
Array(2.2845504, dtype=float32)
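For a single image with logits \(z \in \mathbb{R}^{10}\) and integer label \(y\), the loss computed by optax.softmax_cross_entropy_with_integer_labels is the negative log-probability of the correct class:

```latex
\ell(\phi) = -\log \operatorname{softmax}(z)_y = -z_y + \log \sum_{k=0}^{9} e^{z_k} \geq 0
```

As a check against the outputs above: the first training image is labeled 5, the model assigns it probability 0.10181981, and \(-\log 0.10181981 \approx 2.2846\), which matches the printed loss. A perfectly uniform prediction would give \(\log 10 \approx 2.3026\).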

1.3. Training

Now that we have an objective to minimize, we can train the parameters \(\phi\) of our model. Because our model has no buffers (non-optimizable arrays), its parameters are exactly the leaves of the model tree, which we can list together with their paths.

params = {
    jax.tree_util.keystr(path): leaf
    for path, leaf in jax.tree_util.tree_leaves_with_path(model)
}

print(inox.tree_util.tree_repr(params))
{
  '.layers[1].bias': float32[4],
  '.layers[1].kernel': float32[3, 3, 1, 4],
  '.layers[3].bias': float32[4],
  '.layers[3].kernel': float32[3, 3, 4, 4],
  '.layers[6].bias': float32[256],
  '.layers[6].weight': float32[576, 256],
  '.layers[8].bias': float32[10],
  '.layers[8].weight': float32[256, 10]
}
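The same flattening works on any PyTree. The sketch below applies jax.tree_util.tree_leaves_with_path and jax.tree_util.keystr to a hypothetical toy tree made of dicts and lists, which makes the path format easy to read.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy tree: plain dicts and lists standing in for modules.
tree = {'layers': [{'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros(3)}, {'bias': jnp.zeros(2)}]}

flat = {
    jax.tree_util.keystr(path): leaf
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree)
}

for name, leaf in flat.items():
    print(name, leaf.shape)
# ['layers'][0]['bias'] (3,)
# ['layers'][0]['kernel'] (3, 3)
# ['layers'][1]['bias'] (2,)
```

Dictionary keys appear as `['key']` and are visited in sorted order. The dotted paths printed above (`.layers[1].bias`) suggest that Inox flattens module attributes with attribute-style keys instead.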

Since the parameters are exactly the leaves of the model, we can pass the model itself to an Optax optimizer (here Adam), which initializes its state for every leaf.

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(model)
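For reference, the standard Adam update, applied element-wise to every leaf, reads as follows. Here \(\eta = 10^{-3}\), and \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\) are the defaults of optax.adam; the state returned by optimizer.init holds the step count and the moment estimates \(m\) and \(v\), initialized to zeros with the same tree structure as the model.

```latex
\begin{aligned}
g_t &= \nabla_\phi \ell(\phi_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\phi_t &= \phi_{t-1} - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```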

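A training step consists of computing the gradients of the loss \(\ell(\phi)\), here the mean cross entropy over a batch, with respect to the parameters \(\phi\), and updating the parameters according to these gradients. Because the model processes a single image at a time, we vectorize it over the batch axis with jax.vmap. The whole procedure is written as a single JIT-compiled function, so that JAX compiles it once and reuses the compiled version at every step.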

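@jax.jit
def step(model, opt_state, x, y):
    def ell(model):
        # The model maps one image to 10 logits; jax.vmap maps it over the batch axis.
        logits = jax.vmap(model)(process(x))
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)

        return loss.mean()  # average cross entropy over the batch

    grads = jax.grad(ell)(model)  # gradients share the tree structure of the model
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = optax.apply_updates(model, updates)  # add the updates to each leaf

    return model, opt_state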

Then, to train the model, we apply the training step iteratively to batches of 64 images drawn from the shuffled training set, for one pass over the data.

loader = mnist_np['train'].shuffle(seed=0).iter(batch_size=64, drop_last_batch=True)

for batch in tqdm(loader):
    model, opt_state = step(model, opt_state, batch['image'], batch['label'])
937it [00:13, 71.16it/s]
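The 937 iterations reported by tqdm follow from the batch size: the 60,000 training images yield 937 full batches of 64, and `drop_last_batch=True` discards the 32 leftover images. The sketch below mimics shuffle-then-batch with a NumPy permutation; it only illustrates the idea and will not reproduce the order produced by 🤗 Datasets.

```python
import numpy as np

def shuffled_batches(num_examples, batch_size, seed=0):
    """Yield index batches over a random permutation, dropping the last incomplete batch."""
    perm = np.random.default_rng(seed).permutation(num_examples)
    for start in range(0, num_examples - batch_size + 1, batch_size):
        yield perm[start:start + batch_size]

batches = list(shuffled_batches(60_000, 64))
print(len(batches))                # 937 full batches
print(60_000 - 64 * len(batches))  # 32 images left out of this pass
```

Fixed batch shapes also matter for the JIT-compiled step: a smaller final batch would trigger one extra compilation.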

1.4. Evaluation

Now that the parameters of our model are trained, we use them to make predictions on an image from the test set, which was not used for training.

x = mnist['test'][0]['image']
x
[Image: a handwritten digit from the MNIST test set]
x = mnist_np['test'][0]['image']

model.predict(process(x))
Array([0.00001288, 0.        , 0.0004001 , 0.0006646 , 0.        ,
       0.00000394, 0.        , 0.9988293 , 0.00000029, 0.00008877],
      dtype=float32)
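The predicted digit is the index of the highest probability, and averaging correct predictions over many images gives the accuracy. The sketch below reuses the probabilities printed above and defines a hypothetical `accuracy` helper, checked on a made-up batch of scores.

```python
import jax.numpy as jnp

# Probabilities printed above for the first test image.
probs = jnp.array([0.00001288, 0.0, 0.0004001, 0.0006646, 0.0,
                   0.00000394, 0.0, 0.9988293, 0.00000029, 0.00008877])
print(int(jnp.argmax(probs)))  # 7, the predicted digit

def accuracy(scores, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    Works with logits or probabilities alike, since softmax preserves the argmax.
    """
    return jnp.mean(jnp.argmax(scores, axis=-1) == labels)

# Made-up batch of 3 examples over 3 classes; the first two are predicted correctly.
scores = jnp.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.3, 0.1]])
labels = jnp.array([0, 1, 2])
print(float(accuracy(scores, labels)))  # 2/3
```

To evaluate the trained model on the whole test set, one could apply `jax.vmap(model)` to batches of `process(images)` and pass the resulting logits with their labels to such a helper.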