Inverses (natural gradient)

This example demonstrates how to work with inverses of linear operators.

curvlinops offers multiple ways to compute the inverse of a linear operator, including conjugate gradient (CG) and Neumann inversion. We will demonstrate CG inversion first and conclude with a comparison to Neumann inversion.

Concretely, we will compute the natural gradient \(\mathbf{\tilde{g}} = \mathbf{F}^{-1} \mathbf{g}\), defined by the inverse Fisher information matrix \(\mathbf{F}^{-1}\) and the gradient \(\mathbf{g}\). Instead of the Fisher, we can use the generalized Gauss-Newton (GGN) matrix, since the two coincide for common loss functions like square and cross-entropy loss.

Note

The GGN is positive semi-definite and may be rank-deficient, in which case it has no inverse. This is why we add a damping term \(\delta \mathbf{I}\) with \(\delta > 0\) before inverting: the damped matrix is positive definite and therefore invertible.
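To see why damping is needed, here is a small sketch on a made-up matrix rather than the GGN of this example: a Gram matrix \(\mathbf{J}^\top \mathbf{J}\) built from a wide \(\mathbf{J}\) is PSD but singular, and adding \(\delta \mathbf{I}\) shifts every eigenvalue up by \(\delta\), which makes the linear system uniquely solvable.

```python
import torch

torch.manual_seed(0)

# made-up stand-in for a curvature matrix: J has fewer rows than columns,
# so G = J^T J (8 x 8) is PSD with rank at most 3, hence singular
J = torch.randn(3, 8, dtype=torch.float64)
G = J.T @ J

delta = 1e-2
G_damped = G + delta * torch.eye(8, dtype=torch.float64)

eigvals = torch.linalg.eigvalsh(G)  # ascending order
eigvals_damped = torch.linalg.eigvalsh(G_damped)
print(eigvals)  # the five smallest are zero up to round-off
print(eigvals_damped)  # every eigenvalue shifted up by delta

# the damped matrix is positive definite, so G_damped x = g has a unique solution
g = torch.randn(8, dtype=torch.float64)
g_tilde = torch.linalg.solve(G_damped, g)
```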

As always, let’s first import the required functionality.

import matplotlib.pyplot as plt
from scipy.sparse.linalg import eigsh
from torch import cuda, device, eye, float64, manual_seed, rand
from torch.linalg import inv
from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
from torch.nn.utils import parameters_to_vector

from curvlinops import (
    CGInverseLinearOperator,
    GGNLinearOperator,
    NeumannInverseLinearOperator,
)
from curvlinops.examples import IdentityLinearOperator, gradient_and_loss
from curvlinops.examples.functorch import functorch_ggn, functorch_gradient_and_loss
from curvlinops.utils import allclose_report

# make deterministic
manual_seed(0)

Setup

We will use synthetic data, consisting of two mini-batches, a small MLP, and mean-squared error as loss function.

N = 64
D_in = 7
D_hidden = 5
D_out = 3

DEVICE = device("cuda" if cuda.is_available() else "cpu")
DTYPE = float64  # double precision for better stability when computing inverse

X1, y1 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)
X2, y2 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)

model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    Sigmoid(),
    Linear(D_hidden, D_out),
).to(DEVICE, DTYPE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}

loss_function = MSELoss(reduction="mean").to(DEVICE, DTYPE)

Next, let’s compute the ingredients for the natural gradient.

Inverse GGN/Fisher

First, we set up a linear operator for the damped GGN/Fisher

data = [(X1, y1), (X2, y2)]
GGN = GGNLinearOperator(model, loss_function, params, data)
shapes = [p.shape for p in params.values()]
delta = 1e-2
damping = delta * IdentityLinearOperator(shapes, GGN.device, DTYPE)
damped_GGN = GGN + damping

and the linear operator of its inverse:

inverse_damped_GGN = CGInverseLinearOperator(
    damped_GGN,
    eps=0,  # do not add CG-internal damping
    # use a small number of iterations for a rough solution
    max_iter=5,
    max_tridiag_iter=5,
)
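`CGInverseLinearOperator` does not form the inverse: each product with a vector solves a linear system iteratively, which is why `max_iter` trades accuracy for matrix-vector products. The sketch below is a textbook conjugate gradient in plain PyTorch, not curvlinops' implementation (it has no counterpart to `max_tridiag_iter`). The made-up matrix `G` stands in for the GGN, and the damping is applied lazily inside the matrix-vector product, similar in spirit to `GGN + damping`.

```python
import torch


def conjugate_gradient(matvec, b, max_iter, tol=1e-12):
    """Approximately solve A x = b for symmetric positive definite A given by its matvec.

    Starts from x = 0 and stops after ``max_iter`` iterations or once the
    residual norm drops below ``tol``.
    """
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the initial guess x = 0
    p = r.clone()  # first search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if rs_old.sqrt() < tol:
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)  # step length minimizing the A-norm error along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next direction, A-conjugate to the previous ones
        rs_old = rs_new
    return x


torch.manual_seed(0)
dim = 30
M = torch.randn(dim, dim, dtype=torch.float64)
G = M.T @ M / dim  # made-up PSD stand-in for the GGN
delta = 1e-2


def damped_matvec(v):
    # (G + delta * I) v, evaluated without forming G + delta * I
    return G @ v + delta * v


b = torch.randn(dim, dtype=torch.float64)
x_exact = torch.linalg.solve(G + delta * torch.eye(dim, dtype=torch.float64), b)

errors = {}
for n in (5, 10, 100):
    x = conjugate_gradient(damped_matvec, b, max_iter=n)
    errors[n] = ((x - x_exact).norm() / x_exact.norm()).item()
    print(f"max_iter={n}: relative error {errors[n]:.2e}")
```

With only a handful of iterations the solution is rough, which matches the choice of `max_iter=5` above for a quick, approximate inverse.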

Gradient

We can obtain the gradient via a convenience function from curvlinops.examples:

gradient, _ = gradient_and_loss(model, loss_function, params, data)
# flatten and concatenate
gradient = parameters_to_vector(gradient).detach()
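`parameters_to_vector` flattens each tensor and concatenates the results in the given order, so the flattened gradient lists its entries in the order in which the gradient tensors are returned. A quick check on made-up tensors shaped like the weight and bias of a `Linear(3, 2)` layer:

```python
import torch
from torch.nn.utils import parameters_to_vector

# made-up "gradients" with the shapes of a Linear(3, 2) layer's weight and bias
grad_weight = torch.arange(6, dtype=torch.float64).reshape(2, 3)
grad_bias = torch.tensor([10.0, 11.0], dtype=torch.float64)

flat = parameters_to_vector([grad_weight, grad_bias])

# equivalent manual version: flatten every tensor, then concatenate in order
manual = torch.cat([g.reshape(-1) for g in (grad_weight, grad_bias)])
print(flat)
```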

Natural gradient

Now we have all components together to compute the natural gradient with a simple matrix-vector product:

natural_gradient = inverse_damped_GGN @ gradient

As a first sanity check, let’s verify that the natural gradient satisfies \(\mathbf{F} \mathbf{\tilde{g}} = \mathbf{g}\), with \(\mathbf{F}\) the damped GGN:

approx_gradient = damped_GGN @ natural_gradient

print("Comparing gradient with Fisher @ natural gradient.")
assert allclose_report(approx_gradient, gradient, rtol=1e-4, atol=1e-5)
Comparing gradient with Fisher @ natural gradient.

Verifying results

To check if the code works, let’s compute the GGN with functorch, using a utility function from curvlinops.examples,

GGN_mat_functorch = functorch_ggn(model, loss_function, params, data).detach()

then damp it and invert it:

damped_GGN_mat = GGN_mat_functorch + delta * eye(
    GGN_mat_functorch.shape[0], device=DEVICE, dtype=DTYPE
)
inv_damped_GGN_mat = inv(damped_GGN_mat)

Next, let’s compute the gradient with functorch, using a utility function from curvlinops.examples:

gradient_functorch, _ = functorch_gradient_and_loss(model, loss_function, params, data)
# flatten and concatenate
gradient_functorch = parameters_to_vector(gradient_functorch).detach()

print("Comparing gradient with functorch's gradient.")
assert allclose_report(gradient, gradient_functorch)
Comparing gradient with functorch's gradient.

We can now compute the natural gradient from the functorch quantities. This should yield approximately the same result:

natural_gradient_functorch = inv_damped_GGN_mat @ gradient_functorch

print("Comparing natural gradient with functorch's natural gradient.")
rtol, atol = 5e-3, 5e-5
assert allclose_report(
    natural_gradient, natural_gradient_functorch, rtol=rtol, atol=atol
)
Comparing natural gradient with functorch's natural gradient.

You might have noticed the rather loose tolerances required to achieve approximate equality. We can use stricter convergence hyperparameters for CG to achieve a more accurate inversion:

inverse_damped_GGN = CGInverseLinearOperator(
    damped_GGN,
    eps=0,  # do not add CG-internal damping
    # increase the number of iterations to get a better approximation
    max_iter=10,
    max_tridiag_iter=10,
)
natural_gradient_more_accurate = inverse_damped_GGN @ gradient

smaller_rtol, smaller_atol = rtol / 10, atol / 10
print("Comparing more accurate natural gradient with functorch's natural gradient.")
assert allclose_report(
    natural_gradient_more_accurate,
    natural_gradient_functorch,
    rtol=smaller_rtol,
    atol=smaller_atol,
)
Comparing more accurate natural gradient with functorch's natural gradient.

whereas the less accurate inversion does not pass this check:

print(
    "Comparing natural gradient with functorch's natural gradient (smaller tolerances)."
)
try:
    assert allclose_report(
        natural_gradient,
        natural_gradient_functorch,
        rtol=smaller_rtol,
        atol=smaller_atol,
    )
    raise RuntimeError("This comparison should not pass")
except AssertionError as e:
    print(e)
Comparing natural gradient with functorch's natural gradient (smaller tolerances).
at index [4]: 4.04149e-02 ≠ 4.03777e-02, ratio: 1.00092e+00
at index [8]: 6.76268e-03 ≠ 6.81269e-03, ratio: 9.92659e-01
at index [9]: -1.36847e-02 ≠ -1.37049e-02, ratio: 9.98528e-01
at index [10]: -3.70819e-03 ≠ -3.66605e-03, ratio: 1.01150e+00
at index [11]: -7.79267e-03 ≠ -7.76049e-03, ratio: 1.00415e+00
at index [13]: -1.80134e-03 ≠ -1.76606e-03, ratio: 1.01998e+00
at index [22]: 3.16156e-02 ≠ 3.16480e-02, ratio: 9.98977e-01
at index [23]: 4.82754e-02 ≠ 4.83161e-02, ratio: 9.99158e-01
at index [24]: 1.27004e-02 ≠ 1.26878e-02, ratio: 1.00100e+00
at index [26]: -1.42644e-02 ≠ -1.42521e-02, ratio: 1.00086e+00
at index [27]: 3.20968e-02 ≠ 3.21585e-02, ratio: 9.98081e-01
at index [29]: 2.13577e-02 ≠ 2.13813e-02, ratio: 9.98897e-01
at index [30]: 2.14169e-02 ≠ 2.14483e-02, ratio: 9.98537e-01
at index [31]: 8.43465e-03 ≠ 8.45006e-03, ratio: 9.98176e-01
at index [32]: -2.02736e-03 ≠ -2.00121e-03, ratio: 1.01307e+00
at index [33]: 3.67601e-03 ≠ 3.70411e-03, ratio: 9.92413e-01
at index [34]: 4.81874e-04 ≠ 4.88223e-04, ratio: 9.86996e-01
at index [36]: -1.07670e-02 ≠ -1.07276e-02, ratio: 1.00368e+00
at index [39]: 1.91604e-02 ≠ 1.92130e-02, ratio: 9.97266e-01
at index [40]: -1.24558e-02 ≠ -1.24304e-02, ratio: 1.00204e+00
at index [41]: -2.01354e-03 ≠ -2.00333e-03, ratio: 1.00510e+00
at index [45]: 4.96902e-02 ≠ 4.96296e-02, ratio: 1.00122e+00
at index [46]: 4.18778e-04 ≠ 4.07667e-04, ratio: 1.02725e+00
at index [48]: -1.74391e-02 ≠ -1.74200e-02, ratio: 1.00110e+00
at index [49]: 2.25335e-02 ≠ 2.25163e-02, ratio: 1.00076e+00
at index [50]: 1.02638e-02 ≠ 1.02371e-02, ratio: 1.00260e+00
at index [54]: 4.42425e-02 ≠ 4.42709e-02, ratio: 9.99359e-01
at index [55]: 2.56494e-02 ≠ 2.56147e-02, ratio: 1.00135e+00
at index [70]: 9.86109e-03 ≠ 9.84801e-03, ratio: 1.00133e+00
at index [75]: 2.27872e-03 ≠ 2.27010e-03, ratio: 1.00380e+00
at index [77]: -1.58043e-02 ≠ -1.58248e-02, ratio: 9.98707e-01
Abs max: 1.35858e-01 vs. 1.35860e-01.
Abs min: 0.00000e+00 vs. 0.00000e+00.
Non-close entries: 31 / 88.
rtol = 0.0005, atol = 5e-06.
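To see why entry [4] passes the first comparison but fails the stricter one, here is the arithmetic with `torch.isclose`, using the values as printed (rounded to six digits). This assumes `allclose_report` applies the same elementwise criterion \(|a - b| \le \text{atol} + \text{rtol} \cdot |b|\) as `torch.isclose`; that is an assumption, not something stated in this example.

```python
import torch

# entry [4] from the report above: CG with max_iter=5 vs. functorch
actual = torch.tensor([4.04149e-02], dtype=torch.float64)
expected = torch.tensor([4.03777e-02], dtype=torch.float64)

abs_diff = (actual - expected).abs()  # 3.72e-05

# torch.isclose tests |actual - expected| <= atol + rtol * |expected|
loose = torch.isclose(actual, expected, rtol=5e-3, atol=5e-5)  # bound ~2.5e-04
strict = torch.isclose(actual, expected, rtol=5e-4, atol=5e-6)  # bound ~2.5e-05
print(abs_diff.item(), loose.item(), strict.item())
```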

Visual comparison

Finally, let’s visualize the damped Fisher/GGN and its inverse. For improved visibility, we take the logarithm of the absolute value of each element (blank pixels correspond to zeros).

fig, ax = plt.subplots(ncols=2)
plt.suptitle("Logarithm of absolute values")

ax[0].set_title("Damped GGN/Fisher")
image = ax[0].imshow(damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax[0], shrink=0.5)

ax[1].set_title("Inv. damped GGN/Fisher")
image = ax[1].imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax[1], shrink=0.5)
[Figure: Logarithm of absolute values. Panels: Damped GGN/Fisher; Inv. damped GGN/Fisher]
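The blank pixels mentioned above come from \(\log_{10} 0 = -\infty\): exact zeros turn into non-finite values, while all other entries map to finite exponents. A tiny check of the transformation used in the plots, on a made-up matrix:

```python
import torch

# made-up matrix with one exact zero entry
mat = torch.tensor([[1e-3, 0.0], [-2.0, 100.0]], dtype=torch.float64)

# same transformation as in the plots: log10 of the absolute value
log_abs = mat.abs().log10()
print(log_abs)  # the zero entry becomes -inf, the others stay finite
```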

Neumann inverse (CG alternative)

So far, we used CG to solve the linear system \(\mathbf{F} \mathbf{\tilde{g}} = \mathbf{g}\) for the natural gradient \(\mathbf{\tilde{g}}\) (i.e. to compute the inverse Fisher-gradient product). Alternatively, we can approximate the inverse with a truncated Neumann series, using NeumannInverseLinearOperator.
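For reference, here is the standard identity behind this approach, stated for a generic square matrix \(\mathbf{A}\): if the spectral radius of \(\mathbf{I} - \mathbf{A}\) is below one, the inverse is a geometric matrix series, and truncating it after \(K\) terms leaves an error governed by the \(K\)-th power of \(\mathbf{I} - \mathbf{A}\):

```latex
\mathbf{A}^{-1} = \sum_{k=0}^{\infty} (\mathbf{I} - \mathbf{A})^{k}
\approx \sum_{k=0}^{K-1} (\mathbf{I} - \mathbf{A})^{k},
\qquad
\mathbf{A}^{-1} - \sum_{k=0}^{K-1} (\mathbf{I} - \mathbf{A})^{k}
= (\mathbf{I} - \mathbf{A})^{K} \mathbf{A}^{-1}.
```

Each additional term costs one matrix-vector product, and the error shrinks roughly by a factor equal to the spectral radius of \(\mathbf{I} - \mathbf{A}\) per term.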

Note

The Neumann series does not always converge. But if we know that the matrix is PSD and know its largest eigenvalue, we can rescale the matrix so that the series converges. More information can be found in the docstring of NeumannInverseLinearOperator.
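For a symmetric positive definite matrix \(\mathbf{A}\) (e.g. a damped GGN) with largest eigenvalue \(\lambda_{\max}\), the rescaling trick uses \(\mathbf{A}^{-1} = s\,(s\mathbf{A})^{-1}\) with \(0 < s < 2/\lambda_{\max}\), so that every eigenvalue \(1 - s\lambda\) of \(\mathbf{I} - s\mathbf{A}\) has magnitude below one. The PyTorch sketch below applies the rescaled, truncated series to a vector. It is an illustration on a made-up matrix, not curvlinops' implementation, and `neumann_inverse_matvec` is a hypothetical helper. It uses the same scaling rule as the code below and also shows that skipping the rescaling makes the partial sums diverge.

```python
import torch


def neumann_inverse_matvec(A, v, scale, num_terms):
    """Approximate A^{-1} v by scale * sum_{k=0}^{num_terms - 1} (I - scale * A)^k v."""
    term = v.clone()  # k = 0 term: (I - scale * A)^0 v = v
    total = v.clone()
    for _ in range(num_terms - 1):
        term = term - scale * (A @ term)  # apply (I - scale * A) once more
        total = total + term
    return scale * total


torch.manual_seed(0)
dim = 10
M = torch.randn(dim, dim, dtype=torch.float64)
# made-up SPD matrix: eigenvalues >= 1, largest eigenvalue well above 2
A = M.T @ M + torch.eye(dim, dtype=torch.float64)
v = torch.randn(dim, dtype=torch.float64)
x_exact = torch.linalg.solve(A, v)

# same rule as in the example: rescale only if the largest eigenvalue is >= 2
max_eigval = torch.linalg.eigvalsh(A).max().item()
scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval

errors = {}
for n in (10, 100, 1000):
    x = neumann_inverse_matvec(A, v, scale, n)
    errors[n] = ((x - x_exact).norm() / x_exact.norm()).item()

# without rescaling, |1 - max_eigval| > 1 and the partial sums blow up
x_unscaled = neumann_inverse_matvec(A, v, 1.0, 10)
error_unscaled = ((x_unscaled - x_exact).norm() / x_exact.norm()).item()
print(errors, error_unscaled)
```

With `scale = 1.99 / max_eigval`, the eigenvalue \(\lambda_{\max}\) maps to \(1 - 1.99 = -0.99\), so convergence can be slow; a scale slightly further from \(2/\lambda_{\max}\) would speed up that component at the expense of the smallest eigenvalues.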

To make the Neumann series converge, we need to know the largest eigenvalue of the matrix to be inverted:

max_eigval = eigsh(damped_GGN.to_scipy(), k=1, which="LM", return_eigenvectors=False)[0]
# rescale so that the eigenvalues of scale * damped_GGN lie in (0, 2)
scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval
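`eigsh` wraps ARPACK's implicitly restarted Lanczos method and only needs matrix-vector products with `damped_GGN`. For illustration, here is a simpler matrix-free technique swapped in for comparison: power iteration, sketched in PyTorch on a made-up SPD matrix. `power_iteration` is a hypothetical helper, and it converges more slowly than Lanczos when the two largest eigenvalues are close.

```python
import torch


def power_iteration(matvec, dim, num_iters=1000, dtype=torch.float64):
    """Estimate the largest eigenvalue of a symmetric PSD operator from its matvec."""
    v = torch.randn(dim, dtype=dtype)
    v = v / v.norm()
    eigval = 0.0
    for _ in range(num_iters):
        w = matvec(v)
        eigval = (v @ w).item()  # Rayleigh quotient of the current unit vector
        v = w / w.norm()
    return eigval


torch.manual_seed(0)
dim = 20
M = torch.randn(dim, dim, dtype=torch.float64)
# made-up SPD stand-in for the damped GGN
A = M.T @ M / dim + 1e-2 * torch.eye(dim, dtype=torch.float64)

estimate = power_iteration(lambda x: A @ x, dim)
reference = torch.linalg.eigvalsh(A).max().item()
print(estimate, reference)
```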

Let’s compute the inverse approximation for each truncation order in num_terms (here a single one, 10 terms):

num_terms = [10]
neumann_inverses = []

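for n in num_terms:
    # use a new name to avoid shadowing torch.linalg.inv
    neumann_inv = NeumannInverseLinearOperator(damped_GGN, scale=scale, num_terms=n)
    # materialize the approximate inverse by multiplying it onto the identity
    neumann_inverses.append(
        neumann_inv @ eye(neumann_inv.shape[1], device=DEVICE, dtype=DTYPE)
    )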

Here are their visualizations:

fig, axes = plt.subplots(ncols=len(num_terms) + 1)
plt.suptitle("Inverse damped Fisher (logarithm of absolute values)")

for i, (n, inv_mat) in enumerate(zip(num_terms, neumann_inverses)):
    ax = axes.flat[i]
    ax.set_title(f"Neumann, {n} terms")
    image = ax.imshow(inv_mat.detach().cpu().abs().log10())
    plt.colorbar(image, ax=ax, shrink=0.5)

ax = axes.flat[-1]
ax.set_title("Exact inverse")
image = ax.imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
plt.colorbar(image, ax=ax, shrink=0.5)
[Figure: Inverse damped Fisher (logarithm of absolute values). Panels: Neumann, 10 terms; Exact inverse]
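The Neumann inverses in this comparison are materialized by multiplying each operator onto the identity matrix. This is equivalent to applying the operator to every unit vector: column \(j\) of the result is the operator applied to \(\mathbf{e}_j\), so it costs one operator-vector product per column and is only affordable for small problems like this one. A minimal sketch of that idea for a generic matrix-vector product function (`materialize` is a hypothetical helper, not a curvlinops function):

```python
import torch


def materialize(matvec, dim, dtype=torch.float64):
    """Build the dense matrix of a linear map from its action on unit vectors."""
    identity = torch.eye(dim, dtype=dtype)
    columns = []
    for j in range(dim):
        columns.append(matvec(identity[:, j]))  # column j is matvec(e_j)
    return torch.stack(columns, dim=1)


torch.manual_seed(0)
dim = 5
B = torch.randn(dim, dim, dtype=torch.float64)
dense = materialize(lambda v: B @ v, dim)
print(torch.allclose(dense, B))
```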

The Neumann inversion is usually less accurate than CG inversion, but it might sometimes be preferred if only a rough approximation of the inverse-matrix-vector product is needed.
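To make "usually less accurate" concrete, here is a small experiment on a made-up SPD matrix with eigenvalues spread evenly over \([1, 100]\). SciPy's `cg` (not curvlinops' CG) and a hand-written rescaled Neumann series each get a budget of roughly 20 matrix-vector products. It is a sketch under these assumptions, not a benchmark; the size of the gap depends on the spectrum.

```python
import numpy as np
from scipy.sparse.linalg import cg

rng = np.random.default_rng(0)
dim = 20

# made-up SPD matrix with known eigenvalues evenly spread over [1, 100]
Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
eigvals = np.linspace(1.0, 100.0, dim)
A = Q @ np.diag(eigvals) @ Q.T
b = rng.standard_normal(dim)
x_exact = np.linalg.solve(A, b)

budget = 20  # rough number of matrix-vector products allowed for each method

# conjugate gradient, capped at `budget` iterations (default tolerances)
x_cg, info = cg(A, b, maxiter=budget)

# rescaled truncated Neumann series: scale * sum_{k < budget} (I - scale * A)^k b
scale = 1.99 / eigvals.max()
term = b.copy()
total = b.copy()
for _ in range(budget - 1):
    term = term - scale * (A @ term)
    total = total + term
x_neumann = scale * total


def rel_err(x, x_ref):
    return np.linalg.norm(x - x_ref) / np.linalg.norm(x_ref)


cg_error = rel_err(x_cg, x_exact)
neumann_error = rel_err(x_neumann, x_exact)
print(f"CG: {cg_error:.1e}, Neumann: {neumann_error:.1e}")
```

CG adapts its search directions to the spectrum, whereas the Neumann series applies a fixed polynomial in \(\mathbf{A}\), whose error factors \(1 - s\lambda\) stay close to \(\pm 1\) at both ends of the spectrum here.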
