.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_visual_tour.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_visual_tour.py:


Visual tour of curvature matrices
=================================

This tutorial visualizes different curvature matrices for a model whose
parameter space is small enough that each matrix can be computed and stored
explicitly.

First, the imports.

.. GENERATED FROM PYTHON SOURCE LINES 9-42

.. code-block:: Python


    from collections.abc import Callable

    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from numpy import cumsum
    from torch import Tensor, cuda, device, eye, manual_seed, rand, randint
    from torch.nn import (
        Conv2d,
        CrossEntropyLoss,
        Flatten,
        Linear,
        ReLU,
        Sequential,
        Sigmoid,
    )
    from torch.utils.data import DataLoader, TensorDataset
    from tueplots import bundles

    from curvlinops import (
        EFLinearOperator,
        EKFACLinearOperator,
        GGNLinearOperator,
        HessianLinearOperator,
        KFACLinearOperator,
    )

    # make deterministic
    manual_seed(0)

    DEVICE = device("cuda" if cuda.is_available() else "cpu")


.. GENERATED FROM PYTHON SOURCE LINES 43-48

Setup
-----

We will create a synthetic classification task and a small CNN, and use
cross-entropy as the loss function.

.. GENERATED FROM PYTHON SOURCE LINES 48-85
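Every curvature matrix in this tour is materialized densely, so it has P-by-P
entries, where P is the number of trainable parameters. As a
back-of-the-envelope check, the following plain-Python sketch counts the
parameters of the CNN defined below by hand and estimates the memory of one
dense ``float32`` curvature matrix. The helper functions are hypothetical and
not part of ``curvlinops`` or PyTorch:

.. code-block:: Python

    def conv2d_out_size(n: int, kernel: int, padding: int, stride: int = 1) -> int:
        """Spatial output size of a square convolution on an n x n input."""
        return (n + 2 * padding - kernel) // stride + 1


    def conv2d_params(c_in: int, c_out: int, kernel: int) -> int:
        """Weight of shape (c_out, c_in, kernel, kernel) plus one bias per output channel."""
        return c_out * c_in * kernel * kernel + c_out


    def linear_params(d_in: int, d_out: int) -> int:
        """Weight of shape (d_out, d_in) plus a bias of size d_out."""
        return d_out * d_in + d_out


    def dense_bytes(num_params: int, bytes_per_entry: int = 4) -> int:
        """Memory of a dense num_params x num_params matrix (4 bytes per float32 entry)."""
        return num_params**2 * bytes_per_entry


    # spatial size: 10 -> 10 (first conv) -> 5 (strided conv) -> 5 (last conv)
    side = conv2d_out_size(10, kernel=3, padding=1)
    side = conv2d_out_size(side, kernel=5, padding=2, stride=2)
    side = conv2d_out_size(side, kernel=3, padding=1)

    layer_params = [
        conv2d_params(3, 4, 3),
        conv2d_params(4, 4, 5),
        conv2d_params(4, 1, 3),
        linear_params(1 * side * side, 5),  # one channel of size side x side
    ]
    total = sum(layer_params)

    print(side, layer_params, total)  # 5 [112, 404, 37, 130] 683
    print(dense_bytes(total) / 1e6, "MB per dense matrix")  # about 1.9 MB
    print(dense_bytes(10**6) / 1e12, "TB for one million parameters")  # 4.0

The per-layer counts should match the ``Layer parameters`` printed by the setup
code (ignoring the parameter-free layers), and the quadratic growth in P is the
reason this tour is restricted to a tiny model.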
.. code-block:: Python


    num_data = 50
    batch_size = 20
    in_channels = 3
    in_features_shape = (in_channels, 10, 10)
    num_classes = 5

    # dataset
    dataset = TensorDataset(
        rand(num_data, *in_features_shape),  # X
        randint(size=(num_data,), low=0, high=num_classes),  # y
    )
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # model
    model = Sequential(
        Conv2d(in_channels, 4, 3, padding=1),
        ReLU(),
        Conv2d(4, 4, 5, padding=2, stride=2),
        Sigmoid(),
        Conv2d(4, 1, 3, padding=1),
        Flatten(),
        Linear(25, num_classes),  # 1 channel x 5 x 5 = 25 features after the strided conv
    ).to(DEVICE)

    params = {n: p for n, p in model.named_parameters() if p.requires_grad}
    num_params = sum(p.numel() for p in params.values())
    num_params_layer = [
        sum(p.numel() for p in child.parameters()) for child in model.children()
    ]
    num_tensors_layer = [len(list(child.parameters())) for child in model.children()]
    loss_function = CrossEntropyLoss(reduction="mean").to(DEVICE)

    print(f"Total parameters: {num_params}")
    print(f"Layer parameters: {num_params_layer}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Total parameters: 683
    Layer parameters: [112, 0, 404, 0, 37, 0, 130]




.. GENERATED FROM PYTHON SOURCE LINES 86-94

Computation
-----------

We can now set up a linear operator for each curvature matrix we want to
visualize, and compute the matrix by multiplying its linear operator onto the
identity matrix.

First, create the linear operators:

.. GENERATED FROM PYTHON SOURCE LINES 94-106

.. code-block:: Python


    Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader)
    GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader)
    EF_linop = EFLinearOperator(model, loss_function, params, dataloader)
    EKFAC_linop = EKFACLinearOperator(
        model, loss_function, params, dataloader, separate_weight_and_bias=False
    )
    # `mc_samples=1`: Monte-Carlo approximation of the Fisher
    F_linop = GGNLinearOperator(model, loss_function, params, dataloader, mc_samples=1)
    KFAC_linop = KFACLinearOperator(
        model, loss_function, params, dataloader, separate_weight_and_bias=False
    )
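These operators are designed to avoid storing the dense matrix; instead, they
provide products of the matrix with vectors (or with the columns of a matrix).
Multiplying onto the identity therefore recovers the dense matrix column by
column. The following is a conceptual plain-Python sketch of that idea, not of
how ``curvlinops`` implements ``@`` internally; ``materialize`` and
``toy_hvp`` are hypothetical helpers, and the toy Hessian-vector product
belongs to the quadratic ``f(x) = x0**2 + 3*x0*x1 + 2*x1**2``, whose Hessian
is ``[[2, 3], [3, 4]]``:

.. code-block:: Python

    from collections.abc import Callable

    Vector = list[float]


    def materialize(matvec: Callable[[Vector], Vector], dim: int) -> list[list[float]]:
        """Build the dense matrix of a linear map by applying it to each unit vector."""
        columns = []
        for j in range(dim):
            unit = [1.0 if i == j else 0.0 for i in range(dim)]
            columns.append(matvec(unit))  # A @ e_j is the j-th column of A
        # transpose the list of columns into a list of rows
        return [[columns[j][i] for j in range(dim)] for i in range(dim)]


    def toy_hvp(v: Vector) -> Vector:
        """Closed-form Hessian-vector product of f(x) = x0**2 + 3*x0*x1 + 2*x1**2."""
        return [2 * v[0] + 3 * v[1], 3 * v[0] + 4 * v[1]]


    print(materialize(toy_hvp, dim=2))  # [[2.0, 3.0], [3.0, 4.0]]

In this sketch, materializing a matrix costs one operator application per
column, i.e. 683 applications for the CNN above, which is affordable here but
quickly becomes prohibitive for larger models.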
.. GENERATED FROM PYTHON SOURCE LINES 107-108

Then, compute the matrices:

.. GENERATED FROM PYTHON SOURCE LINES 109-119

.. code-block:: Python


    identity = eye(num_params, device=DEVICE)

    Hessian_mat = Hessian_linop @ identity
    GGN_mat = GGN_linop @ identity
    EF_mat = EF_linop @ identity
    EKFAC_mat = EKFAC_linop @ identity
    F_mat = F_linop @ identity
    KFAC_mat = KFAC_linop @ identity


.. GENERATED FROM PYTHON SOURCE LINES 120-124

Visualization
-------------

We show the entries of all matrices on a shared color scale so that they can
be compared directly.

.. GENERATED FROM PYTHON SOURCE LINES 124-196

.. code-block:: Python


    matrices = [m.cpu() for m in (Hessian_mat, GGN_mat, EF_mat, EKFAC_mat, F_mat, KFAC_mat)]
    titles = [
        "Hessian",
        "Generalized Gauss-Newton",
        "Empirical Fisher",
        "EKFAC",
        "Monte-Carlo Fisher",
        "KFAC",
    ]

    rows, columns = 2, 3


    def plot(
        transform: Callable[[Tensor], Tensor], transform_title: str = None
    ) -> tuple[Figure, Axes]:
        """Visualize transformed curvature matrices using a shared domain.

        Args:
            transform: A transformation that will be applied to the matrices.
                Must accept a matrix and return a matrix of the same shape.
            transform_title: An optional string describing the transformation.
                Default: `None` (empty).

        Returns:
            Figure and axes of the created subplot.
        """
        min_value = min(transform(mat).min() for mat in matrices)
        max_value = max(transform(mat).max() for mat in matrices)

        fig, axes = plt.subplots(nrows=rows, ncols=columns, sharex=True, sharey=True)
        fig.supxlabel("Layer")
        fig.supylabel("Layer")

        for idx, (ax, mat, title) in enumerate(zip(axes.flat, matrices, titles)):
            ax.set_title(title)
            img = ax.imshow(transform(mat), vmin=min_value, vmax=max_value)

            # layer blocks: `imshow` centers pixel i at coordinate i, so the boundary
            # between parameters pos - 1 and pos lies at pos - 0.5; `axhline` and
            # `axvline` span the full axes by default
            boundaries = [0] + cumsum(num_params_layer).tolist()
            for pos in boundaries:
                if pos not in [0, num_params]:
                    style = {"color": "w", "lw": 0.5, "ls": "-"}
                    ax.axhline(y=pos - 0.5, **style)
                    ax.axvline(x=pos - 0.5, **style)

            # tick labels at the center of each layer block (layers without
            # parameters are skipped)
            label_positions = [
                (boundaries[layer_idx] + boundaries[layer_idx + 1]) / 2
                for layer_idx in range(len(boundaries) - 1)
                if boundaries[layer_idx] != boundaries[layer_idx + 1]
            ]
            labels = [str(i + 1) for i in range(len(label_positions))]
            ax.set_xticks(label_positions)
            ax.set_xticklabels(labels)
            ax.set_yticks(label_positions)
            ax.set_yticklabels(labels)

            # colorbar
            last = idx == len(matrices) - 1
            if last:
                fig.colorbar(
                    img, ax=axes.ravel().tolist(), label=transform_title, shrink=0.8
                )

        return fig, axes


    # use `tueplots` to make the plot look pretty
    plot_config = bundles.icml2024(column="full", nrows=1.5 * rows, ncols=columns)








.. GENERATED FROM PYTHON SOURCE LINES 197-198

We first show the base-10 logarithm of the absolute entries:

.. GENERATED FROM PYTHON SOURCE LINES 199-214

.. code-block:: Python



    def logabs(mat: Tensor, epsilon: float = 1e-6) -> Tensor:
        """Return the log10 of the clamped absolute values.

        Args:
            mat: The matrix to transform.
            epsilon: Lower bound applied to the absolute values before taking the
                logarithm, so that zero entries map to `log10(epsilon)` instead of
                `-inf`. Default: `1e-6`.

        Returns:
            Transformed matrix.
        """
        return mat.abs().clamp(min=epsilon).log10()


    with plt.rc_context(plot_config):
        plot(logabs, transform_title="Logarithmic absolute entries")
        plt.savefig("curvature_matrices_log_abs.pdf", bbox_inches="tight")
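The clamping step matters for entries that are exactly zero, whose logarithm
would otherwise be ``-inf``. The following plain-Python sketch mirrors the
tensor transform on a list (``logabs_list`` is a hypothetical helper, not part
of the tutorial code) to show that the sign of an entry is discarded and that
zeros and tiny values both land on the floor ``log10(epsilon)``, i.e. -6 for
the default ``epsilon``:

.. code-block:: Python

    import math


    def logabs_list(values: list[float], epsilon: float = 1e-6) -> list[float]:
        """log10 of the absolute values, clamped from below at epsilon."""
        return [math.log10(max(abs(x), epsilon)) for x in values]


    # -100.0 and 100.0 both map to 2.0; 0.0 and 1e-9 both map to log10(1e-6)
    print(logabs_list([-100.0, 100.0, 0.5, 0.0, 1e-9]))

One consequence: KFAC and EKFAC are block-diagonal with respect to layers, so
their off-diagonal layer blocks are exactly zero and should appear in the
lowest color of the logarithmic plot.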
.. image-sg:: /basic_usage/images/sphx_glr_example_visual_tour_001.png
   :alt: Hessian, Generalized Gauss-Newton, Empirical Fisher, EKFAC, Monte-Carlo Fisher, KFAC
   :srcset: /basic_usage/images/sphx_glr_example_visual_tour_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 215-216

We use the logarithmic transformation because structure is hard to recognize
in the unaltered entries:

.. GENERATED FROM PYTHON SOURCE LINES 217-231

.. code-block:: Python



    def unchanged(mat: Tensor) -> Tensor:
        """Return the matrix unchanged.

        Args:
            mat: The matrix.

        Returns:
            Unchanged matrix.
        """
        return mat


    with plt.rc_context(plot_config):
        plot(unchanged, transform_title="Unaltered matrix entries")




.. image-sg:: /basic_usage/images/sphx_glr_example_visual_tour_002.png
   :alt: Hessian, Generalized Gauss-Newton, Empirical Fisher, EKFAC, Monte-Carlo Fisher, KFAC
   :srcset: /basic_usage/images/sphx_glr_example_visual_tour_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 232-233

That's all for now.

.. GENERATED FROM PYTHON SOURCE LINES 234-236

.. code-block:: Python


    plt.close("all")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 16.131 seconds)


.. _sphx_glr_download_basic_usage_example_visual_tour.py: