Source code for curvlinops.kfac

"""Linear operator for the Fisher/GGN's Kronecker-factored approximation.

Kronecker-Factored Approximate Curvature (KFAC) was originally introduced for MLPs in

- Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored
  approximate curvature. International Conference on Machine Learning (ICML),

extended to CNNs in

- Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix for
  convolution layers. International Conference on Machine Learning (ICML),

and generalized to all linear layers with weight sharing in

- Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., Hennig, P. (2023).
  Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures (NeurIPS).
"""

from collections.abc import Callable, Iterable, MutableMapping

from torch import Tensor
from torch.nn import (
    BCEWithLogitsLoss,
    CrossEntropyLoss,
    Module,
    MSELoss,
)

from curvlinops._torch_base import _ChainPyTorchLinearOperator
from curvlinops.blockdiagonal import BlockDiagonalLinearOperator
from curvlinops.computers._base import ParamGroup, _BaseKFACComputer
from curvlinops.computers.kfac_hooks import HooksKFACComputer
from curvlinops.computers.kfac_make_fx import MakeFxKFACComputer
from curvlinops.kfac_utils import (
    FisherType,
    FromCanonicalLinearOperator,
    KFACType,
    ToCanonicalLinearOperator,
)
from curvlinops.kronecker import KroneckerProductLinearOperator


class KFACLinearOperator(_ChainPyTorchLinearOperator):
    r"""Linear operator to multiply with the Fisher/GGN's KFAC approximation.

    KFAC approximates the per-layer Fisher/GGN with a Kronecker product:
    Consider a weight matrix :math:`\mathbf{W}` and a bias vector :math:`\mathbf{b}`
    in a single layer. The layer's Fisher :math:`\mathbf{F}(\mathbf{\theta})` for

    .. math::

        \mathbf{\theta}
        =
        \begin{pmatrix}
        \mathrm{vec}(\mathbf{W}) \\ \mathbf{b}
        \end{pmatrix}

    where :math:`\mathrm{vec}` denotes column-stacking is approximated as

    .. math::

        \mathbf{F}(\mathbf{\theta})
        \approx
        \mathbf{A}_{(\text{KFAC})} \otimes \mathbf{B}_{(\text{KFAC})}

    (see :class:`curvlinops.GGNLinearOperator` with ``mc_samples > 0``).
    Loosely speaking, the first Kronecker factor is the un-centered covariance of the
    inputs to a layer. The second Kronecker factor is the un-centered covariance of
    'would-be' gradients w.r.t. the layer's output. Those 'would-be' gradients result
    from sampling labels from the model's distribution and computing their gradients.

    Kronecker-Factored Approximate Curvature (KFAC) was originally introduced for MLPs in

    - Martens, J., & Grosse, R. (2015). Optimizing neural networks with
      Kronecker-factored approximate curvature. International Conference on Machine
      Learning (ICML),

    extended to CNNs in

    - Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix
      for convolution layers. International Conference on Machine Learning (ICML),

    and generalized to all linear layers with weight sharing in

    - Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., Hennig, P. (2023).
      Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures
      (NeurIPS).

    Attributes:
        SELF_ADJOINT: Whether the operator is self-adjoint. ``True`` for KFAC.
    """

    # Maps the ``backend`` argument of ``__init__`` to the class that computes
    # the Kronecker factors.
    _BACKENDS: dict[str, type] = {
        "hooks": HooksKFACComputer,
        "make_fx": MakeFxKFACComputer,
    }

    SELF_ADJOINT: bool = True

    def __init__(
        self,
        model_func: Module
        | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
        loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
        params: dict[str, Tensor],
        data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
        seed: int = 2_147_483_647,
        fisher_type: str = FisherType.MC,
        mc_samples: int = 1,
        kfac_approx: str = KFACType.EXPAND,
        num_per_example_loss_terms: int | None = None,
        separate_weight_and_bias: bool = True,
        num_data: int | None = None,
        batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
        backend: str = "hooks",
    ):
        """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.

        Warning:
            This is an early prototype with limitations:

            - Only Linear and Conv2d modules are supported.
            - The ``hooks`` backend assumes each module is called exactly once
              per forward pass. Weight tying (same module called multiple times)
              will silently produce incorrect results. Use ``backend="make_fx"``
              for weight-tied architectures.

        Args:
            model_func: The neural network's forward pass, defining the functional
                relationship ``(params, X) -> prediction``. Either an ``nn.Module``
                (architecture) or a callable ``(params_dict, X) -> prediction``.
                Callables require ``backend="make_fx"``.
            loss_func: The loss function.
            params: The parameter values at which the Fisher/GGN is approximated.
                A dictionary mapping parameter names to tensors.
            data: A data loader containing the data of the Fisher/GGN.
            progressbar: Whether to show a progress bar when computing the Kronecker
                factors. Defaults to ``False``.
            check_deterministic: Whether to check that the linear operator is
                deterministic. Defaults to ``True``.
            seed: The seed for the random number generator used to draw labels
                from the model's predictive distribution. Defaults to ``2147483647``.
            fisher_type: The type of Fisher/GGN to approximate.
                If ``FisherType.TYPE2``, the exact Hessian of the loss w.r.t. the
                model outputs is used. This requires as many backward passes as the
                output dimension, i.e. the number of classes for classification.
                This is sometimes also called type-2 Fisher.
                If ``FisherType.MC``, the expectation is approximated by sampling
                ``mc_samples`` labels from the model's predictive distribution.
                If ``FisherType.EMPIRICAL``, the empirical gradients are used which
                corresponds to the uncentered gradient covariance, or the empirical
                Fisher.
                If ``FisherType.FORWARD_ONLY``, the gradient covariances will be
                identity matrices, see the FOOF method in
                `Benzing, 2022 <https://arxiv.org/abs/2201.12250>`_ or ISAAC in
                `Petersen et al., 2023 <https://arxiv.org/abs/2305.00604>`_.
                Defaults to ``FisherType.MC``.
            mc_samples: The number of Monte-Carlo samples to use per data point.
                Has to be set to ``1`` when ``fisher_type != FisherType.MC``.
                Defaults to ``1``.
            kfac_approx: A string specifying the KFAC approximation that should
                be used for linear weight-sharing layers, e.g. ``Conv2d`` modules
                or ``Linear`` modules that process matrix- or higher-dimensional
                features.
                Possible values are ``KFACType.EXPAND`` and ``KFACType.REDUCE``.
                See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
                for an explanation of the two approximations.
                Defaults to ``KFACType.EXPAND``.
            num_per_example_loss_terms: Number of per-example loss terms, e.g., the
                number of tokens in a sequence. The model outputs will have
                ``num_data * num_per_example_loss_terms * C`` entries, where ``C``
                is the dimension of the random variable we define the likelihood
                over -- for the ``CrossEntropyLoss`` it will be the number of
                classes, for the ``MSELoss`` and ``BCEWithLogitsLoss`` it will be
                the size of the last dimension of the model outputs/targets (our
                convention here). If ``None``, ``num_per_example_loss_terms`` is
                inferred from the data at the cost of one traversal through the
                data loader. It is expected to be the same for all examples.
                Defaults to ``None``.
            separate_weight_and_bias: Whether to treat weights and biases
                separately. Defaults to ``True``. Setting this to ``False`` is more
                efficient because gradient covariances are computed once per layer
                rather than separately for weight and bias.
            num_data: Number of data points. If ``None``, it is inferred from the
                data at the cost of one traversal through the data loader.
            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
                needs to be specified. The intended behavior is to consume the first
                entry of the iterates from ``data`` and return their batch size.
            backend: The backend to use for computing Kronecker factors.
                ``"hooks"`` uses forward/backward hooks (default).
                ``"make_fx"`` uses FX graph tracing via the IO collector.
                Defaults to ``"hooks"``.

        Note:
            The ``"make_fx"`` backend incurs a significant one-time tracing
            overhead (seconds for large models) on the first batch. The traced
            function is cached by batch size, so subsequent batches of the same
            size reuse it. However, each distinct batch size triggers a re-trace.
            Use uniform batch sizes in the data loader to avoid repeated tracing.

        Raises:
            ValueError: If ``backend`` is not supported.
        """
        if backend not in self._BACKENDS:
            raise ValueError(
                f"Invalid backend: {backend!r}. Supported: {tuple(self._BACKENDS)}."
            )

        computer_cls = self._BACKENDS[backend]
        computer = computer_cls(
            model_func,
            loss_func,
            params,
            data,
            progressbar=progressbar,
            check_deterministic=check_deterministic,
            seed=seed,
            fisher_type=fisher_type,
            mc_samples=mc_samples,
            kfac_approx=kfac_approx,
            num_per_example_loss_terms=num_per_example_loss_terms,
            separate_weight_and_bias=separate_weight_and_bias,
            num_data=num_data,
            batch_size_fn=batch_size_fn,
        )

        # KFAC = P @ K @ PT
        K, mapping = self._compute_canonical_op(computer)
        P, PT = self._build_converters(computer, mapping)
        super().__init__(P, K, PT)

    @staticmethod
    def _compute_canonical_op(
        computer: _BaseKFACComputer,
    ) -> tuple[BlockDiagonalLinearOperator, list[ParamGroup]]:
        """Compute Kronecker factors and assemble the canonical block-diagonal operator.

        Args:
            computer: A KFAC computer instance.

        Returns:
            Tuple of (block diagonal operator in canonical basis, mapping).
        """
        input_covariances, gradient_covariances, mapping = computer.compute()

        factors = []
        for group in mapping:
            group_key = tuple(group.values())
            # A group may have no input covariance (``.get`` yields ``None``);
            # its block then consists of the gradient covariance alone.
            aaT = input_covariances.get(group_key)
            ggT = gradient_covariances[group_key]
            # Factor order: gradient covariance first, then input covariance.
            factors.append([ggT, aaT] if aaT is not None else [ggT])

        # Create Kronecker product linear operators for each block
        blocks = [KroneckerProductLinearOperator(*fs) for fs in factors]

        # KFAC in the canonical basis
        return BlockDiagonalLinearOperator(blocks), mapping

    @staticmethod
    def _build_converters(
        computer: _BaseKFACComputer,
        mapping: list[ParamGroup],
    ) -> tuple[FromCanonicalLinearOperator, ToCanonicalLinearOperator]:
        """Build the canonical space converters.

        Args:
            computer: A KFAC computer instance.
            mapping: List of parameter groups.

        Returns:
            Tuple of ``(from_canonical_op, to_canonical_op)``.
        """
        PT = ToCanonicalLinearOperator(
            {name: p.shape for name, p in computer._params.items()},
            mapping,
            computer.device,
            computer.dtype,
        )
        # The converter back from the canonical basis is the adjoint of ``PT``.
        P = PT.adjoint()
        return P, PT

    def trace(self) -> Tensor:
        """Trace of the KFAC approximation.

        Returns:
            Trace of the KFAC approximation.
        """
        # The value is taken from the canonical operator K alone.
        # NOTE(review): assumes the converters P/PT leave the trace unchanged.
        _, K, _ = self
        return K.trace()

    def det(self) -> Tensor:
        """Compute the determinant of the KFAC approximation.

        Returns:
            Determinant of the KFAC approximation.
        """
        # Delegates to the canonical operator K.
        _, K, _ = self
        return K.det()

    def logdet(self) -> Tensor:
        """Log determinant of the KFAC approximation.

        More numerically stable than the ``det`` method.

        Returns:
            Log determinant of the KFAC approximation.
        """
        # Delegates to the canonical operator K.
        _, K, _ = self
        return K.logdet()

    def frobenius_norm(self) -> Tensor:
        """Frobenius norm of the KFAC approximation.

        Returns:
            Frobenius norm of the KFAC approximation.
        """
        # Delegates to the canonical operator K.
        _, K, _ = self
        return K.frobenius_norm()

    def inverse(
        self,
        damping: float = 0.0,
        use_heuristic_damping: bool = False,
        min_damping: float = 1e-8,
        use_exact_damping: bool = False,
        retry_double_precision: bool = True,
    ) -> _ChainPyTorchLinearOperator:
        r"""Return the inverse of the KFAC approximation.

        Inverts each Kronecker-factored block of the canonical operator
        and returns the result in parameter space.

        Args:
            damping: Damping value applied to all Kronecker factors.
                Default: ``0.0``.
            use_heuristic_damping: Whether to use a heuristic damping strategy by
                `Martens and Grosse, 2015 <https://arxiv.org/abs/1503.05671>`_
                (Section 6.3). Only supported for one or two factors.
            min_damping: Minimum damping value. Only used if
                ``use_heuristic_damping`` is ``True``.
            use_exact_damping: Whether to use exact damping, i.e. to invert
                :math:`(A \otimes B) + \text{damping}\; \mathbf{I}`.
            retry_double_precision: Whether to retry Cholesky decomposition used
                for inversion in double precision.

        Returns:
            Inverse of the KFAC approximation as a linear operator
            ``P @ K^-1 @ PT``.
        """
        P, K, PT = self
        # Invert block by block; every block receives the same damping settings.
        K_inv = BlockDiagonalLinearOperator([
            block.inverse(
                damping=damping,
                use_heuristic_damping=use_heuristic_damping,
                min_damping=min_damping,
                use_exact_damping=use_exact_damping,
                retry_double_precision=retry_double_precision,
            )
            for block in K
        ])
        return _ChainPyTorchLinearOperator(P, K_inv, PT)