Source code for curvlinops.ggn

"""Contains LinearOperator implementation of the GGN."""

from collections.abc import Callable, Iterable, MutableMapping

from einops import einsum
from torch import Tensor, manual_seed, no_grad
from torch.func import jacrev, jvp, vjp, vmap
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
from torch.random import fork_rng

from curvlinops._torch_base import CurvatureLinearOperator
from curvlinops.ggn_utils import make_grad_output_fn
from curvlinops.kfac_utils import FisherType
from curvlinops.utils import make_functional_loss


def make_ggn_vector_product(
    f: Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
    c: Callable[[Tensor, tuple], Tensor],
) -> Callable[
    [dict[str, Tensor], Tensor | MutableMapping, tuple, tuple[Tensor, ...]],
    tuple[Tensor, ...],
]:
    """Create a function that computes GGN-vector products for given f and c functions.

    Args:
        f: Function that takes parameters (as a dict) and input, returns prediction.
            Signature: ``(params_dict, X) -> prediction``
        c: Function that takes prediction and a tuple of loss arguments.
            Signature: ``(prediction, loss_args) -> loss``

    Returns:
        A function that computes GGN-vector products.
        Signature: ``(params_dict, X, loss_args, v) -> GGN @ v``
        where ``params_dict`` is a dict mapping parameter names to tensors,
        ``X`` is the model input, ``loss_args`` is a tuple of arguments
        passed to the loss function ``c`` (typically ``(y,)`` or
        ``(y, generator)``), and ``v`` is a tuple of tensors in list format.
    """

    @no_grad()
    def ggn_vector_product(
        params: dict[str, Tensor],
        X: Tensor | MutableMapping,
        loss_args: tuple,
        v: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """Multiply the GGN on a vector in dict format.

        Args:
            params: Parameters of the model as a dict.
            X: Input to the model.
            loss_args: Arguments forwarded to the loss function ``c``, e.g.
                ``(y,)`` or ``(y, generator)``.
            v: Vector as a dict matching the structure of ``params``.

        Returns:
            Result of GGN multiplication as a dict with the same keys as ``params``.
        """
        # Apply the Jacobian of f onto v: v → Jv
        f_val, f_jvp = jvp(lambda p: f(p, X), (params,), (v,))

        # Apply the criterion's Hessian onto Jv: Jv → HJv
        c_grad_func = jacrev(lambda pred: c(pred, loss_args))
        _, c_hvp = jvp(c_grad_func, (f_val,), (f_jvp,))

        # Apply the transposed Jacobian of f onto HJv: HJv → JᵀHJv
        # NOTE This re-evaluates the net's forward pass. [Unverified] It should be op-
        # timized away by common sub-expression elimination if you compile the function.
        _, f_vjp_func = vjp(lambda p: f(p, X), params)
        (result_dict,) = f_vjp_func(c_hvp)
        return result_dict

    return ggn_vector_product


def make_batch_ggn_vector_product(
    f: Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
    loss_func: Module,
) -> Callable[
    [dict[str, Tensor], Tensor | MutableMapping, tuple, dict[str, Tensor]],
    dict[str, Tensor],
]:
    r"""Set up function that multiplies the mini-batch GGN onto a vector in dict format.

    Args:
        f: Functional model with signature ``(params_dict, X) -> prediction``.
        loss_func: The loss function :math:`\ell`.

    Returns:
        A function ``(params_dict, X, loss_args, v_dict) -> Gv`` that takes
        parameters as a dict, model input ``X``, loss arguments
        ``loss_args = (y,)``, and a vector ``v`` as a dict, and returns the
        mini-batch GGN applied to ``v`` as a dict.
    """
    c = make_functional_loss(loss_func)
    return make_ggn_vector_product(f, c)


def make_batch_ggn_mc_vector_product(
    f: Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
    loss_func: Module,
    mc_samples: int,
) -> Callable[
    [dict[str, Tensor], Tensor | MutableMapping, tuple, dict[str, Tensor]],
    dict[str, Tensor],
]:
    r"""Set up function that multiplies the mini-batch MC-approximated GGN onto a vector.

    The MC approximation replaces the exact loss Hessian with a Monte-Carlo estimate
    by sampling from the model's predictive distribution. For exponential family losses
    (MSE, CrossEntropy, BCEWithLogitsLoss), the MC estimate converges to the exact GGN.

    Internally constructs a pseudo-loss whose GGN equals the MC-approximated GGN,
    using sampled gradient output vectors from
    :func:`curvlinops.ggn_utils.make_grad_output_fn`.

    Note:
        Samples from the global RNG (no explicit ``torch.Generator``) so that
        the returned function is ``torch.compile``-compatible. The caller is
        responsible for seeding, e.g. via ``fork_rng`` + ``manual_seed``
        (see :meth:`GGNLinearOperator._matmat`).

    Args:
        f: Functional model with signature ``(params_dict, X) -> prediction``.
        loss_func: The loss function :math:`\ell`.
        mc_samples: Number of Monte-Carlo samples.

    Returns:
        A function ``(params_dict, X, loss_args, v_dict) -> Gv`` that takes
        parameters as a dict, model input ``X``, loss arguments
        ``loss_args = (y,)``, and a vector ``v`` as a dict, and returns
        the mini-batch MC-GGN applied to ``v`` as a dict.
    """
    _grad_output_fn = make_grad_output_fn(loss_func, FisherType.MC, mc_samples)
    # vmap over batch: per-datum grad outputs → batched
    batched_grad_output_fn = vmap(_grad_output_fn, (0, 0), randomness="different")

    def c_pseudo(prediction: Tensor, loss_args: tuple) -> Tensor:
        r"""Pseudo-loss whose GGN equals the MC-approximated GGN.

        Constructs :math:`L' = \frac{1}{2c} \sum_n \sum_k
        \langle \mathbf{g}'_{nk}, \mathbf{f}_n \rangle^2`
        where :math:`\mathbf{g}'_{nk}` are sampled gradient output vectors (scaled
        by :math:`1/\sqrt{M}`) and :math:`c` is the reduction factor.

        Args:
            prediction: Batched model predictions.
            loss_args: Tuple of ``(y,)`` with labels.

        Returns:
            Scalar pseudo-loss.
        """
        (y,) = loss_args

        # [batch, mc_samples, *output_shape], scaled by 1/sqrt(mc_samples)
        grad_outputs = batched_grad_output_fn(prediction.detach(), y)

        # Inner products: [batch, mc_samples]
        ip = einsum(grad_outputs, prediction, "n k ..., n ... -> n k")

        batch_size = prediction.shape[0]
        reduction_factor = {"mean": batch_size, "sum": 1.0}[loss_func.reduction]

        return 0.5 / reduction_factor * (ip**2).sum()

    # Create GGN-vp of pseudo-loss: (params_dict, X, loss_args, v) -> Gv
    return make_ggn_vector_product(f, c_pseudo)


[docs] class GGNLinearOperator(CurvatureLinearOperator): r"""Linear operator for the generalized Gauss-Newton matrix of an empirical risk. Consider the empirical risk .. math:: \mathcal{L}(\mathbf{\theta}) = c \sum_{n=1}^{N} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for ``reduction='sum'``. The GGN matrix is .. math:: c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n) \right)^\top \left( \nabla_{f_\mathbf{\theta}(\mathbf{x}_n)}^2 \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) \right) \left( \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n) \right)\,. Denoting :math:`\mathbf{f}_n = f_{\mathbf{\theta}}(\mathbf{x}_n)` and using a matrix square root :math:`\mathbf{S}_n \mathbf{S}_n^\top = \nabla_{\mathbf{f}_n}^2 \ell(\mathbf{f}_n, \mathbf{y}_n)`, this can be rewritten as .. math:: c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)^\top \mathbf{S}_n \mathbf{S}_n^\top \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)\,. When ``mc_samples > 0``, the loss Hessian's square root is approximated via Monte-Carlo sampling. For exponential family losses (``MSELoss``, ``CrossEntropyLoss``, ``BCEWithLogitsLoss``), the loss Hessian equals :math:`\mathbb{E}_{\tilde{\mathbf{y}}_n \sim q(\cdot \mid \mathbf{f}_n)} [\nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n) \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n)^\top]`, where :math:`q` is the model's predictive distribution. This expectation is approximated by drawing :math:`M` samples :math:`\tilde{\mathbf{y}}_n^{(m)}` and using the sampled gradients :math:`\mathbf{g}_{nm} = \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n^{(m)})` as columns of :math:`\mathbf{S}_n`: .. math:: \nabla_{\mathbf{f}_n}^2 \ell \approx \frac{1}{M} \sum_{m=1}^{M} \mathbf{g}_{nm} \mathbf{g}_{nm}^\top\,. The MC estimate converges to the exact GGN as :math:`M \to \infty`. Attributes: SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for GGNs. MC_SUPPORTED_LOSSES: Loss functions supported by the MC approximation. """ SELF_ADJOINT: bool = True MC_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
[docs] def __init__( self, model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor], loss_func: Callable[[Tensor, Tensor], Tensor], params: dict[str, Tensor], data: Iterable[tuple[Tensor | MutableMapping, Tensor]], progressbar: bool = False, check_deterministic: bool = True, num_data: int | None = None, batch_size_fn: Callable[[MutableMapping], int] | None = None, mc_samples: int = 0, seed: int = 2147483647, ): r"""Linear operator for the GGN of an empirical risk. Note: f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the mini-batch labels y. Args: model_func: The neural network's forward pass, defining the functional relationship ``(params, X) -> prediction``. Either an ``nn.Module`` (architecture) or a callable ``(params_dict, X) -> prediction``. loss_func: Loss function criterion. Maps predictions and mini-batch labels to a scalar value. params: The parameter values at which the GGN is evaluated. A dictionary mapping parameter names to tensors. data: Source from which mini-batches can be drawn, for instance a list of mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X`` could be a ``dict`` or ``UserDict``; this is useful for custom models. In this case, you must (i) specify the ``batch_size_fn`` argument, and (ii) take care of preprocessing like ``X.to(device)`` inside of your ``model.forward()`` function. When using MC sampling, batches must be presented in the same deterministic order (no shuffling!). progressbar: Show a progressbar during matrix-multiplication. Default: ``False``. check_deterministic: Probe that model and data are deterministic, i.e. that the data does not use `drop_last` or data augmentation. Also, the model's forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: ``True``. This is a safeguard, only turn it off if you know what you are doing. num_data: Number of data points. If ``None``, it is inferred from the data at the cost of one traversal through the data loader. batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this needs to be specified. The intended behavior is to consume the first entry of the iterates from ``data`` and return their batch size. mc_samples: Number of Monte-Carlo samples to approximate the loss Hessian. ``0`` (default) uses the exact GGN. Positive values activate the MC approximation, which is only supported for ``MSELoss``, ``CrossEntropyLoss``, and ``BCEWithLogitsLoss``. seed: Seed for the internal random number generator used for MC sampling. Only used when ``mc_samples > 0``. Default: ``2147483647``. Raises: NotImplementedError: If ``mc_samples > 0`` and the loss function is not in ``MC_SUPPORTED_LOSSES``. """ self._mc_samples = mc_samples if mc_samples > 0: if not isinstance(loss_func, self.MC_SUPPORTED_LOSSES): raise NotImplementedError( f"MC-GGN requires loss in {self.MC_SUPPORTED_LOSSES}. " f"Got: {loss_func}." ) self.FIXED_DATA_ORDER = True self._seed = seed super().__init__( model_func, loss_func, params, data, progressbar=progressbar, check_deterministic=check_deterministic, num_data=num_data, batch_size_fn=batch_size_fn, )
def _matmat(self, M): """Multiply the GGN onto a matrix. Uses ``fork_rng`` to isolate the global RNG state for MC sampling, avoiding side effects on the caller's RNG. Seeding is placed here (not in ``__matmul__``) so that internal callers like ``_ChainPyTorchLinearOperator`` also get deterministic MC samples. Args: M: Matrix for multiplication in tensor list format. Returns: Matrix-multiplication result ``mat @ M`` in tensor list format. """ if self._mc_samples > 0: with fork_rng(): manual_seed(self._seed) return super()._matmat(M) return super()._matmat(M) def _init_mp(self): """Set up the batch GGN-vector product function, then build vmap.""" if self._mc_samples > 0: self._vp = make_batch_ggn_mc_vector_product( self._model_func, self._loss_func, self._mc_samples ) else: self._vp = make_batch_ggn_vector_product(self._model_func, self._loss_func) super()._init_mp() def _matvec_batch( self, X: Tensor | MutableMapping, y: Tensor, v: dict[str, Tensor] ) -> dict[str, Tensor]: """Apply the mini-batch GGN to a vector. Args: X: Input to the DNN. y: Ground truth. v: Vector as a dict keyed by parameter names. Returns: Result of GGN-vector multiplication as a dict. """ return self._vp(self._params, X, (y,), v)