.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "basic_usage/example_fisher_monte_carlo.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_basic_usage_example_fisher_monte_carlo.py: Monte-Carlo approximation of the Fisher ======================================= In this tutorial, we will compare two approaches to compute the Fisher information matrix: 1. The Fisher as expected Hessian under the model's likelihood coincides with the generalized Gauss-Newton (GGN) matrix for common loss functions, like :class:`torch.nn.MSELoss` and :class:`torch.nn.CrossEntropyLoss`. For these settings, Fisher = GGN. 2. The Fisher can also be seen as expectation of the gradient outer product w.r.t. the model's likelihood. This expectation can be approximated by computing the outer product of 'would-be' gradients where the loss is evaluated on a label sampled from the model's likelihood, rather than the true label. The first approach is implemented by :class:`curvlinops.GGNLinearOperator`, the second by :class:`curvlinops.FisherMCLinearOperator`. We will see that both approaches coincide as the Monte-Carlo approximation converges. Let's get the imports out of our way. .. GENERATED FROM PYTHON SOURCE LINES 24-50 .. code-block:: Python from math import isclose import matplotlib.pyplot as plt from matplotlib import animation from torch import ( Tensor, cuda, device, eye, int32, logspace, manual_seed, rand, randint, unique, zeros, ) from torch.linalg import matrix_norm from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential, Sigmoid from curvlinops import FisherMCLinearOperator, GGNLinearOperator # make deterministic manual_seed(0) .. rst-class:: sphx-glr-script-out .. code-block:: none .. GENERATED FROM PYTHON SOURCE LINES 51-56 Setup ----- We will create a synthetic classification task, a small CNN, and use cross-entropy error as loss function. .. GENERATED FROM PYTHON SOURCE LINES 56-83 .. code-block:: Python Ns = [4, 6] D_in = 7 D_hidden = 5 C = 3 DEVICE = device("cuda" if cuda.is_available() else "cpu") data = [ ( rand(N, D_in, device=DEVICE), # X randint(low=0, high=C, size=(N,), device=DEVICE), # y ) for N in Ns ] model = Sequential( Linear(D_in, D_hidden), ReLU(), Linear(D_hidden, D_hidden), Sigmoid(), Linear(D_hidden, C), ).to(DEVICE) params = [p for p in model.parameters() if p.requires_grad] loss_function = CrossEntropyLoss(reduction="mean").to(DEVICE) .. GENERATED FROM PYTHON SOURCE LINES 84-90 A first comparison ------------------ Let's create linear operators for the GGN and the Monte-Carlo approximated Fisher and compute their matrix representations by multiplying them onto the identity matrix: .. GENERATED FROM PYTHON SOURCE LINES 90-100 .. code-block:: Python GGN = GGNLinearOperator(model, loss_function, params, data) F = FisherMCLinearOperator(model, loss_function, params, data) D = GGN.shape[0] identity = eye(D, device=DEVICE) GGN_mat = GGN @ identity F_mat = F @ identity .. GENERATED FROM PYTHON SOURCE LINES 101-103 We can use the residual's Frobenius norm to quantify the approximation error of the Monte-Carlo estimator: .. GENERATED FROM PYTHON SOURCE LINES 104-107 .. code-block:: Python residual_norm = matrix_norm(GGN_mat - F_mat) print(f"Residual (Frobenius) norm: {residual_norm:.5f}") .. rst-class:: sphx-glr-script-out .. code-block:: none Residual (Frobenius) norm: 0.37490 .. GENERATED FROM PYTHON SOURCE LINES 108-116 Setting the number of MC samples -------------------------------- To get more accurate estimates, we can use more samples in the MC approximation. This is achieved by specifying the optional :code:`mc_samples` argument to :class:`curvlinops.FisherMCLinearOperator`. The default value is :code:`1`. Here are the residual Frobenius norms when using more samples: .. GENERATED FROM PYTHON SOURCE LINES 116-128 .. code-block:: Python mc_samples = [1, 2, 4, 8] residual_norms = [] for mc in mc_samples: F = FisherMCLinearOperator(model, loss_function, params, data, mc_samples=mc) F_mat = F @ identity residual_norms.append(matrix_norm(GGN_mat - F_mat)) for mc, norm in zip(mc_samples, residual_norms): print(f"mc_samples = {mc},\tresidual (Frobenius) norm = {norm:.5f}") .. rst-class:: sphx-glr-script-out .. code-block:: none mc_samples = 1, residual (Frobenius) norm = 0.37490 mc_samples = 2, residual (Frobenius) norm = 0.30229 mc_samples = 4, residual (Frobenius) norm = 0.17022 mc_samples = 8, residual (Frobenius) norm = 0.06145 .. GENERATED FROM PYTHON SOURCE LINES 129-142 Setting the random seed ----------------------- You may have noticed above that the two linear operators created with :code:`mc_samples=1` yield identical residual Frobenius norms. This is because the two linear operators realize the same matrix, i.e. the same sample from the Monte-Carlo estimator. To see that :class:`curvlinops.FisherMCLinearOperator` indeed represents a deterministic matrix, let's create two linear operators with identical hyperparameters and compare their matrix representations. After creating the first linear operator, we generate some random numbers to show that the global random number generator does not influence the Monte-Carlo estimator: .. GENERATED FROM PYTHON SOURCE LINES 142-157 .. code-block:: Python F1_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity # draw some random numbers to modify the global random number generator's state rand(123) F2_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity # still, we get the same deterministic approximation residual_norm = matrix_norm(F1_mat - F2_mat).item() if isclose(residual_norm, 0.0): print(residual_norm) else: raise RuntimeError(f"Residual Frobenius norm should be 0. Got {residual_norm}.") .. rst-class:: sphx-glr-script-out .. code-block:: none 0.0 .. GENERATED FROM PYTHON SOURCE LINES 158-165 This is because the class uses an internal random number generator to draw samples. Therefore, it will not be affected by changes to the global random number generator's state. You can get different realizations of the Monte-Carlo estimator by specifying the optional :code:`seed` argument. The above comparison with differently seeded linear operators leads to different matrices: .. GENERATED FROM PYTHON SOURCE LINES 166-184 .. code-block:: Python seed1 = 123456 F1_mat = ( FisherMCLinearOperator(model, loss_function, params, data, seed=seed1) @ identity ) seed2 = 654321 F2_mat = ( FisherMCLinearOperator(model, loss_function, params, data, seed=seed2) @ identity ) # now, we get two different deterministic approximations residual_norm = matrix_norm(F1_mat - F2_mat).item() if not isclose(residual_norm, 0.0): print(residual_norm) else: raise RuntimeError(f"Residual Frobenius norm should be ≠0. Got {residual_norm}.") .. rst-class:: sphx-glr-script-out .. code-block:: none 0.15368808805942535 .. GENERATED FROM PYTHON SOURCE LINES 185-197 Approximation quality --------------------- Finally, let's combine what we have seen so far to visualize how well the Monte-Carlo approximated Fisher approximates the GGN. To do that, we will repeatedly draw samples for the Fisher information matrix and combine them to yield an estimate that incorporates all previous iterations. This approach allows to record snapshots of the estimator at a different number of total incorporated MC samples. We will use a :code:`logspace` for taking snapshots. .. GENERATED FROM PYTHON SOURCE LINES 198-217 .. code-block:: Python num_steps = 25 mc_samples = unique(logspace(0, 2, num_steps, dtype=int32)) F_snapshots = [] F_accumulated = zeros((D, D), device=DEVICE) start_seed = 123456789 for seed, mc in enumerate(range(mc_samples.max()), start=start_seed): # NOTE Only use `check_deterministic=False` if you know what you are doing # We do this here because we have previously convinced ourselves that the created # linear operators indeed realize deterministic matrices. F = FisherMCLinearOperator( model, loss_function, params, data, seed=seed, check_deterministic=False ) F_accumulated += F @ identity if mc + 1 in mc_samples: F_snapshots.append(F_accumulated / (mc + 1)) .. GENERATED FROM PYTHON SOURCE LINES 218-221 Let's visualize both the residual matrices and their Frobenius norms. To visualize the matrix, we will use the element-wise logarithm of its absolute value (shifted by a small constant to avoid taking the logarithm of 0): .. GENERATED FROM PYTHON SOURCE LINES 222-243 .. code-block:: Python residual_snapshots = [mat - GGN_mat for mat in F_snapshots] residual_norms = [matrix_norm(res).item() for res in residual_snapshots] def transform(mat: Tensor, epsilon: float = 1e-5) -> Tensor: """Transformation applied to the matrix before plotting. Applies element-wise absolute value, shifts by epsilon, then takes the element-wise logarithm. Args: mat: Matrix. epsilon: Small shift to avoid taking the log of 0. Returns: Transformed matrix """ return (mat.abs() + epsilon).log10() .. GENERATED FROM PYTHON SOURCE LINES 244-245 Here's the plotting code (feel free to skip to the visualization). .. GENERATED FROM PYTHON SOURCE LINES 246-303 .. code-block:: Python img_width = 4 rows, columns = 1, 2 fig, axes = plt.subplots( nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width) ) ax_img, ax_fro = axes[0], axes[1] min_img = min(transform(res).min() for res in residual_snapshots) max_img = max(transform(res).max() for res in residual_snapshots) min_fro = 0.0 max_fro = max(residual_norms) ax_fro.set_xlim(mc_samples.min(), mc_samples.max()) ax_fro.semilogx() ax_fro.set_xlabel("MC samples") ax_fro.set_ylabel("residual Frobenius norm") ax_fro.set_ylim(min_fro, 1.05 * max_fro) ln_style = "bo" # collects artists to draw in each frame of the animation artists = [] for frame_idx in range(len(mc_samples)): snapshot = residual_snapshots[frame_idx].cpu() img = ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img, animated=True) # workaround for animated title: https://stackoverflow.com/a/47421938 ax_img_title = plt.text( 0.5, 1.01, f"Residual ({mc_samples[frame_idx]} samples)", horizontalalignment="center", verticalalignment="bottom", transform=ax_img.transAxes, ) if frame_idx == 0: ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img) plt.colorbar(img, ax=ax_img, label="logarithmic absolute entries", shrink=0.8) ax_fro.plot( mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style ) plt.subplots_adjust(wspace=0.5, bottom=0.2) ln_fro = ax_fro.plot( mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style ) artists.append([ax_img_title, img] + ln_fro) ani = animation.ArtistAnimation( fig, artists, interval=1000, blit=False, repeat_delay=1000 ) .. container:: sphx-glr-animation .. raw:: html
.. GENERATED FROM PYTHON SOURCE LINES 304-306 Here's a more qualitative comparison that contrasts the GGN matrix and the MC approximated Fisher (both transformed to logspace as described above): .. GENERATED FROM PYTHON SOURCE LINES 307-361 .. code-block:: Python img_width = 4 rows, columns = 1, 2 GGN_mat = GGN_mat.cpu() F_snapshots = [fs.cpu() for fs in F_snapshots] fig, axes = plt.subplots( nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width) ) min_value = min(transform(mat).min() for mat in F_snapshots + [GGN_mat]) max_value = max(transform(mat).max() for mat in F_snapshots + [GGN_mat]) # collects artists to draw in each frame of the animation artists = [] ax_GGN, ax_F = axes[0], axes[1] ax_GGN.set_title("GGN") for frame_idx in range(len(mc_samples)): im_GGN = ax_GGN.imshow( transform(GGN_mat), vmin=min_value, vmax=max_value, animated=True ) im_F = ax_F.imshow( transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value, animated=True ) # workaround for animated title: https://stackoverflow.com/a/47421938 ax_F_title = plt.text( 0.5, 1.01, f"Fisher-MC ({mc_samples[frame_idx]} samples)", horizontalalignment="center", verticalalignment="bottom", transform=ax_F.transAxes, ) if frame_idx == 0: img = ax_GGN.imshow(transform(GGN_mat), vmin=min_value, vmax=max_value) ax_F.imshow(transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value) fig.colorbar( img, ax=axes.ravel().tolist(), label="logarithmic absolute entries", shrink=0.8, ) artists.append([ax_F_title, im_GGN, im_F]) ani = animation.ArtistAnimation( fig, artists, interval=1000, blit=False, repeat_delay=1000 ) .. container:: sphx-glr-animation .. raw:: html
.. GENERATED FROM PYTHON SOURCE LINES 362-363 That's all for now. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 11.366 seconds) .. _sphx_glr_download_basic_usage_example_fisher_monte_carlo.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: example_fisher_monte_carlo.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: example_fisher_monte_carlo.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: example_fisher_monte_carlo.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_