Monte-Carlo approximation of the Fisher
In this tutorial, we will compare two approaches to compute the Fisher information matrix:
1. The Fisher as expected Hessian under the model’s likelihood coincides with the generalized Gauss-Newton (GGN) matrix for common loss functions, like torch.nn.MSELoss and torch.nn.CrossEntropyLoss. For these settings, Fisher = GGN.
2. The Fisher can also be seen as the expectation of the gradient outer product w.r.t. the model’s likelihood. This expectation can be approximated by computing the outer product of ‘would-be’ gradients, where the loss is evaluated on a label sampled from the model’s likelihood rather than on the true label.
The first approach is implemented by curvlinops.GGNLinearOperator (exact mode),
the second by curvlinops.GGNLinearOperator with mc_samples > 0. We will see
that both approaches coincide as the Monte-Carlo approximation converges.
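To make the two views concrete, here is a minimal NumPy sketch for a single softmax output, taken with respect to the logits only (not the network parameters, and not using curvlinops). The logits, the sample count, and all variable names are illustrative assumptions.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax of a 1d array."""
    e = np.exp(z - z.max())
    return e / e.sum()


logits = np.array([0.5, -1.0, 2.0])  # arbitrary example logits
p = softmax(logits)
C = len(p)
onehot = np.eye(C)

# View 1: Hessian of the cross-entropy w.r.t. the logits, diag(p) - p p^T.
# It does not depend on the label, so it equals its own expectation.
hessian = np.diag(p) - np.outer(p, p)

# View 2: expected outer product of 'would-be' gradients. For label y the
# gradient w.r.t. the logits is p - onehot(y); each label has weight p[y].
fisher_exact = sum(
    p[y] * np.outer(p - onehot[y], p - onehot[y]) for y in range(C)
)

# Monte-Carlo version of view 2: sample labels from p instead of summing.
rng = np.random.default_rng(0)
num_samples = 20_000
labels = rng.choice(C, size=num_samples, p=p)
grads = p - onehot[labels]  # shape (num_samples, C)
fisher_mc = grads.T @ grads / num_samples

print("max |Hessian - exact Fisher|:", np.abs(hessian - fisher_exact).max())
print("Frobenius norm, Hessian - MC Fisher:", np.linalg.norm(hessian - fisher_mc))
```

The first difference vanishes up to floating-point error, while the second is small but non-zero and shrinks as num_samples grows.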
Let’s get the imports out of our way.
from math import isclose
import matplotlib.pyplot as plt
from matplotlib import animation
from torch import (
    Tensor,
    cuda,
    device,
    eye,
    int32,
    logspace,
    manual_seed,
    rand,
    randint,
    unique,
    zeros,
)
from torch.linalg import matrix_norm
from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential, Sigmoid
from curvlinops import GGNLinearOperator
# make deterministic
manual_seed(0)
<torch._C.Generator object at 0x71fa6dd06810>
Setup
We will create a synthetic classification task and a small MLP, and use cross-entropy as the loss function.
Ns = [4, 6]
D_in = 7
D_hidden = 5
C = 3
DEVICE = device("cuda" if cuda.is_available() else "cpu")
data = [
    (
        rand(N, D_in, device=DEVICE),  # X
        randint(low=0, high=C, size=(N,), device=DEVICE),  # y
    )
    for N in Ns
]
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    Sigmoid(),
    Linear(D_hidden, C),
).to(DEVICE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
loss_function = CrossEntropyLoss(reduction="mean").to(DEVICE)
A first comparison
Let’s create linear operators for the GGN and the Monte-Carlo approximated Fisher and compute their matrix representations by multiplying them onto the identity matrix:
GGN = GGNLinearOperator(model, loss_function, params, data)
F = GGNLinearOperator(model, loss_function, params, data, mc_samples=1)
D = GGN.shape[0]
identity = eye(D, device=DEVICE)
GGN_mat = GGN @ identity
F_mat = F @ identity
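Multiplying onto the identity works for any linear operator that supports matrix products, because column j of the result is the operator applied to the j-th unit vector. The following sketch uses a hypothetical matrix-free operator (OuterProductOperator is not part of curvlinops) to illustrate the idea in NumPy:

```python
import numpy as np


class OuterProductOperator:
    """Hypothetical matrix-free operator representing V^T V."""

    def __init__(self, vectors):
        self.vectors = vectors  # shape (num_vectors, dim)
        dim = vectors.shape[1]
        self.shape = (dim, dim)

    def __matmul__(self, X):
        # Compute (V^T V) X as V^T (V X), never forming the dim x dim matrix.
        return self.vectors.T @ (self.vectors @ X)


rng = np.random.default_rng(0)
V = rng.standard_normal((3, 5))
op = OuterProductOperator(V)

# Materialize the operator: each identity column extracts one matrix column.
dense = op @ np.eye(op.shape[0])
print(dense.shape)
```

This costs one operator application per column, so it is only practical for small parameter counts like the ones in this tutorial.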
We can use the residual’s Frobenius norm to quantify the approximation error of the Monte-Carlo estimator:
residual_norm = matrix_norm(GGN_mat - F_mat)
print(f"Residual (Frobenius) norm: {residual_norm:.5f}")
Residual (Frobenius) norm: 0.37490
Setting the number of MC samples
To get more accurate estimates, we can use more samples in the MC approximation.
This is achieved by specifying the optional mc_samples argument to
curvlinops.GGNLinearOperator. The default value is 0 (exact GGN).
Here are the residual Frobenius norms when using more samples:
mc_samples = [1, 2, 4, 8]
residual_norms = []
for mc in mc_samples:
    F = GGNLinearOperator(model, loss_function, params, data, mc_samples=mc)
    F_mat = F @ identity
    residual_norms.append(matrix_norm(GGN_mat - F_mat))
for mc, norm in zip(mc_samples, residual_norms):
    print(f"mc_samples = {mc},\tresidual (Frobenius) norm = {norm:.5f}")
mc_samples = 1, residual (Frobenius) norm = 0.37490
mc_samples = 2, residual (Frobenius) norm = 0.30229
mc_samples = 4, residual (Frobenius) norm = 0.17022
mc_samples = 8, residual (Frobenius) norm = 0.06145
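The decreasing trend is what one would expect from averaging independent samples: the root-mean-square error of such an average shrinks proportionally to 1/sqrt(S) for S samples. Below is a small NumPy sketch of this behaviour on a toy softmax Fisher with respect to the logits. The probabilities, the helper mc_fisher, and the number of repetitions are assumptions for illustration and unrelated to curvlinops' internals.

```python
import numpy as np

p = np.array([0.2, 0.1, 0.7])  # stand-in for predicted class probabilities
exact = np.diag(p) - np.outer(p, p)  # exact Fisher w.r.t. the logits


def mc_fisher(p, num_samples, rng):
    """Monte-Carlo Fisher from labels sampled according to p."""
    labels = rng.choice(len(p), size=num_samples, p=p)
    grads = p - np.eye(len(p))[labels]  # 'would-be' gradients
    return grads.T @ grads / num_samples


rng = np.random.default_rng(0)
sample_counts = [1, 4, 16, 64]
repetitions = 200  # average over repetitions to smooth out the noise

mean_errors = []
for S in sample_counts:
    errors = [
        np.linalg.norm(exact - mc_fisher(p, S, rng)) for _ in range(repetitions)
    ]
    mean_errors.append(float(np.mean(errors)))

for S, err in zip(sample_counts, mean_errors):
    print(f"S = {S:3d}, mean residual (Frobenius) norm = {err:.4f}")
```

Each fourfold increase in S should roughly halve the mean error. A single run, as in the output above, only follows this trend approximately.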
Setting the random seed
You may have noticed above that the two linear operators created with
mc_samples=1 yield identical residual Frobenius norms. This is
because the two linear operators realize the same matrix, i.e. the same
sample from the Monte-Carlo estimator.
To see that curvlinops.GGNLinearOperator with MC samples indeed represents a
deterministic matrix, let’s create two linear operators with identical
hyperparameters and compare their matrix representations. After creating the
first linear operator, we generate some random numbers to show that the
global random number generator does not influence the Monte-Carlo estimator:
F1_mat = GGNLinearOperator(model, loss_function, params, data, mc_samples=1) @ identity
# draw some random numbers to modify the global random number generator's state
rand(123)
F2_mat = GGNLinearOperator(model, loss_function, params, data, mc_samples=1) @ identity
# still, we get the same deterministic approximation
residual_norm = matrix_norm(F1_mat - F2_mat).item()
if isclose(residual_norm, 0.0):
    print(residual_norm)
else:
    raise RuntimeError(f"Residual Frobenius norm should be 0. Got {residual_norm}.")
0.0
This is because the class uses an internal random number generator to draw samples. Therefore, it will not be affected by changes to the global random number generator’s state.
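The general pattern of a private, seeded generator can be sketched with NumPy as follows. The function sample_labels and its default seed are hypothetical and only mimic the behaviour described above; they do not reproduce the library's implementation. In PyTorch, a torch.Generator object plays the analogous role.

```python
import numpy as np


def sample_labels(probs, num_samples, seed=2147483647):
    """Hypothetical sampler that owns its generator (no global state)."""
    rng = np.random.default_rng(seed)  # fresh, privately seeded generator
    return rng.choice(len(probs), size=num_samples, p=probs)


probs = np.array([0.2, 0.1, 0.7])

first = sample_labels(probs, 10)
np.random.rand(123)  # perturb NumPy's global generator
second = sample_labels(probs, 10)

# Same seed, same samples, regardless of the global generator's state.
print(np.array_equal(first, second))

# A different seed gives a different realization.
other = sample_labels(probs, 10, seed=1)
print(first, other)
```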
You can get different realizations of the Monte-Carlo estimator by specifying
the optional seed argument. The above comparison with differently
seeded linear operators leads to different matrices:
seed1 = 123456
F1_mat = (
    GGNLinearOperator(model, loss_function, params, data, mc_samples=1, seed=seed1)
    @ identity
)
seed2 = 654321
F2_mat = (
    GGNLinearOperator(model, loss_function, params, data, mc_samples=1, seed=seed2)
    @ identity
)
# now, we get two different deterministic approximations
residual_norm = matrix_norm(F1_mat - F2_mat).item()
if not isclose(residual_norm, 0.0):
    print(residual_norm)
else:
    raise RuntimeError(f"Residual Frobenius norm should be ≠0. Got {residual_norm}.")
0.15368805825710297
Approximation quality
Finally, let’s combine what we have seen so far to visualize how well the Monte-Carlo approximated Fisher approximates the GGN.
To do that, we will repeatedly draw single-sample estimates of the Fisher information matrix, each with a different seed, and average them into a running estimate that incorporates all previous iterations. This lets us record snapshots of the estimator at different total numbers of incorporated MC samples.
We will take the snapshots at logarithmically spaced sample counts.
num_steps = 25
mc_samples = unique(logspace(0, 2, num_steps, dtype=int32))
F_snapshots = []
F_accumulated = zeros((D, D), device=DEVICE)
start_seed = 123456789
for seed, mc in enumerate(range(mc_samples.max()), start=start_seed):
    # NOTE Only use `check_deterministic=False` if you know what you are doing
    # We do this here because we have previously convinced ourselves that the created
    # linear operators indeed realize deterministic matrices.
    F = GGNLinearOperator(
        model,
        loss_function,
        params,
        data,
        mc_samples=1,
        seed=seed,
        check_deterministic=False,
    )
    F_accumulated += F @ identity
    if mc + 1 in mc_samples:
        F_snapshots.append(F_accumulated / (mc + 1))
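The accumulate-and-snapshot pattern used in this loop can be checked in isolation. The sketch below replaces the single-sample Fisher matrices with random 2x2 arrays (an assumption purely for illustration) and confirms that each snapshot equals the plain mean of all draws seen so far:

```python
import numpy as np

# Logarithmically spaced snapshot counts between 1 and 100, duplicates removed.
snapshot_counts = np.unique(np.logspace(0, 2, 25).astype(int))

rng = np.random.default_rng(0)
# One random 2x2 'matrix' per step, standing in for a single-sample estimate.
draws = rng.standard_normal((snapshot_counts.max(), 2, 2))

accumulated = np.zeros((2, 2))
snapshots = {}
for step, draw in enumerate(draws, start=1):
    accumulated += draw
    if step in snapshot_counts:
        # Running average over the first `step` draws.
        snapshots[step] = accumulated / step

print(list(snapshots))
```

Storing only the running sum keeps memory constant in the number of samples, while every snapshot is still an equally weighted average.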
Let’s visualize both the residual matrices and their Frobenius norms. To visualize the matrix, we will use the element-wise logarithm of its absolute value (shifted by a small constant to avoid taking the logarithm of 0):
residual_snapshots = [mat - GGN_mat for mat in F_snapshots]
residual_norms = [matrix_norm(res).item() for res in residual_snapshots]
def transform(mat: Tensor, epsilon: float = 1e-5) -> Tensor:
    """Transform the matrix before plotting.

    Applies element-wise absolute value, shifts by epsilon, then takes the
    element-wise logarithm.

    Args:
        mat: Matrix.
        epsilon: Small shift to avoid taking the log of 0.

    Returns:
        Transformed matrix.
    """
    return (mat.abs() + epsilon).log10()
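To see why the shift matters, here is a NumPy analogue of the transform (transform_np is a hypothetical helper mirroring the function above) applied to a small matrix that contains an exact zero:

```python
import numpy as np


def transform_np(mat, epsilon=1e-5):
    """NumPy analogue: log10 of the absolute entries, shifted by epsilon."""
    return np.log10(np.abs(mat) + epsilon)


residual = np.array([[0.0, -0.01], [0.1, 1.0]])

# Without the shift, the zero entry maps to -inf and breaks the color scale.
with np.errstate(divide="ignore"):
    unshifted = np.log10(np.abs(residual))

shifted = transform_np(residual)
print(unshifted)
print(shifted)
```

With the shift, zero entries map to log10(epsilon), which acts as a finite floor for the color range.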
Here’s the plotting code (feel free to skip to the visualization).
img_width = 4
rows, columns = 1, 2
fig, axes = plt.subplots(
    nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
)
ax_img, ax_fro = axes[0], axes[1]
min_img = min(transform(res).min() for res in residual_snapshots)
max_img = max(transform(res).max() for res in residual_snapshots)
min_fro = 0.0
max_fro = max(residual_norms)
ax_fro.set_xlim(mc_samples.min(), mc_samples.max())
ax_fro.semilogx()
ax_fro.set_xlabel("MC samples")
ax_fro.set_ylabel("residual Frobenius norm")
ax_fro.set_ylim(min_fro, 1.05 * max_fro)
ln_style = "bo"
# collects artists to draw in each frame of the animation
artists = []
for frame_idx in range(len(mc_samples)):
    snapshot = residual_snapshots[frame_idx].cpu()
    img = ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img, animated=True)
    # workaround for animated title: https://stackoverflow.com/a/47421938
    ax_img_title = plt.text(
        0.5,
        1.01,
        f"Residual ({mc_samples[frame_idx]} samples)",
        horizontalalignment="center",
        verticalalignment="bottom",
        transform=ax_img.transAxes,
    )
    if frame_idx == 0:
        ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img)
        plt.colorbar(img, ax=ax_img, label="logarithmic absolute entries", shrink=0.8)
        ax_fro.plot(
            mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
        )
        plt.subplots_adjust(wspace=0.5, bottom=0.2)
    ln_fro = ax_fro.plot(
        mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
    )
    artists.append([ax_img_title, img] + ln_fro)
ani = animation.ArtistAnimation(
    fig, artists, interval=1000, blit=False, repeat_delay=1000
)
Here’s a more qualitative comparison that contrasts the GGN matrix and the MC approximated Fisher (both transformed to logspace as described above):
img_width = 4
rows, columns = 1, 2
GGN_mat = GGN_mat.cpu()
F_snapshots = [fs.cpu() for fs in F_snapshots]
fig, axes = plt.subplots(
    nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
)
min_value = min(transform(mat).min() for mat in F_snapshots + [GGN_mat])
max_value = max(transform(mat).max() for mat in F_snapshots + [GGN_mat])
# collects artists to draw in each frame of the animation
artists = []
ax_GGN, ax_F = axes[0], axes[1]
ax_GGN.set_title("GGN")
for frame_idx in range(len(mc_samples)):
    im_GGN = ax_GGN.imshow(
        transform(GGN_mat), vmin=min_value, vmax=max_value, animated=True
    )
    im_F = ax_F.imshow(
        transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value, animated=True
    )
    # workaround for animated title: https://stackoverflow.com/a/47421938
    ax_F_title = plt.text(
        0.5,
        1.01,
        f"Fisher-MC ({mc_samples[frame_idx]} samples)",
        horizontalalignment="center",
        verticalalignment="bottom",
        transform=ax_F.transAxes,
    )
    if frame_idx == 0:
        img = ax_GGN.imshow(transform(GGN_mat), vmin=min_value, vmax=max_value)
        ax_F.imshow(transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value)
        fig.colorbar(
            img,
            ax=axes.ravel().tolist(),
            label="logarithmic absolute entries",
            shrink=0.8,
        )
    artists.append([ax_F_title, im_GGN, im_F])
ani = animation.ArtistAnimation(
    fig, artists, interval=1000, blit=False, repeat_delay=1000
)
That’s all for now.