.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_inverses.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_basic_usage_example_inverses.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_inverses.py:


Inverses (natural gradient)
===============================

This example demonstrates how to work with inverses of linear operators.
:code:`curvlinops` offers multiple ways to compute the inverse of a linear
operator: conjugate gradient (CG) and Neumann inversion. We will demonstrate
CG inversion first and conclude with a comparison to Neumann inversion.

Concretely, we will compute the natural gradient
:math:`\mathbf{\tilde{g}} = \mathbf{F}^{-1} \mathbf{g}`, defined by the inverse
Fisher information matrix :math:`\mathbf{F}^{-1}` and the gradient
:math:`\mathbf{g}`. We can use the GGN, as it corresponds to the Fisher for
common loss functions like squared-error and cross-entropy loss.

.. note::

   The GGN is positive semi-definite, i.e. not necessarily full-rank. But we
   need a full-rank matrix to form the inverse. This is why we will add a
   damping term :math:`\delta \mathbf{I}` before inverting.

As always, let's first import the required functionality.

.. GENERATED FROM PYTHON SOURCE LINES 23-43

.. code-block:: Python


    import matplotlib.pyplot as plt
    from scipy.sparse.linalg import eigsh
    from torch import cuda, device, eye, float64, manual_seed, rand
    from torch.linalg import inv
    from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
    from torch.nn.utils import parameters_to_vector

    from curvlinops import (
        CGInverseLinearOperator,
        GGNLinearOperator,
        NeumannInverseLinearOperator,
    )
    from curvlinops.examples import IdentityLinearOperator, gradient_and_loss
    from curvlinops.examples.functorch import functorch_ggn, functorch_gradient_and_loss
    from curvlinops.utils import allclose_report

    # make deterministic
    manual_seed(0)


.. GENERATED FROM PYTHON SOURCE LINES 44-49

Setup
-----

We will use synthetic data, consisting of two mini-batches, a small MLP, and
mean-squared error as loss function.

.. GENERATED FROM PYTHON SOURCE LINES 50-74

.. code-block:: Python


    N = 64
    D_in = 7
    D_hidden = 5
    D_out = 3

    DEVICE = device("cuda" if cuda.is_available() else "cpu")
    DTYPE = float64  # double precision for better stability when computing inverse

    X1, y1 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)
    X2, y2 = rand(N, D_in).to(DEVICE, DTYPE), rand(N, D_out).to(DEVICE, DTYPE)

    model = Sequential(
        Linear(D_in, D_hidden),
        ReLU(),
        Linear(D_hidden, D_hidden),
        Sigmoid(),
        Linear(D_hidden, D_out),
    ).to(DEVICE, DTYPE)
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}

    loss_function = MSELoss(reduction="mean").to(DEVICE, DTYPE)


.. GENERATED FROM PYTHON SOURCE LINES 75-81

Next, let's compute the ingredients for the natural gradient.

Inverse GGN/Fisher
------------------

First, we set up a linear operator for the damped GGN/Fisher

.. GENERATED FROM PYTHON SOURCE LINES 82-90

.. code-block:: Python


    data = [(X1, y1), (X2, y2)]
    GGN = GGNLinearOperator(model, loss_function, params, data)

    shapes = [p.shape for p in params.values()]
    delta = 1e-2
    damping = delta * IdentityLinearOperator(shapes, GGN.device, DTYPE)
    damped_GGN = GGN + damping


.. GENERATED FROM PYTHON SOURCE LINES 91-92

and the linear operator of its inverse:

.. GENERATED FROM PYTHON SOURCE LINES 93-102

.. code-block:: Python


    inverse_damped_GGN = CGInverseLinearOperator(
        damped_GGN,
        eps=0,  # do not add CG-internal damping
        # use a small number of iterations for a rough solution
        max_iter=5,
        max_tridiag_iter=5,
    )


.. GENERATED FROM PYTHON SOURCE LINES 103-107

Gradient
--------

We can obtain the gradient via a convenience function from
:code:`curvlinops.examples`:

.. GENERATED FROM PYTHON SOURCE LINES 108-113

.. code-block:: Python


    gradient, _ = gradient_and_loss(model, loss_function, params, data)
    # flatten and concatenate
    gradient = parameters_to_vector(gradient).detach()


.. GENERATED FROM PYTHON SOURCE LINES 114-119

Natural gradient
----------------

Now we have all components in place to compute the natural gradient with a
simple matrix-vector product:

.. GENERATED FROM PYTHON SOURCE LINES 120-123

.. code-block:: Python


    natural_gradient = inverse_damped_GGN @ gradient


.. GENERATED FROM PYTHON SOURCE LINES 124-126

As a first sanity check, let's verify that the natural gradient approximately
satisfies :math:`\mathbf{F} \mathbf{\tilde{g}} = \mathbf{g}`, with the damped
GGN in place of :math:`\mathbf{F}`:

.. GENERATED FROM PYTHON SOURCE LINES 127-133

.. code-block:: Python


    approx_gradient = damped_GGN @ natural_gradient

    print("Comparing gradient with Fisher @ natural gradient.")
    assert allclose_report(approx_gradient, gradient, rtol=1e-4, atol=1e-5)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Comparing gradient with Fisher @ natural gradient.


.. GENERATED FROM PYTHON SOURCE LINES 134-140

Verifying results
-----------------

To check that the code works, let's repeat the computation with dense
matrices. First, we compute the GGN matrix with :code:`functorch`, using a
utility function from :code:`curvlinops.examples`:

.. GENERATED FROM PYTHON SOURCE LINES 141-144

.. code-block:: Python


    GGN_mat_functorch = functorch_ggn(model, loss_function, params, data).detach()


.. GENERATED FROM PYTHON SOURCE LINES 145-146

Then we damp it and invert it:

.. GENERATED FROM PYTHON SOURCE LINES 147-153

.. code-block:: Python


    damping_mat = delta * eye(GGN_mat_functorch.shape[0], device=DEVICE, dtype=DTYPE)
    damped_GGN_mat = GGN_mat_functorch + damping_mat
    inv_damped_GGN_mat = inv(damped_GGN_mat)


.. GENERATED FROM PYTHON SOURCE LINES 154-156

Next, let's compute the gradient with :code:`functorch`, using a utility
function from :code:`curvlinops.examples`:

.. GENERATED FROM PYTHON SOURCE LINES 157-165

.. code-block:: Python


    gradient_functorch, _ = functorch_gradient_and_loss(model, loss_function, params, data)
    # flatten and concatenate
    gradient_functorch = parameters_to_vector(gradient_functorch).detach()

    print("Comparing gradient with functorch's gradient.")
    assert allclose_report(gradient, gradient_functorch)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Comparing gradient with functorch's gradient.


.. GENERATED FROM PYTHON SOURCE LINES 166-168

We can now compute the natural gradient from the :code:`functorch` quantities.
This should yield approximately the same result:

.. GENERATED FROM PYTHON SOURCE LINES 169-178

.. code-block:: Python


    natural_gradient_functorch = inv_damped_GGN_mat @ gradient_functorch

    print("Comparing natural gradient with functorch's natural gradient.")
    rtol, atol = 5e-3, 5e-5
    assert allclose_report(
        natural_gradient, natural_gradient_functorch, rtol=rtol, atol=atol
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Comparing natural gradient with functorch's natural gradient.


.. GENERATED FROM PYTHON SOURCE LINES 179-182

You might have noticed the rather loose tolerances required to achieve
approximate equality. We can use stricter convergence hyperparameters for CG
to achieve a more accurate inversion:

.. GENERATED FROM PYTHON SOURCE LINES 183-202

.. code-block:: Python


    inverse_damped_GGN = CGInverseLinearOperator(
        damped_GGN,
        eps=0,  # do not add CG-internal damping
        # increase number of iterations to get a better approximation
        max_iter=10,
        max_tridiag_iter=10,
    )
    natural_gradient_more_accurate = inverse_damped_GGN @ gradient

    smaller_rtol, smaller_atol = rtol / 10, atol / 10

    print("Comparing more accurate natural gradient with functorch's natural gradient.")
    assert allclose_report(
        natural_gradient_more_accurate,
        natural_gradient_functorch,
        rtol=smaller_rtol,
        atol=smaller_atol,
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Comparing more accurate natural gradient with functorch's natural gradient.


.. GENERATED FROM PYTHON SOURCE LINES 203-204

whereas the less accurate inversion does not pass this check:

.. GENERATED FROM PYTHON SOURCE LINES 205-220

.. code-block:: Python


    print(
        "Comparing natural gradient with functorch's natural gradient (smaller tolerances)."
    )

    try:
        assert allclose_report(
            natural_gradient,
            natural_gradient_functorch,
            rtol=smaller_rtol,
            atol=smaller_atol,
        )
        raise RuntimeError("This comparison should not pass")
    except AssertionError as e:
        print(e)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Comparing natural gradient with functorch's natural gradient (smaller tolerances).
    at index [4]: 4.04149e-02 ≠ 4.03777e-02, ratio: 1.00092e+00
    at index [8]: 6.76268e-03 ≠ 6.81269e-03, ratio: 9.92659e-01
    at index [9]: -1.36847e-02 ≠ -1.37049e-02, ratio: 9.98528e-01
    at index [10]: -3.70819e-03 ≠ -3.66605e-03, ratio: 1.01150e+00
    at index [11]: -7.79267e-03 ≠ -7.76049e-03, ratio: 1.00415e+00
    at index [13]: -1.80134e-03 ≠ -1.76606e-03, ratio: 1.01998e+00
    at index [22]: 3.16156e-02 ≠ 3.16480e-02, ratio: 9.98977e-01
    at index [23]: 4.82754e-02 ≠ 4.83161e-02, ratio: 9.99158e-01
    at index [24]: 1.27004e-02 ≠ 1.26878e-02, ratio: 1.00100e+00
    at index [26]: -1.42644e-02 ≠ -1.42521e-02, ratio: 1.00086e+00
    at index [27]: 3.20968e-02 ≠ 3.21585e-02, ratio: 9.98081e-01
    at index [29]: 2.13577e-02 ≠ 2.13813e-02, ratio: 9.98897e-01
    at index [30]: 2.14169e-02 ≠ 2.14483e-02, ratio: 9.98537e-01
    at index [31]: 8.43465e-03 ≠ 8.45006e-03, ratio: 9.98176e-01
    at index [32]: -2.02736e-03 ≠ -2.00121e-03, ratio: 1.01307e+00
    at index [33]: 3.67601e-03 ≠ 3.70411e-03, ratio: 9.92413e-01
    at index [34]: 4.81874e-04 ≠ 4.88223e-04, ratio: 9.86996e-01
    at index [36]: -1.07670e-02 ≠ -1.07276e-02, ratio: 1.00368e+00
    at index [39]: 1.91604e-02 ≠ 1.92130e-02, ratio: 9.97266e-01
    at index [40]: -1.24558e-02 ≠ -1.24304e-02, ratio: 1.00204e+00
    at index [41]: -2.01354e-03 ≠ -2.00333e-03, ratio: 1.00510e+00
    at index [45]: 4.96902e-02 ≠ 4.96296e-02, ratio: 1.00122e+00
    at index [46]: 4.18778e-04 ≠ 4.07667e-04, ratio: 1.02725e+00
    at index [48]: -1.74391e-02 ≠ -1.74200e-02, ratio: 1.00110e+00
    at index [49]: 2.25335e-02 ≠ 2.25163e-02, ratio: 1.00076e+00
    at index [50]: 1.02638e-02 ≠ 1.02371e-02, ratio: 1.00260e+00
    at index [54]: 4.42425e-02 ≠ 4.42709e-02, ratio: 9.99359e-01
    at index [55]: 2.56494e-02 ≠ 2.56147e-02, ratio: 1.00135e+00
    at index [70]: 9.86109e-03 ≠ 9.84801e-03, ratio: 1.00133e+00
    at index [75]: 2.27872e-03 ≠ 2.27010e-03, ratio: 1.00380e+00
    at index [77]: -1.58043e-02 ≠ -1.58248e-02, ratio: 9.98707e-01
    Abs max: 1.35858e-01 vs. 1.35860e-01.
    Abs min: 0.00000e+00 vs. 0.00000e+00.
    Non-close entries: 31 / 88.
    rtol = 0.0005, atol = 5e-06.


.. GENERATED FROM PYTHON SOURCE LINES 221-227

Visual comparison
-----------------

Finally, let's visualize the damped Fisher/GGN and its inverse. For improved
visibility, we take the base-10 logarithm of the absolute value of each
element (blank pixels correspond to zeros).

.. GENERATED FROM PYTHON SOURCE LINES 228-240

.. code-block:: Python


    fig, ax = plt.subplots(ncols=2)
    plt.suptitle("Logarithm of absolute values")

    ax[0].set_title("Damped GGN/Fisher")
    image = ax[0].imshow(damped_GGN_mat.detach().cpu().abs().log10())
    plt.colorbar(image, ax=ax[0], shrink=0.5)

    ax[1].set_title("Inv. damped GGN/Fisher")
    image = ax[1].imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
    plt.colorbar(image, ax=ax[1], shrink=0.5)


.. image-sg:: /basic_usage/images/sphx_glr_example_inverses_001.png
   :alt: Logarithm of absolute values, Damped GGN/Fisher, Inv. damped GGN/Fisher
   :srcset: /basic_usage/images/sphx_glr_example_inverses_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 241-258

Neumann inverse (CG alternative)
--------------------------------

So far, we used CG to solve the linear system
:math:`\mathbf{F} \mathbf{\tilde{g}} = \mathbf{g}` for the natural gradient
:math:`\mathbf{\tilde{g}}` (i.e. the result of the inverse Fisher-gradient
product). Alternatively, we can use the truncated Neumann series to
approximate the inverse, using :py:class:`NeumannInverseLinearOperator`.

.. note::

   The Neumann series does not always converge. But we can use a re-scaling
   trick to make it converge if we know the matrix is PSD and are given its
   largest eigenvalue. More information can be found in the docstring.

To make the Neumann series converge, we need to know the largest eigenvalue of
the matrix to be inverted:

.. GENERATED FROM PYTHON SOURCE LINES 259-263

.. code-block:: Python


    max_eigval = eigsh(damped_GGN.to_scipy(), k=1, which="LM", return_eigenvectors=False)[0]
    # eigenvalues of (scale * damped_GGN_mat) are in [0; 2)
    scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval


.. GENERATED FROM PYTHON SOURCE LINES 264-265

Let's compute the inverse approximation for each truncation number in
:code:`num_terms`:

.. GENERATED FROM PYTHON SOURCE LINES 266-274

.. code-block:: Python


    num_terms = [10]

    neumann_inverses = []

    for n in num_terms:
        inv = NeumannInverseLinearOperator(damped_GGN, scale=scale, num_terms=n)
        neumann_inverses.append(inv @ eye(inv.shape[1], device=DEVICE, dtype=DTYPE))


.. GENERATED FROM PYTHON SOURCE LINES 275-276

Here are their visualizations:

.. GENERATED FROM PYTHON SOURCE LINES 277-292

.. code-block:: Python


    fig, axes = plt.subplots(ncols=len(num_terms) + 1)
    plt.suptitle("Inverse damped Fisher (logarithm of absolute values)")

    for i, (n, inv) in enumerate(zip(num_terms, neumann_inverses)):
        ax = axes.flat[i]
        ax.set_title(f"Neumann, {n} terms")
        image = ax.imshow(inv.detach().cpu().abs().log10())
        plt.colorbar(image, ax=ax, shrink=0.5)

    ax = axes.flat[-1]
    ax.set_title("Exact inverse")
    image = ax.imshow(inv_damped_GGN_mat.detach().cpu().abs().log10())
    plt.colorbar(image, ax=ax, shrink=0.5)


.. image-sg:: /basic_usage/images/sphx_glr_example_inverses_002.png
   :alt: Inverse damped Fisher (logarithm of absolute values), Neumann, 10 terms, Exact inverse
   :srcset: /basic_usage/images/sphx_glr_example_inverses_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 293-296

Neumann inversion is usually less accurate than CG inversion. But it might
sometimes be preferred if only a rough approximation of the product with the
inverse matrix is needed.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.801 seconds)


.. _sphx_glr_download_basic_usage_example_inverses.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: example_inverses.ipynb <example_inverses.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: example_inverses.py <example_inverses.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: example_inverses.zip <example_inverses.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_