Source code for curvlinops.ggn

"""Contains LinearOperator implementation of the GGN."""

from collections.abc import MutableMapping
from typing import List, Union

from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
from torch import Tensor, zeros_like

from curvlinops._torch_base import CurvatureLinearOperator


class GGNLinearOperator(CurvatureLinearOperator):
    r"""Generalized Gauss-Newton (GGN) matrix of an empirical risk as linear operator.

    The empirical risk under consideration is

    .. math::
        \mathcal{L}(\mathbf{\theta})
        =
        c \sum_{n=1}^{N}
        \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)

    where :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for
    ``reduction='sum'``. Its GGN matrix is

    .. math::
        c \sum_{n=1}^{N}
        \left(
            \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n)
        \right)^\top
        \left(
            \nabla_{f_\mathbf{\theta}(\mathbf{x}_n)}^2
            \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)
        \right)
        \left(
            \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n)
        \right)\,.

    Attributes:
        SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for
            GGNs.
    """

    SELF_ADJOINT: bool = True

    def _matmat_batch(
        self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
    ) -> List[Tensor]:
        """Multiply the GGN of one mini-batch onto a matrix.

        The product is assembled column by column: every column of ``M`` is
        passed through one GGN-vector product and written into the matching
        column of the result.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            M: Matrix to be multiplied with, in tensor list format. Each tensor
                has the shape of a trainable model parameter plus one trailing
                axis that indexes the matrix columns.

        Returns:
            The GGN-matrix product in tensor list format. Its entries have the
            same shapes as those of ``M``, i.e. a parameter's shape followed by
            a trailing axis for the matrix columns.

        Raises:
            ValueError: If the tensors in ``M`` do not all share one trailing
                dimension (raised by the unpacking of the set of column counts).
        """
        prediction = self._model_func(X)
        batch_loss = self._loss_func(prediction, y)

        # Unpacking the set into a 1-tuple only succeeds if every block of M
        # reports the same number of columns.
        (num_columns,) = {block.shape[-1] for block in M}
        result = [zeros_like(block) for block in M]

        for col in range(num_columns):
            # Slice out column `col` of every parameter block.
            vector = [block[..., col] for block in M]
            ggn_vector = ggn_vector_product_from_plist(
                batch_loss, prediction, self._params, vector
            )
            # Accumulate in place into the pre-allocated output column.
            for out_block, ggn_block in zip(result, ggn_vector):
                out_block[..., col].add_(ggn_block)

        return result