Fine tuning a pretrained model from Hugging Face Transformers with flax
Published:
Pre-trained models are great. They’re trained on a lot of data us normies probably won’t be able to compile by ourselves and they also require a lot of compute to train from scratch. Ever since BERT was released, the NLP community has been using pre-trained models to fine-tune on their own datasets. This is a great way to leverage the power of these models without having to train them from scratch.
(The last two sentences were suggested by Copilot. I don’t disagree but don’t blame me for plagiarism.)
So to pay homage to the model that brought the ImageNet moment to NLP, I will show you how you can take a pre-trained BERT model from Huggingface and train it on a dataset for movie reviews.
The dataset can be found here: Pang & Lee, 2004
Also, if you’re interested in the paper behind the dataset:
@inproceedings{pang-lee-2004-sentimental,
title = "A Sentimental Education: Sentiment Analysis Using Subjectivity Summarization Based on Minimum Cuts",
author = "Pang, Bo and
Lee, Lillian",
booktitle = "Proceedings of the 42nd Annual Meeting of the Association for Computational Linguistics ({ACL}-04)",
month = jul,
year = "2004",
address = "Barcelona, Spain",
url = "https://aclanthology.org/P04-1035",
doi = "10.3115/1218955.1218990",
pages = "271--278",
}
To keep the main focus on the fine-tuning process, I will abstract the data preprocessing in a separate python script which can be found here.
Dependencies
Before you begin, use the toml file to create an env with uv
.
Dataset and Dataloaders
from utils.pre_polarity import prepare_dataset
from IPython.display import clear_output
main_dataset = prepare_dataset()
clear_output()
from loguru import logger
logger.info(f"Total dataset size: {len(main_dataset)}")
logger.info("Creating Train and Test Splits.")
train_test_dict = main_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_dict["train"]
test_dataset = train_test_dict["test"]
logger.info(f"Train dataset size: {len(train_dataset)}")
logger.info(f"Test dataset size: {len(test_dataset)}")
logger.info("Creating Train Dev Split from Train Dataset.")
train_dev_dict = train_dataset.train_test_split(test_size=0.2)
train_dataset = train_dev_dict["train"]
dev_dataset = train_dev_dict["test"]
logger.info(f"Train dataset size: {len(train_dataset)}")
logger.info(f"Dev dataset size: {len(dev_dataset)}")
2024-07-07 02:58:23.659 | INFO | __main__:<module>:3 - Total dataset size: 2000
2024-07-07 02:58:23.660 | INFO | __main__:<module>:4 - Creating Train and Test Splits.
2024-07-07 02:58:23.664 | INFO | __main__:<module>:10 - Train dataset size: 1600
2024-07-07 02:58:23.664 | INFO | __main__:<module>:11 - Test dataset size: 400
2024-07-07 02:58:23.665 | INFO | __main__:<module>:13 - Creating Train Dev Split from Train Dataset.
2024-07-07 02:58:23.667 | INFO | __main__:<module>:20 - Train dataset size: 1280
2024-07-07 02:58:23.668 | INFO | __main__:<module>:21 - Dev dataset size: 320
import numpy as np
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer
class PolarityReviewDataset(Dataset):
def __init__(self, dataset_split: HFDataset,
tokenizer_model_name: str = "google-bert/bert-base-uncased",
max_len: int = 512):
self.ds = dataset_split
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
self.MAX_LEN = max_len
def __len__(self):
return len(self.ds)
def __getitem__(self, idx):
review = self.ds[idx]["text"]
label = self.ds[idx]["label"]
# encode review text
encoding = self.tokenizer.encode_plus(
review,
add_special_tokens=True,
max_length=self.MAX_LEN,
truncation=True,
return_token_type_ids=False,
padding="max_length",
return_attention_mask=True,
return_tensors="np" # return numpy arrays
)
return encoding["input_ids"], encoding["attention_mask"], np.array([label])
trainset = PolarityReviewDataset(train_dataset)
devset = PolarityReviewDataset(dev_dataset)
testset = PolarityReviewDataset(test_dataset)
import jax_dataloader as jdl
BATCH_SIZE = 24 # Max I could load on an RTX 3090
train_loader = jdl.DataLoader(
trainset, "pytorch", batch_size=BATCH_SIZE, shuffle=True)
val_loader = jdl.DataLoader(
devset, "pytorch", batch_size=BATCH_SIZE, shuffle=False)
test_loader = jdl.DataLoader(
testset, "pytorch", batch_size=BATCH_SIZE, shuffle=False)
Model Definition
I would urge you to pay special attention to this part if you’re coming from pytorch. Jax works differently. So does Flax. Although BERT is available as a Flax module on the HF hub, the loading process is different than that of the pytorch version.
First of all, Flax models are immutable pytrees. Pytorch models are a container of tensors which can be mutated. So you can update or assign new params to a Pytorch model on the fly. The same is not possible with Flax models.
Second, you can’t take a Flax model with pretrained params and just assign it to a flax model with the same architecture. You have to unfreeze the new model params, then overwrite them with the pretrained params and then freeze them again. It’s like opening a pack of chips and sealing it back again so that nobody knows that you ate some.
Let’s check some code first then I will explain.
from transformers import FlaxAutoModel
def load_model(model_name: str = "google-bert/bert-base-uncased") -> tuple:
model = FlaxAutoModel.from_pretrained(model_name)
clear_output()
# extract the module and the params
module = model.module
pretrained_params = {"params": model.params}
return module, pretrained_params
As you can see, I extracted the flax module and the params from the model. Now I will define a new model and assign these params to that one.
import flax.linen as nn
from flax.core.frozen_dict import unfreeze, freeze
import jax.numpy as jnp
class SentimentCLF(nn.Module):
backbone: nn.Module # the pretrained model
@nn.compact
def __call__(self, input_ids: jnp.ndarray, attention_mask: jnp.ndarray) -> jnp.ndarray:
# forward pass
out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
# pooler_output
out = out.pooler_output
# pass through a dense layer that projects to 2 labels types
out = nn.Dense(2)(out)
return out
bert_module, pretrained_params = load_model()
import jax
rng = jax.random.key(42)
model = SentimentCLF(bert_module)
sample_data = trainset[0]
input_ids, attention_mask, label = sample_data
params = model.init(rng, input_ids, attention_mask)
Unfreeze and Freeze, basically.
# unfreeze the new model
params_unfrozen = unfreeze(params)
params_unfrozen["params"]["backbone"] = pretrained_params["params"]
# freeze back
params = freeze(params_unfrozen)
That’s it. Now you can train this model as any other flax model.
Training
import optax
@jax.jit
def calculate_loss(params, input_ids, attention_mask, label):
logits = model.apply(params, input_ids, attention_mask)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, label)
# typical numpy array thing
# should be a scalar
return loss[0]
@jax.jit
def batched_loss(params, input_ids, attention_masks, labels):
batch_loss = jax.vmap(calculate_loss, in_axes=(None, 0, 0, 0))(
params, input_ids, attention_masks, labels)
return batch_loss.mean(axis=-1)
from flax.training import train_state
clipper = optax.clip_by_global_norm(1.0)
tx = optax.chain(optax.adam(learning_rate=2e-5),
optax.clip_by_global_norm(1.0))
initial_state = train_state.TrainState.create(
apply_fn=model.apply,
tx=tx,
params=params,
)
criterion = jax.value_and_grad(batched_loss)
from sklearn.metrics import f1_score
from tqdm.auto import tqdm, trange
@jax.jit
def test_step(state, batch):
input_ids, attention_masks, _ = batch
def infer(params, input_ids, attention_mask):
logits = model.apply(params, input_ids, attention_mask)
return jax.nn.softmax(logits, axis=-1)
probas = jax.vmap(jax.jit(infer), in_axes=(None, 0, 0))(
state.params, input_ids, attention_masks)
return probas
def evaluate(state, test_loader):
scores = list()
for batch in tqdm(test_loader):
labels = batch[2]
probas = test_step(state, batch)
preds = np.argmax(probas, axis=-1)
# f1 score, never trust simple accuracy!
f1s = f1_score(labels, preds)
scores.append(f1s)
return np.array(scores).mean(axis=-1)
@jax.jit
def train_step(state, batch):
input_ids, attention_masks, labels = batch
loss_value, grads = criterion(
state.params, input_ids, attention_masks, labels)
updated_state = state.apply_gradients(grads=grads)
return loss_value, updated_state
@jax.jit
def validation_step(state, batch):
input_ids, attention_masks, labels = batch
loss_value, _ = criterion(state.params, input_ids, attention_masks, labels)
return loss_value
def train(state, epochs, train_loader, val_loader):
steps = 0
train_losses = []
mean_val_losses = []
# =============
for e in trange(epochs):
for batch in tqdm(train_loader, desc="train_step"):
train_loss, state = train_step(state, batch)
steps += 1
# log every 200 steps
if steps % 40 == 0:
train_losses.append(train_loss)
# run validation
validation_losses = []
for batch in tqdm(val_loader, desc="validation_step"):
val_loss = validation_step(state, batch)
validation_losses.append(val_loss)
mean_val_loss = np.array(validation_losses).mean(axis=-1)
mean_val_losses.append(mean_val_loss)
logger.info(
f"Epoch : {e + 1} :: Step : {steps} :: Loss/Train : {train_loss} :: Loss/Validation : {mean_val_loss}")
# ============
return state, train_losses, mean_val_losses
# ============
trained_state, train_losses, mean_val_losses = train(
initial_state, 2, train_loader, val_loader)
0%| | 0/2 [00:00<?, ?it/s]
train_step: 0%| | 0/54 [00:00<?, ?it/s]
validation_step: 0%| | 0/14 [00:00<?, ?it/s]
2024-07-07 02:59:13.291 | INFO | __main__:train:25 - Epoch : 1 :: Step : 40 :: Loss/Train : 0.5054243206977844 :: Loss/Validation : 0.3673432469367981
train_step: 0%| | 0/54 [00:00<?, ?it/s]
validation_step: 0%| | 0/14 [00:00<?, ?it/s]
2024-07-07 02:59:40.769 | INFO | __main__:train:25 - Epoch : 2 :: Step : 80 :: Loss/Train : 0.23009130358695984 :: Loss/Validation : 0.3605358302593231
Kinda sus, looks like slight ovefitting but let’s do an eval first and then I will explain.
Evaluation
Always evaluate your models. You don’t leave home without brushing teeth in the morning, do you?
score = evaluate(trained_state, test_loader)
logger.info(f"Test F1 Score: {score}")
0%| | 0/17 [00:00<?, ?it/s]
2024-07-07 02:59:58.041 | INFO | __main__:<module>:2 - Test F1 Score: 0.8734399378232043
The main problem with this dataset is that the inputs are longer than what BERT can handle (512). In the tokeniser I added truncation. This leads to information loss, so the model is basically reading halfway through the texts and is forced to make a hasty decision about the label.
Either way, our goal was to fine tune a model, we did that. In a real life scenario, always check your data and what you want from your model before you burn electricity.