1. CNN on MNIST

This tutorial demonstrates how to build a simple convolutional neural network (CNN) with Inox and train it to classify handwritten digits. It is intended for readers who are new to JAX and Inox, or simply curious.

Unlike PyTorch, which is a centralized framework, the JAX ecosystem is distributed as a collection of packages, each tackling a well-defined task. This tutorial uses Inox to build the network, Optax to optimize the parameters, and 🤗 Datasets to load the MNIST data.

# !pip install jax[cpu] inox optax datasets[jax] tqdm

import inox
import inox.nn as nn
import jax
import optax

from datasets import load_dataset
from tqdm import tqdm

jax.config.update('jax_platform_name', 'cpu')
jax.numpy.set_printoptions(suppress=True)

1.1. Data

JAX does not provide built-in datasets or data loaders, as many alternatives already exist. We load the MNIST dataset using 🤗 Datasets.

mnist = load_dataset('mnist')
mnist['train'][0]['image']
[Figure: the first training image, a handwritten digit]

We convert the images to NumPy arrays, which JAX accepts as inputs, and define a pre-processing function that rescales the pixel values from integers between 0 and 255 to floats in \([0, 1)\).

mnist_np = mnist.with_format('numpy')


def process(x):
    # Rescale 8-bit pixel values (0 to 255) to floats in [0, 1)
    return x / 256
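As a quick sanity check, the snippet below applies `process` to a few hand-picked pixel values. It is a minimal sketch that redefines `process` so it runs on its own, and it uses a small uint8 JAX array in place of an actual MNIST image; the name `pixels` is only illustrative.

```python
import jax.numpy as jnp


def process(x):
    # Rescale 8-bit pixel values (0 to 255) to floats in [0, 1)
    return x / 256


# A few example pixel values, stored as unsigned 8-bit integers like MNIST images
pixels = jnp.array([0, 64, 128, 255], dtype=jnp.uint8)
scaled = process(pixels)

print(scaled)        # approximately [0, 0.25, 0.5, 0.996]
print(scaled.dtype)  # float32, JAX's default float precision
```

Dividing by 256 maps the brightest pixel to 255/256 rather than exactly 1; dividing by 255 would map it to exactly 1. The difference is small either way.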

1.2. Model

Our model is a simple convolutional neural network. We define its architecture as a sequence of parametric functions, often called layers.

A few remarks:

  1. Following JAX’s random number generation (RNG) design choices, Inox layers like Linear and Conv require an RNG key for parameter initialization.

  2. Like TensorFlow, Inox adopts a channel-last convention for axes, meaning that a batch of images is expected to have a shape \((N, H, W, C)\), where \(C\) is the number of channels.

  3. Rearrange and Repeat are thin wrappers around einops that enable intuitive and efficient axis manipulations.

class CNN(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 4)  # one key per parametric layer

        # Shapes in the comments are for a single (28, 28) image
        self.layers = [
            nn.Repeat('H W -> H W C', C=1),  # -> (28, 28, 1)
            nn.Conv(in_channels=1, out_channels=4, kernel_size=[3, 3], key=keys[0]),  # -> (26, 26, 4)
            nn.ReLU(),
            nn.Conv(in_channels=4, out_channels=4, kernel_size=[3, 3], key=keys[1]),  # -> (24, 24, 4)
            nn.MaxPool(window_size=[2, 2]),  # -> (12, 12, 4)
            nn.Rearrange('H W C -> (H W C)'),  # -> (576,)
            nn.Linear(in_features=576, out_features=256, key=keys[2]),  # -> (256,)
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=10, key=keys[3]),  # -> (10,) logits
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def predict(self, x):
        # Turn the logits into class probabilities
        return jax.nn.softmax(self(x))
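The in_features=576 of the first Linear layer has to match the size of the flattened feature map. Assuming no padding (as the padding = [(0, 0), (0, 0)] entries of the representation below indicate) and a pooling stride equal to the 2 × 2 window, each spatial axis shrinks as computed here. The helper conv_out is hypothetical, written only for this calculation.

```python
def conv_out(size, kernel, stride=1, pad=0):
    # Length of one spatial axis after a convolution or pooling window
    return (size + 2 * pad - kernel) // stride + 1


h = 28
h = conv_out(h, kernel=3)            # first Conv: 28 -> 26
h = conv_out(h, kernel=3)            # second Conv: 26 -> 24
h = conv_out(h, kernel=2, stride=2)  # MaxPool: 24 -> 12
in_features = h * h * 4              # 12 * 12 positions times 4 channels
print(in_features)  # 576
```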

Our model is a PyTree, that is, a nested collection of Python objects. Some of these objects are JAX arrays, like the convolution kernels, while others are arbitrary objects, like the pattern strings. Inox provides a readable representation of its modules.

model = CNN(jax.random.key(0))
model
CNN(
  layers = [
    Repeat(
      lengths = {'C': 1},
      pattern = 'H W -> H W C'
    ),
    Conv(
      bias = Parameter(float32[4]),
      dilation = [1, 1],
      groups = 1,
      kernel = Parameter(float32[3, 3, 1, 4]),
      kernel_size = [3, 3],
      padding = [(0, 0), (0, 0)],
      stride = [1, 1]
    ),
    ReLU(),
    Conv(
      bias = Parameter(float32[4]),
      dilation = [1, 1],
      groups = 1,
      kernel = Parameter(float32[3, 3, 4, 4]),
      kernel_size = [3, 3],
      padding = [(0, 0), (0, 0)],
      stride = [1, 1]
    ),
    MaxPool(
      padding = [(0, 0), (0, 0)],
      stride = [2, 2],
      window_size = [2, 2]
    ),
    Rearrange(
      lengths = {},
      pattern = 'H W C -> (H W C)'
    ),
    Linear(
      bias = Parameter(float32[256]),
      weight = Parameter(float32[576, 256])
    ),
    ReLU(),
    Linear(
      bias = Parameter(float32[10]),
      weight = Parameter(float32[256, 10])
    )
  ]
)

Now that our model is built, we can use it to make predictions. However, since it has not been trained yet, it cannot classify digits. As the next cell shows, the probabilities it assigns to the ten digits (0 to 9) are roughly uniform.

x = mnist_np['train'][0]['image']
y = mnist_np['train'][0]['label']

model.predict(process(x))
Array([0.09936029, 0.10123681, 0.10074387, 0.09930488, 0.0987383 ,
       0.10038019, 0.10089508, 0.10059316, 0.09874175, 0.10000562],      dtype=float32)

We can quantify the quality of our model’s predictions with their cross entropy. The cross entropy is zero for perfect predictions and positive otherwise, which makes it a good training objective.

optax.softmax_cross_entropy_with_integer_labels(
    logits=model(process(x)),
    labels=y,
)
Array(2.2987905, dtype=float32)

1.3. Training

Now that we have an objective to minimize, we can train the parameters \(\phi\) of our model. We use the Module.partition method to split the module into its static definition (structure, hyper-parameters, …), its parameters, and its other arrays (constants, …). Calling static(params, others) rebuilds the module from these parts.

static, params, others = model.partition(nn.Parameter)

print(inox.tree_repr(params))
{
  '.layers[1].bias.value': float32[4],
  '.layers[1].kernel.value': float32[3, 3, 1, 4],
  '.layers[3].bias.value': float32[4],
  '.layers[3].kernel.value': float32[3, 3, 4, 4],
  '.layers[6].bias.value': float32[256],
  '.layers[6].weight.value': float32[576, 256],
  '.layers[8].bias.value': float32[10],
  '.layers[8].weight.value': float32[256, 10]
}
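Since params is itself a PyTree, JAX's tree utilities can summarize it. As a minimal sketch, the snippet below builds a stand-in dictionary with the same keys and shapes as printed above (filled with zeros, since only the shapes matter here) and counts the parameters with jax.tree_util.tree_leaves.

```python
import math

import jax
import jax.numpy as jnp

# Stand-in for `params`: same keys and shapes as printed above, filled with zeros
shapes = {
    '.layers[1].bias.value': (4,),
    '.layers[1].kernel.value': (3, 3, 1, 4),
    '.layers[3].bias.value': (4,),
    '.layers[3].kernel.value': (3, 3, 4, 4),
    '.layers[6].bias.value': (256,),
    '.layers[6].weight.value': (576, 256),
    '.layers[8].bias.value': (10,),
    '.layers[8].weight.value': (256, 10),
}
params_like = {name: jnp.zeros(shape) for name, shape in shapes.items()}

# Count scalars across all leaves of the PyTree
total = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params_like))
print(total)  # 150470

# Share held by the weight matrix of the first Linear layer
share = math.prod(shapes['.layers[6].weight.value']) / total
print(round(share, 2))  # 0.98
```

Almost all of the roughly 150 thousand parameters sit in the 576 × 256 weight of the first Linear layer. With the actual params from the cell above, the same sum over jax.tree_util.tree_leaves(params) should give the same total, assuming params holds exactly these eight arrays.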

We initialize an Optax optimizer (Adam) for the parameters of our model.

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

A training step consists of computing the gradient of the loss \(\ell(\phi)\), here the mean cross entropy over a batch, with respect to the parameters \(\phi\) using jax.grad, and then updating the parameters according to this gradient. Because our model processes one image at a time, we vectorize it over the batch with jax.vmap. The whole procedure is compiled just-in-time (JIT) with jax.jit, so that subsequent calls with inputs of the same shape run fast.

@jax.jit
def step(params, opt_state, x, y):
    def ell(params):
        model = static(params, others)  # rebuild the module from its parameters
        logits = jax.vmap(model)(process(x))  # apply the model to each image of the batch
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)

        return jax.numpy.mean(loss)

    grads = jax.grad(ell)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)  # Adam update
    params = optax.apply_updates(params, updates)

    return params, opt_state

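Then, to train the model, we repeatedly apply the training step to shuffled batches of 64 images from the training set. The last incomplete batch is dropped, so every batch has the same shape and the step is compiled only once.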

loader = mnist_np['train'].shuffle(seed=0).iter(batch_size=64, drop_last_batch=True)

for batch in tqdm(loader):
    params, opt_state = step(params, opt_state, batch['image'], batch['label'])

# Rebuild the model with the trained parameters
model = static(params, others)
937it [00:10, 87.52it/s]
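The 937 iterations reported above follow from the size of the MNIST training set: 60,000 images split into batches of 64 give 937 full batches, and the remaining 32 images are dropped by drop_last_batch=True. The sketch below reproduces this behavior with plain NumPy on example indices. It is only an illustration, not how 🤗 Datasets implements shuffling, so the resulting order differs; the function iterate_batches is hypothetical.

```python
import numpy as np


def iterate_batches(num_examples, batch_size, seed=0):
    # Shuffle the example indices once, then yield consecutive chunks;
    # the incomplete final chunk is skipped, like drop_last_batch=True
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_examples)
    for start in range(0, num_examples - batch_size + 1, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_batches(60_000, 64))
print(len(batches))  # 937
```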

1.4. Evaluation

Now that the parameters of our model are trained, we use the model to classify an image from the test set, which it has not seen during training.

x = mnist['test'][0]['image']
x
[Figure: the first test image, a handwritten digit]
x = mnist_np['test'][0]['image']

model.predict(process(x))
Array([0.0000001 , 0.        , 0.00001961, 0.00014004, 0.        ,
       0.00000015, 0.        , 0.99978584, 0.00000003, 0.00005421],      dtype=float32)
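The model assigns almost all of its probability to digit 7. To go beyond a single image, one can measure accuracy, the fraction of images whose most probable class matches the label. Since softmax preserves the ordering of the logits, taking the argmax of the logits suffices. Below is a minimal sketch with a hypothetical accuracy helper, checked on a toy "model" that returns its input as logits.

```python
import jax
import jax.numpy as jnp


def accuracy(apply_fn, images, labels):
    # Apply the model to every image, then compare the top-scoring class to the label
    logits = jax.vmap(apply_fn)(images)
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)


# Toy check: a stand-in "model" that returns its input as logits
images = jnp.eye(3)            # row i scores class i highest
labels = jnp.array([0, 1, 1])  # the last label disagrees with the top score
acc = accuracy(lambda x: x, images, labels)
print(acc)  # 2 of 3 correct, about 0.667
```

Applied to this tutorial, a call such as accuracy(lambda x: model(process(x)), test['image'], test['label']) with test = mnist_np['test'][:] should estimate the test accuracy, assuming that slicing the NumPy-formatted dataset returns stacked arrays. For large datasets, evaluating in batches keeps memory usage bounded.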