Source code for curvlinops.kfac_utils

"""Utility functions related to KFAC."""

from math import sqrt
from typing import Tuple, Union

from einconv import index_pattern
from einconv.utils import get_conv_paddings
from einops import einsum, rearrange, reduce
from torch import Tensor, diag, eye
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.functional import unfold
from torch.nn.modules.utils import _pair


def loss_hessian_matrix_sqrt(
    output_one_datum: Tensor,
    target_one_datum: Tensor,
    loss_func: Union[MSELoss, CrossEntropyLoss, BCEWithLogitsLoss],
) -> Tensor:
    r"""Compute the loss function's matrix square root for a sample's output.

    Args:
        output_one_datum: The model's prediction on a single datum. Has shape
            ``[1, C]`` where ``C`` is the number of classes (outputs of the neural
            network).
        target_one_datum: The label of the single datum. Its leading (batch)
            dimension must have size 1.
        loss_func: The loss function.

    Returns:
        The matrix square root :math:`\mathbf{S}` of the Hessian. Has shape
        ``[C, C]`` and satisfies the relation

        .. math::
            \mathbf{S} \mathbf{S}^\top
            =
            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y})
            \in \mathbb{R}^{C \times C}

        where :math:`\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C` is the model's
        prediction on a single datum :math:`\mathbf{x}` and :math:`\mathbf{y}` is
        the label.

    Note:
        For :class:`torch.nn.MSELoss` (with :math:`c = 1` for ``reduction='sum'``
        and :math:`c = 1/C` for ``reduction='mean'``), we have:

        .. math::
            \ell(\mathbf{f}) &= c \sum_{i=1}^C (f_i - y_i)^2
            \\
            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C
            \\
            \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C

    Note:
        For :class:`torch.nn.CrossEntropyLoss` (with :math:`c = 1` irrespective of
        the reduction, :math:`\mathbf{p}:=\mathrm{softmax}(\mathbf{f})
        \in \mathbb{R}^C`, and the element-wise natural logarithm :math:`\log`)
        we have:

        .. math::
            \ell(\mathbf{f}, y) = - c \log(\mathbf{p})^\top \mathrm{onehot}(y)
            \\
            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y)
            =
            c \left(
            \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top
            \right)
            \\
            \mathbf{S} = \sqrt{c} \left(
            \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top
            \right)\,,

        where the square root is applied element-wise. See for instance Example 5.1
        of `this thesis <https://d-nb.info/1280233206/34>`_ or equations (5) and (6)
        of `this paper <https://arxiv.org/abs/1901.08244>`_.

    Note:
        For :class:`torch.nn.BCEWithLogitsLoss` (with :math:`c = 1` for
        ``reduction='sum'`` and :math:`c = 1/C` for ``reduction='mean'``) we have
        (:math:`\sigma` is the sigmoid function, and assuming binary labels):

        .. math::
            \ell(\mathbf{f}) &= c \sum_{i=1}^C - y_i \log(\sigma(f_i))
            - (1 - y_i) \log(1 - \sigma(f_i))
            \\
            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y})
            &=
            c \mathrm{diag}( \sigma(f_i) \odot (1 - \sigma(f_i)) )
            \\
            \mathbf{S}
            &=
            \sqrt{c} \mathrm{diag}(\sqrt{\sigma(f_i) \odot (1 - \sigma(f_i))})\,,

        where the square root is applied element-wise.

    Raises:
        ValueError: If the batch size is not one, or the output is not 2d.
        KeyError: If the loss function is ``MSELoss`` or ``BCEWithLogitsLoss`` and
            its reduction is neither ``'sum'`` nor ``'mean'``.
        NotImplementedError: If the loss function is not supported.
        NotImplementedError: If the loss function is ``BCEWithLogitsLoss`` but the
            target is not binary.
    """
    if output_one_datum.ndim != 2 or output_one_datum.shape[0] != 1:
        raise ValueError(
            f"Expected 'output_one_datum' to be 2d with shape [1, C], got "
            f"{output_one_datum.shape}"
        )
    # Targets for 2d predictions are sometimes 1d. A 0d target has no batch axis
    # at all, so reject it here instead of failing on ``shape[0]``.
    if target_one_datum.ndim == 0 or target_one_datum.shape[0] != 1:
        raise ValueError(
            "Expected 'target_one_datum' to have batch_size 1."
            + f" Got {target_one_datum.shape}."
        )

    output = output_one_datum.squeeze(0)
    output_dim = output.numel()

    if isinstance(loss_func, MSELoss):
        c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction]
        return eye(output_dim, device=output.device, dtype=output.dtype).mul_(
            sqrt(2 * c)
        )

    elif isinstance(loss_func, CrossEntropyLoss):
        c = 1.0
        # Only drop the batch axis: a bare ``squeeze()`` would also remove the
        # class axis when ``C == 1`` and make ``diag`` fail on a 0d tensor.
        p = output_one_datum.softmax(dim=1).squeeze(0)
        p_sqrt = p.sqrt()
        # Outer product p sqrt(p)^T via broadcasting: entry (i, j) is p_i * sqrt(p_j).
        p_outer_p_sqrt = p.unsqueeze(1) * p_sqrt.unsqueeze(0)
        return (diag(p_sqrt) - p_outer_p_sqrt).mul_(sqrt(c))

    elif isinstance(loss_func, BCEWithLogitsLoss):
        unique = set(target_one_datum.unique().flatten().tolist())
        if not unique.issubset({0, 1}):
            raise NotImplementedError(
                "Only binary targets (0, 1) are currently supported with "
                f"BCEWithLogitsLoss. Got {unique}."
            )
        c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction]
        p = output_one_datum.sigmoid().squeeze(0)
        hess_diag = sqrt(c) * (p * (1 - p)).sqrt()
        return hess_diag.diag()

    else:
        raise NotImplementedError(f"Loss function {loss_func} not supported.")
def _mean_over_channel_groups(x: Tensor, groups: int) -> Tensor:
    """Split the channel axis into ``groups`` groups and average over the groups."""
    grouped = rearrange(x, "b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups)
    return reduce(grouped, "b g c_in i1 i2 -> b c_in i1 i2", "mean")


def extract_patches(
    x: Tensor,
    kernel_size: Union[Tuple[int, int], int],
    stride: Union[Tuple[int, int], int],
    padding: Union[Tuple[int, int], int, str],
    dilation: Union[Tuple[int, int], int],
    groups: int,
) -> Tensor:
    """Extract patches from the input of a 2d-convolution.

    The input is first averaged over channel groups, then unfolded into patches.

    Args:
        x: Input to a 2d-convolution. Has shape ``[batch_size, C_in, I1, I2]``.
        kernel_size: The convolution's kernel size supplied as 2-tuple or integer.
        stride: The convolution's stride supplied as 2-tuple or integer.
        padding: The convolution's padding supplied as 2-tuple, integer, or string.
        dilation: The convolution's dilation supplied as 2-tuple or integer.
        groups: The number of channel groups.

    Returns:
        A tensor of shape ``[batch_size, O1 * O2, C_in // groups * K1 * K2]`` where
        each column ``[b, o1_o2, :]`` contains the flattened patch of sample ``b``
        used for output location ``(o1, o2)``, averaged over channel groups.

    Raises:
        NotImplementedError: If ``padding`` is a string that would lead to unequal
            padding along a dimension.
    """
    if isinstance(padding, str):
        # Translate the string into one integer padding per spatial dimension.
        symmetric_padding = []
        hyperparams = zip(_pair(kernel_size), _pair(stride), _pair(dilation))
        for kernel, step, dil in hyperparams:
            left, right = get_conv_paddings(kernel, step, padding, dil)
            if left != right:
                raise NotImplementedError("Unequal padding not supported in unfold.")
            symmetric_padding.append(left)
        padding = tuple(symmetric_padding)

    averaged = _mean_over_channel_groups(x, groups)
    columns = unfold(
        averaged, kernel_size, dilation=dilation, padding=padding, stride=stride
    )
    # Move the output-location axis in front of the flattened patch axis.
    return rearrange(columns, "b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2")


def extract_averaged_patches(
    x: Tensor,
    kernel_size: Union[Tuple[int, int], int],
    stride: Union[Tuple[int, int], int],
    padding: Union[Tuple[int, int], int, str],
    dilation: Union[Tuple[int, int], int],
    groups: int,
) -> Tensor:
    """Extract averaged patches from the input of a 2d-convolution.

    The patches are averaged over channel groups and output locations. Uses the
    tensor network formulation of convolution from
    `Dangel, 2023 <https://arxiv.org/abs/2307.02275>`_.

    Args:
        x: Input to a 2d-convolution. Has shape ``[batch_size, C_in, I1, I2]``.
        kernel_size: The convolution's kernel size supplied as 2-tuple or integer.
        stride: The convolution's stride supplied as 2-tuple or integer.
        padding: The convolution's padding supplied as 2-tuple, integer, or string.
        dilation: The convolution's dilation supplied as 2-tuple or integer.
        groups: The number of channel groups.

    Returns:
        A tensor of shape ``[batch_size, C_in // groups * K1 * K2]`` where each
        column ``[b, :]`` contains the flattened patch of sample ``b`` averaged over
        all output locations and channel groups.
    """
    averaged = _mean_over_channel_groups(x, groups)

    # TODO For convolutions with special structure, we don't even need to compute
    # the index pattern tensors, or can resort to contracting only slices thereof.
    # In order for this to work `einconv`'s TN simplification mechanism must first
    # be refactored to work purely symbolically. Once this is done, it will be
    # possible to do the below even more efficiently (memory and run time) for
    # structured convolutions.

    # A string padding applies to both spatial dimensions unchanged.
    paddings = (padding, padding) if isinstance(padding, str) else _pair(padding)
    per_dimension = zip(
        averaged.shape[-2:],
        _pair(kernel_size),
        _pair(stride),
        paddings,
        _pair(dilation),
    )
    # One index pattern per spatial dimension, averaged over its output axis.
    patterns = [
        reduce(
            index_pattern(
                size,
                kernel,
                stride=step,
                padding=pad,
                dilation=dil,
                dtype=averaged.dtype,
                device=averaged.device,
            ),
            "k o i -> k i",
            "mean",
        )
        for size, kernel, step, pad, dil in per_dimension
    ]

    patches = einsum(
        averaged, *patterns, "b c_in i1 i2, k1 i1, k2 i2 -> b c_in k1 k2"
    )
    return rearrange(patches, "b c_in k1 k2 -> b (c_in k1 k2)")