Internals
This section is for internal purposes only and informs developers about implementation details. It exists because rendered LaTeX is easier to read than LaTeX embedded in source code.
KFAC-related
- curvlinops.kfac_utils.loss_hessian_matrix_sqrt(
- output_one_datum: Tensor,
- target_one_datum: Tensor,
- loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
Compute the matrix square root of the loss function’s Hessian for a sample’s output.
- Parameters:
output_one_datum – The model’s prediction on a single datum. Has shape [1, C], where C is the number of classes (outputs of the neural network).
target_one_datum – The label of the single datum.
loss_func – The loss function.
- Returns:
The matrix square root \(\mathbf{S}\) of the Hessian. Has shape [C, C] and satisfies the relation
\[\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) \in \mathbb{R}^{C \times C}\]
where \(\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C\) is the model’s prediction on a single datum \(\mathbf{x}\) and \(\mathbf{y}\) is the label.
Note
For torch.nn.MSELoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean'), we have:
\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C (f_i - y_i)^2 \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C \\ \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C\end{split}\]
Note
For torch.nn.CrossEntropyLoss (with \(c = 1\) irrespective of the reduction, \(\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C\), and the element-wise natural logarithm \(\log\)) we have:
\[\begin{split}\ell(\mathbf{f}, y) &= - c \log(\mathbf{p})^\top \mathrm{onehot}(y) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y) &= c \left( \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top \right) \\ \mathbf{S} &= \sqrt{c} \left( \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top \right)\,,\end{split}\]
where the square root is applied element-wise. The order of the factors in the outer product matters: with \(\mathbf{p} \sqrt{\mathbf{p}}^\top\), the cross terms in \(\mathbf{S} \mathbf{S}^\top\) collapse to \(-\mathbf{p} \mathbf{p}^\top\) because \(\sqrt{\mathbf{p}}^\top \sqrt{\mathbf{p}} = \sum_i p_i = 1\). See for instance Example 5.1 of this thesis or equations (5) and (6) of this paper.
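The cross-entropy square root can be checked numerically. The following is a minimal dependency-free sketch, not the library's implementation: all function names are hypothetical, \(c = 1\), and the logits are arbitrary. It builds \(\mathbf{S} = \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top\) and compares \(\mathbf{S} \mathbf{S}^\top\) with \(\mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top\).

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of logits."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ce_hessian(p):
    """Hessian diag(p) - p p^T as a nested list (c = 1)."""
    C = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(C)] for i in range(C)]


def ce_hessian_sqrt(p):
    """Square root S = diag(sqrt(p)) - p sqrt(p)^T as a nested list (c = 1)."""
    C = len(p)
    r = [math.sqrt(v) for v in p]
    return [[(r[i] if i == j else 0.0) - p[i] * r[j] for j in range(C)] for i in range(C)]


def times_own_transpose(S):
    """Return S S^T for a square nested list S."""
    C = len(S)
    return [[sum(S[i][k] * S[j][k] for k in range(C)) for j in range(C)] for i in range(C)]


p = softmax([0.5, -1.0, 2.0])  # arbitrary example logits (assumption)
H = ce_hessian(p)
SST = times_own_transpose(ce_hessian_sqrt(p))
max_err = max(abs(SST[i][j] - H[i][j]) for i in range(3) for j in range(3))
print(f"max |S S^T - H| = {max_err:.2e}")
```

The residual should be at the level of floating-point round-off. Swapping the outer product to \(\sqrt{\mathbf{p}} \mathbf{p}^\top\) gives the transpose of \(\mathbf{S}\), which in general does not satisfy \(\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell\).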
Note
For torch.nn.BCEWithLogitsLoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean') we have (\(\sigma\) is the sigmoid function, applied element-wise, and assuming binary labels):
\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C - y_i \log(\sigma(f_i)) - (1 - y_i) \log(1 - \sigma(f_i)) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= c \, \mathrm{diag}( \sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f})) ) \\ \mathbf{S} &= \sqrt{c} \, \mathrm{diag}(\sqrt{\sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f}))})\,,\end{split}\]
where the square root is applied element-wise.
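As a sanity check of the diagonal case, the sketch below compares the squared diagonal of \(\mathbf{S}\) with a central finite-difference estimate of the second derivative of the loss. It is a plain-Python illustration under assumed inputs (three logits, binary targets, reduction='mean'). The helper names are hypothetical and not part of the library.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce_loss(f, y, c):
    """c * sum_i [-y_i log(sigma(f_i)) - (1 - y_i) log(1 - sigma(f_i))]."""
    return c * sum(
        -yi * math.log(sigmoid(fi)) - (1.0 - yi) * math.log(1.0 - sigmoid(fi))
        for fi, yi in zip(f, y)
    )


def bce_hessian_sqrt_diag(f, c):
    """Diagonal of S = sqrt(c) diag(sqrt(sigma(f) * (1 - sigma(f))))."""
    return [math.sqrt(c * sigmoid(v) * (1.0 - sigmoid(v))) for v in f]


def second_derivative(f, y, c, i, h=1e-4):
    """Central finite difference of the loss w.r.t. coordinate i."""
    up, down = list(f), list(f)
    up[i] += h
    down[i] -= h
    return (bce_loss(up, y, c) - 2.0 * bce_loss(f, y, c) + bce_loss(down, y, c)) / h**2


f = [0.3, -1.2, 2.0]  # example logits (assumption)
y = [1.0, 0.0, 1.0]   # binary targets
c = 1.0 / len(f)      # reduction='mean'
s = bce_hessian_sqrt_diag(f, c)
errors = [abs(s[i] ** 2 - second_derivative(f, y, c, i)) for i in range(len(f))]
print(max(errors))
```

Because the Hessian is diagonal, checking each coordinate separately is sufficient. The Hessian does not depend on the binary targets, so changing `y` should leave `s` unchanged.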
- Raises:
ValueError – If the batch size is not one, or the output is not 2d.
NotImplementedError – If the loss function is not supported.
NotImplementedError – If the loss function is BCEWithLogitsLoss but the target is not binary.
EKFAC-related
- curvlinops.ekfac.compute_eigenvalue_correction_linear_weight_sharing(
- g: Tensor,
- ggT_eigvecs: Tensor,
- a: Tensor,
- aaT_eigvecs: Tensor | None,
- _force_strategy: str | None = None,
Computes eigenvalue corrections for a linear layer with weight sharing.
Chooses between two computational strategies depending on memory requirements.
- Parameters:
g – Output gradients of the layer with shape [N, S, D1], where N is the batch size, S the weight sharing dimension, and D1 the output dimension.
ggT_eigvecs – Eigenvectors of the gradient covariance with shape [D1, D1].
a – Layer inputs with shape [N, S, D2], where D2 is the input dimension.
aaT_eigvecs – Eigenvectors of the input covariance with shape [D2, D2] or None if the layer has no weights (bias only).
_force_strategy – If specified, forces the use of either the 'gramian' or the 'per_example_gradients' strategy. If None, the strategy is chosen based on memory requirements. Defaults to None. This flag serves mainly for testing purposes.
- Returns:
The eigencorrection with shape [D1, D2] (or [D1] for the bias case).
- Raises:
ValueError – If an invalid _force_strategy is provided.
Below we explain the mathematical details of what this function does. The arguments map to the symbols as follows: (g, \(\mathbf{Y}\)), (ggT_eigvecs, \(\mathbf{Q}_1\)), (a, \(\mathbf{X}\)), (aaT_eigvecs, \(\mathbf{Q}_2\)).
Note
Introduction: In the following, let \(D_1\) be the output dimension of the layer, \(D_2\) the input dimension, \(S\) the weight sharing dimension, and \(N\) the batch size.
Given the layer inputs \(\mathbf{X} \in \mathbb{R}^{N \times S \times D_2}\), output gradients \(\mathbf{Y} \in \mathbb{R}^{N \times S \times D_1}\), and a Kronecker-factored basis \(\mathbf{Q}_1 \otimes \mathbf{Q}_2\) with factors \(\mathbf{Q}_i \in \mathbb{R}^{D_i \times D_i}\), our goal is to compute the eigencorrection \(\mathbf{E} \in \mathbb{R}^{D_1 \times D_2}\) which has the same shape as the layer’s weights.
The common way to do that is to compute the per-example gradients \(\mathbf{G} \in \mathbb{R}^{N \times D_1 \times D_2}\) with
\[\mathbf{G}_{n,d_1,d_2} = \sum_s \mathbf{Y}_{n,s,d_1} \mathbf{X}_{n,s,d_2},\]
rotate them into the Kronecker-factored basis,
\[\mathbf{\tilde{G}}_{n,d_1,d_2} = \sum_{i,j} \mathbf{G}_{n,i,j} \mathbf{Q}_{1,i,d_1} \mathbf{Q}_{2,j,d_2},\]
and compute the correction by squaring and summing out the batch dimension,
\[\mathbf{E}_{d_1,d_2} = \sum_{n} \mathbf{\tilde{G}}_{n,d_1,d_2}^2.\]
Building up the per-example gradients can be extremely memory-costly. Therefore, we also consider an alternative approach which can have a smaller memory footprint if the weight sharing is mild.
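The three steps above (build \(\mathbf{G}\), rotate, square and sum) can be transcribed literally into nested loops. This is an illustrative pure-Python sketch on nested lists, not the library's einsum-based implementation. The function name is hypothetical, and the identity matrices stand in for the eigenvector bases so the result is easy to verify by hand.

```python
def eigencorrection_per_example(Y, X, Q1, Q2):
    """Eigencorrection via explicit per-example gradients (nested lists).

    Y has shape [N][S][D1], X has shape [N][S][D2],
    Q1 has shape [D1][D1], Q2 has shape [D2][D2].
    """
    N, S = len(Y), len(Y[0])
    D1, D2 = len(Q1), len(Q2)
    E = [[0.0] * D2 for _ in range(D1)]
    for n in range(N):
        # Per-example gradient: G[i][j] = sum_s Y[n][s][i] * X[n][s][j]
        G = [
            [sum(Y[n][s][i] * X[n][s][j] for s in range(S)) for j in range(D2)]
            for i in range(D1)
        ]
        for d1 in range(D1):
            for d2 in range(D2):
                # Rotated entry: sum_{i,j} G[i][j] * Q1[i][d1] * Q2[j][d2]
                g = sum(
                    G[i][j] * Q1[i][d1] * Q2[j][d2]
                    for i in range(D1)
                    for j in range(D2)
                )
                E[d1][d2] += g * g  # square, then sum over the batch
    return E


# One datum (N=1) with weight sharing S=2, D1=2 outputs, D2=3 inputs.
Y = [[[1.0, 2.0], [1.0, 0.0]]]
X = [[[3.0, 4.0, 5.0], [1.0, 1.0, 1.0]]]
I2 = [[1.0, 0.0], [0.0, 1.0]]
I3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
E = eigencorrection_per_example(Y, X, I2, I3)
print(E)  # [[16.0, 25.0, 36.0], [36.0, 64.0, 100.0]]
```

With identity bases, \(\mathbf{G}\) for the single datum is [[4, 5, 6], [6, 8, 10]], and the eigencorrection is its element-wise square. The intermediate `G` is the \(D_1 \times D_2\) per-example object whose batched version dominates memory in this approach.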
Note
(1) Cost analysis of per-example gradient approach: The peak memory of building up per-example gradients is dominated by \(N D_1 D_2\).
We have two options to compute the rotated per-example gradient:
1. First compute \(\mathbf{G}\) and then rotate it. The first step costs \(N S D_1 D_2\) time and the rotation costs \(N D_1 D_2 (D_1 + D_2)\) time.
2. Rotate the activations and output gradients, then compute the rotated per-example gradient. The rotations cost \(N S (D_1^2 + D_2^2)\) time. The last step costs \(N S D_1 D_2\) time.
So in practice, we should prefer the first approach over the second if
\[D_1 D_2 (D_1 + D_2) < S (D_1^2 + D_2^2).\]
In the implementation, opt-einsum will automatically do that for us.
Adding the cost for squaring and contracting, the overall cost is \(N S D_1 D_2 + N \min(S (D_1^2 + D_2^2), D_1 D_2 (D_1 + D_2)) + 2 N D_1 D_2\).
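To make the comparison concrete, the cost formulas above can be evaluated for specific layer sizes. The helper below is a hypothetical illustration (the name and return format are assumptions). It only evaluates the operation counts stated in the text and does not measure actual run time.

```python
def per_example_gradient_time(N, S, D1, D2):
    """Operation count of the per-example gradient approach.

    Returns (total, option), where option is 1 if rotating the per-example
    gradients is cheaper and 2 if rotating inputs and output gradients
    first is cheaper.
    """
    build = N * S * D1 * D2                      # contract Y and X over s
    rotate_gradients = N * D1 * D2 * (D1 + D2)   # option 1: rotate G
    rotate_inputs = N * S * (D1**2 + D2**2)      # option 2: rotate X and Y
    square_and_sum = 2 * N * D1 * D2
    option = 1 if rotate_gradients < rotate_inputs else 2
    total = build + min(rotate_gradients, rotate_inputs) + square_and_sum
    return total, option


# Heavy weight sharing: rotating the small per-example gradient is cheaper.
print(per_example_gradient_time(N=1, S=100, D1=4, D2=4))  # (1760, 1)
# No weight sharing: rotating inputs and output gradients first is cheaper.
print(per_example_gradient_time(N=1, S=1, D1=4, D2=4))    # (80, 2)
```

Both options share the \(N S D_1 D_2\) contraction over the sharing dimension, so only the rotation cost decides between them.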
Note
(2) Cost analysis of Gramian contraction approach: A way to avoid building up per-example gradients is to write the eigencorrection as one big contraction of the rotated activations \(\mathbf{\tilde{X}}\) and rotated output gradients \(\mathbf{\tilde{Y}}\), with \(\mathbf{\tilde{X}}_{n,s,d_2} = \sum_j \mathbf{X}_{n,s,j} \mathbf{Q}_{2,j,d_2}\) and \(\mathbf{\tilde{Y}}_{n,s,d_1} = \sum_i \mathbf{Y}_{n,s,i} \mathbf{Q}_{1,i,d_1}\), and then rearrange the contractions such that the batch dimension can be directly summed:
\[\begin{split}\mathbf{E}_{d_1,d_2} &= \sum_{n} \left( \sum_{s} \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{X}}_{n,s,d_2} \right) \left( \sum_{t} \mathbf{\tilde{Y}}_{n,t,d_1} \mathbf{\tilde{X}}_{n,t,d_2} \right) \\ &= \sum_{n} \sum_{s} \sum_{t} \left( \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{Y}}_{n,t,d_1} \right) \left( \mathbf{\tilde{X}}_{n,s,d_2} \mathbf{\tilde{X}}_{n,t,d_2} \right)\end{split}\]
This requires building up the Gramians
\[\begin{split}\mathbf{G^Y}_{n,s,t,d_1} &= \mathbf{\tilde{Y}}_{n,s,d_1} \mathbf{\tilde{Y}}_{n,t,d_1}, \\ \mathbf{G^X}_{n,s,t,d_2} &= \mathbf{\tilde{X}}_{n,s,d_2} \mathbf{\tilde{X}}_{n,t,d_2}.\end{split}\]
Peak memory is dominated by \(N S^2 (D_1 + D_2)\). The time is \(N S (D_1^2 + D_2^2)\) for the rotations, \(N S^2 (D_1 + D_2)\) for building up the Gramians, and \(N S^2 D_1 D_2\) for the final contraction. In total, this is \(N S (D_1^2 + D_2^2) + N S^2 (D_1 + D_2 + D_1 D_2)\).
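The rearrangement can be checked numerically on small random inputs. This is a minimal pure-Python sketch with hypothetical function names, written with explicit loops to mirror the index notation rather than for speed. The random matrices merely stand in for the eigenvector bases, since the identity between the two forms holds for any \(\mathbf{Q}_1, \mathbf{Q}_2\).

```python
import random


def random_tensor(*shape):
    """Nested list of standard normal samples with the given shape."""
    if len(shape) == 1:
        return [random.gauss(0.0, 1.0) for _ in range(shape[0])]
    return [random_tensor(*shape[1:]) for _ in range(shape[0])]


def rotate(T, Q):
    """Rotate the last axis: out[n][s][d] = sum_i T[n][s][i] * Q[i][d]."""
    return [
        [[sum(v[i] * Q[i][d] for i in range(len(Q))) for d in range(len(Q[0]))] for v in sample]
        for sample in T
    ]


def eigencorrection_per_example(Yr, Xr):
    """Square the rotated per-example gradients, then sum over the batch."""
    D1, D2 = len(Yr[0][0]), len(Xr[0][0])
    return [
        [
            sum(sum(y[d1] * x[d2] for y, x in zip(Yn, Xn)) ** 2 for Yn, Xn in zip(Yr, Xr))
            for d2 in range(D2)
        ]
        for d1 in range(D1)
    ]


def eigencorrection_gramian(Yr, Xr):
    """Contract the per-feature Gramians over the sharing indices s and t."""
    S, D1, D2 = len(Yr[0]), len(Yr[0][0]), len(Xr[0][0])
    E = [[0.0] * D2 for _ in range(D1)]
    for Yn, Xn in zip(Yr, Xr):
        # GY[s][t][d1] = Yn[s][d1] * Yn[t][d1], and likewise GX for the inputs.
        GY = [[[Yn[s][d] * Yn[t][d] for d in range(D1)] for t in range(S)] for s in range(S)]
        GX = [[[Xn[s][d] * Xn[t][d] for d in range(D2)] for t in range(S)] for s in range(S)]
        for d1 in range(D1):
            for d2 in range(D2):
                E[d1][d2] += sum(
                    GY[s][t][d1] * GX[s][t][d2] for s in range(S) for t in range(S)
                )
    return E


random.seed(0)
N, S, D1, D2 = 2, 3, 2, 4
Yr = rotate(random_tensor(N, S, D1), random_tensor(D1, D1))
Xr = rotate(random_tensor(N, S, D2), random_tensor(D2, D2))
E_ref = eigencorrection_per_example(Yr, Xr)
E_gram = eigencorrection_gramian(Yr, Xr)
max_diff = max(abs(E_ref[i][j] - E_gram[i][j]) for i in range(D1) for j in range(D2))
print(f"max difference: {max_diff:.2e}")
```

The Gramian variant never forms a \(D_1 \times D_2\) object per datum. Its intermediates have sizes \(S^2 D_1\) and \(S^2 D_2\), which is where the \(N S^2 (D_1 + D_2)\) memory estimate comes from.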
We select the approach with the smaller memory footprint, i.e. Gramian contraction if \(S^2 (D_1 + D_2) < D_1 D_2\), and squaring per-example gradients otherwise. So generally speaking, the more weight sharing, the better building up per-example gradients will be.
In the extreme case \(S=1\) (no weight sharing), the Gramian contraction approach uses only \(N (D_1 + D_2)\) memory, which for typical layer sizes is less than the \(N D_1 D_2\) of the per-example gradient approach. In terms of time, the Gramian contraction costs \(N (D_1^2 + D_2^2 + D_1 + D_2 + D_1 D_2)\), compared to \(N (3 D_1 D_2 + D_1^2 + D_2^2)\) for the per-example gradient approach.
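The selection rule and the \(S = 1\) comparison can be written down as a short sketch. The function names below are hypothetical and this is not the library's code. The returned strings reuse the strategy names accepted by _force_strategy, and the batch size cancels from the memory comparison.

```python
def choose_strategy(S, D1, D2):
    """Pick the approach with the smaller peak memory (per datum)."""
    gramian_memory = S**2 * (D1 + D2)
    per_example_memory = D1 * D2
    if gramian_memory < per_example_memory:
        return "gramian"
    return "per_example_gradients"


def gramian_time(N, S, D1, D2):
    """Operation count of the Gramian contraction approach."""
    rotations = N * S * (D1**2 + D2**2)
    gramians_and_contraction = N * S**2 * (D1 + D2 + D1 * D2)
    return rotations + gramians_and_contraction


print(choose_strategy(S=1, D1=4, D2=4))     # gramian
print(choose_strategy(S=2, D1=64, D2=64))   # gramian
print(choose_strategy(S=16, D1=64, D2=64))  # per_example_gradients
print(gramian_time(N=1, S=1, D1=4, D2=4))   # 56
```

For \(S = 1\) and \(D_1 = D_2 = 4\), the Gramian count of 56 matches \(N (D_1^2 + D_2^2 + D_1 + D_2 + D_1 D_2)\) and is below the \(N (3 D_1 D_2 + D_1^2 + D_2^2) = 80\) of the per-example gradient approach.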