.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_huggingface.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_basic_usage_example_huggingface.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_huggingface.py:


Usage with Huggingface LLMs
===============================

This example demonstrates how to work with Huggingface (HF) language models.

As always, let's first import the required functionality.
Remember to run :code:`pip install -U transformers datasets`.

.. GENERATED FROM PYTHON SOURCE LINES 9-30

.. code-block:: Python


    from collections import UserDict
    from collections.abc import MutableMapping

    import torch.utils.data as data_utils
    from datasets import Dataset
    from torch import Tensor, bfloat16, eye, manual_seed, no_grad
    from torch.nn import CrossEntropyLoss, Module
    from transformers import (
        DataCollatorWithPadding,
        GPT2Config,
        GPT2ForSequenceClassification,
        GPT2Tokenizer,
        PreTrainedTokenizer,
    )

    from curvlinops import GGNLinearOperator

    # make deterministic
    manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 31-36

Data
----

We use synthetic data for simplicity, but any HF dataloader can be used instead.
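The batches in this example are built by :code:`DataCollatorWithPadding`. It right-pads every sequence in a batch to the longest one, using the tokenizer's pad token (GPT-2 pads on the right by default). It also pads :code:`attention_mask` with zeros and renames the :code:`label` column to :code:`labels`, which is why the Curvlinops section reads the targets from :code:`data["labels"]`. The following pure-Python sketch mimics that behavior with plain lists. The helper :code:`pad_batch` and the token ids are hypothetical illustrations, not the :code:`transformers` implementation. :code:`batch_size_fn` is a list-based analogue of the tensor-based function passed to Curvlinops.

.. code-block:: Python

    from collections import UserDict


    def pad_batch(features, pad_token_id):
        """Hypothetical, simplified stand-in for DataCollatorWithPadding.

        Right-pads every ``input_ids`` list to the longest one in the batch,
        builds the matching ``attention_mask`` (1 = real token, 0 = padding)
        and renames ``label`` to ``labels``.
        """
        max_len = max(len(f["input_ids"]) for f in features)
        batch = UserDict(input_ids=[], attention_mask=[], labels=[])
        for f in features:
            n_real = len(f["input_ids"])
            n_pad = max_len - n_real
            batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
            batch["attention_mask"].append([1] * n_real + [0] * n_pad)
            batch["labels"].append(f["label"])
        return batch


    def batch_size_fn(x):
        """Batch size of a dict-like batch: number of rows in ``input_ids``."""
        return len(x["input_ids"])


    # Hypothetical token ids; 50256 is GPT-2's EOS id, reused here as padding.
    features = [
        {"input_ids": [10, 11, 12], "label": 1},
        {"input_ids": [20], "label": 0},
    ]
    batch = pad_batch(features, pad_token_id=50256)
    print(batch["input_ids"])  # [[10, 11, 12], [20, 50256, 50256]]
    print(batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
    print(batch_size_fn(batch))  # 2

Because padding only appends tokens on the right, the number of rows (the batch size) is unaffected by padding. Only the sequence length changes from batch to batch.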
.. GENERATED FROM PYTHON SOURCE LINES 37-60

.. code-block:: Python


    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    data = [
        {"text": "Today is hot, but I will manage!!!!", "label": 1},
        {"text": "Tomorrow is cold", "label": 0},
        {"text": "Carpe diem", "label": 1},
        {"text": "Tempus fugit", "label": 1},
    ]
    dataset = Dataset.from_list(data)


    def tokenize(row):
        return tokenizer(row["text"])


    dataset = dataset.map(tokenize, remove_columns=["text"])
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    dataloader = data_utils.DataLoader(
        dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer)
    )

    # batch_size=100 exceeds the 4 samples, so there is a single batch;
    # wrap it in a UserDict, the dict-like input used in the rest of the example
    data = UserDict(next(iter(dataloader)))

Model
-----

Next, we wrap the HF model in a :code:`torch.nn.Module` whose forward pass
consumes the dict-like batch directly and returns the logits. For simplicity,
only the final classification layer is trainable.

.. code-block:: Python


    class MyGPT2(Module):
        def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
            super().__init__()
            config = GPT2Config.from_pretrained("gpt2")
            config.pad_token_id = tokenizer.pad_token_id
            config.num_labels = 2
            self.hf_model = GPT2ForSequenceClassification.from_pretrained(
                "gpt2", config=config
            )

            # For simplicity, only enable grad for the last layer (the `score` head)
            for p in self.hf_model.parameters():
                p.requires_grad = False

            for p in self.hf_model.score.parameters():
                p.requires_grad = True

        def forward(self, data: MutableMapping) -> Tensor:
            """
            Custom forward function. Handles things like moving the
            input tensor to the correct device inside.

            Args:
                data: A dict-like data structure with `input_ids` inside.
                    This is the default data structure assumed by Huggingface
                    dataloaders.

            Returns:
                logits: An `(batch_size, n_classes)`-sized tensor of logits.
            """
            device = next(self.parameters()).device
            input_ids = data["input_ids"].to(device)
            output_dict = self.hf_model(input_ids)
            return output_dict.logits


    model = MyGPT2(tokenizer).to(bfloat16)

    with no_grad():
        logits = model(data)

    print(f"Logits shape: {logits.shape}")
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
    You may ignore this warning if your `pad_token_id` (50256) is identical to the `bos_token_id` (50256), `eos_token_id` (50256), or the `sep_token_id` (None), and your input is not padded.
    Logits shape: torch.Size([4, 2])

.. GENERATED FROM PYTHON SOURCE LINES 136-143

Curvlinops
----------

We are now ready to compute the curvature of this HF model with Curvlinops.
The only extra ingredient is a function, passed as :code:`batch_size_fn`, that
tells Curvlinops how to obtain the batch size of the :code:`UserDict` input batch.
Everything else is unchanged from the standard usage of Curvlinops!

.. GENERATED FROM PYTHON SOURCE LINES 144-166

.. code-block:: Python


    def batch_size_fn(x: MutableMapping) -> int:
        return x["input_ids"].shape[0]


    # Only the trainable parameters, i.e. the classification head
    params = [p for p in model.parameters() if p.requires_grad]

    ggn = GGNLinearOperator(
        model,
        CrossEntropyLoss(),
        params,
        [(data, data["labels"])],  # We still need to input a list of "(X, y)" pairs!
        check_deterministic=False,
        batch_size_fn=batch_size_fn,  # Remember to specify this!
    )

    G = ggn @ eye(ggn.shape[0], device=params[0].device)

    print(f"GGN shape: {G.shape}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    GGN shape: torch.Size([1536, 1536])

.. GENERATED FROM PYTHON SOURCE LINES 167-173

Conclusion
----------

Passing the batch as a :code:`UserDict` (or any other dict-like data structure)
together with a matching :code:`batch_size_fn` is a very flexible pattern. It is
not limited to HF models: you can use it for any custom model whose forward pass
consumes such a batch!

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 7.395 seconds)
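The 1536 × 1536 GGN shape printed above can be derived from the model. Only the classification head :code:`score` is trainable. In :code:`GPT2ForSequenceClassification` it is a bias-free linear layer mapping GPT-2's 768-dimensional hidden state :math:`h` to the 2 labels, so it has 2 × 768 = 1536 parameters. For a layer :math:`z = W h` followed by softmax cross-entropy with :math:`p = \mathrm{softmax}(z)`, each sample contributes :math:`(\mathrm{diag}(p) - p p^\top) \otimes h h^\top` to the GGN. The default mean reduction of :code:`CrossEntropyLoss` averages these contributions.

The pure-Python sketch below builds this matrix for toy sizes and checks it against an explicit :math:`J^\top H J` product. It assumes row-major flattening of :math:`W` and frozen features :math:`h`. It only illustrates the quantity; it is not how Curvlinops computes the GGN internally.

.. code-block:: Python

    import math


    def softmax(z):
        """Numerically stable softmax of a list of logits."""
        m = max(z)
        exps = [math.exp(v - m) for v in z]
        total = sum(exps)
        return [e / total for e in exps]


    def ce_logit_hessian(z):
        """Hessian of cross-entropy w.r.t. the logits: diag(p) - p p^T."""
        p = softmax(z)
        n = len(p)
        return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)] for i in range(n)]


    def linear_ce_ggn(W, features):
        """GGN of mean cross-entropy w.r.t. the weight of z = W h (no bias).

        W has shape C x D; features holds N vectors h of length D. Entry (c, d)
        of W is flattened to index c * D + d (row-major, like W.flatten()), so
        each sample contributes the Kronecker product H (x) h h^T.
        """
        C, D, N = len(W), len(W[0]), len(features)
        G = [[0.0] * (C * D) for _ in range(C * D)]
        for h in features:
            z = [sum(W[c][d] * h[d] for d in range(D)) for c in range(C)]
            H = ce_logit_hessian(z)
            for c in range(C):
                for c2 in range(C):
                    for d in range(D):
                        for d2 in range(D):
                            G[c * D + d][c2 * D + d2] += H[c][c2] * h[d] * h[d2] / N
        return G


    def ggn_via_jacobian(W, features):
        """Reference: average J^T H J using the explicit Jacobian dz/dvec(W)."""
        C, D, N = len(W), len(W[0]), len(features)
        P = C * D
        G = [[0.0] * P for _ in range(P)]
        for h in features:
            z = [sum(W[c][d] * h[d] for d in range(D)) for c in range(C)]
            H = ce_logit_hessian(z)
            # dz_a / dW[c][d] = h[d] if a == c else 0
            J = [[h[k % D] if k // D == a else 0.0 for k in range(P)] for a in range(C)]
            for i in range(P):
                for j in range(P):
                    G[i][j] += sum(
                        J[a][i] * H[a][b] * J[b][j] for a in range(C) for b in range(C)
                    ) / N
        return G


    # Toy sizes: C = 2 labels (as in the example), D = 3 features instead of 768.
    W = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1]]
    features = [[1.0, 2.0, 0.5], [0.3, -1.0, 2.0]]
    G = linear_ce_ggn(W, features)
    print(len(G), len(G[0]))  # 6 6, i.e. (C * D) x (C * D)

With C = 2 and D = 768, the same construction gives a 1536 × 1536 matrix. The labels never enter the computation, because the cross-entropy Hessian with respect to the logits depends only on :math:`p`.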
.. _sphx_glr_download_basic_usage_example_huggingface.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: example_huggingface.ipynb <example_huggingface.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: example_huggingface.py <example_huggingface.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: example_huggingface.zip <example_huggingface.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_