# Source code for curvlinops.gradient_moments

"""Contains linear operator implementation of gradient moment matrices."""

from collections.abc import MutableMapping
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
from einops import einsum, rearrange
from torch import Tensor, zeros_like
from torch.autograd import grad
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter

from curvlinops._torch_base import CurvatureLinearOperator


class EFLinearOperator(CurvatureLinearOperator):
    r"""Uncentered gradient covariance as PyTorch linear operator.

    The uncentered gradient covariance is often called 'empirical Fisher' (EF).

    Consider the empirical risk

    .. math::
        \mathcal{L}(\mathbf{\theta})
        =
        c \sum_{n=1}^{N}
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)

    with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for
    ``reduction='sum'``. The uncentered gradient covariance matrix is

    .. math::
        c \sum_{n=1}^{N}
        \left(
            \nabla_{\mathbf{\theta}}
            \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
        \right)
        \left(
            \nabla_{\mathbf{\theta}}
            \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
        \right)^\top\,.

    Note:
        For ``MSELoss`` and ``BCEWithLogitsLoss`` with ``reduction='mean'``,
        PyTorch also averages over the :math:`C` output features. This
        corresponds to :math:`c = \frac{1}{NC}` with :math:`\ell` summed over
        the features. Model outputs with more than two dimensions are
        flattened, so every entry along the additional (non-batch,
        non-feature) axes counts as a separate term :math:`n`.

    Attributes:
        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for
            empirical Fisher.
    """

    supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
    supported_reductions = ("mean", "sum")
    SELF_ADJOINT: bool = True

    def __init__(
        self,
        model_func: Callable[[Union[MutableMapping, Tensor]], Tensor],
        loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
        params: List[Parameter],
        data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
        num_data: Optional[int] = None,
        batch_size_fn: Optional[Callable[[Union[Tensor, MutableMapping]], int]] = None,
    ):
        """Linear operator for the uncentered gradient covariance/empirical Fisher (EF).

        Note:
            f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
            input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the
            mini-batch labels y.

        Args:
            model_func: A function that maps the mini-batch input X to predictions.
                Could be a PyTorch module representing a neural network.
            loss_func: Loss function criterion. Maps predictions and mini-batch labels
                to a scalar value. Must use ``reduction='mean'`` or ``'sum'``.
            params: List of differentiable parameters used by the prediction function.
            data: Source from which mini-batches can be drawn, for instance a list of
                mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X``
                could be a ``dict`` or ``UserDict``; this is useful for custom models.
                In this case, you must (i) specify the ``batch_size_fn`` argument, and
                (ii) take care of preprocessing like ``X.to(device)`` inside of your
                ``model.forward()`` function.
            progressbar: Show a progressbar during matrix-multiplication.
                Default: ``False``.
            check_deterministic: Probe that model and data are deterministic, i.e.
                that the data does not use `drop_last` or data augmentation. Also, the
                model's forward pass could depend on the order in which mini-batches
                are presented (BatchNorm, Dropout). Default: ``True``. This is a
                safeguard, only turn it off if you know what you are doing.
            num_data: Number of data points. If ``None``, it is inferred from the data
                at the cost of one traversal through the data loader.
            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
                needs to be specified. The intended behavior is to consume the first
                entry of the iterates from ``data`` and return their batch size.

        Raises:
            NotImplementedError: If the loss function differs from ``MSELoss``,
                ``BCEWithLogitsLoss``, or ``CrossEntropyLoss``.
            ValueError: If the loss function's ``reduction`` is not ``'mean'`` or
                ``'sum'``.
        """
        if not isinstance(loss_func, self.supported_losses):
            raise NotImplementedError(
                f"Loss must be one of {self.supported_losses}. Got: {loss_func}."
            )
        # Validate early: ``_matmat_batch`` can only undo 'mean'/'sum' reductions,
        # and would otherwise fail later with an opaque KeyError.
        if loss_func.reduction not in self.supported_reductions:
            raise ValueError(
                f"Loss reduction must be one of {self.supported_reductions}. "
                f"Got: {loss_func.reduction!r}."
            )
        super().__init__(
            model_func,
            loss_func,
            params,
            data,
            progressbar=progressbar,
            check_deterministic=check_deterministic,
            num_data=num_data,
            batch_size_fn=batch_size_fn,
        )

    def _matmat_batch(
        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
    ) -> List[Tensor]:
        """Apply the mini-batch empirical Fisher to a matrix in tensor list format.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            M: Matrix to be multiplied with in tensor list format.
                Tensors have same shape as trainable model parameters, and an
                additional trailing axis for the matrix columns.

        Returns:
            Result of EF multiplication in tensor list format. Has the same shape as
            ``M``, i.e. each tensor in the list has the shape of a parameter and a
            trailing dimension of matrix columns.

        Raises:
            ValueError: If ``M`` is empty or its tensors disagree on the number of
                columns (size of the trailing axis).
        """
        # All tensors in M must share one trailing (column) dimension.
        column_counts = {m.shape[-1] for m in M}
        if len(column_counts) != 1:
            raise ValueError(
                "Tensors in M must be non-empty and share the same number of "
                f"columns (trailing dimension). Got: {column_counts}."
            )
        (num_vectors,) = column_counts

        output = self._model_func(X)
        # If >2d output we convert to an equivalent 2d output. CrossEntropyLoss
        # expects the class axis second, the other losses expect it last.
        if isinstance(self._loss_func, CrossEntropyLoss):
            output = rearrange(output, "batch c ... -> (batch ...) c")
            y = rearrange(y, "batch ... -> (batch ...)")
        else:
            output = rearrange(output, "batch ... c -> (batch ...) c")
            y = rearrange(y, "batch ... c -> (batch ...) c")

        # Factor by which the loss' reduction scales the sum of per-term losses.
        # CrossEntropyLoss averages over terms; MSE/BCE also over the C features.
        # NOTE(review): CrossEntropyLoss with ``weight`` or with targets equal to
        # ``ignore_index`` normalizes 'mean' by the (weighted) count of non-ignored
        # targets instead of ``num_loss_terms``, which would make this rescaling
        # wrong. Confirm those options are rejected elsewhere.
        num_loss_terms, C = output.shape
        reduction_factor = {
            "mean": (
                num_loss_terms
                if isinstance(self._loss_func, CrossEntropyLoss)
                else num_loss_terms * C
            ),
            "sum": 1.0,
        }[self._loss_func.reduction]

        # compute gₙ = ∂ℓₙ/∂fₙ without the reduction factor of L
        (grad_output,) = grad(self._loss_func(output, y), output)
        grad_output = grad_output.detach() * reduction_factor

        # Pseudo-loss L' := (0.5 / reduction_factor) ∑ₙ (fₙᵀ gₙ)², with gₙ detached.
        # Its Hessian w.r.t. fₙ is gₙ gₙᵀ / reduction_factor, so the GGN of L'
        # linearized at fₙ is the empirical Fisher. We can thus multiply with the
        # EF by computing GGN-vector products of L'.
        loss = (
            0.5
            / reduction_factor
            * (einsum(output, grad_output, "n ..., n ... -> n") ** 2).sum()
        )

        # Multiply the EF onto each column of the input matrix
        EM = [zeros_like(m) for m in M]
        for v in range(num_vectors):
            column = [m[..., v] for m in M]
            ggnvps = ggn_vector_product_from_plist(loss, output, self._params, column)
            for em, ggnvp in zip(EM, ggnvps):
                em[..., v].add_(ggnvp.detach())

        return EM