.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_submatrices.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_submatrices.py:


Sub-matrices of linear operators
================================

This tutorial explains how to create linear operators that correspond to a
sub-matrix of another linear operator.

Specifically, given the linear operator :code:`A`, we want to construct the
linear operator that corresponds to its sub-matrix
:code:`A[row_idxs, :][:, col_idxs]`, where :code:`row_idxs` contains the
sub-matrix's row indices and :code:`col_idxs` contains its column indices.

First, the imports.

.. GENERATED FROM PYTHON SOURCE LINES 14-28

.. code-block:: Python


    from time import time

    from torch import Tensor, cuda, device, eye, manual_seed, rand
    from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid

    from curvlinops import HessianLinearOperator
    from curvlinops.examples.functorch import functorch_hessian
    from curvlinops.submatrix import SubmatrixLinearOperator
    from curvlinops.utils import allclose_report

    # make deterministic
    manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 29-33

Setup
-----

Let's create some toy data and a small MLP, and use mean-squared error as
the loss function.

.. GENERATED FROM PYTHON SOURCE LINES 34-57

.. code-block:: Python


    N = 4
    D_in = 7
    D_hidden = 5
    D_out = 3

    DEVICE = device("cuda" if cuda.is_available() else "cpu")

    X1, y1 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE)
    X2, y2 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE)
    data = [(X1, y1), (X2, y2)]

    model = Sequential(
        Linear(D_in, D_hidden),
        ReLU(),
        Linear(D_hidden, D_hidden),
        Sigmoid(),
        Linear(D_hidden, D_out),
    ).to(DEVICE)
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}

    loss_function = MSELoss(reduction="mean").to(DEVICE)

.. GENERATED FROM PYTHON SOURCE LINES 58-60

We will investigate the Hessian. To make sure our results are correct, let's
keep a Hessian matrix computed via :mod:`functorch` around.

.. GENERATED FROM PYTHON SOURCE LINES 61-64

.. code-block:: Python


    H_functorch = functorch_hessian(model, loss_function, params, data)

.. GENERATED FROM PYTHON SOURCE LINES 65-68

Here is the corresponding linear operator, together with a quick check that
builds up its matrix representation by multiplying it onto the identity
matrix and compares the result to the Hessian matrix computed via
:mod:`functorch`.

.. GENERATED FROM PYTHON SOURCE LINES 69-76

.. code-block:: Python


    H = HessianLinearOperator(model, loss_function, params, data)

    num_params = sum(p.numel() for p in params.values())
    identity = eye(num_params, device=DEVICE)

    assert allclose_report(H_functorch, H @ identity)

.. GENERATED FROM PYTHON SOURCE LINES 77-84

Diagonal blocks
---------------

The Hessian consists of blocks :code:`(i, j)` that contain the second-order
derivatives of the loss w.r.t. the :code:`i`-th and :code:`j`-th parameter
tensors. Let's define a function to extract these blocks from the Hessian:

.. GENERATED FROM PYTHON SOURCE LINES 85-107

.. code-block:: Python


    def extract_block(mat: Tensor, params: dict[str, Tensor], i: int, j: int) -> Tensor:
        """Extract the Hessian block from parameters ``i`` and ``j``.

        Args:
            mat: The matrix with block structure.
            params: The parameters defining the blocks.
            i: Row index of the block to be extracted.
            j: Column index of the block to be extracted.

        Returns:
            Block ``(i, j)``. Has shape ``[P_i, P_j]`` where ``P_i`` and ``P_j``
            are the number of elements of the ``i``-th and ``j``-th parameter.
        """
        param_dims = [p.numel() for p in params.values()]
        row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
        col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])
        return mat[row_start:row_end, :][:, col_start:col_end]

.. GENERATED FROM PYTHON SOURCE LINES 108-110

As an example, let's extract the block that corresponds to the Hessian w.r.t.
the first layer's weights in our model.

.. GENERATED FROM PYTHON SOURCE LINES 111-116

.. code-block:: Python


    i, j = 0, 0
    param_names = list(params.keys())

    H_param0_functorch = extract_block(H_functorch, params, i, j)

.. GENERATED FROM PYTHON SOURCE LINES 117-119

We can build a linear operator for this sub-Hessian by providing only the
first layer's weight as parameter:

.. GENERATED FROM PYTHON SOURCE LINES 120-124

.. code-block:: Python


    param_i = {param_names[i]: params[param_names[i]]}
    H_param0 = HessianLinearOperator(model, loss_function, param_i, data)

.. GENERATED FROM PYTHON SOURCE LINES 125-130

This way, we can obtain blocks on the diagonal. Let's check that this linear
operator works as expected by multiplying it onto the identity matrix and
comparing the result to the block we extracted from our ground truth:

.. GENERATED FROM PYTHON SOURCE LINES 131-136

.. code-block:: Python


    assert allclose_report(
        H_param0_functorch, H_param0 @ eye(params[param_names[i]].numel(), device=DEVICE)
    )

.. GENERATED FROM PYTHON SOURCE LINES 137-153

Now you might be wondering whether we can also build linear operators for
off-diagonal blocks. These blocks contain mixed second-order derivatives and
are no longer Hessians. In particular, such a block is rectangular in
general, and thus non-symmetric.
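
To make this concrete, here is a small sketch in plain PyTorch (it does not
use curvlinops). It assumes a hypothetical single linear layer with weight
:code:`W` of shape :code:`[2, 3]` and bias :code:`b` of shape :code:`[2]`,
and relies on :code:`torch.autograd.functional.hessian`, which returns the
Hessian as nested blocks :code:`H_blocks[i][j]` when given a tuple of inputs.
The diagonal block of :code:`W` is square and symmetric, whereas the mixed
block of :code:`(W, b)` is rectangular and only related to its mirror block
:code:`(b, W)` by transposition:

.. code-block:: Python

    import torch
    from torch.autograd.functional import hessian

    torch.manual_seed(0)

    # Hypothetical toy problem (not the MLP from this tutorial): one linear
    # layer with weight W [2, 3] and bias b [2], mean-squared-error loss.
    X, y = torch.rand(4, 3), torch.rand(4, 2)
    W, b = torch.rand(2, 3), torch.rand(2)


    def loss(weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Mean-squared error of the affine map ``X @ weight.T + bias``."""
        return ((X @ weight.T + bias - y) ** 2).mean()


    # For a tuple of inputs, H_blocks[i][j] has shape
    # inputs[i].shape + inputs[j].shape.
    H_blocks = hessian(loss, (W, b))

    # Flatten each block into a matrix of shape [P_i, P_j].
    H_WW = H_blocks[0][0].reshape(W.numel(), W.numel())  # diagonal block
    H_Wb = H_blocks[0][1].reshape(W.numel(), b.numel())  # off-diagonal block
    H_bW = H_blocks[1][0].reshape(b.numel(), W.numel())  # mirrored block

    print(H_WW.shape, H_Wb.shape, H_bW.shape)
    # torch.Size([6, 6]) torch.Size([6, 2]) torch.Size([2, 6])

    # The diagonal block is itself a Hessian and hence symmetric ...
    assert torch.allclose(H_WW, H_WW.T)
    # ... while symmetry of the full Hessian only ties block (W, b) to (b, W).
    assert torch.allclose(H_bW, H_Wb.T)
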
Since we are not asking for a Hessian anymore, we cannot use the interface of
:class:`HessianLinearOperator`. Luckily, there is a different way to achieve
this.

Off-diagonal blocks
-------------------

As an example, let's try to extract the Hessian block from the first and
second parameters in our network (i.e. the weights and biases in the first
layer). For that, we need to slice the Hessian differently along its rows and
columns. We can use the :class:`curvlinops.SubmatrixLinearOperator` class for
that:

.. GENERATED FROM PYTHON SOURCE LINES 154-165

.. code-block:: Python


    param_dims = [p.numel() for p in params.values()]

    i, j = 0, 1
    row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
    col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])

    row_idxs = list(range(row_start, row_end))  # keep the following row indices
    col_idxs = list(range(col_start, col_end))  # keep the following column indices

    H_param0_param1 = SubmatrixLinearOperator(H, row_idxs, col_idxs)

.. GENERATED FROM PYTHON SOURCE LINES 166-168

As the following test shows, this linear operator indeed represents the
desired rectangular Hessian block:

.. GENERATED FROM PYTHON SOURCE LINES 169-177

.. code-block:: Python


    H_param0_param1_functorch = extract_block(H_functorch, params, i, j)

    # multiply the sub-matrix operator (not the ground truth) onto the identity
    assert allclose_report(
        H_param0_param1_functorch,
        H_param0_param1 @ eye(param_dims[j], device=DEVICE),
    )

.. GENERATED FROM PYTHON SOURCE LINES 178-187

Arbitrary sub-matrices
----------------------

So far, we have only extracted blocks spanned by whole parameter tensors. As
the name :class:`SubmatrixLinearOperator` suggests, we can also use it to
create arbitrary sub-matrices. As an example, let's say we want to keep rows
:code:`[0, 13, 42]` of the Hessian, and columns :code:`[1, 2, 3]`. This works
as follows:

.. GENERATED FROM PYTHON SOURCE LINES 188-195

.. code-block:: Python


    row_idxs = [0, 13, 42]  # keep the following row indices
    col_idxs = [1, 2, 3]  # keep the following column indices

    H_sub = SubmatrixLinearOperator(H, row_idxs, col_idxs)
    H_sub_functorch = H_functorch[row_idxs, :][:, col_idxs]

.. GENERATED FROM PYTHON SOURCE LINES 196-197

Quick check to see if it worked:

.. GENERATED FROM PYTHON SOURCE LINES 198-201

.. code-block:: Python


    assert allclose_report(H_sub_functorch, H_sub @ eye(len(col_idxs), device=DEVICE))

.. GENERATED FROM PYTHON SOURCE LINES 202-209

Looks good.

Performance remarks
-------------------

By the way, using this interface, we could also have constructed the first
parameter's Hessian as follows:

.. GENERATED FROM PYTHON SOURCE LINES 210-224

.. code-block:: Python


    i, j = 0, 0
    row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
    col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])
    row_idxs = list(range(row_start, row_end))
    col_idxs = list(range(col_start, col_end))

    H_param0_alternative = SubmatrixLinearOperator(H, row_idxs, col_idxs)

    assert allclose_report(
        H_param0_functorch, H_param0_alternative @ eye(param_dims[0], device=DEVICE)
    )

.. GENERATED FROM PYTHON SOURCE LINES 225-233

In general, though, it is a good idea to first reduce the linear operator's
size as much as possible (in our case, by restricting the parameters to the
necessary ones using the :code:`params` argument of
:class:`HessianLinearOperator`) and apply slicing afterwards to save
computations. In our example, the matrix-vector product of :code:`H_param0`
should therefore be faster than that of :code:`H_param0_alternative`, which
still has to multiply with the full Hessian :code:`H` and discard most of the
result:

.. GENERATED FROM PYTHON SOURCE LINES 234-249

.. code-block:: Python


    x = rand(param_dims[0], device=DEVICE)

    # fewer computations (synchronize because GPU kernels run asynchronously)
    if cuda.is_available():
        cuda.synchronize()
    start = time()
    _ = H_param0 @ x
    if cuda.is_available():
        cuda.synchronize()
    end = time()
    print(f"H_param0.matvec: {end - start:.2e} s")

    # more computations
    start = time()
    _ = H_param0_alternative @ x
    if cuda.is_available():
        cuda.synchronize()
    end = time()
    print(f"H_param0_alternative.matvec: {end - start:.2e} s")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    H_param0.matvec: 8.54e-03 s
    H_param0_alternative.matvec: 9.19e-03 s

.. GENERATED FROM PYTHON SOURCE LINES 250-251

That's all for now.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.111 seconds)


.. _sphx_glr_download_basic_usage_example_submatrices.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: example_submatrices.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: example_submatrices.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: example_submatrices.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_