Source code for curvlinops.ekfac

"""Contains LinearOperator implementation of EKFAC approximation of the Fisher/GGN."""

from curvlinops._torch_base import _ChainPyTorchLinearOperator
from curvlinops.blockdiagonal import BlockDiagonalLinearOperator
from curvlinops.computers._base import ParamGroup
from curvlinops.computers.ekfac_hooks import HooksEKFACComputer
from curvlinops.computers.ekfac_make_fx import MakeFxEKFACComputer
from curvlinops.eigh import EighDecomposedLinearOperator
from curvlinops.kfac import KFACLinearOperator
from curvlinops.kronecker import KroneckerProductLinearOperator


[docs] class EKFACLinearOperator(KFACLinearOperator): """Linear operator to multiply with the Fisher/GGN's EKFAC approximation. Eigenvalue-corrected Kronecker-Factored Approximate Curvature (EKFAC) was originally introduced in - George, T., Laurent, C., Bouthillier, X., Ballas, N., Vincent, P. (2018). Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis (NeurIPS) and concurrently in the context of continual learning in - Liu, X., Masana, M., Herranz, L., Van de Weijer, J., Lopez, A., Bagdanov, A. (2018). Rotate your networks: Better weight consolidation and less catastrophic forgetting (ICPR). """ _BACKENDS: dict[str, type] = { "hooks": HooksEKFACComputer, "make_fx": MakeFxEKFACComputer, } @staticmethod def _compute_canonical_op( computer: HooksEKFACComputer | MakeFxEKFACComputer, ) -> tuple[BlockDiagonalLinearOperator, list[ParamGroup]]: """Compute EKFAC factors and assemble the canonical block-diagonal operator. Args: computer: An EKFAC computer instance (hooks or FX backend). Returns: Tuple of (block diagonal operator in canonical basis, mapping). """ input_eigvecs, gradient_eigvecs, corrected_eigenvalues, mapping = ( computer.compute() ) bases = [] corrections = [] for group in mapping: group_key = tuple(group.values()) Q_a = input_eigvecs.get(group_key) Q_g = gradient_eigvecs[group_key] lambdas = corrected_eigenvalues[group_key] bases.append([Q_g, Q_a] if Q_a is not None else [Q_g]) corrections.append(lambdas) # Create Kronecker product linear operators for each block blocks = [ EighDecomposedLinearOperator( correction.flatten(), KroneckerProductLinearOperator(*basis) ) for basis, correction in zip(bases, corrections) ] # EKFAC in the canonical basis return BlockDiagonalLinearOperator(blocks), mapping
[docs] def inverse(self, damping: float = 0.0) -> _ChainPyTorchLinearOperator: """Return the inverse of the EKFAC approximation. Inverts each eigendecomposed block of the canonical operator and returns the result in parameter space. Args: damping: Damping term added to eigenvalues before inversion. Default: ``0.0``. Returns: Inverse of the EKFAC approximation as a linear operator. """ P, K, PT = self K_inv = BlockDiagonalLinearOperator([ block.inverse(damping=damping) for block in K ]) return _ChainPyTorchLinearOperator(P, K_inv, PT)