Internals
This section is for internal purposes only. It documents implementation details for developers here because rendered LaTeX is easier to read than source code.
GGN-related
curvlinops.ggn_utils.loss_hessian_matrix_sqrt(
    output_one_datum: Tensor,
    target_one_datum: Tensor,
    loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
)
Compute the matrix square root of the loss function’s Hessian with respect to a single sample’s output.
- Parameters:
output_one_datum – The model’s prediction on a single datum. Has shape [C, *D] for CE, where C is the number of classes, or [*D] for MSE/BCE, with *D optional (and potentially multiple) sequence dimensions. Has no batch axis.
target_one_datum – The label of the single datum. Has shape [*D]. Has no batch axis.
loss_func – The loss function.
- Returns:
The matrix square root \(\mathbf{S}\) of the Hessian. Has shape [C, *D, C, *D] for CE and [*D, *D] for BCE/MSE loss. Its matrix view satisfies
\[\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y})\]
where \(\mathbf{f} := f(\mathbf{x})\) is the model’s prediction on a single datum \(\mathbf{x}\) and \(\mathbf{y}\) is the label.
Below, we list the Hessian square roots for vector-valued predictions of shape [C].
Note
For torch.nn.MSELoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean'), we have:
\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C (f_i - y_i)^2 \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C \\ \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C\end{split}\]
Note
For torch.nn.CrossEntropyLoss (with \(c = 1\) irrespective of the reduction, \(\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C\), and the element-wise natural logarithm \(\log\)), we have:
\[\begin{split}\ell(\mathbf{f}, y) &= - c \log(\mathbf{p})^\top \mathrm{onehot}(y) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y) &= c \left( \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top \right) \\ \mathbf{S} &= \sqrt{c} \left( \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top \right)\,,\end{split}\]
where the square root is applied element-wise. Because \(\sqrt{\mathbf{p}}^\top \sqrt{\mathbf{p}} = \sum_i p_i = 1\), expanding \(\mathbf{S} \mathbf{S}^\top\) recovers the Hessian. See for instance Example 5.1 of this thesis or equations (5) and (6) of this paper.
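As a quick sanity check, the following dependency-free sketch verifies numerically that the factor \(\mathbf{S} = \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top\) (with \(c = 1\)) reproduces the cross-entropy Hessian \(\mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top\). It uses plain Python lists rather than PyTorch tensors. All helper names (softmax, ce_hessian, ce_hessian_sqrt, gram) are hypothetical and not part of curvlinops.

```python
import math


def softmax(f):
    """Numerically stable softmax of a list of logits."""
    m = max(f)
    exps = [math.exp(x - m) for x in f]
    total = sum(exps)
    return [e / total for e in exps]


def ce_hessian(p):
    """Cross-entropy Hessian w.r.t. the logits: diag(p) - p p^T."""
    C = len(p)
    return [
        [(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(C)]
        for i in range(C)
    ]


def ce_hessian_sqrt(p):
    """Candidate factor S = diag(sqrt(p)) - p sqrt(p)^T."""
    C = len(p)
    return [
        [(math.sqrt(p[i]) if i == j else 0.0) - p[i] * math.sqrt(p[j]) for j in range(C)]
        for i in range(C)
    ]


def gram(S):
    """Return S S^T for a matrix given as a list of rows."""
    return [[sum(a * b for a, b in zip(row_i, row_j)) for row_j in S] for row_i in S]


p = softmax([0.5, -1.0, 2.0])  # arbitrary logits for C = 3 classes
H = ce_hessian(p)
SST = gram(ce_hessian_sqrt(p))

# Largest entry-wise deviation between S S^T and the Hessian
max_error = max(abs(H[i][j] - SST[i][j]) for i in range(3) for j in range(3))
print(max_error)
```

The order of the outer product matters: the transposed variant \(\sqrt{\mathbf{p}} \mathbf{p}^\top\) does not satisfy \(\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell\) in general.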
Note
For torch.nn.BCEWithLogitsLoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean'), we have (\(\sigma\) is the element-wise sigmoid function; targets may be any value in \([0, 1]\)):
\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C \left[ - y_i \log(\sigma(f_i)) - (1 - y_i) \log(1 - \sigma(f_i)) \right] \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= c \, \mathrm{diag}\left( \sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f})) \right) \\ \mathbf{S} &= \sqrt{c} \, \mathrm{diag}\left( \sqrt{\sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f}))} \right)\,,\end{split}\]
where the square root is applied element-wise.
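The BCE formulas can be checked in the same spirit. The sketch below compares the closed-form Hessian diagonal \(c \, \sigma(f_i) (1 - \sigma(f_i))\) with a central finite-difference estimate, using plain Python instead of PyTorch. The helper names (sigmoid, bce_with_logits, second_derivative) and the test values are assumptions made for illustration and are not part of curvlinops. Off-diagonal entries are not checked because the loss is a sum of per-coordinate terms, so they vanish.

```python
import math


def sigmoid(x):
    """Logistic sigmoid of a scalar."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(f, y, c):
    """BCE-with-logits loss of one datum, scaled by the reduction factor c."""
    return c * sum(
        -yi * math.log(sigmoid(fi)) - (1.0 - yi) * math.log(1.0 - sigmoid(fi))
        for fi, yi in zip(f, y)
    )


def second_derivative(f, y, c, i, eps=1e-3):
    """Central finite-difference estimate of d^2 loss / d f_i^2."""
    up, down = list(f), list(f)
    up[i] += eps
    down[i] -= eps
    center = bce_with_logits(f, y, c)
    return (bce_with_logits(up, y, c) - 2.0 * center + bce_with_logits(down, y, c)) / eps**2


f = [0.3, -1.2, 2.0]  # logits of one datum, C = 3
y = [1.0, 0.0, 0.4]   # targets in [0, 1], including a soft label
c = 1.0 / len(f)      # factor for reduction='mean'

# Closed-form Hessian diagonal and the diagonal of its square root S
hessian_diag = [c * sigmoid(fi) * (1.0 - sigmoid(fi)) for fi in f]
sqrt_diag = [math.sqrt(h) for h in hessian_diag]

# Finite-difference estimate of the same diagonal
fd_diag = [second_derivative(f, y, c, i) for i in range(len(f))]
fd_error = max(abs(a - b) for a, b in zip(hessian_diag, fd_diag))
print(fd_error)
```

Note that the closed-form Hessian does not involve the targets, which is why soft labels in \([0, 1]\) are unproblematic for the square root.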
- Raises:
NotImplementedError – If the loss function is not supported.