Usage with Hugging Face LLMs

This example demonstrates how to use Curvlinops with Hugging Face (HF) language models.

As always, let’s first import the required functionality. Remember to install the dependencies first with pip install -U transformers datasets.

from collections import UserDict
from collections.abc import MutableMapping

import torch.utils.data as data_utils
from datasets import Dataset
from torch import Tensor, bfloat16, eye, manual_seed, no_grad
from torch.nn import CrossEntropyLoss, Module
from transformers import (
    DataCollatorWithPadding,
    GPT2Config,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    PreTrainedTokenizer,
)

from curvlinops import GGNLinearOperator

# make deterministic
manual_seed(0)

Data

We use a small synthetic dataset for simplicity, but it can be replaced with any HF dataset and dataloader.

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

data = [
    {"text": "Today is hot, but I will manage!!!!", "label": 1},
    {"text": "Tomorrow is cold", "label": 0},
    {"text": "Carpe diem", "label": 1},
    {"text": "Tempus fugit", "label": 1},
]
dataset = Dataset.from_list(data)


def tokenize(row):
    """Tokenize a dataset row for GPT-2.

    Returns:
        Tokenized dictionary for the given row.
    """
    return tokenizer(row["text"])


dataset = dataset.map(tokenize, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
dataloader = data_utils.DataLoader(
    dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer)
)
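DataCollatorWithPadding pads each batch to the length of its longest sequence and returns a dict-like batch. The function below illustrates what such a collator does. It is a hypothetical plain-PyTorch sketch called `pad_collate`, not the transformers implementation, and it makes three assumptions:

* Padding is on the right.
* The pad token id is GPT-2's 50256.
* Each row has a 1d `input_ids` tensor and an integer `label`.

Like the HF collator, it stores the labels under the key `labels`. It returns a plain dict, which, like UserDict, is a MutableMapping.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(rows, pad_token_id=50256):
    """Hypothetical collate function: right-pad token ids and build a dict batch."""
    ids = [row["input_ids"] for row in rows]
    # Right-pad all sequences to the longest one in the batch.
    input_ids = pad_sequence(ids, batch_first=True, padding_value=pad_token_id)
    # Mask is 1 for real tokens and 0 for padding positions.
    lengths = torch.tensor([len(x) for x in ids])
    positions = torch.arange(input_ids.shape[1])
    attention_mask = (positions[None, :] < lengths[:, None]).long()
    labels = torch.tensor([row["label"] for row in rows])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


rows = [
    {"input_ids": torch.tensor([10, 11, 12]), "label": 1},
    {"input_ids": torch.tensor([20]), "label": 0},
]
batch = pad_collate(rows)
for k, v in batch.items():
    print(k, v.shape)
```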

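Let’s inspect a batch produced by the dataloader. It is a UserDict (more precisely, a transformers BatchEncoding, which subclasses UserDict) containing the input and label tensors. Note that DataCollatorWithPadding renames the label column to labels. Since UserDict is a MutableMapping, such batches are compatible with Curvlinops.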
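The MutableMapping claim can be checked with the standard library alone. The toy batch below is a hypothetical stand-in for a real HF batch.

```python
from collections import UserDict
from collections.abc import MutableMapping

# UserDict subclasses MutableMapping, so it passes isinstance checks
# against MutableMapping, just like a plain dict does.
toy_batch = UserDict({"input_ids": [[1, 2, 3]], "labels": [0]})
print(isinstance(toy_batch, MutableMapping))        # True
print(isinstance(dict(toy_batch), MutableMapping))  # True
print(issubclass(UserDict, MutableMapping))         # True
# A tuple, by contrast, is not a mapping.
print(isinstance((1, 2), MutableMapping))           # False
```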

data = next(iter(dataloader))
print(f"Is the data a UserDict? {isinstance(data, UserDict)}")
for k, v in data.items():
    print(k, v.shape)
Is the data a UserDict? True
input_ids torch.Size([4, 9])
attention_mask torch.Size([4, 9])
labels torch.Size([4])

Model

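Curvlinops supports general dict-like (MutableMapping) inputs such as UserDict. However, everything that depends on the structure of the dict, such as extracting the relevant tensors and moving them to the right device, must be handled inside the model’s forward function. This gives users maximal flexibility with little overhead.

Let’s therefore wrap the HF model to conform to this requirement.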
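The same pattern can be shown without downloading a pretrained model. Below is a minimal sketch of a toy, hypothetical text classifier called `DictInputClassifier`. Its forward function:

* receives the whole dict-like batch,
* moves the tensors it needs to the model's device,
* uses `attention_mask` to ignore padding via masked mean pooling.

The architecture is purely illustrative and unrelated to GPT-2. Unlike the HF wrapper, it also reads `attention_mask` from the batch.

```python
from collections.abc import MutableMapping

import torch
from torch import Tensor, nn


class DictInputClassifier(nn.Module):
    """Hypothetical toy classifier whose forward pass consumes a dict-like batch."""

    def __init__(self, vocab_size: int = 100, dim: int = 8, num_classes: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.score = nn.Linear(dim, num_classes, bias=False)

    def forward(self, data: MutableMapping) -> Tensor:
        # All unpacking and device handling happens here, not in the caller.
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        mask = data["attention_mask"].to(device).unsqueeze(-1).to(self.embed.weight.dtype)
        hidden = self.embed(input_ids)  # (batch, seq, dim)
        # Average the embeddings over real (non-padding) tokens only.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.score(pooled)  # (batch, num_classes)


torch.manual_seed(0)
toy_model = DictInputClassifier()
toy_batch = {
    "input_ids": torch.tensor([[5, 6, 7], [8, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 0, 0]]),
}
toy_logits = toy_model(toy_batch)
print(toy_logits.shape)  # torch.Size([2, 2])
```

Because padding positions are masked out, changing the token ids at those positions does not change the logits.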

class MyGPT2(Module):
    """Huggingface LLM wrapper.

    Args:
        tokenizer: The tokenizer used for preprocessing the text data. Needed
            since the model needs to know the padding token id.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        """Initialize the wrapper with a tokenizer."""
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

        # For simplicity, only enable grad for the last layer
        for p in self.hf_model.parameters():
            p.requires_grad = False

        for p in self.hf_model.score.parameters():
            p.requires_grad = True

    def forward(self, data: MutableMapping) -> Tensor:
        """Run the model forward pass and move inputs to the correct device.

        Args:
            data: A dict-like data structure with `input_ids` inside.
                This is the default data structure assumed by Huggingface
                dataloaders.

        Returns:
            logits: An `(batch_size, n_classes)`-sized tensor of logits.
        """
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        output_dict = self.hf_model(input_ids)
        return output_dict.logits


model = MyGPT2(tokenizer).to(bfloat16)

with no_grad():
    logits = model(data)
    print(f"Logits shape: {logits.shape}")
[transformers] GPT2ForSequenceClassification LOAD REPORT from: gpt2
Key          | Status  |
-------------+---------+-
score.weight | MISSING |

Notes:
- MISSING:      those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
[transformers] We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
You may ignore this warning if your `pad_token_id` (50256) is identical to the `bos_token_id` (50256), `eos_token_id` (50256), or the `sep_token_id` (None), and your input is not padded.
Logits shape: torch.Size([4, 2])

Curvlinops

We are now ready to compute the curvature of this HF model with Curvlinops. The only additional ingredient is a function, passed as batch_size_fn, that tells Curvlinops how to determine the batch size of a dict-like input batch. Everything else is unchanged from the standard usage of Curvlinops!
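The batch_size_fn used in this example reads the size from input_ids. Batches may not always contain that key. For such cases, `generic_batch_size_fn` below is a hypothetical, more generic variant that falls back to the leading dimension of the first tensor it finds. It assumes that all tensors in a batch are stacked along dimension 0.

```python
from collections.abc import MutableMapping

import torch
from torch import Tensor


def generic_batch_size_fn(x: MutableMapping) -> int:
    """Return the leading dimension of the first tensor in a dict-like batch.

    Assumes every tensor in the batch is stacked along dimension 0.
    """
    for value in x.values():
        if isinstance(value, Tensor):
            return value.shape[0]
    raise ValueError("Batch contains no tensors to infer the batch size from.")


example_batch = {
    "input_ids": torch.zeros(4, 9, dtype=torch.long),
    "attention_mask": torch.ones(4, 9),
}
print(generic_batch_size_fn(example_batch))  # 4
```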

def batch_size_fn(x: MutableMapping):
    """Return the batch size for a dict-like batch.

    Returns:
        Batch size for the input batch.
    """
    return x["input_ids"].shape[0]


params = {n: p for n, p in model.named_parameters() if p.requires_grad}

ggn = GGNLinearOperator(
    model,
    CrossEntropyLoss(),
    params,
    [(data, data["labels"])],  # Curvlinops still expects a list of (X, y) pairs!
    check_deterministic=False,
    batch_size_fn=batch_size_fn,  # Remember to specify this!
)

G = ggn @ eye(ggn.shape[0], device=next(iter(params.values())).device)

print(f"GGN shape: {G.shape}")
GGN shape: torch.Size([1536, 1536])
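The GGN is 1536 × 1536 because only the classification head score is trainable. That head maps GPT-2's 768-dimensional hidden states to num_labels = 2 classes without a bias, giving 2 × 768 = 1536 parameters. The load report lists only score.weight, which is consistent with a bias-free layer.

The sketch below reproduces this count. It uses a stand-in module rather than the HF model, with a single hypothetical linear layer playing the frozen backbone.

```python
from torch import nn

hidden_size, num_labels = 768, 2  # GPT-2 (small) hidden size, binary labels

# Stand-in for a frozen backbone plus a trainable, bias-free scoring head.
stand_in = nn.ModuleDict(
    {
        "backbone": nn.Linear(hidden_size, hidden_size),
        "score": nn.Linear(hidden_size, num_labels, bias=False),
    }
)

# Freeze everything, then unfreeze only the head (as in MyGPT2).
for p in stand_in.parameters():
    p.requires_grad = False
for p in stand_in["score"].parameters():
    p.requires_grad = True

trainable = {n: p for n, p in stand_in.named_parameters() if p.requires_grad}
num_params = sum(p.numel() for p in trainable.values())
print(list(trainable), num_params)  # ['score.weight'] 1536
```

Materializing the GGN via ggn @ eye(...) is mathematically equivalent to one GGN-vector product per parameter. This is affordable here only because the parameter count is small. For larger parameter sets, working with matrix-vector products is the more practical choice.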

Conclusion

Accepting UserDict (or any other dict-like) inputs is very flexible, and it is not limited to HF models: any custom model whose forward function consumes a dict-like batch can be used with Curvlinops in the same way.
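As a sketch of this with a non-HF model, the hypothetical `TabularImageModel` below consumes a dict with two differently shaped inputs, image and features. A matching batch_size_fn and a list of (X, y) pairs follow the same layout as the GGNLinearOperator call in this example. The Curvlinops call itself is omitted, so only the data contract is shown.

```python
from collections.abc import MutableMapping

import torch
from torch import Tensor, nn


class TabularImageModel(nn.Module):
    """Hypothetical toy model combining an image and a feature vector from one dict batch."""

    def __init__(self, num_features: int = 5, num_classes: int = 3):
        super().__init__()
        self.image_net = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.head = nn.Linear(4 + num_features, num_classes)

    def forward(self, data: MutableMapping) -> Tensor:
        device = next(self.parameters()).device
        image = data["image"].to(device)        # (batch, 1, H, W)
        features = data["features"].to(device)  # (batch, num_features)
        return self.head(torch.cat([self.image_net(image), features], dim=1))


def tabular_batch_size_fn(x: MutableMapping) -> int:
    """Read the batch size from the feature tensor."""
    return x["features"].shape[0]


torch.manual_seed(0)
multi_model = TabularImageModel()
X = {"image": torch.randn(6, 1, 8, 8), "features": torch.randn(6, 5)}
y = torch.randint(0, 3, (6,))
pairs = [(X, y)]  # same list-of-(X, y) layout as passed to GGNLinearOperator
loss = nn.CrossEntropyLoss()(multi_model(X), y)
print(multi_model(X).shape, tabular_batch_size_fn(X))  # torch.Size([6, 3]) 6
```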
