Source code for curvlinops.fisher

"""Contains LinearOperator implementation of the approximate Fisher."""

from collections.abc import MutableMapping
from math import sqrt
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
from einops import einsum, rearrange
from torch import Generator, Tensor, as_tensor, normal, softmax, zeros, zeros_like
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter
from torch.nn.functional import one_hot

from curvlinops._torch_base import CurvatureLinearOperator


[docs] class FisherMCLinearOperator(CurvatureLinearOperator): r"""Monte-Carlo approximation of the Fisher as SciPy linear operator. Consider the empirical risk .. math:: \mathcal{L}(\mathbf{\theta}) = c \sum_{n=1}^{N} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for ``reduction='sum'``. Let :math:`\ell(\mathbf{f}, \mathbf{y}) = - \log q(\mathbf{y} \mid \mathbf{f})` be a negative log-likelihood loss. Denoting :math:`\mathbf{f}_n = f_{\mathbf{\theta}}(\mathbf{x}_n)`, the Fisher information matrix is .. math:: c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)^\top \mathbb{E}_{\mathbf{\tilde{y}}_n \sim q( \cdot \mid \mathbf{f}_n)} \left[ \nabla_{\mathbf{f}_n}^2 \ell(\mathbf{f}_n, \mathbf{\tilde{y}}_n) \right] \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right) \\ = c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)^\top \mathbb{E}_{\mathbf{\tilde{y}}_n \sim q( \cdot \mid \mathbf{f}_n)} \left[ \left( \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \mathbf{\tilde{y}}_n) \right) \left( \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \mathbf{\tilde{y}}_n) \right)^{\top} \right] \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right) \\ \approx c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)^\top \frac{1}{M} \sum_{m=1}^M \left[ \left( \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \mathbf{\tilde{y}}_{n}^{(m)}) \right) \left( \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \mathbf{\tilde{y}}_{n}^{(m)}) \right)^{\top} \right] \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right) with sampled targets :math:`\mathbf{\tilde{y}}_{n}^{(m)} \sim q( \cdot \mid \mathbf{f}_n)`. The expectation over the model's likelihood is approximated via a Monte-Carlo estimator with :math:`M` samples. The linear operator represents a deterministic sample from this MC Fisher estimator. To generate different samples, you have to create instances with varying random seed argument. Attributes: SELF_ADJOINT: Whether the operator is self-adjoint. ``True`` for the Fisher. supported_losses: Supported loss functions. FIXED_DATA_ORDER: Whether the data order must be fix. ``True`` for MC-Fisher. """ SELF_ADJOINT: bool = True FIXED_DATA_ORDER: bool = True supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)
[docs] def __init__( self, model_func: Callable[[Union[Tensor, MutableMapping]], Tensor], loss_func: Union[MSELoss, CrossEntropyLoss], params: List[Parameter], data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]], progressbar: bool = False, check_deterministic: bool = True, seed: int = 2147483647, mc_samples: int = 1, num_data: Optional[int] = None, batch_size_fn: Optional[Callable[[MutableMapping], int]] = None, ): """Linear operator for the Monte-Carlo approximation of the type-I Fisher. Note: f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the mini-batch labels y. Args: model_func: A function that maps the mini-batch input X to predictions. Could be a PyTorch module representing a neural network. loss_func: Loss function criterion. Maps predictions and mini-batch labels to a scalar value. params: List of differentiable parameters used by the prediction function. data: Source from which mini-batches can be drawn, for instance a list of mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X`` could be a ``dict`` or ``UserDict``; this is useful for custom models. In this case, you must (i) specify the ``batch_size_fn`` argument, and (ii) take care of preprocessing like ``X.to(device)`` inside of your ``model.forward()`` function. Due to the sequential internal Monte-Carlo sampling, batches must be presented in the same deterministic order (no shuffling!). progressbar: Show a progressbar during matrix-multiplication. Default: ``False``. check_deterministic: Probe that model and data are deterministic, i.e. that the data does not use `drop_last` or data augmentation. Also, the model's forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: ``True``. This is a safeguard, only turn it off if you know what you are doing. seed: Seed used to construct the internal random number generator used to draw samples at the beginning of each matrix-vector product. Default: ``2147483647`` mc_samples: Number of samples to use. Default: ``1``. num_data: Number of data points. If ``None``, it is inferred from the data at the cost of one traversal through the data loader. batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``, this needs to be specified. The intended behavior is to consume the first entry of the iterates from ``data`` and return their batch size. Raises: NotImplementedError: If the loss function differs from ``MSELoss``, BCEWithLogitsLoss, or ``CrossEntropyLoss``. """ if not isinstance(loss_func, self.supported_losses): raise NotImplementedError( f"Loss must be one of {self.supported_losses}. Got: {loss_func}." ) self._seed = seed self._generator: Union[None, Generator] = None self._mc_samples = mc_samples super().__init__( model_func, loss_func, params, data, progressbar=progressbar, check_deterministic=check_deterministic, num_data=num_data, batch_size_fn=batch_size_fn, )
def _matmat(self, M: List[Tensor]) -> List[Tensor]: """Multiply the MC-Fisher onto a matrix. Create and seed the random number generator. Args: M: Matrix for multiplication in tensor list format. Returns: Matrix-multiplication result ``mat @ M`` in tensor list format. """ if self._generator is None or self._generator.device != self.device: self._generator = Generator(device=self.device) self._generator.manual_seed(self._seed) return super()._matmat(M) def _matmat_batch( self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor] ) -> List[Tensor]: """Apply the mini-batch MC-Fisher to a matrix. Args: X: Input to the DNN. y: Ground truth. M: Matrix to be multiplied with in tensor list format. Tensors have same shape as trainable model parameters, and an additional trailing axis for the matrix columns. Returns: Result of MC-Fisher multiplication in tensor list format. Has the same shape as ``M``, i.e. each tensor in the list has the shape of a parameter and a trailing dimension of matrix columns. """ # compute ∂ℓₙ(yₙₘ)/∂fₙ where fₙ is the prediction for datum n and # yₙₘ is the m-th sampled label for datum n output = self._model_func(X) # If >2d output we convert to an equivalent 2d output if isinstance(self._loss_func, CrossEntropyLoss): output = rearrange(output, "batch c ... -> (batch ...) c") y = rearrange(y, "batch ... -> (batch ...)") else: output = rearrange(output, "batch ... c -> (batch ...) c") y = rearrange(y, "batch ... c -> (batch ...) c") grad_output = self.sample_grad_output(output, self._mc_samples, y) # Adjust the scale depending on the loss function and reduction used num_loss_terms, C = output.shape reduction_factor = { "mean": ( num_loss_terms if isinstance(self._loss_func, CrossEntropyLoss) else num_loss_terms * C ), "sum": 1.0, }[self._loss_func.reduction] # Compute the pseudo-loss L' := 0.5 / (M * c) ∑ₙ ∑ₘ fₙᵀ (gₙₘ gₙₘᵀ) fₙ where # gₙₘ = ∂ℓₙ(yₙₘ)/∂fₙ (detached) and M is the number of MC samples. # The GGN of L' linearized at fₙ is the MC Fisher. # We can thus multiply with it by computing the GGN-vector products of L'. loss = ( 0.5 / reduction_factor / self._mc_samples * (einsum(output, grad_output, "n ..., m n ... -> m n") ** 2).sum() ) # Multiply the MC Fisher onto each vector in the input matrix FM = [zeros_like(m) for m in M] (num_vectors,) = {m.shape[-1] for m in M} for v in range(num_vectors): for idx, Fm in enumerate( ggn_vector_product_from_plist( loss, output, self._params, [m[..., v] for m in M] ) ): FM[idx][..., v].add_(Fm.detach()) return FM def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor: """Draw would-be gradients ``∇_f log p(·|f)``. For a single data point, the would-be gradient's outer product equals the Hessian ``∇²_f log p(·|f)`` in expectation. Currently only supports ``MSELoss``, ``CrossEntropyLoss``, and ``BCEWithLogitsLoss``. The returned gradient does not account for the scaling of the loss function by the output dimension ``C`` that ``MSELoss`` and ``BCEWithLogitsLoss`` apply when ``reduction='mean'``. Args: output: model prediction ``f`` for multiple data with batch axis as 0th dimension. num_samples: Number of samples to draw. y: Labels of the data on which output was produced. Returns: Samples of the gradient w.r.t. the model prediction. Has shape ``[num_samples, *output.shape]``. Raises: NotImplementedError: For unsupported loss functions. NotImplementedError: If the prediction does not have two dimensions. NotImplementedError: If binary classification labels are not binary. """ if output.ndim != 2: raise NotImplementedError(f"Only 2d outputs supported. Got {output.shape}") C = output.shape[1] if isinstance(self._loss_func, MSELoss): std = as_tensor(sqrt(0.5), device=output.device) mean = zeros( num_samples, *output.shape, device=output.device, dtype=output.dtype ) return 2 * normal(mean, std, generator=self._generator) elif isinstance(self._loss_func, CrossEntropyLoss): prob = softmax(output, dim=1) sample = prob.multinomial( num_samples=num_samples, replacement=True, generator=self._generator ) sample = rearrange(sample, "batch s -> s batch") onehot_sample = one_hot(sample, num_classes=C) # repeat ``num_sample`` times along a new leading axis to avoid broadcasting prob = prob.unsqueeze(0).expand_as(onehot_sample) return prob - onehot_sample elif isinstance(self._loss_func, BCEWithLogitsLoss): unique = set(y.unique().flatten().tolist()) if not unique.issubset({0, 1}): raise NotImplementedError( "Only binary targets (0, 1) are currently supported with" + f" BCEWithLogitsLoss. Got {unique}." ) prob = output.sigmoid() # repeat ``num_sample`` times along a new leading axis prob = prob.unsqueeze(0).expand(num_samples, -1, -1) sample = prob.bernoulli(generator=self._generator) return prob - sample else: raise NotImplementedError(f"Supported losses: {self.supported_losses}")