"""Contains LinearOperator implementation of EKFAC approximation of the Fisher/GGN."""
from __future__ import annotations
from collections.abc import MutableMapping
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from einops import einsum, rearrange
from torch import Generator, Tensor, cat
from torch.linalg import eigh
from torch.nn import (
BCEWithLogitsLoss,
Conv2d,
CrossEntropyLoss,
Module,
MSELoss,
Parameter,
)
from torch.utils.hooks import RemovableHandle
from curvlinops.kfac import (
FisherType,
KFACLinearOperator,
KFACType,
)
from curvlinops.kfac_utils import extract_patches
[docs]
class EKFACLinearOperator(KFACLinearOperator):
    """Linear operator to multiply with the Fisher/GGN's EKFAC approximation.

    Eigenvalue-corrected Kronecker-Factored Approximate Curvature (EKFAC) was originally
    introduced in

    - George, T., Laurent, C., Bouthillier, X., Ballas, N., Vincent, P. (2018).
      Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis
      (NeurIPS)

    and concurrently in the context of continual learning in

    - Liu, X., Masana, M., Herranz, L., Van de Weijer, J., Lopez, A., Bagdanov, A.
      (2018). Rotate your networks: Better weight consolidation and less catastrophic
      forgetting (ICPR).

    Attributes:
        _SUPPORTED_FISHER_TYPE: Tuple with supported Fisher types.
    """

    _SUPPORTED_FISHER_TYPE: Tuple[FisherType, ...] = (
        FisherType.TYPE2,
        FisherType.MC,
        FisherType.EMPIRICAL,
    )

    def __init__(
        self,
        model_func: Module,
        loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
        params: List[Parameter],
        data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
        seed: int = 2147483647,
        fisher_type: str = FisherType.MC,
        mc_samples: int = 1,
        kfac_approx: str = KFACType.EXPAND,
        num_per_example_loss_terms: Optional[int] = None,
        separate_weight_and_bias: bool = True,
        num_data: Optional[int] = None,
        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
    ):
        """Eigenvalue-corrected KFAC (EKFAC) proxy of the Fisher/GGN.

        Warning:
            If the model's parameters change, e.g. during training, you need to
            create a fresh instance of this object. This is because, for performance
            reasons, the Kronecker factors are computed once and cached during the
            first matrix-vector product. They will thus become outdated if the model
            changes.

        Warning:
            This is an early prototype with limitations:
                - Only Linear and Conv2d modules are supported.
                - Only models with 2d output are supported.

        Args:
            model_func: The neural network. Must consist of modules.
            loss_func: The loss function.
            params: The parameters defining the Fisher/GGN that will be approximated
                through EKFAC.
            data: A data loader containing the data of the Fisher/GGN.
            progressbar: Whether to show a progress bar when computing the Kronecker
                factors. Defaults to ``False``.
            check_deterministic: Whether to check that the linear operator is
                deterministic. Defaults to ``True``.
            seed: The seed for the random number generator used to draw labels
                from the model's predictive distribution. Defaults to ``2147483647``.
            fisher_type: The type of Fisher/GGN to approximate.
                If ``FisherType.TYPE2``, the exact Hessian of the loss w.r.t. the model
                outputs is used. This requires as many backward passes as the output
                dimension, i.e. the number of classes for classification. This is
                sometimes also called type-2 Fisher.
                If ``FisherType.MC``, the expectation is approximated by sampling
                ``mc_samples`` labels from the model's predictive distribution.
                If ``FisherType.EMPIRICAL``, the empirical gradients are used which
                corresponds to the uncentered gradient covariance/empirical Fisher.
                Defaults to ``FisherType.MC``.
            mc_samples: The number of Monte-Carlo samples to use per data point.
                Has to be set to ``1`` when ``fisher_type != FisherType.MC``.
                Defaults to ``1``.
            kfac_approx: A string specifying the KFAC approximation that should
                be used for linear weight-sharing layers, e.g. ``Conv2d`` modules
                or ``Linear`` modules that process matrix- or higher-dimensional
                features.
                Possible values are ``KFACType.EXPAND`` and ``KFACType.REDUCE``.
                See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
                for an explanation of the two approximations.
                Defaults to ``KFACType.EXPAND``.
            num_per_example_loss_terms: Number of per-example loss terms, e.g., the
                number of tokens in a sequence. The model outputs will have
                ``num_data * num_per_example_loss_terms * C`` entries, where ``C`` is
                the dimension of the random variable we define the likelihood over --
                for the ``CrossEntropyLoss`` it will be the number of classes, for the
                ``MSELoss`` and ``BCEWithLogitsLoss`` it will be the size of the last
                dimension of the model outputs/targets (our convention here).
                If ``None``, ``num_per_example_loss_terms`` is inferred from the data at
                the cost of one traversal through the data loader. It is expected to be
                the same for all examples. Defaults to ``None``.
            separate_weight_and_bias: Whether to treat weights and biases separately.
                Defaults to ``True``.
            num_data: Number of data points. If ``None``, it is inferred from the data
                at the cost of one traversal through the data loader.
            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this
                needs to be specified. The intended behavior is to consume the first
                entry of the iterates from ``data`` and return their batch size.
        """
        # The determinism check is deferred until the EKFAC-specific attributes
        # below exist, because the check exercises this subclass' code paths.
        super().__init__(
            model_func=model_func,
            loss_func=loss_func,
            params=params,
            data=data,
            progressbar=progressbar,
            check_deterministic=False,
            seed=seed,
            fisher_type=fisher_type,
            mc_samples=mc_samples,
            kfac_approx=kfac_approx,
            num_per_example_loss_terms=num_per_example_loss_terms,
            separate_weight_and_bias=separate_weight_and_bias,
            num_data=num_data,
            batch_size_fn=batch_size_fn,
        )

        # Initialize the eigenvectors of the Kronecker factors
        self._input_covariances_eigenvectors: Dict[str, Tensor] = {}
        self._gradient_covariances_eigenvectors: Dict[str, Tensor] = {}
        # Initialize the cache for activations
        self._cached_activations: Dict[str, Tensor] = {}
        # Initialize the corrected eigenvalues for EKFAC. Values are a single tensor
        # when weight and bias are treated jointly, otherwise a dict keyed by the
        # parameter's position in ``params``.
        self._corrected_eigenvalues: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}

        if check_deterministic:
            self._check_deterministic()

    def _rearrange_for_larger_than_2d_output(
        self, output: Tensor, y: Tensor
    ) -> Tuple[Tensor, Tensor]:
        r"""Validate that the output is 2d and the target is 1d/2d.

        Unlike the KFAC base class, EKFAC does not rearrange >2d outputs. They are
        rejected instead.

        Args:
            output: The model's prediction
                :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
            y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.

        Returns:
            The unchanged outputs and targets.

        Raises:
            ValueError: If the output is not 2d or y is not 1d/2d.
        """
        # Our individual gradient implementation for EKFAC does not support computing
        # the individual gradients for any loss terms that might depend on each other,
        # i.e., loss terms other than the per-data point loss terms.
        if output.ndim != 2 or y.ndim not in {1, 2}:
            raise ValueError(
                "Only 2d output and 1d/2d target are supported for EKFAC. "
                f"Got {output.ndim=} and {y.ndim=}."
            )
        return output, y

    def _maybe_compute_ekfac(self):
        """Compute the EKFAC approximation when necessary.

        Only the corrected eigenvalues need to be checked here.
        ``compute_eigenvalue_correction`` computes the eigenvectors when they are
        missing. ``_compute_eigenvectors`` in turn computes the Kronecker factors
        when those are missing. Eigenvectors that are already available (e.g.
        loaded via ``load_state_dict``) therefore do not trigger a redundant pass
        over the data to recompute Kronecker factors that would go unused.
        """
        if not self._corrected_eigenvalues:
            self.compute_eigenvalue_correction()

    def _matmat(self, M: List[Tensor]) -> List[Tensor]:
        """Apply EKFAC to a matrix (multiple vectors) in tensor list format.

        This allows for matrix-matrix products with the EKFAC approximation in PyTorch
        without converting tensors to numpy arrays, which avoids unnecessary
        device transfers when working with GPUs and flattening/concatenating.

        Args:
            M: Matrix for multiplication in tensor list format. Each entry has the
                same shape as a parameter with an additional trailing dimension of size
                ``K`` for the columns, i.e. ``[(*p1.shape, K), (*p2.shape, K), ...]``.

        Returns:
            Matrix-multiplication result ``EKFAC @ M`` in tensor list format. Has the
            same shapes as the input.
        """
        self._maybe_compute_ekfac()

        KM: List[Tensor | None] = [None] * len(M)

        for mod_name, param_pos in self._mapping.items():
            # cache the weight shape to ensure correct shapes are returned
            if "weight" in param_pos:
                weight_shape = M[param_pos["weight"]].shape

            # Get the EKFAC approximation components for the current module
            # aaT_eigenvectors does not exist if the weight matrix is excluded
            aaT_eigenvectors = self._input_covariances_eigenvectors.get(mod_name)
            # ggT_eigenvectors and corrected_eigenvalues always exist
            ggT_eigenvectors = self._gradient_covariances_eigenvectors[mod_name]
            corrected_eigenvalues = self._corrected_eigenvalues[mod_name]

            # bias and weights are treated jointly
            if (
                not self._separate_weight_and_bias
                and "weight" in param_pos
                and "bias" in param_pos
            ):
                w_pos, b_pos = param_pos["weight"], param_pos["bias"]
                # v denotes the free dimension for treating multiple vectors in parallel
                M_w = rearrange(M[w_pos], "c_out ... v -> c_out (...) v")
                M_joint = cat([M_w, M[b_pos].unsqueeze(-2)], dim=-2)
                M_joint = self._left_and_right_multiply(
                    M_joint, aaT_eigenvectors, ggT_eigenvectors, corrected_eigenvalues
                )
                w_cols = M_w.shape[1]
                KM_w, KM_b = M_joint.split([w_cols, 1], dim=-2)
                KM[w_pos] = KM_w
                # Out-of-place squeeze: avoids mutating a view returned by ``split``
                KM[b_pos] = KM_b.squeeze(1)
            else:
                self._separate_left_and_right_multiply(
                    KM,
                    M,
                    param_pos,
                    aaT_eigenvectors,
                    ggT_eigenvectors,
                    corrected_eigenvalues,
                )

            # restore original shapes; ``reshape`` (unlike ``view``) also works when
            # the product is not contiguous, e.g. a slice produced by ``split``
            if "weight" in param_pos:
                w_pos = param_pos["weight"]
                KM[w_pos] = KM[w_pos].reshape(weight_shape)

        return KM

    def _compute_eigenvectors(self):
        """Compute the eigenvectors of the KFAC approximation.

        Computes the Kronecker factors first if none are cached. Each factor is
        popped from its covariance dict, so the factor memory is released once its
        eigenvectors have been stored.
        """
        if not (self._input_covariances or self._gradient_covariances):
            self.compute_kronecker_factors()

        for mod_name in self._mapping.keys():
            for source, destination in zip(
                (self._input_covariances, self._gradient_covariances),
                (
                    self._input_covariances_eigenvectors,
                    self._gradient_covariances_eigenvectors,
                ),
            ):
                factor = source.pop(mod_name, None)
                # Absent e.g. for the input factor of modules without a weight
                if factor is not None:
                    destination[mod_name] = eigh(factor).eigenvectors

    def compute_eigenvalue_correction(self):
        """Compute and cache the corrected eigenvalues for EKFAC.

        Computes the Kronecker factors' eigenvectors first if they are not
        available. Then traverses the data once with hooks installed that
        accumulate the squared per-example gradients in the Kronecker eigenbasis.

        Previously computed corrected eigenvalues are discarded, so repeated calls
        do not accumulate on top of each other. The hooks and cached activations are
        removed even if the traversal raises. In that case no partial result is
        kept.
        """
        self._reset_matrix_properties()

        # Compute the eigenvectors of the KFAC approximation
        if not (
            self._input_covariances_eigenvectors
            or self._gradient_covariances_eigenvectors
        ):
            self._compute_eigenvectors()

        # The hooks accumulate into this dict via ``_set_or_add_``, so start fresh.
        # Rebind rather than ``clear()`` so a previously returned ``state_dict`` that
        # references the old dict is not mutated.
        self._corrected_eigenvalues = {}

        hook_handles: List[RemovableHandle] = []
        try:
            # install forward and backward hooks
            for mod_name, param_pos in self._mapping.items():
                module = self._model_func.get_submodule(mod_name)

                # cache activations for computing per-example gradients
                if "weight" in param_pos:
                    hook_handles.append(
                        module.register_forward_pre_hook(
                            partial(self._hook_cache_inputs, module_name=mod_name)
                        )
                    )

                # compute the corrected eigenvalues using the per-example gradients
                hook_handles.append(
                    module.register_forward_hook(
                        partial(
                            self._register_tensor_hook_on_output_to_accumulate_corrected_eigenvalues,
                            module_name=mod_name,
                        )
                    )
                )

            if self._generator is None or self._generator.device != self.device:
                self._generator = Generator(device=self.device)
            self._generator.manual_seed(self._seed)

            # loop over data set, computing the corrected eigenvalues
            for X, y in self._loop_over_data(desc="Eigenvalue correction"):
                output = self._model_func(X)
                output, y = self._rearrange_for_larger_than_2d_output(output, y)
                self._compute_loss_and_backward(output, y)

                # The activations are only needed for this batch's backward pass(es)
                self._cached_activations.clear()
        except BaseException:
            # A partially accumulated result would otherwise be treated as complete
            # by ``_maybe_compute_ekfac`` (it only checks for a non-empty dict).
            self._corrected_eigenvalues = {}
            raise
        finally:
            # clean up, also on failure, so no hooks linger on the user's model
            self._cached_activations.clear()
            for handle in hook_handles:
                handle.remove()

    def _hook_cache_inputs(
        self, module: Module, inputs: Tuple[Tensor], module_name: str
    ):
        """Pre-forward hook that caches the inputs of a layer.

        Updates ``self._cached_activations``.

        Args:
            module: Module on which the hook is called.
            inputs: Inputs to the module.
            module_name: Name of the module in the neural network.

        Raises:
            ValueError: If the module has multiple inputs.
        """
        if len(inputs) != 1:
            raise ValueError("Modules with multiple inputs are not supported.")
        self._cached_activations[module_name] = inputs[0].data.detach()

    def _register_tensor_hook_on_output_to_accumulate_corrected_eigenvalues(
        self, module: Module, inputs: Tuple[Tensor], output: Tensor, module_name: str
    ):
        """Register tensor hook on layer's output to accumulate the corrected eigenvalues.

        Note:
            The easier way to compute the corrected eigenvalues would be via a full
            backward hook on the module itself which performs the computation.
            However, this approach breaks down if the output of a layer feeds into an
            activation with ``inplace=True`` (see
            https://github.com/pytorch/pytorch/issues/61519). Hence we use the
            workaround
            https://github.com/pytorch/pytorch/issues/61519#issuecomment-883524237, and
            install a module hook which installs a tensor hook on the module's output
            tensor, which performs the accumulation of the corrected eigenvalues.

        Args:
            module: Layer onto whose output a tensor hook to accumulate the corrected
                eigenvalues will be installed.
            inputs: The layer's input tensors.
            output: The layer's output tensor.
            module_name: The name of the layer in the neural network.
        """
        tensor_hook = partial(
            self._accumulate_corrected_eigenvalues,
            module=module,
            module_name=module_name,
        )
        output.register_hook(tensor_hook)

    def _accumulate_corrected_eigenvalues(
        self, grad_output: Tensor, module: Module, module_name: str
    ):
        r"""Accumulate the corrected eigenvalues.

        For each example :math:`n` of the batch, the per-example gradient
        :math:`G_n` is rotated into the Kronecker eigenbasis and squared
        element-wise. The squares are summed over the batch and added to the
        running total:
        :math:`\lambda_{\text{corrected}} \mathrel{+}= c \sum_n (Q_g^\top G_n Q_a)^{\odot 2}`,
        where :math:`Q_a` and :math:`Q_g` are the eigenvectors of the input and
        gradient covariances, respectively. :math:`c` corrects for the loss
        reduction and the number of Monte-Carlo samples. For a bias treated
        separately, only :math:`Q_g` is applied. The corrected eigenvalues replace
        the eigenvalues of the KFAC approximation (EKFAC).

        Updates ``self._corrected_eigenvalues``.

        Args:
            grad_output: The gradient w.r.t. the output.
            module: The layer for which corrected eigenvalues will be accumulated.
            module_name: The name of the layer in the neural network.
        """
        g = grad_output.data.detach()
        batch_size = g.shape[0]
        if isinstance(module, Conv2d):
            g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")
        g = rearrange(g, "batch ... d_out -> batch (...) d_out")

        # Compute correction for the loss scaling depending on the loss reduction used
        num_loss_terms = batch_size * self._num_per_example_loss_terms
        # self._mc_samples will be 1 if fisher_type != FisherType.MC
        correction = {
            "sum": 1.0 / self._mc_samples,
            "mean": num_loss_terms**2
            / (self._N_data * self._mc_samples * self._num_per_example_loss_terms),
        }[self._loss_func.reduction]

        # Compute the corrected eigenvalues for the EKFAC approximation
        param_pos = self._mapping[module_name]
        # aaT_eigenvectors does not exist if the weight matrix of the module is excluded
        aaT_eigenvectors = self._input_covariances_eigenvectors.get(module_name)
        # ggT_eigenvectors always exists
        ggT_eigenvectors = self._gradient_covariances_eigenvectors[module_name]

        # Rearrange the activations for computing per-example gradients
        activations = self._cached_activations.get(module_name)
        if activations is not None:
            if isinstance(module, Conv2d):
                activations = extract_patches(
                    activations,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                )
            activations = rearrange(activations, "batch ... d_in -> batch (...) d_in")

        if (
            not self._separate_weight_and_bias
            and "weight" in param_pos
            and "bias" in param_pos
        ):
            # Append a constant-one feature so the bias is handled as an extra column
            activations = cat(
                [activations, activations.new_ones(*activations.shape[:-1], 1)], dim=-1
            )
            # Compute per-example gradient using the cached activations
            per_example_gradient = einsum(
                g,
                activations,
                "batch shared d_out, batch shared d_in -> batch d_out d_in",
            )
            # Transform the per-example gradient to the eigenbasis and square it
            self._corrected_eigenvalues = self._set_or_add_(
                self._corrected_eigenvalues,
                module_name,
                einsum(
                    ggT_eigenvectors,
                    per_example_gradient,
                    aaT_eigenvectors,
                    "d_out1 d_out2, batch d_out1 d_in1, d_in1 d_in2 -> batch d_out2 d_in2",
                )
                .square_()
                .sum(dim=0)
                .mul_(correction),
            )
        else:
            per_param_eigenvalues = self._corrected_eigenvalues.setdefault(
                module_name, {}
            )
            for p_name, pos in param_pos.items():
                # Compute per-example gradient using the cached activations; the
                # bias gradient is the output gradient summed over shared dims
                if p_name == "weight":
                    per_example_gradient = einsum(
                        g,
                        activations,
                        "batch shared d_out, batch shared d_in -> batch d_out d_in",
                    )
                    # Rotate the input side into the eigenbasis of the input factor
                    per_example_gradient = einsum(
                        per_example_gradient,
                        aaT_eigenvectors,
                        "batch d_out d_in1, d_in1 d_in2 -> batch d_out d_in2",
                    )
                else:
                    per_example_gradient = einsum(g, "batch shared d_out -> batch d_out")

                # Transform the per-example gradient to the eigenbasis and square it
                per_param_eigenvalues = self._set_or_add_(
                    per_param_eigenvalues,
                    pos,
                    einsum(
                        ggT_eigenvectors,
                        per_example_gradient,
                        "d_out1 d_out2, batch d_out1 ... -> batch d_out2 ...",
                    )
                    .square_()
                    .sum(dim=0)
                    .mul_(correction),
                )
                self._corrected_eigenvalues[module_name] = per_param_eigenvalues

    def _iter_corrected_eigenvalues(self) -> Iterable[Tensor]:
        """Yield every corrected-eigenvalue tensor.

        Entries stored per parameter (dict values) are flattened, so callers see
        one tensor per independent Kronecker block.

        Yields:
            The corrected eigenvalues of each block, in insertion order.
        """
        for corrected_eigenvalues in self._corrected_eigenvalues.values():
            if isinstance(corrected_eigenvalues, dict):
                yield from corrected_eigenvalues.values()
            else:
                yield corrected_eigenvalues

    @property
    def trace(self) -> Tensor:
        r"""Trace of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction`` if
        either of them has not been called before and will cache the trace until one of
        them is called again.

        Returns:
            Trace of the EKFAC approximation.
        """
        if self._trace is not None:
            return self._trace

        self._maybe_compute_ekfac()
        # The trace is invariant under the orthogonal change of basis, so it is the
        # sum of the corrected eigenvalues
        trace = 0.0
        for eigenvalues in self._iter_corrected_eigenvalues():
            trace = trace + eigenvalues.sum()
        self._trace = trace
        return self._trace

    @property
    def det(self) -> Tensor:
        r"""Determinant of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction`` if
        either of them has not been called before and will cache the determinant until
        one of them is called again.

        Returns:
            Determinant of the EKFAC approximation.
        """
        if self._det is not None:
            return self._det

        self._maybe_compute_ekfac()
        # Compute the determinant as the product of the corrected eigenvalues
        det = 1.0
        for eigenvalues in self._iter_corrected_eigenvalues():
            det = det * eigenvalues.prod()
        self._det = det
        return self._det

    @property
    def logdet(self) -> Tensor:
        r"""Log determinant of the EKFAC approximation.

        More numerically stable than the ``det`` property.
        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction`` if
        either of them has not been called before and will cache the logdet until one of
        them is called again.

        Returns:
            Log determinant of the EKFAC approximation.
        """
        if self._logdet is not None:
            return self._logdet

        self._maybe_compute_ekfac()
        # Compute the log determinant as the sum of the logged corrected eigenvalues
        logdet = 0.0
        for eigenvalues in self._iter_corrected_eigenvalues():
            logdet = logdet + eigenvalues.log().sum()
        self._logdet = logdet
        return self._logdet

    @property
    def frobenius_norm(self) -> Tensor:
        r"""Frobenius norm of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction`` if
        either of them has not been called before and will cache the Frobenius norm
        until one of them is called again.

        Returns:
            Frobenius norm of the EKFAC approximation.
        """
        if self._frobenius_norm is not None:
            return self._frobenius_norm

        self._maybe_compute_ekfac()
        # The eigenbases come from ``eigh`` and are orthonormal, so the Frobenius
        # norm is the Euclidean norm of the corrected eigenvalues. Accumulate the
        # squared norm locally and cache only the final value.
        squared_norm = 0.0
        for eigenvalues in self._iter_corrected_eigenvalues():
            squared_norm = squared_norm + eigenvalues.square().sum()
        self._frobenius_norm = squared_norm.sqrt()
        return self._frobenius_norm

    def state_dict(self) -> Dict[str, Any]:
        """Return the state of the EKFAC linear operator.

        Extends the base class' state with the eigenvectors, cached activations
        (cleared after every batch of ``compute_eigenvalue_correction``) and
        corrected eigenvalues.

        Returns:
            State dictionary.
        """
        state_dict = super().state_dict()
        # Add quantities specifically for EKFAC (if computed)
        state_dict.update(
            {
                "input_covariances_eigenvectors": self._input_covariances_eigenvectors,
                "gradient_covariances_eigenvectors": self._gradient_covariances_eigenvectors,
                "cached_activations": self._cached_activations,
                "corrected_eigenvalues": self._corrected_eigenvalues,
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the state of the EKFAC linear operator.

        The eigenvector dicts are validated against the module mapping via
        ``_check_if_keys_match_mapping_keys`` before they are installed.

        Args:
            state_dict: State dictionary.
        """
        super().load_state_dict(state_dict)

        # Set EKFAC-specific quantities
        self._check_if_keys_match_mapping_keys(
            state_dict["input_covariances_eigenvectors"]
        )
        self._check_if_keys_match_mapping_keys(
            state_dict["gradient_covariances_eigenvectors"]
        )
        self._input_covariances_eigenvectors = state_dict[
            "input_covariances_eigenvectors"
        ]
        self._gradient_covariances_eigenvectors = state_dict[
            "gradient_covariances_eigenvectors"
        ]
        self._cached_activations = state_dict["cached_activations"]
        self._corrected_eigenvalues = state_dict["corrected_eigenvalues"]