Internals
This section documents internal details for developers. It exists because rendered LaTeX is easier to read than source code.
KFAC-related
- curvlinops.kfac_utils.loss_hessian_matrix_sqrt(output_one_datum: Tensor, target_one_datum: Tensor, loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss) → Tensor

Compute the loss function's matrix square root for a sample's output.
- Parameters:
  - output_one_datum – The model's prediction on a single datum. Has shape [1, C] where C is the number of classes (outputs of the neural network).
  - target_one_datum – The label of the single datum.
  - loss_func – The loss function.
- Returns:
  The matrix square root \(\mathbf{S}\) of the Hessian. Has shape [C, C] and satisfies the relation

  \[\mathbf{S} \mathbf{S}^\top = \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) \in \mathbb{R}^{C \times C}\]

  where \(\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C\) is the model's prediction on a single datum \(\mathbf{x}\) and \(\mathbf{y}\) is the label.
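This defining relation can be sanity-checked numerically. The sketch below is a NumPy illustration, not the library's implementation; the choice of a mean-reduced MSE on one datum and the central finite-difference stencil are assumptions made purely for the example. It compares \(\mathbf{S} \mathbf{S}^\top\) against a finite-difference Hessian of the loss with respect to the prediction:

```python
import numpy as np

C = 4  # number of outputs
c = 1.0 / C  # scaling of reduction='mean' for a single datum
f = np.array([0.3, -1.2, 0.7, 2.0])  # model prediction
y = np.zeros(C)  # regression target

def loss(f):
    # mean-reduced MSE on one datum: c * sum_i (f_i - y_i)^2
    return c * np.sum((f - y) ** 2)

# central finite-difference Hessian (exact for quadratics, up to roundoff)
eps = 0.1
I = np.eye(C)
H = np.zeros((C, C))
for i in range(C):
    for j in range(C):
        H[i, j] = (
            loss(f + eps * (I[i] + I[j])) - loss(f + eps * (I[i] - I[j]))
            - loss(f - eps * (I[i] - I[j])) + loss(f - eps * (I[i] + I[j]))
        ) / (4 * eps**2)

# for MSE, the Hessian is 2c I_C, so S = sqrt(2c) I_C is a matrix square root
S = np.sqrt(2 * c) * np.eye(C)
assert np.allclose(S @ S.T, H)
assert np.allclose(H, 2 * c * np.eye(C))
```

The same stencil works for any of the supported losses, since only `loss` changes.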
Note

For torch.nn.MSELoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean'), we have:

\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C (f_i - y_i)^2 \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C \\ \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C\end{split}\]

Note
For torch.nn.CrossEntropyLoss (with \(c = 1\) irrespective of the reduction, \(\mathbf{p} := \mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C\), and the element-wise natural logarithm \(\log\)) we have:

\[\begin{split}\ell(\mathbf{f}, y) &= - c \log(\mathbf{p})^\top \mathrm{onehot}(y) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y) &= c \left( \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top \right) \\ \mathbf{S} &= \sqrt{c} \left( \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top \right)\,,\end{split}\]

where the square root is applied element-wise. See for instance Example 5.1 of this thesis or equations (5) and (6) of this paper.
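Because the softmax Hessian \(\mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top\) is a standard result, its square root is easy to check numerically. The following NumPy sketch (illustrative only, independent of the library's code) verifies that \(\mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top\) is a matrix square root for a random logit vector, with \(c = 1\):

```python
import numpy as np

rng = np.random.default_rng(0)
C = 5
f = rng.normal(size=C)  # logits for one datum

# softmax probabilities (shift by the max for numerical stability)
p = np.exp(f - f.max())
p /= p.sum()

# Hessian of cross-entropy w.r.t. the logits (c = 1); independent of the label
H = np.diag(p) - np.outer(p, p)

# candidate matrix square root: diag(sqrt(p)) - p sqrt(p)^T
S = np.diag(np.sqrt(p)) - np.outer(p, np.sqrt(p))

assert np.allclose(S @ S.T, H)
```

Note that \(\mathbf{S}\) is not symmetric here; only the relation \(\mathbf{S} \mathbf{S}^\top = \mathbf{H}\) is required.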
Note

For torch.nn.BCEWithLogitsLoss (with \(c = 1\) for reduction='sum' and \(c = 1/C\) for reduction='mean') we have (\(\sigma\) is the sigmoid function, applied element-wise, and assuming binary labels):

\[\begin{split}\ell(\mathbf{f}, \mathbf{y}) &= c \sum_{i=1}^C - y_i \log(\sigma(f_i)) - (1 - y_i) \log(1 - \sigma(f_i)) \\ \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= c \, \mathrm{diag}\left( \sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f})) \right) \\ \mathbf{S} &= \sqrt{c} \, \mathrm{diag}\left( \sqrt{\sigma(\mathbf{f}) \odot (1 - \sigma(\mathbf{f}))} \right)\,,\end{split}\]

where the square root is applied element-wise.
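Since the Hessian is diagonal in this case, the square root is just the element-wise square root of the diagonal. A NumPy sketch (the concrete logits, binary targets, and reduction='mean' scaling are assumptions for the example) checks this against a finite-difference Hessian of the binary cross-entropy with logits:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

C = 3
c = 1.0 / C  # scaling of reduction='mean' for a single datum
f = np.array([0.5, -1.0, 2.0])  # logits
y = np.array([1.0, 0.0, 1.0])  # binary targets

def loss(f):
    # BCE with logits on one datum
    s = sigmoid(f)
    return c * np.sum(-y * np.log(s) - (1 - y) * np.log(1 - s))

# central finite-difference Hessian
eps = 1e-4
I = np.eye(C)
H = np.zeros((C, C))
for i in range(C):
    for j in range(C):
        H[i, j] = (
            loss(f + eps * (I[i] + I[j])) - loss(f + eps * (I[i] - I[j]))
            - loss(f - eps * (I[i] - I[j])) + loss(f - eps * (I[i] + I[j]))
        ) / (4 * eps**2)

# diagonal Hessian c * diag(sigma(f) * (1 - sigma(f))) and its square root
s = sigmoid(f)
S = np.sqrt(c) * np.diag(np.sqrt(s * (1 - s)))
assert np.allclose(S @ S.T, H, atol=1e-6)
assert np.allclose(H, c * np.diag(s * (1 - s)), atol=1e-6)
```

As the second derivative \(c \, \sigma(f_i)(1 - \sigma(f_i))\) does not depend on \(y_i\), the check holds for any binary target.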
- Raises:
  - ValueError – If the batch size is not one, or the output is not 2d.
  - NotImplementedError – If the loss function is not supported.
  - NotImplementedError – If the loss function is BCEWithLogitsLoss but the target is not binary.