Source code for curvlinops.ggn_diagonal

"""Contains a linear operator class for the diagonal of the GGN matrix."""

from collections.abc import Callable, Iterable, MutableMapping

from torch import Tensor
from torch.nn import Module

from curvlinops.computers.ggn_diagonal import GGNDiagonalComputer
from curvlinops.diag import DiagonalLinearOperator


class GGNDiagonalLinearOperator(DiagonalLinearOperator):
    r"""Linear operator holding the diagonal of the generalized Gauss-Newton matrix.

    Represents :math:`\mathrm{diag}(\mathbf{G})`, where :math:`\mathbf{G}` is the
    generalized Gauss-Newton (GGN) matrix as defined for :class:`GGNLinearOperator`.

    If ``mc_samples`` is positive, the exact loss Hessian is replaced by a
    Monte-Carlo estimate drawn from the model's predictive distribution (the
    sampling scheme is the one described in :class:`GGNLinearOperator`).

    The diagonal is evaluated eagerly during construction by a
    :class:`GGNDiagonalComputer`, and the result is handed to the
    :class:`DiagonalLinearOperator` base class.
    """

    def __init__(
        self,
        model_func: Module
        | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
        loss_func: Callable[[Tensor, Tensor], Tensor],
        params: dict[str, Tensor],
        data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
        num_data: int | None = None,
        batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
        mc_samples: int = 0,
        seed: int = 2_147_483_647,
    ):
        """Evaluate the GGN diagonal and set up the diagonal operator.

        Every argument is forwarded to a :class:`GGNDiagonalComputer`. The
        diagonal it returns is arranged in the computer's parameter order and
        passed to :class:`DiagonalLinearOperator`.

        Args:
            model_func: Forward pass of the network, i.e. the map
                ``(params, X) -> prediction``. Either an ``nn.Module``
                describing the architecture or a callable
                ``(params_dict, X) -> prediction``.
            loss_func: Loss criterion mapping predictions and mini-batch labels
                to a scalar.
            params: Dictionary from parameter names to the tensor values at
                which the GGN diagonal is evaluated.
            data: Iterable of mini-batches ``(X, y)``, for example a list
                ``[(X, y), ...]`` or a torch ``DataLoader``. ``X`` may also be
                a ``dict`` or ``UserDict``, which is useful for custom models.
                In that case you must (i) provide ``batch_size_fn`` and
                (ii) handle preprocessing such as ``X.to(device)`` inside
                ``model.forward()`` yourself. With MC sampling enabled, the
                batches must be yielded in the same deterministic order on
                every pass (no shuffling).
            progressbar: Whether to display a progress bar during the
                computation. Defaults to ``False``.
            check_deterministic: Whether to verify that model and data behave
                deterministically. For example, the data must not use
                ``drop_last`` or data augmentation. The forward pass must not
                depend on the order in which mini-batches arrive, as it can
                with BatchNorm or Dropout. Defaults to ``True``. This is a
                safeguard; only disable it if you know what you are doing.
            num_data: Total number of data points. When ``None``, it is
                inferred from ``data``, which costs one extra pass over the
                data loader.
            batch_size_fn: Callable returning the batch size of an input
                ``X``. For ``torch.Tensor`` inputs this is typically
                ``X.shape[0]``. For ``dict``/``UserDict`` inputs it should
                return the batch size of the tensors they contain.
            mc_samples: Number of Monte-Carlo samples used to approximate the
                loss Hessian. ``0`` (the default) yields the exact GGN
                diagonal; positive values switch on the MC approximation.
            seed: Seed of the internal random number generator used for MC
                sampling. Only relevant when ``mc_samples > 0``. Defaults to
                ``2147483647``.
        """
        # Options forwarded by keyword; the four leading arguments stay positional.
        computer_options = {
            "progressbar": progressbar,
            "check_deterministic": check_deterministic,
            "num_data": num_data,
            "batch_size_fn": batch_size_fn,
            "mc_samples": mc_samples,
            "seed": seed,
        }
        computer = GGNDiagonalComputer(
            model_func, loss_func, params, data, **computer_options
        )

        # ``compute()`` returns a mapping indexed by the entries of
        # ``computer._params``. The base class receives a list, so the blocks
        # are arranged in the order the computer enumerates its parameters.
        diagonal_by_param = computer.compute()
        ordered_blocks = [diagonal_by_param[key] for key in computer._params]
        super().__init__(ordered_blocks)