Usage with Huggingface LLMs
This example demonstrates how to work with Huggingface (HF) language models.
As always, let’s first import the required functionality.
Remember to run pip install -U transformers datasets beforehand.
from collections import UserDict
from collections.abc import MutableMapping
import torch.utils.data as data_utils
from datasets import Dataset
from torch import Tensor, bfloat16, eye, manual_seed, no_grad
from torch.nn import CrossEntropyLoss, Module
from transformers import (
    DataCollatorWithPadding,
    GPT2Config,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    PreTrainedTokenizer,
)
from curvlinops import GGNLinearOperator
# make deterministic
manual_seed(0)
Data
We use synthetic data for simplicity, but it can be replaced with any HF dataset and dataloader.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
data = [
    {"text": "Today is hot, but I will manage!!!!", "label": 1},
    {"text": "Tomorrow is cold", "label": 0},
    {"text": "Carpe diem", "label": 1},
    {"text": "Tempus fugit", "label": 1},
]
dataset = Dataset.from_list(data)
def tokenize(row):
    """Tokenize a dataset row for GPT-2.

    Returns:
        Tokenized dictionary for the given row.
    """
    return tokenizer(row["text"])
dataset = dataset.map(tokenize, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
dataloader = data_utils.DataLoader(
    dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer)
)
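Let’s check the batch emitted by the dataloader. We will see that it is a UserDict
(HF's BatchEncoding, a subclass of UserDict) containing the input and label tensors.
Note that the collator renames the label column to labels. Since UserDict is a
MutableMapping, the batch is compatible with Curvlinops.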
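Here is a quick stdlib check of the MutableMapping claim. It is a minimal sketch. The class MyBatch is hypothetical and only illustrates that subclasses of UserDict inherit the relationship.

```python
from collections import UserDict
from collections.abc import MutableMapping

# UserDict subclasses MutableMapping, and the built-in dict is registered as one too.
batch = UserDict({"input_ids": [[1, 2], [3, 4]], "labels": [0, 1]})
print(isinstance(batch, MutableMapping))  # True
print(isinstance(dict(batch), MutableMapping))  # True


class MyBatch(UserDict):
    """Hypothetical dict-like batch type, for illustration only."""


# Subclasses of UserDict inherit the MutableMapping relationship.
print(isinstance(MyBatch(), MutableMapping))  # True
```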
data = next(iter(dataloader))
print(f"Is the data a UserDict? {isinstance(data, UserDict)}")
for k, v in data.items():
    print(k, v.shape)
Is the data a UserDict? True
input_ids torch.Size([4, 9])
attention_mask torch.Size([4, 9])
labels torch.Size([4])
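The shapes above come from dynamic padding. Each sequence is padded to the longest one in the batch (9 tokens here), so all four rows share one length. The following is a simplified pure-Python sketch of what DataCollatorWithPadding does for GPT-2. It right-pads token ids with the pad id, builds the attention mask, and renames "label" to "labels".

Some details are assumptions. The helper pad_batch and the token ids are hypothetical. The real collator returns tensors rather than lists.

```python
from collections import UserDict


def pad_batch(rows, pad_token_id):
    """Right-pad token ids and build an attention mask (hypothetical helper).

    Simplified imitation of DataCollatorWithPadding: pad to the longest
    sequence, mark real tokens with 1 and padding with 0, and rename the
    "label" key to "labels".
    """
    max_len = max(len(r["input_ids"]) for r in rows)
    batch = UserDict({"input_ids": [], "attention_mask": [], "labels": []})
    for r in rows:
        n_real = len(r["input_ids"])
        n_pad = max_len - n_real
        batch["input_ids"].append(r["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        batch["labels"].append(r["label"])
    return batch


rows = [
    {"input_ids": [10, 11, 12], "label": 1},
    {"input_ids": [20], "label": 0},
]
# 50256 is GPT-2's eos id, which the example reuses as the padding id.
batch = pad_batch(rows, pad_token_id=50256)
print(batch["input_ids"])  # [[10, 11, 12], [20, 50256, 50256]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
print(batch["labels"])  # [1, 0]
```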
Model
Curvlinops supports general UserDict (more generally, dict-like) inputs. However, all
unpacking of the dict must happen inside the forward function of the model. This gives
users the most flexibility without much overhead.
Let’s wrap the HF model to conform to this requirement.
class MyGPT2(Module):
    """Huggingface LLM wrapper.

    Args:
        tokenizer: The tokenizer used for preprocessing the text data. Needed
            because the model must know the padding token id.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        """Initialize the wrapper with a tokenizer."""
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

        # For simplicity, only enable gradients for the last layer (the classification head)
        for p in self.hf_model.parameters():
            p.requires_grad = False
        for p in self.hf_model.score.parameters():
            p.requires_grad = True

    def forward(self, data: MutableMapping) -> Tensor:
        """Run the model forward pass and move inputs to the correct device.

        Args:
            data: A dict-like data structure with `input_ids` inside.
                This is the default data structure assumed by Huggingface
                dataloaders.

        Returns:
            logits: An `(batch_size, n_classes)`-sized tensor of logits.
        """
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        output_dict = self.hf_model(input_ids)
        return output_dict.logits
model = MyGPT2(tokenizer).to(bfloat16)
with no_grad():
    logits = model(data)
    print(f"Logits shape: {logits.shape}")
[transformers] GPT2ForSequenceClassification LOAD REPORT from: gpt2
Key | Status |
-------------+---------+-
score.weight | MISSING |
Notes:
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
[transformers] We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
You may ignore this warning if your `pad_token_id` (50256) is identical to the `bos_token_id` (50256), `eos_token_id` (50256), or the `sep_token_id` (None), and your input is not padded.
Logits shape: torch.Size([4, 2])
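Why does the wrapper pass pad_token_id into the config? GPT2ForSequenceClassification produces logits at every position. Roughly speaking, it keeps the logits at the last token that is not the padding token, which yields one (n_classes,)-sized vector per sequence.

The following is a simplified pure-Python sketch of that pooling step. It is not HF's actual implementation. The token ids and per-token logits are made up, and every row is assumed to contain at least one real token.

```python
def last_non_pad_index(input_ids, pad_token_id):
    """Position of the last non-padding token in each row (simplified sketch).

    Assumes every row contains at least one real (non-padding) token.
    """
    indices = []
    for row in input_ids:
        last = 0
        for i, tok in enumerate(row):
            if tok != pad_token_id:
                last = i
        indices.append(last)
    return indices


def pool_logits(per_token_logits, input_ids, pad_token_id):
    """Keep one logit vector per sequence, taken at its last real token."""
    idx = last_non_pad_index(input_ids, pad_token_id)
    return [seq[i] for seq, i in zip(per_token_logits, idx)]


input_ids = [[10, 11, 12], [20, 50256, 50256]]
# Fake per-token logits with shape (batch=2, seq_len=3, n_classes=2).
per_token_logits = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.1], [9.0, 9.0], [9.0, 9.0]],
]
print(last_non_pad_index(input_ids, 50256))  # [2, 0]
print(pool_logits(per_token_logits, input_ids, 50256))  # [[0.5, 0.6], [1.0, 1.1]]
```

Because GPT-2 attends causally and the padding sits on the right, the padded positions cannot influence the logits at the last real token.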
Curvlinops
We are now ready to compute the curvature of this HF model using Curvlinops.
For this, we need to define a function to tell Curvlinops how to get the
batch size of the UserDict input batch. Everything else is unchanged
from the standard usage of Curvlinops!
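Why is the batch size needed at all? Here is a plausible reading, stated as an assumption rather than a description of Curvlinops internals. With a mean-reduced loss such as CrossEntropyLoss(), the dataset-wide curvature averages over all N data points. Each batch's contribution must then be weighted by its size divided by N, and Curvlinops cannot infer that size from an arbitrary dict-like batch.

The sketch below uses a hypothetical helper, batch_weights, to show that this weighting recovers the overall mean.

```python
from collections import UserDict


def batch_size_fn(x):
    """Same idea as in the example: leading dimension of input_ids."""
    return len(x["input_ids"])


def batch_weights(batches):
    """Weight of each batch in a dataset-wide average (assumed scheme)."""
    sizes = [batch_size_fn(X) for X, _ in batches]
    total = sum(sizes)
    return [s / total for s in sizes]


# Two hypothetical (X, y) batches of unequal size (3 and 1 data points).
batches = [
    (UserDict({"input_ids": [[1], [2], [3]]}), [0, 1, 1]),
    (UserDict({"input_ids": [[4]]}), [0]),
]
weights = batch_weights(batches)
print(weights)  # [0.75, 0.25]

# Weighted per-batch means equal the mean over all individual data points.
per_sample = [[1.0, 2.0, 3.0], [10.0]]
batch_means = [sum(v) / len(v) for v in per_sample]
weighted = sum(w * m for w, m in zip(weights, batch_means))
overall = sum(sum(v) for v in per_sample) / 4
print(weighted, overall)  # 4.0 4.0
```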
def batch_size_fn(x: MutableMapping):
    """Return the batch size for a dict-like batch.

    Returns:
        Batch size for the input batch.
    """
    return x["input_ids"].shape[0]
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
ggn = GGNLinearOperator(
    model,
    CrossEntropyLoss(),
    params,
    [(data, data["labels"])],  # We still need to input a list of "(X, y)" pairs!
    check_deterministic=False,
    batch_size_fn=batch_size_fn,  # Remember to specify this!
)
G = ggn @ eye(ggn.shape[0], device=next(iter(params.values())).device)
print(f"GGN shape: {G.shape}")
GGN shape: torch.Size([1536, 1536])
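The product ggn @ eye(...) turns the matrix-free operator into a dense matrix by applying it to every column of the identity: column j of the result is G e_j.

The shape follows from the trainable parameters. Only score.weight requires gradients, and for GPT-2 small with two classes it has shape (num_labels, n_embd) = (2, 768). The score head has no bias. That gives 2 × 768 = 1536 parameters. Dense materialization is only practical at such small sizes, since memory grows quadratically with the parameter count.

The stdlib sketch below uses a hypothetical 2 × 2 operator in place of GGNLinearOperator.

```python
def matvec(v):
    """Hypothetical matrix-free operator: v -> A v for A = [[2, 1], [1, 3]].

    Only products are available, mimicking a curvature operator that never
    stores its matrix explicitly.
    """
    return [2 * v[0] + 1 * v[1], 1 * v[0] + 3 * v[1]]


def materialize(matvec, dim):
    """Build the dense matrix column by column from unit vectors e_j."""
    columns = []
    for j in range(dim):
        e_j = [1.0 if i == j else 0.0 for i in range(dim)]
        columns.append(matvec(e_j))
    # Transpose the list of columns into a list of rows.
    return [list(row) for row in zip(*columns)]


print(materialize(matvec, 2))  # [[2.0, 1.0], [1.0, 3.0]]

# Parameter count behind the printed GGN shape (GPT-2 small, 2 classes).
num_labels, n_embd = 2, 768
print(num_labels * n_embd)  # 1536
```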
Conclusion
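Accepting a UserDict (or any other dict-like data structure) as model input is very
flexible, and it is not limited to HF models. The same pattern works for any custom model,
as long as its forward function unpacks the dict and you provide a matching batch_size_fn.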
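As an illustration, here is a minimal sketch of a custom (non-HF) PyTorch model that consumes a dict-like batch with two inputs, along with a matching batch size function. TwoInputModel and the keys "a" and "b" are hypothetical. We assume, without showing it here, that passing such a model, the function, and a list of (batch, labels) pairs to GGNLinearOperator works exactly as in the GPT-2 example above.

```python
from collections.abc import MutableMapping

import torch
from torch import Tensor
from torch.nn import Linear, Module


class TwoInputModel(Module):
    """Hypothetical non-HF model whose forward consumes a dict-like batch."""

    def __init__(self, dim_a: int, dim_b: int, num_classes: int) -> None:
        super().__init__()
        self.head = Linear(dim_a + dim_b, num_classes)

    def forward(self, data: MutableMapping) -> Tensor:
        # All unpacking of the dict happens inside forward.
        features = torch.cat([data["a"], data["b"]], dim=1)
        return self.head(features)


def batch_size_fn(x: MutableMapping) -> int:
    """Batch size of the dict-like input, read from one of its tensors."""
    return x["a"].shape[0]


torch.manual_seed(0)
batch = {"a": torch.randn(5, 3), "b": torch.randn(5, 4)}  # a plain dict is a MutableMapping
model = TwoInputModel(dim_a=3, dim_b=4, num_classes=2)
print(model(batch).shape)  # torch.Size([5, 2])
print(batch_size_fn(batch))  # 5
```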
Total running time of the script: (0 minutes 29.323 seconds)