Source code for curvlinops.gradient_moments

"""Contains linear operator implementation of gradient moment matrices."""

from collections.abc import MutableMapping
from functools import cached_property, partial
from typing import Callable, List, Tuple, Union

from einops import einsum
from torch import Tensor, vmap
from torch.func import grad
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss, Parameter

from curvlinops._torch_base import CurvatureLinearOperator
from curvlinops.ggn import make_ggn_vector_product
from curvlinops.utils import make_functional_flattened_model_and_loss


def _ef_reduction_factor(
    loss_func: Module, num_loss_terms: int, num_outputs: int
) -> Union[int, float]:
    """Return the normalization that ``loss_func``'s reduction applies to the loss.

    Args:
        loss_func: The loss function. Its ``reduction`` must be ``'mean'`` or
            ``'sum'``.
        num_loss_terms: Number of rows of the flattened model output.
        num_outputs: Number of columns of the flattened model output.

    Returns:
        ``1.0`` for ``reduction='sum'``. For ``reduction='mean'``: ``num_loss_terms``
        if ``loss_func`` is a ``CrossEntropyLoss`` (it averages over loss terms),
        otherwise ``num_loss_terms * num_outputs`` (the loss averages over every
        output entry).

    Raises:
        ValueError: If the reduction is neither ``'mean'`` nor ``'sum'``.
    """
    reduction = loss_func.reduction
    if reduction == "sum":
        return 1.0
    if reduction == "mean":
        # NOTE(review): CrossEntropyLoss with a class ``weight`` or ignored targets
        # normalizes by the summed weights of non-ignored targets rather than by
        # ``num_loss_terms`` — confirm such configurations are excluded upstream.
        if isinstance(loss_func, CrossEntropyLoss):
            return num_loss_terms
        return num_loss_terms * num_outputs
    raise ValueError(f"Loss reduction must be 'mean' or 'sum'. Got: {reduction!r}.")


def make_batch_ef_matrix_product(
    model_func: Module, loss_func: Module, params: Tuple[Parameter, ...]
) -> Callable[
    [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...]
]:
    r"""Set up function that multiplies the mini-batch empirical Fisher onto a matrix.

    The empirical Fisher is computed as the GGN of a pseudo-loss that is quadratic
    in the gradients of the original loss. Specifically, for loss gradients
    :math:`g_n = \nabla_f \ell(f_n, y_n)`, the pseudo-loss is:

    .. math::
        L'(\mathbf{\theta}) = \frac{1}{2c} \sum_{n=1}^{N} \langle f_n, g_n \rangle^2

    where :math:`c` is the reduction factor and :math:`f_n = f_{\mathbf{\theta}}(x_n)`.
    The GGN of this pseudo-loss equals the empirical Fisher of the original loss.

    Args:
        model_func: The neural network :math:`f_{\mathbf{\theta}}`.
        loss_func: The loss function :math:`\ell`. Its ``reduction`` must be
            ``'mean'`` or ``'sum'``.
        params: A tuple of parameters w.r.t. which the empirical Fisher is computed.
            All parameters must be part of ``model_func.parameters()``.

    Returns:
        A function called as ``mp(X, y, *M)``. It takes inputs ``X``, labels ``y``,
        and the tensors of a matrix ``M`` in tensor list format passed as separate
        positional arguments. Each tensor has the shape of the corresponding entry
        of ``params`` plus a trailing axis that indexes the matrix columns. It
        returns the mini-batch empirical Fisher applied to ``M`` in the same format.

    Raises:
        ValueError: If ``loss_func.reduction`` is neither ``'mean'`` nor ``'sum'``.
    """
    # Reject unsupported reductions up front instead of failing with an opaque
    # ``KeyError`` inside ``vmap`` on the first matrix product.
    if loss_func.reduction not in ("mean", "sum"):
        raise ValueError(
            f"Loss reduction must be 'mean' or 'sum'. Got: {loss_func.reduction!r}."
        )

    f_flat, c_flat = make_functional_flattened_model_and_loss(
        model_func, loss_func, params
    )
    # function that computes gradients of the loss w.r.t. the flattened outputs
    c_flat_grad = grad(c_flat, argnums=0)

    def c_pseudo_flat(output_flat: Tensor, y: Tensor) -> Tensor:
        """Compute pseudo-loss: L' = 0.5 / c * sum_n <f_n, g_n>^2.

        This pseudo-loss L' := 0.5 / c ∑ₙ fₙᵀ (gₙ gₙᵀ) fₙ where gₙ = ∂ℓₙ/∂fₙ
        (detached). The GGN of L' linearized at fₙ is the empirical Fisher.
        We can thus multiply with the EF by computing the GGN-vector products of L'.

        The reduction factor adjusts the scale depending on the loss reduction used.

        Args:
            output_flat: Flattened model outputs for the mini-batch.
            y: Un-flattened labels for the mini-batch.

        Returns:
            The pseudo-loss whose GGN is the empirical Fisher on the batch.
        """
        # Gradient of the *reduced* loss w.r.t. the detached outputs; it still
        # contains the reduction's normalization, which is undone below.
        grad_output_flat = c_flat_grad(output_flat.detach(), y)

        num_loss_terms, num_outputs = output_flat.shape
        reduction_factor = _ef_reduction_factor(
            loss_func, num_loss_terms, num_outputs
        )

        # Undo the normalization to obtain per-term gradients gₙ = ∂ℓₙ/∂fₙ,
        # then form the pseudo-loss
        grad_output_flat = grad_output_flat * reduction_factor
        inner_products = einsum(output_flat, grad_output_flat, "n ..., n ... -> n")
        return 0.5 / reduction_factor * (inner_products**2).sum()

    # Create the functional EF-vector product using GGN of pseudo-loss
    ef_vp = make_ggn_vector_product(f_flat, c_pseudo_flat)

    # Freeze parameter values
    efvp = partial(ef_vp, params)  # X, y, *v -> *EFv

    # Parallelize over vectors to multiply onto a matrix in list format.
    # For a tensor with ``p.ndim + 1`` axes, index ``p.ndim`` is the last axis.
    list_format_vmap_dims = tuple(p.ndim for p in params)
    return vmap(
        efvp,
        # No vmap in X, y, assume last axis is vmapped in the matrix list
        in_dims=(None, None, *list_format_vmap_dims),
        # Vmapped output axis is last
        out_dims=list_format_vmap_dims,
        # We want each vector to be multiplied with the same mini-batch EF
        randomness="same",
    )


class EFLinearOperator(CurvatureLinearOperator):
    r"""PyTorch linear operator for the uncentered gradient covariance.

    The uncentered gradient covariance is commonly referred to as the
    'empirical Fisher' (EF). Given the empirical risk

    .. math::
        \mathcal{L}(\mathbf{\theta})
        =
        c \sum_{n=1}^{N}
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)

    with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for
    ``reduction='sum'``, the uncentered gradient covariance matrix reads

    .. math::
        c \sum_{n=1}^{N}
        \left(
            \nabla_{\mathbf{\theta}}
            \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
        \right)
        \left(
            \nabla_{\mathbf{\theta}}
            \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
        \right)^\top\,.

    For ``MSELoss`` and ``BCEWithLogitsLoss`` with ``reduction='mean'``, the
    normalization additionally divides by the number of output features per term,
    matching how those losses average.

    Attributes:
        SUPPORTED_LOSSES: Loss function classes for which EF products are available.
        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for the
            empirical Fisher.
    """

    SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
    SELF_ADJOINT: bool = True

    @cached_property
    def _mp(
        self,
    ) -> Callable[
        [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...]
    ]:
        """Build, on first access, the mini-batch EF matrix-product function.

        Returns:
            A callable invoked as ``mp(X, y, *M)``. It receives inputs ``X``, labels
            ``y``, and the tensors of a matrix ``M`` in tensor list format (each with
            a trailing column axis). It produces tensors of the same shapes that
            hold the mini-batch EF multiplied onto ``M``.

        Raises:
            NotImplementedError: If the loss function is not supported.
        """
        loss_func = self._loss_func
        supported = self.SUPPORTED_LOSSES
        if isinstance(loss_func, supported):
            return make_batch_ef_matrix_product(
                self._model_func, loss_func, tuple(self._params)
            )
        raise NotImplementedError(
            f"Loss must be one of {supported}. Got: {loss_func}."
        )

    def _matmat_batch(
        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
    ) -> List[Tensor]:
        """Multiply the mini-batch empirical Fisher onto a matrix in tensor list format.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            M: Matrix in tensor list format. Each tensor has the shape of the
                corresponding trainable parameter plus a trailing axis that indexes
                the matrix columns.

        Returns:
            The mini-batch EF times ``M`` in tensor list format. Every tensor has
            the same shape as its counterpart in ``M``.
        """
        product = self._mp(X, y, *M)
        return [*product]