# Source code for curvlinops.hessian

"""Contains a linear operator implementation of the Hessian."""

from collections.abc import MutableMapping
from functools import cached_property, partial
from typing import Callable, List, Optional, Tuple, Union

from torch import Tensor, no_grad, vmap
from torch.func import jacrev, jvp
from torch.nn import Module, Parameter

from curvlinops._torch_base import CurvatureLinearOperator
from curvlinops.utils import make_functional_model_and_loss, split_list


def make_batch_hessian_matrix_product(
    model_func: Module,
    loss_func: Module,
    params: Tuple[Parameter, ...],
    block_sizes: Optional[List[int]] = None,
) -> Callable[[Tensor, Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...]]:
    r"""Set up function that multiplies the mini-batch Hessian onto a matrix in list format.

    The returned function is a ``vmap``-ed Hessian-vector product. Each HVP is
    computed forward-over-reverse: ``jacrev`` of the scalar mini-batch loss yields
    the gradient, and ``jvp`` of that gradient along ``v`` yields ``H @ v``.

    Args:
        model_func: The neural network :math:`f_{\mathbf{\theta}}`.
        loss_func: The loss function :math:`\ell`.
        params: A tuple of parameters w.r.t. which the Hessian is computed.
            All parameters must be part of ``model_func.parameters()``.
        block_sizes: Sizes of parameter blocks for block-diagonal approximation.
            Each entry is the number of consecutive entries of ``params`` that
            form one block. If ``None``, the full Hessian is used.

    Returns:
        A function that takes inputs ``X``, ``y``, and the tensors of a matrix ``M``
        in list format as separate positional arguments (``fn(X, y, *M)``), and
        returns the mini-batch Hessian applied to ``M`` in list format as a tuple.
        Each tensor of ``M`` has the shape of the corresponding parameter plus one
        trailing axis that indexes the matrix columns; the outputs share that layout.
    """
    # A single block that spans all parameters corresponds to the full Hessian
    block_sizes = [len(params)] if block_sizes is None else block_sizes

    # Create block-specific functional calls: *block_params, X -> prediction
    block_params = split_list(list(params), block_sizes)
    block_functionals = []

    for block in block_params:
        # criterion functional c is the same for all blocks
        f_block, c = make_functional_model_and_loss(model_func, loss_func, tuple(block))
        block_functionals.append(f_block)

    @no_grad()
    def hessian_vector_product(
        X: Union[Tensor, MutableMapping], y: Tensor, *v: Tensor
    ) -> Tuple[Tensor, ...]:
        """Multiply the mini-batch Hessian on a vector in list format.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            *v: Vector to be multiplied with in tensor list format, one tensor
                per entry of ``params`` and in the same order.

        Returns:
            Result of Hessian multiplication in list format. Has the same shape as
            ``v``, i.e. each tensor in the tuple has the shape of a parameter.
        """

        def loss_fn(f: Callable[..., Tensor], *block_ps: Tensor) -> Tensor:
            """Compute the mini-batch loss given the neural net and its parameters.

            Args:
                f: Functional model with signature (*params, X) -> prediction
                *block_ps: Parameters of one block, passed on to ``f``.

            Returns:
                Mini-batch loss.
            """
            return c(f(*block_ps, X), y)

        # Split input vectors by blocks
        v_blocks = split_list(list(v), block_sizes)

        # Compute the HVPs per block and concatenate the results. The gradient is
        # taken w.r.t. one block's parameters only and perturbed along that block's
        # tangents only, so curvature between different blocks is never formed.
        hvps = []
        for f_block, ps, vs in zip(block_functionals, block_params, v_blocks):
            block_grad_fn = jacrev(
                partial(loss_fn, f_block), argnums=tuple(range(len(ps)))
            )
            # jvp returns (gradient, directional derivative of the gradient);
            # only the latter, the Hessian-vector product, is needed
            _, hvp_block = jvp(block_grad_fn, tuple(ps), tuple(vs))
            hvps.extend(hvp_block)

        return tuple(hvps)

    # Parallelize over vectors to multiply onto a matrix in list format.
    # Axis index ``p.ndim`` is the axis appended after a parameter's own axes,
    # i.e. the last axis of the corresponding tensor in the matrix list.
    list_format_vmap_dims = tuple(p.ndim for p in params)  # last axis
    return vmap(
        hessian_vector_product,
        # No vmap in X, y, last-axis vmap over vector in list format
        in_dims=(None, None, *list_format_vmap_dims),
        # Vmapped output axis is last
        out_dims=list_format_vmap_dims,
        # We want each vector to be multiplied with the same mini-batch Hessian
        randomness="same",
    )


class HessianLinearOperator(CurvatureLinearOperator):
    r"""Linear operator for the Hessian of an empirical risk.

    Consider the empirical risk

    .. math::
        \mathcal{L}(\mathbf{\theta})
        =
        c \sum_{n=1}^{N}
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)

    with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for
    ``reduction='sum'``. The Hessian matrix is

    .. math::
        \nabla^2_{\mathbf{\theta}} \mathcal{L}
        =
        c \sum_{n=1}^{N}
        \nabla_{\mathbf{\theta}}^2
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\,.

    Example:
        >>> from torch import rand, eye, allclose, kron, manual_seed
        >>> from torch.nn import Linear, MSELoss
        >>> from curvlinops import HessianLinearOperator
        >>>
        >>> # Create a simple linear model without bias
        >>> _ = manual_seed(0) # make deterministic
        >>> D_in, D_out = 4, 2
        >>> num_data, num_batches = 10, 3
        >>> model = Linear(D_in, D_out, bias=False)
        >>> params = list(model.parameters())
        >>> loss_func = MSELoss(reduction='sum')
        >>>
        >>> # Generate synthetic dataset and chunk into batches
        >>> X, y = rand(num_data, D_in), rand(num_data, D_out)
        >>> data = list(zip(X.split(num_batches), y.split(num_batches)))
        >>>
        >>> # Create Hessian linear operator
        >>> H_op = HessianLinearOperator(model, loss_func, params, data)
        >>>
        >>> # Compare with the known Hessian matrix 2 I ⊗ Xᵀ X
        >>> H_mat = 2 * kron(eye(D_out), X.T @ X)
        >>> P = sum(p.numel() for p in params)
        >>> v = rand(P) # generate a random vector
        >>> (H_mat @ v).allclose(H_op @ v)
        True

    Attributes:
        SUPPORTS_BLOCKS: Whether the linear operator supports block operations.
            Default is ``True``.
        SELF_ADJOINT: Whether the linear operator is self-adjoint (``True`` for
            Hessians).
    """

    SELF_ADJOINT: bool = True
    SUPPORTS_BLOCKS: bool = True

    @cached_property
    def _mp(
        self,
    ) -> Callable[
        [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...]
    ]:
        """Lazy initialization of batch-Hessian matrix product function.

        The function is built on first access and then cached on the instance, so
        it is created from the model, loss, parameters, and block sizes stored on
        the operator at that time.

        Returns:
            Function that computes mini-batch Hessian-matrix products, given inputs
            ``X``, labels ``y``, and the entries ``M1, M2, ...`` of the matrix in
            list format as separate positional arguments. Produces a tuple of
            tensors with the same shapes as the input entries that represents the
            result of the batch-Hessian multiplication.
        """
        return make_batch_hessian_matrix_product(
            self._model_func, self._loss_func, tuple(self._params), self._block_sizes
        )

    def _matmat_batch(
        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
    ) -> List[Tensor]:
        """Apply the mini-batch Hessian to a matrix.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            M: Matrix to be multiplied with in tensor list format.
                Tensors have same shape as trainable model parameters, and an
                additional trailing axis for the matrix columns.

        Returns:
            Result of Hessian multiplication in list format. Has the same shape as
            ``M``, i.e. each tensor in the list has the shape of a parameter and a
            trailing dimension of matrix columns.
        """
        # The product function takes the matrix entries as separate positional
        # arguments and returns a tuple; convert back to the list format
        return list(self._mp(X, y, *M))