.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_fisher_monte_carlo.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_fisher_monte_carlo.py:

Monte-Carlo approximation of the Fisher
=======================================

In this tutorial, we will compare two approaches to compute the Fisher
information matrix:

1. The Fisher as expected Hessian under the model's likelihood coincides with
   the generalized Gauss-Newton (GGN) matrix for common loss functions, like
   :class:`torch.nn.MSELoss` and :class:`torch.nn.CrossEntropyLoss`. For these
   settings, Fisher = GGN.

2. The Fisher can also be seen as the expectation of the gradient outer product
   w.r.t. the model's likelihood. This expectation can be approximated by
   computing the outer product of 'would-be' gradients, where the loss is
   evaluated on a label sampled from the model's likelihood rather than on the
   true label.

The first approach is implemented by :class:`curvlinops.GGNLinearOperator`, the
second by :class:`curvlinops.FisherMCLinearOperator`. We will see that both
approaches coincide as the Monte-Carlo approximation converges.

Let's get the imports out of our way.

.. code-block:: Python

    import matplotlib.pyplot as plt
    import numpy
    import torch
    from matplotlib import animation
    from torch import nn

    from curvlinops import FisherMCLinearOperator, GGNLinearOperator

    # make deterministic
    torch.manual_seed(0)
    numpy.random.seed(0)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Setup
-----

We will create a synthetic classification task and a small multi-layer
perceptron (MLP), and use cross-entropy as the loss function.
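
For the softmax cross-entropy loss used here, the equality Fisher = GGN can be
verified in a few lines. The derivation below is a sketch under assumptions
introduced for illustration: :code:`reduction="mean"` over :math:`N` data
points, network outputs :math:`f_n = f_\theta(x_n) \in \mathbb{R}^C`, Jacobians
:math:`J_n = \partial f_n / \partial \theta`, predicted probabilities
:math:`p_n = \operatorname{softmax}(f_n)`, and :math:`e_y` the one-hot vector of
class :math:`y`. The step :math:`F = G` uses
:math:`\mathbb{E}[e_{\hat{y}}] = p_n` and
:math:`\mathbb{E}[e_{\hat{y}} e_{\hat{y}}^\top] = \operatorname{diag}(p_n)`:

.. math::

   \ell(f_n, y) = -\log [p_n]_y, \qquad
   \nabla_{f_n} \ell(f_n, y) = p_n - e_y, \qquad
   \nabla^2_{f_n} \ell(f_n, y) = \operatorname{diag}(p_n) - p_n p_n^\top

   G = \frac{1}{N} \sum_{n=1}^{N} J_n^\top \left( \operatorname{diag}(p_n) - p_n p_n^\top \right) J_n

   F = \frac{1}{N} \sum_{n=1}^{N} \mathbb{E}_{\hat{y} \sim \operatorname{Cat}(p_n)}
       \left[ J_n^\top (p_n - e_{\hat{y}}) (p_n - e_{\hat{y}})^\top J_n \right] = G

   \hat{F} = \frac{1}{N M} \sum_{n=1}^{N} \sum_{m=1}^{M}
       J_n^\top (p_n - e_{\hat{y}_{nm}}) (p_n - e_{\hat{y}_{nm}})^\top J_n,
       \qquad \hat{y}_{nm} \sim \operatorname{Cat}(p_n)

:math:`\hat{F}` replaces the expectation by an average over :math:`M` sampled
labels per data point; we assume this is the role of the :code:`mc_samples`
argument used later. It is an unbiased estimate of :math:`F` whose variance
decreases like :math:`1/M`, so its residual to the GGN shrinks as :math:`M`
grows.
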

.. code-block:: Python

    Ns = [4, 6]
    D_in = 7
    D_hidden = 5
    C = 3

    data = [
        (
            torch.rand(N, D_in).to(DEVICE),  # X
            torch.randint(low=0, high=C, size=(N,)).to(DEVICE),  # y
        )
        for N in Ns
    ]

    model = nn.Sequential(
        nn.Linear(D_in, D_hidden),
        nn.ReLU(),
        nn.Linear(D_hidden, D_hidden),
        nn.Sigmoid(),
        nn.Linear(D_hidden, C),
    ).to(DEVICE)
    params = [p for p in model.parameters() if p.requires_grad]

    loss_function = nn.CrossEntropyLoss(reduction="mean").to(DEVICE)

A first comparison
------------------

Let's create linear operators for the GGN and the Monte-Carlo approximated
Fisher and compute their matrix representations by applying them to the
identity matrix:

.. code-block:: Python

    GGN = GGNLinearOperator(model, loss_function, params, data)
    F = FisherMCLinearOperator(model, loss_function, params, data)

    D = GGN.shape[0]
    # use the operators' dtype (float32) to avoid an implicit float64 -> float32 conversion
    identity = numpy.eye(D, dtype=GGN.dtype)

    GGN_mat = GGN @ identity
    F_mat = F @ identity

We can use the residual's Frobenius norm to quantify the approximation error of
the Monte-Carlo estimator:

.. code-block:: Python

    residual_norm = numpy.linalg.norm(GGN_mat - F_mat)
    print(f"Residual (Frobenius) norm: {residual_norm:.5f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Residual (Frobenius) norm: 0.31432

Setting the number of MC samples
--------------------------------

To get more accurate estimates, we can use more samples in the MC
approximation.
This is achieved by specifying the optional :code:`mc_samples` argument of
:class:`curvlinops.FisherMCLinearOperator` (default: :code:`1`). Here are the
residual Frobenius norms when using more samples:

.. code-block:: Python

    mc_samples = [1, 2, 4, 8]
    residual_norms = []

    for mc in mc_samples:
        F = FisherMCLinearOperator(model, loss_function, params, data, mc_samples=mc)
        F_mat = F @ identity
        residual_norms.append(numpy.linalg.norm(GGN_mat - F_mat))

    for mc, norm in zip(mc_samples, residual_norms):
        print(f"mc_samples = {mc},\tresidual (Frobenius) norm = {norm:.5f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    mc_samples = 1, residual (Frobenius) norm = 0.31432
    mc_samples = 2, residual (Frobenius) norm = 0.23078
    mc_samples = 4, residual (Frobenius) norm = 0.17235
    mc_samples = 8, residual (Frobenius) norm = 0.11499

Setting the random seed
-----------------------

You may have noticed that the linear operator from the first comparison
(default :code:`mc_samples=1`) and the one created above with
:code:`mc_samples=1` yield identical residual Frobenius norms. This is because
both realize the same matrix, i.e. the same sample from the Monte-Carlo
estimator.

To see that :class:`curvlinops.FisherMCLinearOperator` indeed represents a
deterministic matrix, let's create two linear operators with identical
hyperparameters and compare their matrix representations. After creating the
first linear operator, we draw some random numbers to show that the global
random number generator does not influence the Monte-Carlo estimator:

.. code-block:: Python

    F1_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity

    # draw some random numbers to modify the global random number generator's state
    torch.rand(123)

    F2_mat = FisherMCLinearOperator(model, loss_function, params, data) @ identity

    # still, we get the same deterministic approximation
    residual_norm = numpy.linalg.norm(F1_mat - F2_mat)
    if numpy.isclose(residual_norm, 0.0):
        print(residual_norm)
    else:
        raise RuntimeError(f"Residual Frobenius norm should be 0. Got {residual_norm}.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.0

This is because the class draws samples with an internal random number
generator, which is not affected by changes to the global random number
generator's state.

You can get different realizations of the Monte-Carlo estimator by specifying
the optional :code:`seed` argument. Repeating the above comparison with
differently seeded linear operators yields different matrices:

.. code-block:: Python

    seed1 = 123456
    F1_mat = (
        FisherMCLinearOperator(model, loss_function, params, data, seed=seed1)
        @ identity
    )

    seed2 = 654321
    F2_mat = (
        FisherMCLinearOperator(model, loss_function, params, data, seed=seed2)
        @ identity
    )

    # now, we get two different deterministic approximations
    residual_norm = numpy.linalg.norm(F1_mat - F2_mat)
    if not numpy.isclose(residual_norm, 0.0):
        print(residual_norm)
    else:
        raise RuntimeError(f"Residual Frobenius norm should be ≠0. Got {residual_norm}.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.20334943

Approximation quality
---------------------

Finally, let's combine what we have seen so far to visualize how well the
Monte-Carlo approximated Fisher approximates the GGN.
To do that, we will repeatedly draw samples of the Fisher information matrix
and average them into an estimate that incorporates all previous iterations.
This lets us record snapshots of the estimator for different total numbers of
incorporated MC samples. We take the snapshots at logarithmically spaced sample
counts generated with :code:`numpy.logspace`.

.. code-block:: Python

    num_steps = 25
    mc_samples = numpy.unique(
        numpy.logspace(0, 2, num_steps, endpoint=True, dtype=numpy.int32)
    )

    F_snapshots = []
    F_accumulated = numpy.zeros((D, D))

    start_seed = 123456789
    for seed, mc in enumerate(range(mc_samples.max()), start=start_seed):
        # NOTE Only use `check_deterministic=False` if you know what you are doing.
        # We do this here because we have previously convinced ourselves that the
        # created linear operators indeed realize deterministic matrices.
        F = FisherMCLinearOperator(
            model, loss_function, params, data, seed=seed, check_deterministic=False
        )
        F_accumulated += F @ identity

        if mc + 1 in mc_samples:
            F_snapshots.append(F_accumulated / (mc + 1))

Let's visualize both the residual matrices and their Frobenius norms. To
visualize a matrix, we will use the element-wise base-10 logarithm of its
absolute value (shifted by a small constant to avoid taking the logarithm of 0):

.. code-block:: Python

    residual_snapshots = [mat - GGN_mat for mat in F_snapshots]
    residual_norms = [numpy.linalg.norm(res) for res in residual_snapshots]


    def transform(mat: numpy.ndarray, epsilon: float = 1e-5) -> numpy.ndarray:
        """Transformation applied to the matrix before plotting.

        Applies element-wise absolute value, shifts by epsilon, then takes the
        element-wise logarithm.

        Args:
            mat: Matrix.
            epsilon: Small shift to avoid taking the log of 0.

        Returns:
            Transformed matrix
        """
        return numpy.log10(numpy.abs(mat) + epsilon)
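
The snapshot loop above keeps a running sum of single-sample estimates and
divides it by the number of incorporated operators. The following minimal NumPy
sketch, which uses random matrices as hypothetical stand-ins for single-sample
Fisher estimates, checks that each snapshot is simply the mean of the first
:math:`k` estimates, i.e. (under our assumption that an estimator with
:math:`k` samples averages :math:`k` independent outer products) statistically
equivalent to a single estimator with :math:`k` MC samples:

.. code-block:: Python

    import numpy

    rng = numpy.random.default_rng(0)

    # hypothetical stand-ins for 100 single-sample Fisher estimates (3x3 matrices)
    single_estimates = rng.normal(size=(100, 3, 3))

    # snapshot points, built the same way as `mc_samples` above
    checkpoints = numpy.unique(numpy.logspace(0, 2, 25, endpoint=True, dtype=numpy.int32))

    snapshots = []
    accumulated = numpy.zeros((3, 3))
    for idx in range(checkpoints.max()):
        accumulated += single_estimates[idx]
        if idx + 1 in checkpoints:
            # running sum divided by the number of incorporated estimates
            snapshots.append(accumulated / (idx + 1))

    # each snapshot equals the plain mean of the first k single-sample estimates
    for k, snapshot in zip(checkpoints, snapshots):
        assert numpy.allclose(snapshot, single_estimates[:k].mean(axis=0))
    print(f"{len(snapshots)} snapshots at sample counts {checkpoints.tolist()}")

Accumulating a sum instead of storing all single-sample matrices keeps the
memory cost at one :math:`D \times D` array, regardless of the number of
samples.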

Here's the plotting code (feel free to skip to the visualization).

.. code-block:: Python

    img_width = 4
    rows, columns = 1, 2
    fig, axes = plt.subplots(
        nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
    )
    ax_img, ax_fro = axes[0], axes[1]

    min_img = min(transform(res).min() for res in residual_snapshots)
    max_img = max(transform(res).max() for res in residual_snapshots)

    min_fro = 0.0
    max_fro = max(residual_norms)

    ax_fro.set_xlim(mc_samples.min(), mc_samples.max())
    ax_fro.semilogx()
    ax_fro.set_xlabel("MC samples")
    ax_fro.set_ylabel("residual Frobenius norm")
    ax_fro.set_ylim(min_fro, 1.05 * max_fro)
    ln_style = "bo"

    # collects artists to draw in each frame of the animation
    artists = []

    for frame_idx in range(len(mc_samples)):
        snapshot = residual_snapshots[frame_idx]
        img = ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img, animated=True)
        # workaround for animated title: https://stackoverflow.com/a/47421938
        ax_img_title = plt.text(
            0.5,
            1.01,
            f"Residual ({mc_samples[frame_idx]} samples)",
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=ax_img.transAxes,
        )

        if frame_idx == 0:
            ax_img.imshow(transform(snapshot), vmin=min_img, vmax=max_img)
            plt.colorbar(img, ax=ax_img, label="logarithmic absolute entries", shrink=0.8)
            ax_fro.plot(
                mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
            )
            plt.subplots_adjust(wspace=0.5, bottom=0.2)

        ln_fro = ax_fro.plot(
            mc_samples[: frame_idx + 1], residual_norms[: frame_idx + 1], ln_style
        )
        artists.append([ax_img_title, img] + ln_fro)

    ani = animation.ArtistAnimation(
        fig, artists, interval=1000, blit=False, repeat_delay=1000
    )

*Animation: log-scaled absolute residual entries (left) and residual Frobenius
norm versus the number of MC samples (right).*
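
The shrinking residual is what one expects from an unbiased Monte-Carlo
estimator. The following is a minimal, self-contained sketch in plain PyTorch
(it does not use curvlinops and assumes PyTorch 2.x for :code:`torch.func`); the
toy model and the helpers :code:`logits_fn`, :code:`expected_fisher`, and
:code:`mc_fisher` are hypothetical names for illustration. It builds the exact
GGN of a softmax cross-entropy loss from per-sample Jacobians, checks that
averaging the gradient outer product over all labels, weighted by the model's
predicted probabilities, reproduces the GGN, and then prints the residual of
Monte-Carlo estimates with an increasing number :code:`M` of sampled labels:

.. code-block:: Python

    import torch
    from torch import nn
    from torch.func import functional_call, jacrev

    torch.manual_seed(0)
    N, D_in, C = 8, 4, 3  # hypothetical toy sizes

    # tiny float64 model and data, for illustration only
    model = nn.Sequential(nn.Linear(D_in, 5), nn.Tanh(), nn.Linear(5, C)).double()
    X = torch.rand(N, D_in, dtype=torch.float64)
    params = {name: p.detach() for name, p in model.named_parameters()}


    def logits_fn(p):
        """Network outputs as a function of the parameter dictionary, shape [N, C]."""
        return functional_call(model, p, (X,))


    # Jacobian of the outputs w.r.t. all parameters, flattened to shape [N, C, P]
    jac = jacrev(logits_fn)(params)
    J = torch.cat([j.reshape(N, C, -1) for j in jac.values()], dim=2)

    probs = logits_fn(params).softmax(dim=1)  # [N, C]

    # exact GGN: the Hessian of cross-entropy w.r.t. the logits is diag(p) - p p^T
    H = torch.diag_embed(probs) - probs.unsqueeze(2) * probs.unsqueeze(1)
    GGN = torch.einsum("nci,ncd,ndj->ij", J, H, J) / N


    def expected_fisher():
        """Exact expectation of the gradient outer product over labels y ~ Cat(p)."""
        g_out = probs.unsqueeze(1) - torch.eye(C, dtype=probs.dtype)  # [N, K, C]
        g = torch.einsum("nci,nkc->nki", J, g_out)  # gradient for each label, [N, K, P]
        return torch.einsum("nk,nki,nkj->ij", probs, g, g) / N


    def mc_fisher(M, generator):
        """Monte-Carlo Fisher from M labels per data point sampled from the model."""
        y_hat = torch.multinomial(probs, M, replacement=True, generator=generator)
        g_out = probs.unsqueeze(1) - nn.functional.one_hot(y_hat, C).to(probs.dtype)
        g = torch.einsum("nci,nmc->nmi", J, g_out)  # 'would-be' gradients, [N, M, P]
        return torch.einsum("nmi,nmj->ij", g, g) / (N * M)


    print(torch.allclose(expected_fisher(), GGN))

    generator = torch.Generator().manual_seed(42)
    errors = []
    for M in [1, 10, 100, 1000, 10000]:
        errors.append(torch.linalg.norm(mc_fisher(M, generator) - GGN).item())
        print(f"M = {M:5d}, residual Frobenius norm = {errors[-1]:.5f}")

Under these assumptions, the printed residual should shrink roughly like
:math:`1/\sqrt{M}`, mirroring the trend in the plot above.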

Here's a more qualitative comparison that contrasts the GGN matrix and the MC
approximated Fisher (both passed through :code:`transform` as described above):

.. code-block:: Python

    img_width = 4
    rows, columns = 1, 2
    fig, axes = plt.subplots(
        nrows=rows, ncols=columns, figsize=(columns * img_width, rows * img_width)
    )

    min_value = min(transform(mat).min() for mat in F_snapshots + [GGN_mat])
    max_value = max(transform(mat).max() for mat in F_snapshots + [GGN_mat])

    # collects artists to draw in each frame of the animation
    artists = []

    ax_GGN, ax_F = axes[0], axes[1]
    ax_GGN.set_title("GGN")

    for frame_idx in range(len(mc_samples)):
        im_GGN = ax_GGN.imshow(
            transform(GGN_mat), vmin=min_value, vmax=max_value, animated=True
        )
        im_F = ax_F.imshow(
            transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value, animated=True
        )
        # workaround for animated title: https://stackoverflow.com/a/47421938
        ax_F_title = plt.text(
            0.5,
            1.01,
            f"Fisher-MC ({mc_samples[frame_idx]} samples)",
            horizontalalignment="center",
            verticalalignment="bottom",
            transform=ax_F.transAxes,
        )

        if frame_idx == 0:
            img = ax_GGN.imshow(transform(GGN_mat), vmin=min_value, vmax=max_value)
            ax_F.imshow(transform(F_snapshots[frame_idx]), vmin=min_value, vmax=max_value)
            fig.colorbar(
                img,
                ax=axes.ravel().tolist(),
                label="logarithmic absolute entries",
                shrink=0.8,
            )

        artists.append([ax_F_title, im_GGN, im_F])

    ani = animation.ArtistAnimation(
        fig, artists, interval=1000, blit=False, repeat_delay=1000
    )

*Animation: transformed GGN (left) and Monte-Carlo approximated Fisher for an
increasing number of MC samples (right).*
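
Returning to the determinism discussed in *Setting the random seed*: one way to
make sampling independent of the global random number generator is to draw from
a dedicated, explicitly seeded :class:`torch.Generator`. The sketch below
illustrates that idea with a hypothetical helper :code:`sample_labels` and
made-up probabilities; it is not curvlinops' actual implementation.

.. code-block:: Python

    import torch


    def sample_labels(probs, num_samples, seed):
        """Draw labels y ~ Cat(probs) from a dedicated generator (hypothetical helper).

        The generator is created and seeded on every call, so the result depends
        only on ``seed``, not on the global random number generator's state.
        """
        generator = torch.Generator().manual_seed(seed)
        return torch.multinomial(probs, num_samples, replacement=True, generator=generator)


    probs = torch.tensor([[0.2, 0.5, 0.3], [0.6, 0.1, 0.3]])

    labels1 = sample_labels(probs, num_samples=5, seed=0)
    torch.rand(123)  # modify the global random number generator's state
    labels2 = sample_labels(probs, num_samples=5, seed=0)
    labels3 = sample_labels(probs, num_samples=5, seed=1)

    print(torch.equal(labels1, labels2))  # identical seeds give identical labels
    print(labels1, labels3, sep="\n")  # a different seed generally gives different labels

Seeding a fresh generator per call trades a little overhead for reproducibility
that does not depend on what other code did to the global state beforehand.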

That's all for now.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 40.956 seconds)