Internals
This section is for internal use. It documents implementation details for developers, since rendered LaTeX is easier to read than source code.
KFAC-related
- curvlinops.kfac_utils.loss_hessian_matrix_sqrt(output_one_datum: Tensor, loss_func: MSELoss | CrossEntropyLoss) → Tensor
Compute the matrix square root of the loss function’s Hessian with respect to a sample’s output.
- Parameters:
output_one_datum – The model’s prediction on a single datum. Has shape [1, C], where C is the number of classes (outputs of the neural network).
loss_func – The loss function.
- Returns:
The matrix square root \(\mathbf{S}\) of the Hessian. Has shape [C, C] and satisfies the relation
\[\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) \in \mathbb{R}^{C \times C}\]
where \(\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C\) is the model’s prediction on a single datum \(\mathbf{x}\) and \(\mathbf{y}\) is the label.
Note
For torch.nn.MSELoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean'), we have:
\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C (f_i - y_i)^2 \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C \\ \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C\end{split}\]
Note
For torch.nn.CrossEntropyLoss (with \(c = 1\) irrespective of the reduction, \(\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C\), and the element-wise natural logarithm \(\log\)) we have:
\[\begin{split}\ell(\mathbf{f}, y) &= - c \log(\mathbf{p})^\top \mathrm{onehot}(y) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y) &= c \left( \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top \right) \\ \mathbf{S} &= \sqrt{c} \left( \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top \right)\,,\end{split}\]
where the square root is applied element-wise. The relation \(\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y)\) holds because \(\sqrt{\mathbf{p}}^\top \sqrt{\mathbf{p}} = \sum_{i=1}^C p_i = 1\). See for instance Example 5.1 of this thesis or equations (5) and (6) of this paper.
- Raises:
ValueError – If the batch size is not one, or the output is not 2d.
NotImplementedError – If the loss function is not supported.
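The closed-form square roots in the notes above can be checked numerically. The following is a minimal sketch, not the library's implementation: it uses only the Python standard library, does not call curvlinops or PyTorch, and all helper names (softmax, mse_hessian_sqrt, ce_hessian_sqrt, ce_hessian, outer_square, max_abs_diff) are hypothetical. It builds \(\mathbf{S}\) for both losses with \(\mathbf{S} = \sqrt{c} \left( \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top \right)\) for the cross-entropy case, and compares \(\mathbf{S} \mathbf{S}^\top\) against the Hessian.

```python
import math


def softmax(f):
    """Numerically stable softmax of a list of floats."""
    shift = max(f)
    exps = [math.exp(v - shift) for v in f]
    total = sum(exps)
    return [e / total for e in exps]


def mse_hessian_sqrt(num_classes, reduction="mean"):
    """S = sqrt(2c) * I, with c = 1 for 'sum' and c = 1/C for 'mean'."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"Unsupported reduction: {reduction}")
    c = 1.0 if reduction == "sum" else 1.0 / num_classes
    diag = math.sqrt(2.0 * c)
    return [
        [diag if i == j else 0.0 for j in range(num_classes)]
        for i in range(num_classes)
    ]


def ce_hessian_sqrt(f):
    """S = diag(sqrt(p)) - p sqrt(p)^T with p = softmax(f) and c = 1."""
    p = softmax(f)
    n = len(p)
    return [
        [(math.sqrt(p[i]) if i == j else 0.0) - p[i] * math.sqrt(p[j])
         for j in range(n)]
        for i in range(n)
    ]


def ce_hessian(f):
    """Hessian diag(p) - p p^T of the cross-entropy loss w.r.t. the logits."""
    p = softmax(f)
    n = len(p)
    return [
        [(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)]
        for i in range(n)
    ]


def outer_square(S):
    """Return the matrix product S S^T for a square matrix S."""
    n = len(S)
    return [
        [sum(S[i][k] * S[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


def max_abs_diff(A, B):
    """Largest element-wise absolute difference between two matrices."""
    return max(abs(a - b) for row_a, row_b in zip(A, B) for a, b in zip(row_a, row_b))


# Arbitrary example logits for C = 3 classes.
f = [0.5, -1.0, 2.0]
C = len(f)

# Cross-entropy: S S^T should reproduce diag(p) - p p^T.
ce_error = max_abs_diff(outer_square(ce_hessian_sqrt(f)), ce_hessian(f))

# MSE with reduction='mean': S S^T should reproduce 2c I = (2 / C) I.
mse_target = [[2.0 / C if i == j else 0.0 for j in range(C)] for i in range(C)]
mse_error = max_abs_diff(outer_square(mse_hessian_sqrt(C, "mean")), mse_target)

print(f"cross-entropy error: {ce_error:.2e}, MSE error: {mse_error:.2e}")
```

Both errors should be at the level of floating-point round-off. Note that the cross-entropy square root is not symmetric, so the order of the factors in \(\mathbf{p} \sqrt{\mathbf{p}}^\top\) matters for the check to pass.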