Source code for curvlinops.ekfac

"""Contains LinearOperator implementation of EKFAC approximation of the Fisher/GGN."""

from __future__ import annotations

from collections.abc import MutableMapping
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from einops import einsum, rearrange
from torch import Generator, Tensor, cat
from torch.linalg import eigh
from torch.nn import (
    BCEWithLogitsLoss,
    Conv2d,
    CrossEntropyLoss,
    Module,
    MSELoss,
    Parameter,
)
from torch.utils.hooks import RemovableHandle

from curvlinops.kfac import (
    FisherType,
    KFACLinearOperator,
    KFACType,
)
from curvlinops.kfac_utils import extract_patches


def compute_eigenvalue_correction_linear_weight_sharing(
    g: Tensor,
    ggT_eigvecs: Tensor,
    a: Tensor,
    aaT_eigvecs: Union[Tensor, None],
    _force_strategy: Optional[str] = None,
) -> Tensor:
    r"""Computes eigenvalue corrections for a linear layer with weight sharing.

    Chooses between two computational strategies depending on memory requirements.

    Args:
        g: Output gradients of the layer with shape ``[N, S, D1]``, where ``N`` is
            the batch size, ``S`` the weight sharing dimension, and ``D1`` the
            output dimension.
        ggT_eigvecs: Eigenvectors of the gradient covariance with shape
            ``[D1, D1]``.
        a: Layer inputs with shape ``[N, S, D2]``, where ``D2`` is the input
            dimension.
        aaT_eigvecs: Eigenvectors of the input covariance with shape ``[D2, D2]``,
            or ``None`` if the layer has no weights (bias only).
        _force_strategy: If specified, forces the use of either the ``'gramian'``
            or the ``'per_example_gradients'`` strategy. If ``None``, the strategy
            is chosen based on memory requirements. Defaults to ``None``. This flag
            serves mainly for testing purposes.

    Returns:
        The eigencorrection with shape ``[D1, D2]`` (or ``[D1]`` for the bias case).

    Raises:
        ValueError: If an invalid ``_force_strategy`` is provided.

    Below we explain the mathematical details of what this function does. The
    mapping is as follows: (``g``, :math:`\mathbf{Y}`), (``ggT_eigvecs``,
    :math:`\mathbf{Q}_1`), (``a``, :math:`\mathbf{X}`), (``aaT_eigvecs``,
    :math:`\mathbf{Q}_2`).

    Note:
        **Introduction:**

        In the following, let :math:`D_1` be the output dimension of the layer,
        :math:`D_2` the input dimension, :math:`S` the weight sharing dimension,
        and :math:`N` the batch size.

        Given the layer inputs
        :math:`\mathbf{X} \in \mathbb{R}^{N \times S \times D_2}`, output gradients
        :math:`\mathbf{Y} \in \mathbb{R}^{N \times S \times D_1}`, and a
        Kronecker-factored basis :math:`\mathbf{Q}_1 \otimes \mathbf{Q}_2` with
        factors :math:`\mathbf{Q}_i \in \mathbb{R}^{D_i \times D_i}`, our goal is to
        compute the eigencorrection
        :math:`\mathbf{E} \in \mathbb{R}^{D_1 \times D_2}`, which has the same shape
        as the layer's weights.

        The common way to do that is to compute the per-example gradients
        :math:`\mathbf{G} \in \mathbb{R}^{N \times D_1 \times D_2}` with

        .. math::
            \mathbf{G}_{n,d_1,d_2} = \sum_s \mathbf{Y}_{n,s,d_1} \mathbf{X}_{n,s,d_2},

        rotate them into the Kronecker-factored basis,

        .. math::
            \mathbf{\tilde{G}}_{n,d_1,d_2}
            = \sum_{i,j} \mathbf{G}_{n,i,j} \mathbf{Q}_{1,i,d_1} \mathbf{Q}_{2,j,d_2},

        and compute the correction by squaring and summing out the batch dimension,

        .. math::
            \mathbf{E}_{d_1,d_2} = \sum_{n} \mathbf{\tilde{G}}_{n,d_1,d_2}^2.

        Building up the per-example gradients can be extremely memory-costly.
        Therefore, we also consider an alternative approach which can have a
        smaller memory footprint if the weight sharing is mild.

    Note:
        **(1) Cost analysis of per-example gradient approach:**

        The peak memory of building up per-example gradients is dominated by
        :math:`N D_1 D_2`. We have two options to compute the rotated per-example
        gradient.

        1. First compute :math:`\mathbf{G}` and then rotate it. The first step
           costs :math:`N S D_1 D_2` time and the rotation costs
           :math:`N D_1 D_2 (D_1 + D_2)` time.
        2. Rotate the activations and output gradients, then compute the rotated
           per-example gradient. The rotations cost :math:`N S (D_1^2 + D_2^2)`
           time. The last step costs :math:`N S D_1 D_2` time.

        So in practice, we should prefer the first approach over the second if

        .. math::
            D_1 D_2 (D_1 + D_2) < S (D_1^2 + D_2^2).

        In the implementation, ``opt-einsum`` automatically picks the cheaper
        contraction order for us.

        Adding the cost for squaring and contracting, the overall cost is
        :math:`N S D_1 D_2 + N \min(S (D_1^2 + D_2^2), D_1 D_2 (D_1 + D_2)) + 2 N D_1 D_2`.

    Note:
        **(2) Cost analysis of Gramian contraction approach:**

        A way to avoid building up per-example gradients is to write the
        eigencorrection as a big contraction of the rotated activations
        :math:`\mathbf{\tilde{X}}_{n,s,d_2} = \sum_j \mathbf{X}_{n,s,j} \mathbf{Q}_{2,j,d_2}`
        and rotated output gradients
        :math:`\mathbf{\tilde{Y}}_{n,s,d_1} = \sum_i \mathbf{Y}_{n,s,i} \mathbf{Q}_{1,i,d_1}`,
        and then rearrange the contractions such that the batch dimension can be
        summed directly:

        .. math::
            \mathbf{E}_{d_1,d_2}
            &= \sum_{n}
            \left( \sum_{s} \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{X}}_{n,s,d_2} \right)
            \left( \sum_{t} \mathbf{\tilde{Y}}_{n,t,d_1} \mathbf{\tilde{X}}_{n,t,d_2} \right)
            \\
            &= \sum_{n} \sum_{s} \sum_{t}
            \left( \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{Y}}_{n,t,d_1} \right)
            \left( \mathbf{\tilde{X}}_{n,s,d_2} \mathbf{\tilde{X}}_{n,t,d_2} \right).

        This requires building up the Gramians

        .. math::
            \mathbf{G^Y}_{n,s,t,d_1}
            &= \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{Y}}_{n,t,d_1},
            \\
            \mathbf{G^X}_{n,s,t,d_2}
            &= \mathbf{\tilde{X}}_{n,s,d_2} \mathbf{\tilde{X}}_{n,t,d_2}.

        Peak memory is dominated by :math:`N S^2 (D_1 + D_2)`. The time is
        :math:`N S (D_1^2 + D_2^2)` for the rotations, :math:`N S^2 (D_1 + D_2)`
        for building up the Gramians, and :math:`N S^2 D_1 D_2` for the final
        contraction. In total, this is
        :math:`N S (D_1^2 + D_2^2) + N S^2 (D_1 + D_2 + D_1 D_2)`.

        **We select the approach with the smaller memory footprint**, i.e. Gramian
        contraction if :math:`S^2 (D_1 + D_2) < D_1 D_2`, and squaring per-example
        gradients otherwise. So, generally speaking, the more weight sharing, the
        more favorable building up per-example gradients becomes. In the extreme
        case :math:`S=1` (no weight sharing), the Gramian contraction approach uses
        only :math:`N (D_1 + D_2) < N D_1 D_2` memory compared to the per-example
        gradient approach.
        In terms of time, the Gramian contraction uses
        :math:`N (D_1^2 + D_2^2 + D_1 + D_2 + D_1 D_2) < N (3 D_1 D_2 + D_1^2 + D_2^2)`
        compared to the per-example gradient approach.
    """
    strategies = {"gramian", "per_example_gradients", None}
    if _force_strategy not in strategies:
        raise ValueError(
            f"Invalid _force_strategy: {_force_strategy}. Supported: {strategies}."
        )

    Q1, Q2 = ggT_eigvecs, aaT_eigvecs
    Y, X = g, a

    if Q2 is None:  # -> 1d (bias) case
        eigencorrection = (
            einsum(Q1, Y, "j d1, batch shared j -> batch d1").square_().sum(0)
        )
    else:  # -> 2d (weight or weight+bias) case
        (_, S, D1), (_, _, D2) = g.shape, a.shape

        # Determine approach (Gramian contraction or per-example gradients) based
        # on _force_strategy or memory requirements
        use_gramian = (
            _force_strategy == "gramian"
            if _force_strategy is not None
            # Memory of per-example gradients is dominated by N * D1 * D2
            # Memory of Gramian contraction is dominated by N * S^2 * (D1 + D2)
            # We choose the approach that requires less memory.
            else S**2 * (D1 + D2) < D1 * D2
        )

        if use_gramian:  # -> Gramian approach
            X_rot = einsum(X, Q2, "batch shared j, j d2 -> batch shared d2")
            Y_rot = einsum(Y, Q1, "batch shared i, i d1 -> batch shared d1")
            # In the absence of weight sharing (S=1), the Gramians reduce to the
            # element-wise squares of X_rot and Y_rot, and the final contraction
            # sums their products over the batch
            X_gram = einsum(X_rot, X_rot, "batch s d2, batch t d2 -> batch s t d2")
            Y_gram = einsum(Y_rot, Y_rot, "batch s d1, batch t d1 -> batch s t d1")
            eigencorrection = einsum(
                Y_gram, X_gram, "batch s t d1, batch s t d2 -> d1 d2"
            )
        else:  # -> per-example gradient approach
            rotated_per_example_gradient = einsum(
                Q1,
                Y,
                X,
                Q2,
                "i d1, batch shared i, batch shared j, j d2 -> batch d1 d2",
            )
            eigencorrection = rotated_per_example_gradient.square_().sum(dim=0)

    return eigencorrection
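The two strategies described in the docstring above compute the same eigencorrection and differ only in memory and time. The following sketch checks this numerically on random data with plain `torch.einsum` instead of the module's `einops` calls; the helpers `eigencorrection_per_example`, `eigencorrection_gramian`, `prefers_gramian`, and `random_orthogonal` are hypothetical illustrations, not part of curvlinops.

```python
import torch


def eigencorrection_per_example(Y, X, Q1, Q2):
    """Per-example gradient strategy: build G, rotate it, square, and sum."""
    # G[n] = sum_s outer(Y[n, s], X[n, s]), shape [N, D1, D2]
    G = torch.einsum("nsi,nsj->nij", Y, X)
    # Rotate into the Kronecker-factored eigenbasis Q1 (x) Q2
    G_rot = torch.einsum("id,nij,je->nde", Q1, G, Q2)
    # Square element-wise and sum out the batch dimension
    return G_rot.square().sum(dim=0)


def eigencorrection_gramian(Y, X, Q1, Q2):
    """Gramian strategy: never materializes the [N, D1, D2] per-example gradients."""
    Y_rot = torch.einsum("nsi,id->nsd", Y, Q1)
    X_rot = torch.einsum("nsj,je->nse", X, Q2)
    # Gramians over the weight-sharing dimension: [N, S, S, D1] and [N, S, S, D2]
    Y_gram = torch.einsum("nsd,ntd->nstd", Y_rot, Y_rot)
    X_gram = torch.einsum("nse,nte->nste", X_rot, X_rot)
    # Contract the batch and both weight-sharing indices in one go
    return torch.einsum("nstd,nste->de", Y_gram, X_gram)


def prefers_gramian(S, D1, D2):
    """Memory heuristic from the docstring: compare S^2 (D1 + D2) with D1 D2."""
    return S**2 * (D1 + D2) < D1 * D2


def random_orthogonal(D, generator):
    """Eigenvectors of a random SPD matrix, standing in for ggT/aaT eigenvectors."""
    A = torch.randn(D, D, generator=generator, dtype=torch.float64)
    return torch.linalg.eigh(A @ A.T).eigenvectors


gen = torch.Generator().manual_seed(0)
N, S, D1, D2 = 4, 3, 5, 6
Y = torch.randn(N, S, D1, generator=gen, dtype=torch.float64)
X = torch.randn(N, S, D2, generator=gen, dtype=torch.float64)
Q1, Q2 = random_orthogonal(D1, gen), random_orthogonal(D2, gen)

E_pe = eigencorrection_per_example(Y, X, Q1, Q2)
E_gram = eigencorrection_gramian(Y, X, Q1, Q2)
print(E_pe.shape, torch.allclose(E_pe, E_gram))  # torch.Size([5, 6]) True
print(prefers_gramian(1, 64, 64), prefers_gramian(16, 8, 8))  # True False
```

The heuristic mirrors the docstring's conclusion: without weight sharing the Gramians are tiny, while heavy weight sharing (large `S`) makes the per-example gradients the cheaper intermediate.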


class EKFACLinearOperator(KFACLinearOperator):
    """Linear operator to multiply with the Fisher/GGN's EKFAC approximation.

    Eigenvalue-corrected Kronecker-Factored Approximate Curvature (EKFAC) was
    originally introduced in

    - George, T., Laurent, C., Bouthillier, X., Ballas, N., Vincent, P. (2018).
      Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis
      (NeurIPS)

    and concurrently in the context of continual learning in

    - Liu, X., Masana, M., Herranz, L., Van de Weijer, J., Lopez, A., Bagdanov, A.
      (2018). Rotate your networks: Better weight consolidation and less
      catastrophic forgetting (ICPR).

    Attributes:
        _SUPPORTED_FISHER_TYPE: Tuple with supported Fisher types.
    """

    _SUPPORTED_FISHER_TYPE: Tuple[FisherType, ...] = (
        FisherType.TYPE2,
        FisherType.MC,
        FisherType.EMPIRICAL,
    )

    def __init__(
        self,
        model_func: Module,
        loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
        params: List[Parameter],
        data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
        progressbar: bool = False,
        check_deterministic: bool = True,
        seed: int = 2147483647,
        fisher_type: str = FisherType.MC,
        mc_samples: int = 1,
        kfac_approx: str = KFACType.EXPAND,
        num_per_example_loss_terms: Optional[int] = None,
        separate_weight_and_bias: bool = True,
        num_data: Optional[int] = None,
        batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
    ):
        """Eigenvalue-corrected KFAC (EKFAC) proxy of the Fisher/GGN.

        Warning:
            If the model's parameters change, e.g. during training, you need to
            create a fresh instance of this object. This is because, for
            performance reasons, the Kronecker factors are computed once and cached
            during the first matrix-vector product. They will thus become outdated
            if the model changes.

        Warning:
            This is an early prototype with limitations:

            - Only ``Linear`` and ``Conv2d`` modules are supported.
            - Only models with 2d output are supported.

        Args:
            model_func: The neural network. Must consist of modules.
            loss_func: The loss function.
            params: The parameters defining the Fisher/GGN that will be
                approximated through EKFAC.
            data: A data loader containing the data of the Fisher/GGN.
            progressbar: Whether to show a progress bar when computing the
                Kronecker factors. Defaults to ``False``.
            check_deterministic: Whether to check that the linear operator is
                deterministic. Defaults to ``True``.
            seed: The seed for the random number generator used to draw labels
                from the model's predictive distribution. Defaults to
                ``2147483647``.
            fisher_type: The type of Fisher/GGN to approximate.
                If ``FisherType.TYPE2``, the exact Hessian of the loss w.r.t. the
                model outputs is used. This requires as many backward passes as the
                output dimension, i.e. the number of classes for classification.
                This is sometimes also called type-2 Fisher.
                If ``FisherType.MC``, the expectation is approximated by sampling
                ``mc_samples`` labels from the model's predictive distribution.
                If ``FisherType.EMPIRICAL``, the empirical gradients are used, which
                corresponds to the uncentered gradient covariance/empirical Fisher.
                Defaults to ``FisherType.MC``.
            mc_samples: The number of Monte-Carlo samples to use per data point.
                Has to be set to ``1`` when ``fisher_type != FisherType.MC``.
                Defaults to ``1``.
            kfac_approx: A string specifying the KFAC approximation that should be
                used for linear weight-sharing layers, e.g. ``Conv2d`` modules or
                ``Linear`` modules that process matrix- or higher-dimensional
                features. Possible values are ``KFACType.EXPAND`` and
                ``KFACType.REDUCE``. See `Eschenhagen et al., 2023
                <https://arxiv.org/abs/2311.00636>`_ for an explanation of the two
                approximations. Defaults to ``KFACType.EXPAND``.
            num_per_example_loss_terms: Number of per-example loss terms, e.g., the
                number of tokens in a sequence. The model outputs will have
                ``num_data * num_per_example_loss_terms * C`` entries, where ``C``
                is the dimension of the random variable we define the likelihood
                over. For the ``CrossEntropyLoss``, it is the number of classes; for
                the ``MSELoss`` and ``BCEWithLogitsLoss``, it is the size of the last
                dimension of the model outputs/targets (our convention here).
                If ``None``, ``num_per_example_loss_terms`` is inferred from the
                data at the cost of one traversal through the data loader. It is
                expected to be the same for all examples. Defaults to ``None``.
            separate_weight_and_bias: Whether to treat weights and biases
                separately. Defaults to ``True``.
            num_data: Number of data points. If ``None``, it is inferred from the
                data at the cost of one traversal through the data loader.
                Defaults to ``None``.
            batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``,
                this needs to be specified. The intended behavior is to consume the
                first entry of the iterates from ``data`` and return their batch
                size.
        """
        # The determinism check is deferred until the EKFAC-specific attributes
        # below have been initialized
        super().__init__(
            model_func=model_func,
            loss_func=loss_func,
            params=params,
            data=data,
            progressbar=progressbar,
            check_deterministic=False,
            seed=seed,
            fisher_type=fisher_type,
            mc_samples=mc_samples,
            kfac_approx=kfac_approx,
            num_per_example_loss_terms=num_per_example_loss_terms,
            separate_weight_and_bias=separate_weight_and_bias,
            num_data=num_data,
            batch_size_fn=batch_size_fn,
        )

        # Initialize the eigenvectors of the Kronecker factors
        self._input_covariances_eigenvectors: Dict[str, Tensor] = {}
        self._gradient_covariances_eigenvectors: Dict[str, Tensor] = {}
        # Initialize the cache for activations
        self._cached_activations: Dict[str, Tensor] = {}
        # Initialize the corrected eigenvalues for EKFAC; if weight and bias are
        # treated separately, the inner dictionary is keyed by parameter position
        self._corrected_eigenvalues: Dict[str, Union[Tensor, Dict[int, Tensor]]] = {}

        if check_deterministic:
            self._check_deterministic()

    def _rearrange_for_larger_than_2d_output(
        self, output: Tensor, y: Tensor
    ) -> Tuple[Tensor, Tensor]:
        r"""Rearrange the output and target if output is >2d.

        This will determine what kind of Fisher/GGN is approximated. For EKFAC, no
        rearrangement is performed: the inputs are validated and returned
        unchanged.

        Args:
            output: The model's prediction
                :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
            y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.

        Returns:
            The rearranged outputs and targets.

        Raises:
            ValueError: If the output is not 2d or ``y`` is not 1d/2d.
        """
        # Our individual gradient implementation for EKFAC does not support
        # computing the individual gradients for any loss terms that might depend on
        # each other, i.e., loss terms other than the per-data point loss terms.
        if output.ndim != 2 or y.ndim not in {1, 2}:
            raise ValueError(
                "Only 2d output and 1d/2d target are supported for EKFAC. "
                f"Got {output.ndim=} and {y.ndim=}."
            )
        return output, y

    def _maybe_compute_ekfac(self):
        """Compute the EKFAC approximation when necessary."""
        if not self._corrected_eigenvalues:
            if not (self._input_covariances or self._gradient_covariances):
                self.compute_kronecker_factors()
            self.compute_eigenvalue_correction()

    def _matmat(self, M: List[Tensor]) -> List[Tensor]:
        """Apply EKFAC to a matrix (multiple vectors) in tensor list format.

        This allows for matrix-matrix products with the EKFAC approximation in
        PyTorch without converting tensors to numpy arrays, which avoids
        unnecessary device transfers when working with GPUs, as well as
        flattening/concatenating.

        Args:
            M: Matrix for multiplication in tensor list format. Each entry has the
                same shape as a parameter with an additional trailing dimension of
                size ``K`` for the columns, i.e.
                ``[(*p1.shape, K), (*p2.shape, K), ...]``.

        Returns:
            Matrix-multiplication result ``EKFAC @ M`` in tensor list format. Has
            the same shapes as the input.
        """
        self._maybe_compute_ekfac()

        KM: List[Tensor | None] = [None] * len(M)

        for mod_name, param_pos in self._mapping.items():
            # cache the weight shape to ensure correct shapes are returned
            if "weight" in param_pos:
                weight_shape = M[param_pos["weight"]].shape

            # Get the EKFAC approximation components for the current module
            # aaT_eigenvectors does not exist if the weight matrix is excluded
            aaT_eigenvectors = self._input_covariances_eigenvectors.get(mod_name)
            # ggT_eigenvectors and corrected_eigenvalues always exist
            ggT_eigenvectors = self._gradient_covariances_eigenvectors[mod_name]
            corrected_eigenvalues = self._corrected_eigenvalues[mod_name]

            # bias and weights are treated jointly
            if (
                not self._separate_weight_and_bias
                and "weight" in param_pos.keys()
                and "bias" in param_pos.keys()
            ):
                w_pos, b_pos = param_pos["weight"], param_pos["bias"]
                # v denotes the free dimension for treating multiple vectors in
                # parallel
                M_w = rearrange(M[w_pos], "c_out ... v -> c_out (...) v")
                M_joint = cat([M_w, M[b_pos].unsqueeze(-2)], dim=-2)
                M_joint = self._left_and_right_multiply(
                    M_joint, aaT_eigenvectors, ggT_eigenvectors, corrected_eigenvalues
                )
                w_cols = M_w.shape[1]
                KM[w_pos], KM[b_pos] = M_joint.split([w_cols, 1], dim=-2)
                KM[b_pos].squeeze_(1)
            else:
                self._separate_left_and_right_multiply(
                    KM,
                    M,
                    param_pos,
                    aaT_eigenvectors,
                    ggT_eigenvectors,
                    corrected_eigenvalues,
                )

            # restore original shapes
            if "weight" in param_pos:
                KM[param_pos["weight"]] = KM[param_pos["weight"]].view(weight_shape)

        return KM

    def _compute_eigenvectors(self):
        """Compute the eigenvectors of the KFAC approximation."""
        if not (self._input_covariances or self._gradient_covariances):
            self.compute_kronecker_factors()

        for mod_name in self._mapping.keys():
            for source, destination in zip(
                (self._input_covariances, self._gradient_covariances),
                (
                    self._input_covariances_eigenvectors,
                    self._gradient_covariances_eigenvectors,
                ),
            ):
                factor = source.pop(mod_name, None)
                if factor is not None:
                    destination[mod_name] = eigh(factor).eigenvectors

    def compute_eigenvalue_correction(self):
        """Compute and cache the corrected eigenvalues for EKFAC."""
        self._reset_matrix_properties()

        # Compute the eigenvectors of the KFAC approximation
        if not (
            self._input_covariances_eigenvectors
            or self._gradient_covariances_eigenvectors
        ):
            self._compute_eigenvectors()

        # install forward hooks (which in turn install tensor hooks for the
        # backward pass)
        hook_handles: List[RemovableHandle] = []

        for mod_name, param_pos in self._mapping.items():
            module = self._model_func.get_submodule(mod_name)

            # cache activations for computing per-example gradients
            if "weight" in param_pos.keys():
                hook_handles.append(
                    module.register_forward_pre_hook(
                        partial(self._hook_cache_inputs, module_name=mod_name)
                    )
                )

            # compute the corrected eigenvalues using the per-example gradients
            hook_handles.append(
                module.register_forward_hook(
                    partial(
                        self._register_tensor_hook_on_output_to_accumulate_corrected_eigenvalues,
                        module_name=mod_name,
                    )
                )
            )

        if self._generator is None or self._generator.device != self.device:
            self._generator = Generator(device=self.device)
        self._generator.manual_seed(self._seed)

        # loop over data set, computing the corrected eigenvalues
        for X, y in self._loop_over_data(desc="Eigenvalue correction"):
            output = self._model_func(X)
            output, y = self._rearrange_for_larger_than_2d_output(output, y)
            self._compute_loss_and_backward(output, y)

        # Clear the cached activations
        self._cached_activations.clear()

        # clean up
        for handle in hook_handles:
            handle.remove()

    def _hook_cache_inputs(
        self, module: Module, inputs: Tuple[Tensor], module_name: str
    ):
        """Pre-forward hook that caches the inputs of a layer.

        Updates ``self._cached_activations``.

        Args:
            module: Module on which the hook is called.
            inputs: Inputs to the module.
            module_name: Name of the module in the neural network.

        Raises:
            ValueError: If the module has multiple inputs.
        """
        if len(inputs) != 1:
            raise ValueError("Modules with multiple inputs are not supported.")
        self._cached_activations[module_name] = inputs[0].data.detach()

    def _register_tensor_hook_on_output_to_accumulate_corrected_eigenvalues(
        self, module: Module, inputs: Tuple[Tensor], output: Tensor, module_name: str
    ):
        """Register tensor hook on layer's output to accumulate the corrected eigenvalues.

        Note:
            The easier way to compute the corrected eigenvalues would be via a full
            backward hook on the module itself which performs the computation.
            However, this approach breaks down if the output of a layer feeds into
            an activation with ``inplace=True`` (see
            https://github.com/pytorch/pytorch/issues/61519). Hence we use the
            workaround
            https://github.com/pytorch/pytorch/issues/61519#issuecomment-883524237,
            and install a module hook which installs a tensor hook on the module's
            output tensor, which performs the accumulation of the corrected
            eigenvalues.

        Args:
            module: Layer onto whose output a tensor hook to accumulate the
                corrected eigenvalues will be installed.
            inputs: The layer's input tensors.
            output: The layer's output tensor.
            module_name: The name of the layer in the neural network.
        """
        tensor_hook = partial(
            self._accumulate_corrected_eigenvalues,
            module=module,
            module_name=module_name,
        )
        output.register_hook(tensor_hook)

    def _accumulate_corrected_eigenvalues(
        self, grad_output: Tensor, module: Module, module_name: str
    ):
        r"""Accumulate the corrected eigenvalues.

        The corrected eigenvalues are computed as
        :math:`\lambda_{\text{corrected}} = \sum_n (Q_g^\top G_n Q_a)^2`, where the
        square is taken element-wise, :math:`G_n` is the gradient of the
        :math:`n`-th example w.r.t. the layer's weight matrix, and :math:`Q_a` and
        :math:`Q_g` are the eigenvectors of the input and gradient covariances,
        respectively. The sum is additionally scaled to account for the loss
        reduction and the number of Monte-Carlo samples. The corrected eigenvalues
        replace the eigenvalues of the KFAC approximation (EKFAC).

        Updates ``self._corrected_eigenvalues``.

        Args:
            grad_output: The gradient w.r.t. the output.
            module: The layer for which corrected eigenvalues will be accumulated.
            module_name: The name of the layer in the neural network.
        """
        g = grad_output.data.detach()
        batch_size = g.shape[0]
        if isinstance(module, Conv2d):
            g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")
        g = rearrange(g, "batch ... d_out -> batch (...) d_out")

        # Compute correction for the loss scaling depending on the loss reduction used
        num_loss_terms = batch_size * self._num_per_example_loss_terms
        # self._mc_samples will be 1 if fisher_type != FisherType.MC
        correction = {
            "sum": 1.0 / self._mc_samples,
            "mean": num_loss_terms**2
            / (self._N_data * self._mc_samples * self._num_per_example_loss_terms),
        }[self._loss_func.reduction]

        # Compute the corrected eigenvalues for the EKFAC approximation
        param_pos = self._mapping[module_name]
        # aaT_eigenvectors does not exist if the weight matrix of the module is excluded
        aaT_eigenvectors = self._input_covariances_eigenvectors.get(module_name)
        # ggT_eigenvectors always exists
        ggT_eigenvectors = self._gradient_covariances_eigenvectors[module_name]

        # Rearrange the activations for computing per-example gradients
        activations = self._cached_activations.get(module_name)
        if activations is not None:
            if isinstance(module, Conv2d):
                activations = extract_patches(
                    activations,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                )
            activations = rearrange(activations, "batch ... d_in -> batch (...) d_in")

        if (
            not self._separate_weight_and_bias
            and "weight" in param_pos.keys()
            and "bias" in param_pos.keys()
        ):
            a_augmented = cat(
                [activations, activations.new_ones(*activations.shape[:-1], 1)], dim=-1
            )
            eigencorrection = compute_eigenvalue_correction_linear_weight_sharing(
                g, ggT_eigenvectors, a_augmented, aaT_eigenvectors
            )
            self._corrected_eigenvalues = self._set_or_add_(
                self._corrected_eigenvalues,
                module_name,
                eigencorrection.mul_(correction),
            )
        else:
            if module_name not in self._corrected_eigenvalues:
                self._corrected_eigenvalues[module_name] = {}
            for p_name, pos in param_pos.items():
                eigencorrection = compute_eigenvalue_correction_linear_weight_sharing(
                    g,
                    ggT_eigenvectors,
                    activations,
                    aaT_eigvecs=None if p_name == "bias" else aaT_eigenvectors,
                )
                self._corrected_eigenvalues[module_name] = self._set_or_add_(
                    self._corrected_eigenvalues[module_name],
                    pos,
                    eigencorrection.mul_(correction),
                )

    @property
    def trace(self) -> Tensor:
        r"""Trace of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction``
        if either of them has not been called before and will cache the trace until
        one of them is called again.

        Returns:
            Trace of the EKFAC approximation.
        """
        if self._trace is not None:
            return self._trace

        self._maybe_compute_ekfac()

        # Compute the trace using the corrected eigenvalues
        self._trace = 0.0
        for corrected_eigenvalues in self._corrected_eigenvalues.values():
            if isinstance(corrected_eigenvalues, dict):
                for val in corrected_eigenvalues.values():
                    self._trace += val.sum()
            else:
                self._trace += corrected_eigenvalues.sum()

        return self._trace

    @property
    def det(self) -> Tensor:
        r"""Determinant of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction``
        if either of them has not been called before and will cache the determinant
        until one of them is called again.

        Returns:
            Determinant of the EKFAC approximation.
        """
        if self._det is not None:
            return self._det

        self._maybe_compute_ekfac()

        # Compute the determinant using the corrected eigenvalues
        self._det = 1.0
        for corrected_eigenvalues in self._corrected_eigenvalues.values():
            if isinstance(corrected_eigenvalues, dict):
                for val in corrected_eigenvalues.values():
                    self._det *= val.prod()
            else:
                self._det *= corrected_eigenvalues.prod()

        return self._det

    @property
    def logdet(self) -> Tensor:
        r"""Log determinant of the EKFAC approximation.

        More numerically stable than the ``det`` property.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction``
        if either of them has not been called before and will cache the logdet until
        one of them is called again.

        Returns:
            Log determinant of the EKFAC approximation.
        """
        if self._logdet is not None:
            return self._logdet

        self._maybe_compute_ekfac()

        # Compute the log determinant using the corrected eigenvalues
        self._logdet = 0.0
        for corrected_eigenvalues in self._corrected_eigenvalues.values():
            if isinstance(corrected_eigenvalues, dict):
                for val in corrected_eigenvalues.values():
                    self._logdet += val.log().sum()
            else:
                self._logdet += corrected_eigenvalues.log().sum()

        return self._logdet

    @property
    def frobenius_norm(self) -> Tensor:
        r"""Frobenius norm of the EKFAC approximation.

        Will call ``compute_kronecker_factors`` and ``compute_eigenvalue_correction``
        if either of them has not been called before and will cache the Frobenius
        norm until one of them is called again.

        Returns:
            Frobenius norm of the EKFAC approximation.
        """
        if self._frobenius_norm is not None:
            return self._frobenius_norm

        self._maybe_compute_ekfac()

        # Compute the Frobenius norm using the corrected eigenvalues
        self._frobenius_norm = 0.0
        for corrected_eigenvalues in self._corrected_eigenvalues.values():
            if isinstance(corrected_eigenvalues, dict):
                for val in corrected_eigenvalues.values():
                    self._frobenius_norm += val.square().sum()
            else:
                self._frobenius_norm += corrected_eigenvalues.square().sum()

        return self._frobenius_norm.sqrt_()

    def state_dict(self) -> Dict[str, Any]:
        """Return the state of the EKFAC linear operator.

        Returns:
            State dictionary.
        """
        state_dict = super().state_dict()
        # Add quantities specifically for EKFAC (if computed)
        state_dict.update(
            {
                "input_covariances_eigenvectors": self._input_covariances_eigenvectors,
                "gradient_covariances_eigenvectors": self._gradient_covariances_eigenvectors,
                "cached_activations": self._cached_activations,
                "corrected_eigenvalues": self._corrected_eigenvalues,
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the state of the EKFAC linear operator.

        Args:
            state_dict: State dictionary.
        """
        super().load_state_dict(state_dict)

        # Set EKFAC-specific quantities
        self._check_if_keys_match_mapping_keys(
            state_dict["input_covariances_eigenvectors"]
        )
        self._check_if_keys_match_mapping_keys(
            state_dict["gradient_covariances_eigenvectors"]
        )
        self._input_covariances_eigenvectors = state_dict[
            "input_covariances_eigenvectors"
        ]
        self._gradient_covariances_eigenvectors = state_dict[
            "gradient_covariances_eigenvectors"
        ]
        self._cached_activations = state_dict["cached_activations"]
        self._corrected_eigenvalues = state_dict["corrected_eigenvalues"]
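The `trace`, `det`, `logdet`, and `frobenius_norm` properties only touch the corrected eigenvalues. That is valid if each layer block has the form (Q1 ⊗ Q2) diag(vec(E)) (Q1 ⊗ Q2)^T with orthogonal eigenvector matrices Q1 and Q2. The sketch below assumes exactly that form for a single weight matrix (no bias, no weight sharing), multiplies with it through a hypothetical `ekfac_matvec` helper (the operator's own `_left_and_right_multiply` lives in `KFACLinearOperator` and is not shown here), and compares every property against an explicitly built dense matrix.

```python
import torch


def ekfac_matvec(V, Q1, Q2, E):
    """Hypothetical helper: apply (Q1 kron Q2) diag(vec(E)) (Q1 kron Q2)^T to V.

    V has the weight's shape [D1, D2]; the dense matrix is never formed.
    """
    V_rot = Q1.T @ V @ Q2  # rotate into the Kronecker-factored eigenbasis
    V_rot = V_rot * E  # scale by the corrected eigenvalues
    return Q1 @ V_rot @ Q2.T  # rotate back


torch.manual_seed(0)
D1, D2 = 3, 4
A1 = torch.randn(D1, D1, dtype=torch.float64)
A2 = torch.randn(D2, D2, dtype=torch.float64)
Q1 = torch.linalg.eigh(A1 @ A1.T).eigenvectors
Q2 = torch.linalg.eigh(A2 @ A2.T).eigenvectors
E = torch.rand(D1, D2, dtype=torch.float64) + 0.1  # positive corrected eigenvalues

# Dense reference, using row-major vectorization of [D1, D2] tensors
Q = torch.kron(Q1, Q2)
K = Q @ torch.diag(E.flatten()) @ Q.T

V = torch.randn(D1, D2, dtype=torch.float64)
matvec_ok = torch.allclose(K @ V.flatten(), ekfac_matvec(V, Q1, Q2, E).flatten())

# Matrix properties follow directly from the corrected eigenvalues
trace_ok = torch.allclose(torch.trace(K), E.sum())
det_ok = torch.allclose(torch.linalg.det(K), E.prod())
logdet_ok = torch.allclose(torch.linalg.slogdet(K).logabsdet, E.log().sum())
fro_ok = torch.allclose(torch.linalg.matrix_norm(K), E.square().sum().sqrt())
print(matvec_ok, trace_ok, det_ok, logdet_ok, fro_ok)  # True True True True True
```

Rotating, scaling, and rotating back costs two small matrix products per side instead of a product with a `D1 * D2` by `D1 * D2` matrix, which is why the operator never materializes the dense block.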