Sub-matrices of linear operators
This tutorial explains how to create linear operators that correspond to a sub-matrix of another linear operator.
Specifically, given the linear operator A, we are
interested in constructing the linear operator that corresponds to its sub-matrix
A[row_idxs, :][:, col_idxs], where row_idxs contains the sub-matrix’s
row indices, and col_idxs contains the sub-matrix’s column indices.
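For intuition, here is a minimal sketch of one way such a sub-matrix operator can be realized using only matrix-vector products with A. It scatters the input vector into a zero vector of A's column dimension, multiplies with A, and keeps the selected rows. This is an assumed design for illustration, not necessarily how curvlinops implements it. The helper name submatrix_matvec is hypothetical, and a small dense tensor stands in for A.

```python
import torch


def submatrix_matvec(matvec, num_cols, row_idxs, col_idxs, x):
    """Compute A[row_idxs, :][:, col_idxs] @ x using only products with A.

    Hypothetical helper for illustration: ``matvec(v)`` must return ``A @ v``.
    """
    # Embed x into a zero vector of A's column dimension; index_add_ sums
    # entries that share a column index, so repeated col_idxs are handled.
    v = torch.zeros(num_cols, dtype=x.dtype)
    v.index_add_(0, torch.tensor(col_idxs), x)
    # One product with the full operator, then keep only the requested rows
    return matvec(v)[row_idxs]


torch.manual_seed(0)
A = torch.rand(6, 5)  # dense stand-in for the operator A
row_idxs, col_idxs = [0, 2, 5], [1, 3]
x = torch.rand(len(col_idxs))

result = submatrix_matvec(lambda v: A @ v, A.shape[1], row_idxs, col_idxs, x)
expected = A[row_idxs, :][:, col_idxs] @ x
print(torch.allclose(result, expected))  # True
```

Note that in this design every product with the sub-matrix still costs one product with the full A, no matter how few rows and columns are kept.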
First, the imports.
from time import time
from torch import Tensor, cuda, device, eye, manual_seed, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.utils import allclose_report
# make deterministic
manual_seed(0)
Setup
Let’s create some toy data, a small MLP, and use mean-squared error as the loss function.
N = 4
D_in = 7
D_hidden = 5
D_out = 3
DEVICE = device("cuda" if cuda.is_available() else "cpu")
X1, y1 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE)
X2, y2 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE)
data = [(X1, y1), (X2, y2)]
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    Sigmoid(),
    Linear(D_hidden, D_out),
).to(DEVICE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
loss_function = MSELoss(reduction="mean").to(DEVICE)
We will investigate the Hessian. To make sure our results are correct, let’s keep
a Hessian matrix computed via functorch around.
H_functorch = functorch_hessian(model, loss_function, params, data)
Here is the corresponding linear operator and a quick check that builds up
its matrix representation through multiplication with the identity matrix,
followed by comparison to the Hessian matrix computed via functorch.
H = HessianLinearOperator(model, loss_function, params, data)
num_params = sum(p.numel() for p in params.values())
identity = eye(num_params, device=DEVICE)
assert allclose_report(H_functorch, H @ identity)
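The identity check works because multiplying an operator with the k-th column of the identity (the basis vector e_k) returns the k-th column of the matrix. The following self-contained sketch illustrates this with plain PyTorch Hessian-vector products (double backward) on a toy function; the function f and the helper hvp are hypothetical stand-ins for the model's loss and are not curvlinops code.

```python
import torch


def f(w):
    # Toy scalar function of a parameter vector (stand-in for the loss)
    return (w**3).sum() + w[0] * w[1]


def hvp(w, v):
    """Hessian-vector product of f at w via double backward."""
    w = w.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(f(w), w, create_graph=True)
    (hv,) = torch.autograd.grad(grad @ v, w)
    return hv


torch.manual_seed(0)
w = torch.rand(4)
identity = torch.eye(4)
# Column k of the Hessian is the product with the k-th basis vector
H_from_hvps = torch.stack([hvp(w, e) for e in identity], dim=1)
H_dense = torch.autograd.functional.hessian(f, w)
print(torch.allclose(H_from_hvps, H_dense))  # True
```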
Diagonal blocks
The Hessian consists of blocks (i, j) that contain the second-order
derivatives of the loss w.r.t. the i-th and j-th parameter tensors (in the
order in which they appear in params).
Let’s define a function to extract these blocks from the Hessian:
def extract_block(mat: Tensor, params: dict[str, Tensor], i: int, j: int) -> Tensor:
    """Extract the Hessian block from parameters ``i`` and ``j``.

    Args:
        mat: The matrix with block structure.
        params: The parameters defining the blocks.
        i: Row index of the block to be extracted.
        j: Column index of the block to be extracted.

    Returns:
        Block ``(i, j)``. Has shape ``[P_i, P_j]`` where ``P_i`` and ``P_j``
        are the number of elements of the ``i``-th and ``j``-th parameter.
    """
    param_dims = [p.numel() for p in params.values()]
    row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
    col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])
    return mat[row_start:row_end, :][:, col_start:col_end]
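To make the index arithmetic in extract_block concrete, the sketch below computes each parameter’s half-open index range for the MLP defined above (D_in=7, D_hidden=5, D_out=3). It relies only on the fact that Linear(a, b) has a weight with a*b entries and a bias with b entries, and assumes Sequential’s default parameter names (0.weight, 0.bias, 2.weight, ...). The helper block_ranges is hypothetical.

```python
D_in, D_hidden, D_out = 7, 5, 3

# Number of elements per parameter, in the order of model.named_parameters()
param_dims = {
    "0.weight": D_hidden * D_in,  # 35
    "0.bias": D_hidden,  # 5
    "2.weight": D_hidden * D_hidden,  # 25
    "2.bias": D_hidden,  # 5
    "4.weight": D_out * D_hidden,  # 15
    "4.bias": D_out,  # 3
}


def block_ranges(dims):
    """Return the half-open index range [start, end) of each parameter."""
    ranges, start = {}, 0
    for name, numel in dims.items():
        ranges[name] = (start, start + numel)
        start += numel
    return ranges


ranges = block_ranges(param_dims)
print(ranges["0.bias"])  # (35, 40)
print(sum(param_dims.values()))  # 88
```

Under these assumptions, the Hessian is 88 × 88, and block (0, 1) is H[0:35, 35:40], a 35 × 5 matrix.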
As an example, let’s extract the block that corresponds to the Hessian w.r.t. the first layer’s weights in our model.
i, j = 0, 0
param_names = list(params.keys())
H_param0_functorch = extract_block(H_functorch, params, i, j)
We can build a linear operator for this sub-Hessian by only providing the first layer’s weight as parameter:
param_i = {param_names[i]: params[param_names[i]]}
H_param0 = HessianLinearOperator(model, loss_function, param_i, data)
This way, we can obtain any block on the diagonal.
Let’s check that this linear operator works as expected by multiplying it with the identity matrix and comparing the result to the block we extracted from our ground truth:
assert allclose_report(
    H_param0_functorch, H_param0 @ eye(params[param_names[i]].numel(), device=DEVICE)
)
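Why does restricting params yield a diagonal block? Holding all other parameters fixed and differentiating twice w.r.t. one parameter produces exactly the second derivatives that make up block (i, i). A minimal plain-PyTorch sketch with torch.autograd.functional.hessian shows this on a toy loss of two parameter vectors a and b (hypothetical, not the MLP above):

```python
import torch


def loss(a, b):
    # Toy loss coupling two "parameters" a (size 3) and b (size 2)
    return (a.sum() * b.sum()) ** 2 + (a**2).sum()


torch.manual_seed(0)
a, b = torch.rand(3), torch.rand(2)

# Full Hessian w.r.t. (a, b): a nested tuple of blocks H_full[i][j]
H_full = torch.autograd.functional.hessian(loss, (a, b))

# Hessian w.r.t. a only, with b held fixed
H_a = torch.autograd.functional.hessian(lambda a_: loss(a_, b), a)

print(torch.allclose(H_full[0][0], H_a))  # True
```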
Now you might be wondering whether we can also build linear operators for
off-diagonal blocks. These blocks contain mixed second-order derivatives and
are no longer Hessians: in general, such a block is rectangular and therefore
not symmetric. Since we are no longer asking for a Hessian, we cannot use the
interface of HessianLinearOperator.
Luckily, there is a different way to achieve this.
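To see this concretely, the sketch below uses a toy loss of two parameter vectors (hypothetical, sizes 3 and 2) and inspects the mixed block. It has shape [3, 2], so it cannot be symmetric. Because the full Hessian is symmetric, however, block (1, 0) is the transpose of block (0, 1).

```python
import torch


def loss(a, b):
    # Toy loss coupling two "parameters" a (size 3) and b (size 2)
    return (a.sum() * b.sum()) ** 2 + (a**2).sum()


torch.manual_seed(0)
a, b = torch.rand(3), torch.rand(2)
H = torch.autograd.functional.hessian(loss, (a, b))

H_ab = H[0][1]  # mixed second derivatives d^2 loss / (da db)
print(H_ab.shape)  # torch.Size([3, 2])
print(torch.allclose(H[1][0], H_ab.T))  # True
```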
Off-diagonal blocks
As an example, let’s try to extract the Hessian block from the first and
second parameters in our network (i.e. the weights and biases in the first
layer). For that we need to slice the Hessian differently along its rows and
columns. We can use the curvlinops.SubmatrixLinearOperator class for
that:
param_dims = [p.numel() for p in params.values()]
i, j = 0, 1
row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])
row_idxs = list(range(row_start, row_end)) # keep the following row indices
col_idxs = list(range(col_start, col_end)) # keep the following column indices
H_param0_param1 = SubmatrixLinearOperator(H, row_idxs, col_idxs)
As the following test shows, this linear operator indeed represents the desired rectangular Hessian block:
H_param0_param1_functorch = extract_block(H_functorch, params, i, j)
assert allclose_report(
    H_param0_param1_functorch,
    H_param0_param1 @ eye(param_dims[j], device=DEVICE),
)
Arbitrary sub-matrices
So far, we were constrained to blocks spanned by parameter tensors rather
than arbitrary elements. As the name SubmatrixLinearOperator
suggests, we can use it to create arbitrary sub-matrices.
As an example, let’s say we want to keep rows [0, 13, 42] of the
Hessian, and columns [1, 2, 3]. This works as follows:
row_idxs = [0, 13, 42] # keep the following row indices
col_idxs = [1, 2, 3] # keep the following column indices
H_sub = SubmatrixLinearOperator(H, row_idxs, col_idxs)
H_sub_functorch = H_functorch[row_idxs, :][:, col_idxs]
Quick check to see if it worked:
assert allclose_report(H_sub_functorch, H_sub @ eye(len(col_idxs), device=DEVICE))
Looks good.
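The two-step indexing H_functorch[row_idxs, :][:, col_idxs] is deliberate. In PyTorch (as in NumPy), passing both index lists at once, as in H[row_idxs, col_idxs], pairs them up element-wise. It returns only the entries (0, 1), (13, 2), and (42, 3), not a 3 × 3 sub-matrix. The sketch below contrasts both on a random stand-in matrix. It also shows an equivalent one-step form that broadcasts a column of row indices against the column indices.

```python
import torch

torch.manual_seed(0)
H = torch.rand(88, 88)  # random stand-in for the 88 x 88 Hessian
row_idxs, col_idxs = [0, 13, 42], [1, 2, 3]

sub = H[row_idxs, :][:, col_idxs]  # 3 x 3 sub-matrix
pairs = H[row_idxs, col_idxs]  # only 3 entries: H[0, 1], H[13, 2], H[42, 3]

# One-step alternative: broadcast a [3, 1] row index against a [3] column index
rows = torch.tensor(row_idxs).unsqueeze(1)
cols = torch.tensor(col_idxs)
sub_broadcast = H[rows, cols]

print(sub.shape, pairs.shape)  # torch.Size([3, 3]) torch.Size([3])
print(torch.equal(sub, sub_broadcast))  # True
```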
Performance remarks
By the way, using this interface, we could have also constructed the first parameter’s Hessian as follows:
i, j = 0, 0
row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1])
col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1])
row_idxs = list(range(row_start, row_end))
col_idxs = list(range(col_start, col_end))
H_param0_alternative = SubmatrixLinearOperator(H, row_idxs, col_idxs)
assert allclose_report(
    H_param0_functorch, H_param0_alternative @ eye(param_dims[0], device=DEVICE)
)
In general though, it is a good idea to first reduce the linear operator’s
size as much as possible (in our case, by restricting the parameters to the
necessary ones using the params argument in
HessianLinearOperator) and apply slicing afterwards to save
computations.
In our example, the matrix-vector product of H_param0 should
therefore be faster than that of H_param0_alternative:
H_param0.matvec: 8.54e-03 s
H_param0_alternative.matvec: 9.19e-03 s
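Below is a minimal sketch of how a comparison like this could be timed. It assumes a simple wall-clock average over repeated calls using time.perf_counter. The helper average_runtime and the dense stand-in matrices are hypothetical. Slicing after the fact still pays for a full matrix-vector product, whereas the restricted operator multiplies only with the smaller block. For the Hessian, the gap is likely much smaller than in this dense stand-in, since each Hessian-vector product still has to evaluate the loss through the whole network.

```python
from time import perf_counter

import torch


def average_runtime(f, num_repeats=20):
    """Average wall-clock time of calling ``f()`` (hypothetical helper)."""
    f()  # warm-up call, excluded from the measurement
    start = perf_counter()
    for _ in range(num_repeats):
        f()
    return (perf_counter() - start) / num_repeats


torch.manual_seed(0)
n, k = 2000, 200  # full size and size of the leading diagonal block
A = torch.rand(n, n)  # dense stand-in for a large operator
A_block = A[:k, :k].clone()  # "restricted" operator, built up front
x = torch.rand(k)


def sliced_matvec():
    # Pad x with zeros, multiply with the full matrix, keep the first k rows
    v = torch.zeros(n)
    v[:k] = x
    return (A @ v)[:k]


def restricted_matvec():
    # Multiply only with the pre-extracted block
    return A_block @ x


print(torch.allclose(sliced_matvec(), restricted_matvec(), atol=1e-4))
print(f"sliced:     {average_runtime(sliced_matvec):.2e} s")
print(f"restricted: {average_runtime(restricted_matvec):.2e} s")
```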
That’s all for now.