Inverses (natural gradient)
This example demonstrates how to work with inverses of linear operators.
curvlinops offers multiple ways to compute the inverse of a linear operator:
conjugate gradient (CG) and Neumann inversion. We will demonstrate CG inversion
first and conclude with a comparison to Neumann inversion.
Concretely, we will compute the natural gradient \(\mathbf{\tilde{g}} = \mathbf{F}^{-1} \mathbf{g}\), defined by the inverse Fisher information matrix \(\mathbf{F}^{-1}\) and the gradient \(\mathbf{g}\). We can use the GGN, as it corresponds to the Fisher for common loss functions like square and cross-entropy loss.
Note
The GGN is only positive semi-definite, so it can be rank-deficient. But we need a full-rank matrix to form the inverse. This is why we will add a damping term \(\delta \mathbf{I}\) with \(\delta > 0\) before inverting.
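To illustrate why damping helps, here is a minimal sketch in plain NumPy (not curvlinops). The rank-1 matrix `F`, the vector `g`, and the value of `delta` are hypothetical toy values that stand in for the GGN, the gradient, and the damping strength.

```python
import numpy as np

# Hypothetical toy data: a rank-1 PSD matrix standing in for an undamped GGN
J = np.array([[1.0, 2.0, 3.0]])
F = J.T @ J  # 3 x 3, positive semi-definite, rank 1
g = np.array([1.0, 0.0, -1.0])  # stand-in for the gradient

# Adding delta * I shifts every eigenvalue up by delta, so none remains zero
delta = 1e-2
F_damped = F + delta * np.eye(3)

rank_F = np.linalg.matrix_rank(F)
rank_damped = np.linalg.matrix_rank(F_damped)
min_eig_damped = np.linalg.eigvalsh(F_damped).min()

# The damped system has a unique solution (the damped natural gradient)
natural_gradient_toy = np.linalg.solve(F_damped, g)

print(f"rank without damping: {rank_F}, with damping: {rank_damped}")
print(f"smallest eigenvalue after damping: {min_eig_damped:.4f}")
```

The smallest eigenvalue of the damped matrix equals `delta`, which bounds the norm of the inverse by `1 / delta`.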
As always, let’s first import the required functionality.
import matplotlib.pyplot as plt
from scipy.sparse.linalg import eigsh
from torch import cuda, device, eye, float64, manual_seed, rand
from torch.linalg import inv
from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
from torch.nn.utils import parameters_to_vector
from curvlinops import (
    CGInverseLinearOperator,
    GGNLinearOperator,
    NeumannInverseLinearOperator,
)
from curvlinops.examples import IdentityLinearOperator, gradient_and_loss
from curvlinops.examples.functorch import functorch_ggn, functorch_gradient_and_loss
from curvlinops.utils import allclose_report
# make deterministic
manual_seed(0)
<torch._C.Generator object at 0x71fa6dd06810>
Setup
We will use synthetic data, consisting of two mini-batches, a small MLP, and mean-squared error as loss function.
N = 64
D_in = 7
D_hidden = 5
D_out = 3
DEVICE = device("cuda" if cuda.is_available() else "cpu")
DTYPE = float64 # double precision for better stability when computing inverse
X1, y1 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)
X2, y2 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    Sigmoid(),
    Linear(D_hidden, D_out),
).to(DEVICE, DTYPE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
loss_function = MSELoss(reduction="mean").to(DEVICE, DTYPE)
Next, let’s compute the ingredients for the natural gradient.
Inverse GGN/Fisher
First, we set up a linear operator for the damped GGN/Fisher
and the linear operator of its inverse:
inverse_damped_GGN = CGInverseLinearOperator(
    damped_GGN,
    eps=0,  # do not add CG-internal damping
    # use a small number of iterations for a rough solution
    max_iter=5,
    max_tridiag_iter=5,
)
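To see how the iteration budget affects accuracy, here is a minimal textbook CG sketch in plain NumPy. It is not the implementation behind CGInverseLinearOperator and does not model its eps or max_tridiag_iter arguments. The function `conjugate_gradient` and the random test matrix are hypothetical and serve only as an illustration.

```python
import numpy as np


def conjugate_gradient(matvec, b, max_iter):
    """Approximately solve A x = b for a symmetric positive definite A.

    Only matrix-vector products with A are required, never A itself.
    """
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x for the initial guess x = 0
    p = r.copy()  # first search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if rs_old == 0.0:  # already solved exactly
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)  # step size along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next A-conjugate direction
        rs_old = rs_new
    return x


# Hypothetical well-conditioned SPD matrix, standing in for a damped GGN
rng = np.random.default_rng(0)
M = rng.standard_normal((20, 40))
A = M @ M.T / 40 + 0.1 * np.eye(20)
b = rng.standard_normal(20)

exact = np.linalg.solve(A, b)
errors = {
    k: np.linalg.norm(conjugate_gradient(lambda v: A @ v, b, k) - exact)
    for k in (5, 10, 20)
}
for k, err in errors.items():
    print(f"max_iter={k:2d}: error {err:.2e}")
```

A small iteration budget yields a rough solution at low cost, while more iterations shrink the error at the price of additional matrix-vector products.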
Gradient
We can obtain the gradient via a convenience function from curvlinops.examples:
gradient, _ = gradient_and_loss(model, loss_function, params, data)
# flatten and concatenate
gradient = parameters_to_vector(gradient).detach()
Natural gradient
Now we have all components together to compute the natural gradient with a simple matrix-vector product:
natural_gradient = inverse_damped_GGN @ gradient
As a first sanity check, let’s verify that the natural gradient approximately satisfies \(\mathbf{F} \mathbf{\tilde{g}} = \mathbf{g}\):
approx_gradient = damped_GGN @ natural_gradient
print("Comparing gradient with Fisher @ natural gradient.")
assert allclose_report(approx_gradient, gradient, rtol=1e-4, atol=1e-5)
Comparing gradient with Fisher @ natural gradient.
Verifying results
To check if the code works, let’s compute the GGN with functorch,
using a utility function of curvlinops.examples; then damp it, invert
it, and multiply it onto the gradient.
GGN_mat_functorch = functorch_ggn(model, loss_function, params, data).detach()
Then, we damp it and invert it:
damping_mat = delta * eye(GGN_mat_functorch.shape[0], device=DEVICE, dtype=DTYPE)
damped_GGN_mat = GGN_mat_functorch + damping_mat
inv_damped_GGN_mat = inv(damped_GGN_mat)
Next, let’s compute the gradient with functorch, using a utility
function from curvlinops.examples:
gradient_functorch, _ = functorch_gradient_and_loss(model, loss_function, params, data)
# flatten and concatenate
gradient_functorch = parameters_to_vector(gradient_functorch).detach()
print("Comparing gradient with functorch's gradient.")
assert allclose_report(gradient, gradient_functorch)
Comparing gradient with functorch's gradient.
We can now compute the natural gradient from the functorch
quantities. This should yield approximately the same result:
natural_gradient_functorch = inv_damped_GGN_mat @ gradient_functorch
print("Comparing natural gradient with functorch's natural gradient.")
rtol, atol = 5e-3, 5e-5
assert allclose_report(
    natural_gradient, natural_gradient_functorch, rtol=rtol, atol=atol
)
Comparing natural gradient with functorch's natural gradient.
You might have noticed the rather loose tolerances required to achieve approximate equality. We can use stricter convergence hyperparameters for CG to achieve a more accurate inversion:
inverse_damped_GGN = CGInverseLinearOperator(
    damped_GGN,
    eps=0,  # do not add CG-internal damping
    # increase number of iterations to get a better approximation
    max_iter=10,
    max_tridiag_iter=10,
)
natural_gradient_more_accurate = inverse_damped_GGN @ gradient
smaller_rtol, smaller_atol = rtol / 10, atol / 10
print("Comparing more accurate natural gradient with functorch's natural gradient.")
assert allclose_report(
    natural_gradient_more_accurate,
    natural_gradient_functorch,
    rtol=smaller_rtol,
    atol=smaller_atol,
)
Comparing more accurate natural gradient with functorch's natural gradient.
In contrast, the less accurate inversion from before does not pass this check:
print(
    "Comparing natural gradient with functorch's natural gradient (smaller tolerances)."
)
try:
    assert allclose_report(
        natural_gradient,
        natural_gradient_functorch,
        rtol=smaller_rtol,
        atol=smaller_atol,
    )
    raise RuntimeError("This comparison should not pass")
except AssertionError as e:
    print(e)
Comparing natural gradient with functorch's natural gradient (smaller tolerances).
at index [4]: 4.04149e-02 ≠ 4.03777e-02, ratio: 1.00092e+00
at index [8]: 6.76268e-03 ≠ 6.81269e-03, ratio: 9.92659e-01
at index [9]: -1.36847e-02 ≠ -1.37049e-02, ratio: 9.98528e-01
at index [10]: -3.70819e-03 ≠ -3.66605e-03, ratio: 1.01150e+00
at index [11]: -7.79267e-03 ≠ -7.76049e-03, ratio: 1.00415e+00
at index [13]: -1.80134e-03 ≠ -1.76606e-03, ratio: 1.01998e+00
at index [22]: 3.16156e-02 ≠ 3.16480e-02, ratio: 9.98977e-01
at index [23]: 4.82754e-02 ≠ 4.83161e-02, ratio: 9.99158e-01
at index [24]: 1.27004e-02 ≠ 1.26878e-02, ratio: 1.00100e+00
at index [26]: -1.42644e-02 ≠ -1.42521e-02, ratio: 1.00086e+00
at index [27]: 3.20968e-02 ≠ 3.21585e-02, ratio: 9.98081e-01
at index [29]: 2.13577e-02 ≠ 2.13813e-02, ratio: 9.98897e-01
at index [30]: 2.14169e-02 ≠ 2.14483e-02, ratio: 9.98537e-01
at index [31]: 8.43465e-03 ≠ 8.45006e-03, ratio: 9.98176e-01
at index [32]: -2.02736e-03 ≠ -2.00121e-03, ratio: 1.01307e+00
at index [33]: 3.67601e-03 ≠ 3.70411e-03, ratio: 9.92413e-01
at index [34]: 4.81874e-04 ≠ 4.88223e-04, ratio: 9.86996e-01
at index [36]: -1.07670e-02 ≠ -1.07276e-02, ratio: 1.00368e+00
at index [39]: 1.91604e-02 ≠ 1.92130e-02, ratio: 9.97266e-01
at index [40]: -1.24558e-02 ≠ -1.24304e-02, ratio: 1.00204e+00
at index [41]: -2.01354e-03 ≠ -2.00333e-03, ratio: 1.00510e+00
at index [45]: 4.96902e-02 ≠ 4.96296e-02, ratio: 1.00122e+00
at index [46]: 4.18778e-04 ≠ 4.07667e-04, ratio: 1.02725e+00
at index [48]: -1.74391e-02 ≠ -1.74200e-02, ratio: 1.00110e+00
at index [49]: 2.25335e-02 ≠ 2.25163e-02, ratio: 1.00076e+00
at index [50]: 1.02638e-02 ≠ 1.02371e-02, ratio: 1.00260e+00
at index [54]: 4.42425e-02 ≠ 4.42709e-02, ratio: 9.99359e-01
at index [55]: 2.56494e-02 ≠ 2.56147e-02, ratio: 1.00135e+00
at index [70]: 9.86109e-03 ≠ 9.84801e-03, ratio: 1.00133e+00
at index [75]: 2.27872e-03 ≠ 2.27010e-03, ratio: 1.00380e+00
at index [77]: -1.58043e-02 ≠ -1.58248e-02, ratio: 9.98707e-01
Abs max: 1.35858e-01 vs. 1.35860e-01.
Abs min: 0.00000e+00 vs. 0.00000e+00.
Non-close entries: 31 / 88.
rtol = 0.0005, atol = 5e-06.
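To make the report concrete, the following sketch checks the entry at index [4] by hand. It assumes that allclose_report uses the same closeness criterion as torch.allclose and numpy.isclose, namely \(|a - b| \le \text{atol} + \text{rtol} \cdot |b|\). The helper `is_close` is hypothetical and not part of curvlinops.

```python
# Entry at index [4] of the report: CG result vs. functorch reference
a, b = 4.04149e-02, 4.03777e-02


def is_close(a, b, rtol, atol):
    """Closeness criterion |a - b| <= atol + rtol * |b| (assumed convention)."""
    return abs(a - b) <= atol + rtol * abs(b)


# Loose tolerances from the first comparison
loose = is_close(a, b, rtol=5e-3, atol=5e-5)
# Ten times smaller tolerances from the second comparison
strict = is_close(a, b, rtol=5e-4, atol=5e-6)

print(f"difference: {abs(a - b):.3e}")
print(f"close with loose tolerances: {loose}, with smaller tolerances: {strict}")
```

The difference of about 3.7e-05 lies below the loose threshold (about 2.5e-04) but above the smaller one (about 2.5e-05), which is why this entry is listed as non-close.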
Visual comparison
Finally, let’s visualize the damped Fisher/GGN and its inverse. For improved visibility, we take the logarithm of the absolute value of each element (blank pixels correspond to zeros).
fig, ax = plt.subplots(ncols=2)
plt.suptitle("Logarithm of absolute values")
ax[0].set_title("Damped GGN/Fisher")
image = ax[0].imshow(damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax[0], shrink=0.5)
ax[1].set_title("Inv. damped GGN/Fisher")
image = ax[1].imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax[1], shrink=0.5)

<matplotlib.colorbar.Colorbar object at 0x71f9441ba080>
Neumann inverse (CG alternative)
So far, we used CG to solve the linear system \(\mathbf{F}
\mathbf{\tilde{g}} = \mathbf{g}\) for the natural gradient
\(\mathbf{\tilde{g}}\) (i.e. the result of the inverse Fisher-gradient
product). Alternatively, we can approximate the inverse with a truncated Neumann series,
using NeumannInverseLinearOperator.
Note
The Neumann series does not always converge. But we can use a re-scaling trick to make it converge if we know the matrix is PSD and are given its largest eigenvalue. More information can be found in the docstring.
To make the Neumann series converge, we need to know the largest eigenvalue of the matrix to be inverted:
max_eigval = eigsh(damped_GGN.to_scipy(), k=1, which="LM", return_eigenvectors=False)[0]
# re-scale such that the eigenvalues of (scale * damped_GGN_mat) lie in [0; 2)
scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval
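The role of the re-scaling can be sketched in plain NumPy. The example assumes the standard form of the series, \(\mathbf{A}^{-1} = \alpha \sum_{k=0}^{\infty} (\mathbf{I} - \alpha \mathbf{A})^k\), which converges when all eigenvalues of \(\alpha \mathbf{A}\) lie in \((0, 2)\). The function `neumann_inverse_matvec` and the 2 x 2 matrix are hypothetical, and the library's implementation may differ in its details.

```python
import numpy as np


def neumann_inverse_matvec(A, v, num_terms, scale=1.0):
    """Approximate A^{-1} v by scale * sum_{k=0}^{num_terms} (I - scale * A)^k v."""
    term = v.copy()  # k = 0 term
    result = v.copy()
    for _ in range(num_terms):
        term = term - scale * (A @ term)  # multiply previous term by (I - scale * A)
        result = result + term
    return scale * result


# Hypothetical SPD matrix whose largest eigenvalue exceeds 2
A = np.array([[4.0, 1.0], [1.0, 3.0]])
v = np.array([1.0, 2.0])

max_eigval = np.linalg.eigvalsh(A).max()
# same re-scaling rule as in the snippet above
scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval

exact = np.linalg.solve(A, v)
errors = {
    K: np.linalg.norm(neumann_inverse_matvec(A, v, K, scale) - exact)
    for K in (5, 50, 500)
}
# Without re-scaling, the series diverges for this matrix
unscaled = neumann_inverse_matvec(A, v, 50, scale=1.0)

for K, err in errors.items():
    print(f"{K:3d} terms: error {err:.2e}")
print(f"norm of unscaled result after 50 terms: {np.linalg.norm(unscaled):.2e}")
```

With the factor 1.99, the largest eigenvalue of the re-scaled matrix sits just inside the convergence region, so the error shrinks by only about one percent per term in that direction. This is one reason why many terms can be needed for an accurate result.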
Let’s compute the inverse approximation for different truncation numbers:
Here are their visualizations:
fig, axes = plt.subplots(ncols=len(num_terms) + 1)
plt.suptitle("Inverse damped Fisher (logarithm of absolute values)")
# the loop variable is called inv_mat to avoid shadowing torch.linalg.inv
for i, (n, inv_mat) in enumerate(zip(num_terms, neumann_inverses)):
    ax = axes.flat[i]
    ax.set_title(f"Neumann, {n} terms")
    image = ax.imshow(inv_mat.detach().cpu().abs().log10())
    plt.colorbar(image, ax=ax, shrink=0.5)
ax = axes.flat[-1]
ax.set_title("Exact inverse")
image = ax.imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax, shrink=0.5)

<matplotlib.colorbar.Colorbar object at 0x71f93ec0e8f0>
Neumann inversion is usually less accurate than CG inversion. But it might sometimes be preferred if only a rough approximation of the inverse matrix-vector product is needed.