"""Contains a linear operator implementation of the Hessian."""
from collections.abc import Callable, MutableMapping
from torch import Tensor, no_grad
from torch.func import jacrev, jvp
from torch.nn import Module
from curvlinops._torch_base import CurvatureLinearOperator
from curvlinops.utils import make_functional_loss
def make_batch_hessian_vector_product(
    f: Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
    loss_func: Module,
) -> Callable[
    [dict[str, Tensor], Tensor | MutableMapping, tuple, dict[str, Tensor]],
    dict[str, Tensor],
]:
    r"""Build a closure that applies the mini-batch Hessian to a dict-valued vector.

    The closure differentiates the mini-batch loss twice with respect to the
    parameters: ``jacrev`` turns the loss into its gradient function and
    ``jvp`` pushes the direction ``v`` through that gradient function
    (forward-over-reverse differentiation).

    Args:
        f: Functional model with signature ``(params_dict, X) -> prediction``.
        loss_func: The loss function :math:`\ell`.

    Returns:
        A callable ``(params_dict, X, loss_args, v_dict) -> Hv``. It receives
        the parameters as a dict, the model input ``X``, the loss arguments
        ``loss_args = (y,)``, and the vector ``v`` as a dict, and hands back
        the mini-batch Hessian applied to ``v``, again as a dict.
    """
    functional_loss = make_functional_loss(loss_func)

    @no_grad()
    def hessian_vector_product(
        params: dict[str, Tensor],
        X: Tensor | MutableMapping,
        loss_args: tuple,
        v: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """Apply the mini-batch Hessian to a vector given in dict format.

        Args:
            params: Model parameters, keyed by name.
            X: Input that is fed to the model.
            loss_args: Arguments forwarded to the loss function, e.g. ``(y,)``.
                Must contain exactly one entry.
            v: Direction to multiply with; a dict with the same structure as
                ``params``.

        Returns:
            The Hessian-vector product as a dict with the same keys as
            ``params``.
        """
        # Unpacking enforces that exactly one loss argument was supplied.
        (target,) = loss_args

        def batch_loss(param_dict: dict[str, Tensor]) -> Tensor:
            """Evaluate the mini-batch loss for the given parameters.

            Args:
                param_dict: Model parameters, keyed by name.

            Returns:
                The mini-batch loss.
            """
            prediction = f(param_dict, X)
            return functional_loss(prediction, (target,))

        gradient_fn = jacrev(batch_loss)
        # ``jvp`` yields ``(gradient, directional derivative of the gradient)``;
        # only the second entry, i.e. the Hessian-vector product, is needed.
        return jvp(gradient_fn, (params,), (v,))[1]

    return hessian_vector_product
# (Sphinx "[docs]" link marker from the rendered source view; not code.)
class HessianLinearOperator(CurvatureLinearOperator):
    r"""Linear operator representing the Hessian of an empirical risk.

    The empirical risk under consideration is

    .. math::
        \mathcal{L}(\mathbf{\theta})
        =
        c \sum_{n=1}^{N}
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)

    where :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for
    ``reduction='sum'``. Its Hessian matrix reads

    .. math::
        \nabla^2_{\mathbf{\theta}} \mathcal{L}
        =
        c \sum_{n=1}^{N}
        \nabla_{\mathbf{\theta}}^2
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\,.

    Example:
        >>> from torch import rand, eye, allclose, kron, manual_seed
        >>> from torch.nn import Linear, MSELoss
        >>> from curvlinops import HessianLinearOperator
        >>>
        >>> # Create a simple linear model without bias
        >>> _ = manual_seed(0) # make deterministic
        >>> D_in, D_out = 4, 2
        >>> num_data, batch_size = 10, 3
        >>> model = Linear(D_in, D_out, bias=False)
        >>> params = dict(model.named_parameters())
        >>> loss_func = MSELoss(reduction='sum')
        >>>
        >>> # Generate synthetic dataset and chunk into batches of size batch_size
        >>> X, y = rand(num_data, D_in), rand(num_data, D_out)
        >>> data = list(zip(X.split(batch_size), y.split(batch_size)))
        >>>
        >>> # Create Hessian linear operator
        >>> H_op = HessianLinearOperator(model, loss_func, params, data)
        >>>
        >>> # Compare with the known Hessian matrix 2 I ⊗ Xᵀ X
        >>> H_mat = 2 * kron(eye(D_out), X.T @ X)
        >>> P = sum(p.numel() for p in params.values())
        >>> v = rand(P) # generate a random vector
        >>> (H_mat @ v).allclose(H_op @ v)
        True

    Attributes:
        SELF_ADJOINT: Whether the linear operator is self-adjoint (``True`` for
            Hessians).
    """

    SELF_ADJOINT: bool = True

    def _init_mp(self):
        """Create the batch Hessian-vector product, then run the parent set-up.

        ``self._vp`` has to exist before the parent's ``_init_mp`` runs; the
        parent is expected to build its vmap machinery on top of it
        (NOTE(review): base class not visible here -- confirm).
        """
        batch_hvp = make_batch_hessian_vector_product(
            self._model_func, self._loss_func
        )
        self._vp = batch_hvp
        super()._init_mp()

    def _matvec_batch(
        self, X: Tensor | MutableMapping, y: Tensor, v: dict[str, Tensor]
    ) -> dict[str, Tensor]:
        """Multiply the Hessian of one mini-batch onto a vector.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            v: Vector as a dict keyed by parameter names.

        Returns:
            The Hessian-vector product as a dict.
        """
        # The batch product expects the loss arguments bundled in a tuple.
        loss_args = (y,)
        return self._vp(self._params, X, loss_args, v)