Monte-Carlo approximation of the Fisher

In this tutorial, we will compare two approaches to compute the Fisher information matrix:

  1. The Fisher, viewed as the expected Hessian of the negative log-likelihood under the model’s likelihood, coincides with the generalized Gauss-Newton (GGN) matrix for common loss functions such as torch.nn.MSELoss and torch.nn.CrossEntropyLoss. In these settings, Fisher = GGN.

  2. The Fisher can also be seen as the expected outer product of the gradient under the model’s likelihood. This expectation can be approximated by Monte-Carlo sampling: evaluate the loss on a label sampled from the model’s likelihood rather than on the true label, and take the outer product of the resulting ‘would-be’ gradient.
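For softmax cross-entropy, both views can be checked on the logits of a single sample. The sketch below uses plain PyTorch rather than curvlinops, and the logits are made up. It compares the Hessian of the loss with respect to the logits (the GGN view) with the expected outer product of the logit gradient when the label is drawn from the softmax probabilities (the Fisher view), computed exactly by enumerating all classes. Both equal diag(p) - p pᵀ. In parameter space, both are multiplied by the same network Jacobians, Jᵀ (·) J, so they agree there as well.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 3
logits = torch.randn(C, dtype=torch.float64)
probs = logits.softmax(dim=0)


def ce(z: torch.Tensor, label: int) -> torch.Tensor:
    """Cross-entropy of a single sample with logits z and the given label."""
    return F.cross_entropy(z.unsqueeze(0), torch.tensor([label]))


# GGN view: Hessian of the loss w.r.t. the logits (independent of the label)
hessian = torch.autograd.functional.hessian(lambda z: ce(z, 0), logits)

# Fisher view: E_{y ~ softmax(logits)}[g g^T] with g the logit gradient,
# computed exactly by weighting every possible label with its probability
fisher = torch.zeros(C, C, dtype=torch.float64)
for y in range(C):
    z = logits.clone().requires_grad_(True)
    (g,) = torch.autograd.grad(ce(z, y), z)
    fisher += probs[y] * torch.outer(g, g)

closed_form = torch.diag(probs) - torch.outer(probs, probs)
print(torch.allclose(hessian, closed_form), torch.allclose(fisher, closed_form))
# True True
```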

The first approach is implemented by curvlinops.GGNLinearOperator (exact mode), the second by curvlinops.GGNLinearOperator with mc_samples > 0. We will see that both approaches coincide as the Monte-Carlo approximation converges.

Let’s get the imports out of the way.

from math import isclose

import matplotlib.pyplot as plt
from matplotlib import animation
from torch import (
    Tensor,
    cuda,
    device,
    eye,
    int32,
    logspace,
    manual_seed,
    rand,
    randint,
    unique,
    zeros,
)
from torch.linalg import matrix_norm
from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential, Sigmoid

from curvlinops import GGNLinearOperator

# make deterministic
manual_seed(0)

Setup

We will create a synthetic classification task with two mini-batches (of sizes 4 and 6), a small multi-layer perceptron (MLP), and use cross-entropy as the loss function.

Ns = [4, 6]
D_in = 7
D_hidden = 5
C = 3

DEVICE = device("cuda" if cuda.is_available() else "cpu")

data = [
    (
        rand(N, D_in, device=DEVICE),  # X
        randint(low=0, high=C, size=(N,), device=DEVICE),  # y
    )
    for N in Ns
]
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    Sigmoid(),
    Linear(D_hidden, C),
).to(DEVICE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}

loss_function = CrossEntropyLoss(reduction="mean").to(DEVICE)

A first comparison

Let’s create linear operators for the GGN and the Monte-Carlo approximated Fisher and compute their matrix representations by multiplying them with the identity matrix:

We can use the residual’s Frobenius norm to quantify the approximation error of the Monte-Carlo estimator:

residual_norm = matrix_norm(GGN_mat - F_mat)
print(f"Residual (Frobenius) norm: {residual_norm:.5f}")
Residual (Frobenius) norm: 0.37490

Setting the number of MC samples

To get more accurate estimates, we can use more samples in the MC approximation. This is achieved by specifying the optional mc_samples argument to curvlinops.GGNLinearOperator. The default value is 0 (exact GGN).

Here are the residual Frobenius norms when using more samples:

mc_samples = [1, 2, 4, 8]
residual_norms = []

for mc in mc_samples:
    F = GGNLinearOperator(model, loss_function, params, data, mc_samples=mc)
    F_mat = F @ identity
    residual_norms.append(matrix_norm(GGN_mat - F_mat))

for mc, norm in zip(mc_samples, residual_norms):
    print(f"mc_samples = {mc},\tresidual (Frobenius) norm = {norm:.5f}")
mc_samples = 1, residual (Frobenius) norm = 0.37490
mc_samples = 2, residual (Frobenius) norm = 0.30229
mc_samples = 4, residual (Frobenius) norm = 0.17022
mc_samples = 8, residual (Frobenius) norm = 0.06145
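The shrinking residual reflects the averaging inside the estimator. As an illustration in plain PyTorch (not curvlinops’ internal code), the sketch below Monte-Carlo estimates the logit-level Fisher of a single sample with made-up logits: it draws M labels from the model’s predictive distribution and averages the outer products of the would-be gradients p - e_y. On average, the error of such an estimator shrinks roughly like 1/√M.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 3
logits = torch.randn(C, dtype=torch.float64)
probs = logits.softmax(dim=0)

# Exact logit-level Fisher (= Hessian of the cross-entropy w.r.t. the logits)
exact = torch.diag(probs) - torch.outer(probs, probs)

# Dedicated generator so that the draws below are reproducible
gen = torch.Generator().manual_seed(42)

errors = {}
for M in [1, 10, 100, 1_000, 10_000, 100_000]:
    # Sample M labels from the model's likelihood, Categorical(probs)
    labels = torch.multinomial(probs, M, replacement=True, generator=gen)
    # Would-be gradients of the cross-entropy w.r.t. the logits: p - e_y
    grads = probs - F.one_hot(labels, C).to(probs.dtype)
    # Average of the M gradient outer products
    estimate = grads.T @ grads / M
    errors[M] = torch.linalg.matrix_norm(exact - estimate).item()

for M, err in errors.items():
    print(f"M = {M:>6}: residual Frobenius norm = {err:.5f}")
```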

Setting the random seed

You may have noticed above that the two linear operators created with mc_samples=1 (in the first comparison and in the loop) yield identical residual Frobenius norms. This is because both linear operators realize the same matrix, i.e., the same sample from the Monte-Carlo estimator.

To see that curvlinops.GGNLinearOperator with MC samples indeed represents a deterministic matrix, let’s create two linear operators with identical hyperparameters and compare their matrix representations. After creating the first linear operator, we generate some random numbers to show that the global random number generator does not influence the Monte-Carlo estimator:

F1_mat = GGNLinearOperator(model, loss_function, params, data, mc_samples=1) @ identity

# draw some random numbers to modify the global random number generator's state
rand(123)

F2_mat = GGNLinearOperator(model, loss_function, params, data, mc_samples=1) @ identity

# still, we get the same deterministic approximation
residual_norm = matrix_norm(F1_mat - F2_mat).item()
if isclose(residual_norm, 0.0):
    print(residual_norm)
else:
    raise RuntimeError(f"Residual Frobenius norm should be 0. Got {residual_norm}.")
0.0

This is because curvlinops.GGNLinearOperator draws its samples from an internal random number generator. Therefore, its result is not affected by changes to the global random number generator’s state.

You can get different realizations of the Monte-Carlo estimator by specifying the optional seed argument. Repeating the comparison from above with differently seeded linear operators yields different matrices:

seed1 = 123456
F1_mat = (
    GGNLinearOperator(model, loss_function, params, data, mc_samples=1, seed=seed1)
    @ identity
)

seed2 = 654321
F2_mat = (
    GGNLinearOperator(model, loss_function, params, data, mc_samples=1, seed=seed2)
    @ identity
)

# now, we get two different deterministic approximations
residual_norm = matrix_norm(F1_mat - F2_mat).item()
if not isclose(residual_norm, 0.0):
    print(residual_norm)
else:
    raise RuntimeError(f"Residual Frobenius norm should be ≠0. Got {residual_norm}.")
0.15368805825710297
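Both behaviors can be mimicked with a dedicated `torch.Generator`. This sketch only illustrates the idea and is not curvlinops’ implementation; the helper `sample_labels` and the class probabilities are hypothetical. Draws from a privately seeded generator ignore the state of the global random number generator, while a different seed gives a different realization.

```python
import torch

# Hypothetical predictive distribution over C = 3 classes
probs = torch.tensor([0.2, 0.5, 0.3])


def sample_labels(seed: int, num_samples: int = 100) -> torch.Tensor:
    """Hypothetical helper: draw labels with a private, explicitly seeded generator."""
    generator = torch.Generator().manual_seed(seed)
    return torch.multinomial(probs, num_samples, replacement=True, generator=generator)


first = sample_labels(seed=123456)
torch.rand(123)  # advance the global random number generator
second = sample_labels(seed=123456)
other = sample_labels(seed=654321)

print(torch.equal(first, second))  # same seed: identical draws despite torch.rand
print(torch.equal(first, other))  # different seed: almost surely different draws
```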

Approximation quality

Finally, let’s combine what we have seen so far to visualize how well the Monte-Carlo approximated Fisher approximates the GGN.

To do that, we will repeatedly draw single-sample estimates of the Fisher information matrix with different seeds and average them, so that each estimate incorporates all previous iterations. This approach allows us to record snapshots of the estimator at different numbers of incorporated MC samples.

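We take the snapshots at logarithmically spaced sample counts between 1 and 100, removing duplicates that arise from converting the counts to integers.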

num_steps = 25
mc_samples = unique(logspace(0, 2, num_steps, dtype=int32))
F_snapshots = []
F_accumulated = zeros((D, D), device=DEVICE)
start_seed = 123456789

for seed, mc in enumerate(range(mc_samples.max()), start=start_seed):
    # NOTE Only use `check_deterministic=False` if you know what you are doing
    # We do this here because we have previously convinced ourselves that the created
    # linear operators indeed realize deterministic matrices.
    F = GGNLinearOperator(
        model,
        loss_function,
        params,
        data,
        mc_samples=1,
        seed=seed,
        check_deterministic=False,
    )
    F_accumulated += F @ identity
    if mc + 1 in mc_samples:
        F_snapshots.append(F_accumulated / (mc + 1))
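The loop above keeps a running sum and divides it by the number of incorporated samples at each snapshot. An equivalent alternative is an incremental running mean that never stores the sum and stays at the scale of the estimate. A small sketch, with random symmetric matrices standing in for single-sample Fisher estimates (the sizes and snapshot counts are illustrative only):

```python
import torch

torch.manual_seed(0)
D = 4
num_samples = 100
snapshot_at = {1, 2, 5, 10, 100}

# Random symmetric matrices as stand-ins for single-sample Fisher estimates
samples = []
for _ in range(num_samples):
    A = torch.randn(D, D, dtype=torch.float64)
    samples.append(A + A.T)

running_sum = torch.zeros(D, D, dtype=torch.float64)
running_mean = torch.zeros(D, D, dtype=torch.float64)
sum_snapshots, mean_snapshots = [], []

for k, sample in enumerate(samples, start=1):
    # Variant used in the tutorial: accumulate, then divide at snapshot time
    running_sum += sample
    # Incremental variant: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
    running_mean += (sample - running_mean) / k
    if k in snapshot_at:
        sum_snapshots.append(running_sum / k)
        mean_snapshots.append(running_mean.clone())

print(all(torch.allclose(a, b) for a, b in zip(sum_snapshots, mean_snapshots)))
# True
```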

Let’s visualize both the residual matrices and their Frobenius norms. To visualize a matrix, we plot the element-wise base-10 logarithm of its absolute value, shifted by a small constant to avoid taking the logarithm of 0:

residual_snapshots = [mat - GGN_mat for mat in F_snapshots]
residual_norms = [matrix_norm(res).item() for res in residual_snapshots]


def transform(mat: Tensor, epsilon: float = 1e-5) -> Tensor:
    """Transform the matrix before plotting.

    Applies element-wise absolute value, shifts by epsilon, then takes the
    element-wise logarithm.

    Args:
        mat: Matrix.
        epsilon: Small shift to avoid taking the log of 0.

    Returns:
        Transformed matrix.
    """
    return (mat.abs() + epsilon).log10()

Here’s the plotting code (feel free to skip to the visualization).

img_width = 4
rows, columns = 1, 2
fig, axes = plt.subplots(
    nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
)
ax_img, ax_fro = axes[0], axes[1]

min_img = min(transform(res).min() for res in residual_snapshots)
max_img = max(transform(res).max() for res in residual_snapshots)

min_fro = 0.0
max_fro = max(residual_norms)

ax_fro.set_xlim(mc_samples.min(), mc_samples.max())
ax_fro.semilogx()
ax_fro.set_xlabel("MC samples")
ax_fro.set_ylabel("residual Frobenius norm")
ax_fro.set_ylim(min_fro, 1.05 * max_fro)
ln_style = "bo"

# collects artists to draw in each frame of the animation
artists = []

for frame_idx in range(len(mc_samples)):
    snapshot = residual_snapshots[frame_idx].cpu()
    img = ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img, animated=True)

    # workaround for animated title: https://stackoverflow.com/a/47421938
    ax_img_title = plt.text(
        0.5,
        1.01,
        f"Residual ({mc_samples[frame_idx]} samples)",
        horizontalalignment="center",
        verticalalignment="bottom",
        transform=ax_img.transAxes,
    )

    if frame_idx == 0:
        ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img)
        plt.colorbar(img, ax=ax_img, label="logarithmic absolute entries", shrink=0.8)
        ax_fro.plot(
            mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
        )
        plt.subplots_adjust(wspace=0.5, bottom=0.2)

    ln_fro = ax_fro.plot(
        mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
    )

    artists.append([ax_img_title, img] + ln_fro)


ani = animation.ArtistAnimation(
    fig, artists, interval=1000, blit=False, repeat_delay=1000
)

Here’s a more qualitative comparison that contrasts the GGN matrix and the MC approximated Fisher (both transformed to logspace as described above):

img_width = 4
rows, columns = 1, 2
GGN_mat = GGN_mat.cpu()
F_snapshots = [fs.cpu() for fs in F_snapshots]
fig, axes = plt.subplots(
    nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
)

min_value = min(transform(mat).min() for mat in F_snapshots + [GGN_mat])
max_value = max(transform(mat).max() for mat in F_snapshots + [GGN_mat])

# collects artists to draw in each frame of the animation
artists = []

ax_GGN, ax_F = axes[0], axes[1]
ax_GGN.set_title("GGN")


for frame_idx in range(len(mc_samples)):
    im_GGN = ax_GGN.imshow(
        transform(GGN_mat), vmin=min_value, vmax=max_value, animated=True
    )
    im_F = ax_F.imshow(
        transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value, animated=True
    )

    # workaround for animated title: https://stackoverflow.com/a/47421938
    ax_F_title = plt.text(
        0.5,
        1.01,
        f"Fisher-MC ({mc_samples[frame_idx]} samples)",
        horizontalalignment="center",
        verticalalignment="bottom",
        transform=ax_F.transAxes,
    )

    if frame_idx == 0:
        img = ax_GGN.imshow(transform(GGN_mat), vmin=min_value, vmax=max_value)
        ax_F.imshow(transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value)
        fig.colorbar(
            img,
            ax=axes.ravel().tolist(),
            label="logarithmic absolute entries",
            shrink=0.8,
        )

    artists.append([ax_F_title, im_GGN, im_F])


ani = animation.ArtistAnimation(
    fig, artists, interval=1000, blit=False, repeat_delay=1000
)

That’s all for now.
