Model Checkpointing using Orbax

9 minute read

Published:

So say you’ve trained a model using flax, it trained fine, has a nice learning curve (train vs validation) and now you want to save it. Or, you want to save checkpoints of the model during specific stages of the training process and later, use the best checkpoints for inference. Technically, all flax modules are dataclasses and params (part of model state in flax) are what store the model, so what we need to do for checkpointing is to persist the params.

However, that poses on small problem. You see, params from flax modules are not regular python data types. They’re tree maps. In other libraries, e.g. pytorch, you can save a state dict as a regular python dictionary. Such isn’t compatible with flax. Instead, we use a package called orbax for checkpointing.

Let’s train a small cnn on cifar 10 and then add checkpointing to it. You can skip the notebook cells until the Checkpointing section if you want as they have nothing to do with checkpointing.

Dependencies

Before you begin, use the toml file to create an env with uv.

import torch
import jax_dataloader as jdl
import torchvision
import torchvision.transforms as transforms


class ToNumpy:
    def __call__(self, x: torch.Tensor):
        return x.numpy()
    
    
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ToNumpy()])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)


batch_size = 128

trainloader = jdl.DataLoader(trainset, backend="pytorch", batch_size=batch_size,
                             shuffle=True)
testloader = jdl.DataLoader(testset, backend="pytorch", batch_size=batch_size,
                            shuffle=False)

# classes in cifar10
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn


from einops import rearrange


class ConvNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        # convs
        out = nn.Conv(features=6, kernel_size=(5, 5))(x)
        out = nn.max_pool(out, window_shape=(2, 2))
        out = nn.Conv(features=16, kernel_size=(5, 5))(out)
        out = nn.max_pool(out, window_shape=(2, 2))

        # flatten into a vector
        # skip the batch dim
        if len(x.shape) > 3:
            out = rearrange(x, "batch c h w -> batch (c h w)")
        else:
            out = out.flatten()

        # dense
        out = nn.Dense(features=120)(out)
        out = nn.Dense(features=84)(out)
        out = nn.Dense(features=10)(out)

        return out
    

# ======================
model = ConvNet()
rng = jax.random.key(0)
params = model.init(rng, jnp.empty((3, 32, 32)))

# run a sample forward pass
logits = model.apply(params, jnp.empty((3, 32, 32)))
print(logits.shape)
(10,)
import optax
from tqdm.notebook import trange, tqdm
from flax.training import train_state


@jax.jit
def calculate_loss(params, x, y):
    logits = model.apply(params, x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return loss


@jax.jit
def batched_loss(params, xs, ys):
    batch_loss = jax.vmap(calculate_loss, in_axes=(None, 0, 0))(params, xs, ys)
    return batch_loss.mean(axis=-1)


optimiser = optax.adam(learning_rate=0.001)
criterion = jax.value_and_grad(batched_loss)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimiser
)


@jax.jit
def train_step(state, batch):
    loss_value, grads = criterion(state.params, *batch)
    updated_state = state.apply_gradients(grads=grads)
    return loss_value, updated_state
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score


@jax.jit
def test_step(state, xs):
    def infer(params, x):
        logits = model.apply(params, x)
        return jax.nn.softmax(logits, axis=-1)

    preds = jax.vmap(jax.jit(infer), in_axes=(None, 0))(state.params, xs)
    return preds


def evaluate(state, test_loader):
    scores = list()
    for batch in tqdm(test_loader):
        xs, ys = batch
        preds = test_step(state, xs)
        preds = jnp.argmax(preds, axis=-1)
        f1 = f1_score(preds, ys, average="micro")
        scores.append(f1)

    return np.array(scores).mean(axis=-1)


def custom_classification_report(state, test_loader):
    preds = list()
    actual = list()
    for batch in tqdm(test_loader):
        xs, ys = batch
        pred = test_step(state, xs)
        pred = jnp.argmax(pred, axis=-1).tolist()
        preds.extend(pred)
        actual.extend(ys.tolist())

    clf = classification_report(preds, actual, target_names=classes)
    print(clf)

Checkpointing

Let’s first prepare orbax and then go through how checkpointing via it works.

!rm -rf /tmp/cnn_cifar10_checkpointing_example
# delete the existing checkpoints
import orbax
from flax.training import orbax_utils

# since everything in jax is a pytree
# the checkpoints are basically the pytree versions of the params
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# checkpoint manager for managing how many checkpoints to keep
# keep a max of 5 checkpoints
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=5, create=True)
# structure for the checkpoint
# will be used by orbax later to restore the saved model
# you can also add other information regarding the model
# but make sure to keep the structure consistent
ckpt = {
    "state": state,
}
# tell orbax to use the structure
save_args = orbax_utils.save_args_from_target(ckpt)
# add the save path
save_path = "/tmp/cnn_cifar10_checkpointing_example"
# ckpt manager for versioning
checkpoint_manager = orbax.checkpoint.CheckpointManager(save_path, orbax_checkpointer, options)

Breaking down the mess

Since we’re training using train_state from flax, we update the params with the train_state. So it would be wise to checkpoint the entire state instead of the params only. Now, there are two stages of checkpointing using Orbax. First, we have PyTreeCheckpointer, which can save a single checkpoint. It doesn’t keep track of updates over training iterations so no matter how many iterations your training runs for, it’ll only save a checkpoint on the first call. To track different checkpoints throughout training, we need CheckpointManager.

Now, orbax will save the state as a pytree object. But if we want to restore a full state from it, the checkpoint manager provides no such direct API. So we have to improvise a bit (just follow through, comes right after the training function). Furthermore, orbax has no clue about the data structure of your checkpoints. All it cares is that it gets a pytree, as long as you supply the structure for it. The ckpt dict here provides the structure to orbax. It basically acts as a schema for the checkpoints.

Sounds complicated? Kinda is. You see in Pytorch, you can just use the model class and map a saved state_dict to it. Could flax have made it simpler? May be! You can read more here.

N.B: Checkpoint manager creates the checkpoint directory during init and maintains the directory and checkpoint metadata (paths etc.) as state. So if you want to rerun the training, delete the save_path directory first and re init the manager.

Now we can add the checkpointing code inside the train function.

def train(state, epochs, train_loader, test_loader, ckpt_manager=checkpoint_manager, save_args=save_args):
    steps = 0
    losses = []
    f1_scores = []
    
    lowest_loss = np.inf

    # =============
    for e in trange(epochs):
        for batch in tqdm(train_loader):
            loss, state = train_step(state, batch)
            steps += 1

            # log every 200 steps
            if steps % 200 == 0:
                losses.append(loss)

                # run evaluation
                print("Evaluating ... ")
                score = evaluate(state, test_loader)

                f1_scores.append(score)

                print(f"Epoch : {e + 1} :: Step : {steps} :: Loss : {loss} :: F1 : {score}")
                
                # checkpoint only if train loss decreases
                # ideally one would checkpoint on validation loss
                # but we don't have a validation split on the dataset
                # take this as an example
            
                if loss < lowest_loss:
                    print("Checkpointing")
                    lowest_loss = loss
                    # save model ckpt
                    ckpt = {
                        "state": state
                    }
                    ckpt_manager.save(steps, ckpt, save_kwargs={"save_args": save_args})

    # ============
    return state, losses, f1_scores
_, _, _ = train(state, 5, trainloader, testloader)
# don't really need these as we'll be restoring checkpoints
  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 1 :: Step : 200 :: Loss : 1.6986947059631348 :: F1 : 0.4495648734177215
Checkpointing

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 2 :: Step : 400 :: Loss : 1.507976770401001 :: F1 : 0.4664754746835443
Checkpointing
Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 2 :: Step : 600 :: Loss : 1.4895985126495361 :: F1 : 0.4799248417721519
Checkpointing

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 3 :: Step : 800 :: Loss : 1.3953628540039062 :: F1 : 0.48783623417721517
Checkpointing
Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 3 :: Step : 1000 :: Loss : 1.5259835720062256 :: F1 : 0.49831882911392406

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 4 :: Step : 1200 :: Loss : 1.2466219663619995 :: F1 : 0.5143393987341772
Checkpointing
Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 4 :: Step : 1400 :: Loss : 1.433868408203125 :: F1 : 0.4990110759493671

  0%|          | 0/391 [00:00<?, ?it/s]

Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 5 :: Step : 1600 :: Loss : 1.349440574645996 :: F1 : 0.5150316455696202
Evaluating ... 

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 5 :: Step : 1800 :: Loss : 1.4629029035568237 :: F1 : 0.5240308544303798

Now let’s check the saved checkpoints.

import os

print(os.listdir(save_path))
['200', '400', '600', '800', '1200']

Question is though, how to figure out which one is the best? If you check back on the checkpointing condition, we only saved the params when there was a new min train loss. So, based on the steps count, the last one is our best choice.

latest_step = checkpoint_manager.latest_step()
print(latest_step)
1200

Time to restore the model and run some inference on it. I’m just going to iterate through the test loader again here.

def restore_state_from_step(step, ckpt_manager=checkpoint_manager):
    restored_ckpt = ckpt_manager.restore(step)
    restored_params = restored_ckpt["state"]["params"]
    
    # create a new train state object from the params
    restored_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=restored_params,
        tx=optimiser
    )

    return restored_state


restored_state = restore_state_from_step(latest_step)

So what we have to do to restore the model is to load the latest checkpoint, take the params from it, create a new state and return it.

evaluate(restored_state, testloader)
0%|          | 0/79 [00:00<?, ?it/s]

0.5143393987341772
custom_classification_report(restored_state, testloader)
  0%|          | 0/79 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       plane       0.59      0.53      0.56      1106
         car       0.57      0.67      0.62       859
        bird       0.38      0.34      0.36      1101
         cat       0.38      0.36      0.37      1052
        deer       0.46      0.45      0.45      1013
         dog       0.36      0.50      0.42       721
        frog       0.64      0.51      0.57      1251
       horse       0.55      0.60      0.57       918
        ship       0.67      0.63      0.65      1070
       truck       0.55      0.60      0.57       909

    accuracy                           0.51     10000
   macro avg       0.51      0.52      0.51     10000
weighted avg       0.52      0.51      0.51     10000