Eigenvalues
This example demonstrates how to compute a subset of eigenvalues of a linear
operator, using scipy.sparse.linalg.eigsh(). Concretely, we will compute
leading eigenvalues of the Hessian.
As always, imports go first.
import numpy
import scipy
import torch
from torch import nn
from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.examples.utils import report_nonclose
# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)
Setup
We will use synthetic data consisting of two mini-batches, together with a small MLP and mean-squared error as the loss function.
N = 20
D_in = 7
D_hidden = 5
D_out = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X1, y1 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)
X2, y2 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)
model = nn.Sequential(
    nn.Linear(D_in, D_hidden),
    nn.ReLU(),
    nn.Linear(D_hidden, D_hidden),
    nn.Sigmoid(),
    nn.Linear(D_hidden, D_out),
).to(DEVICE)
params = [p for p in model.parameters() if p.requires_grad]
loss_function = nn.MSELoss(reduction="mean").to(DEVICE)
Linear operator
We are ready to set up the linear operator. In this example, we will use the Hessian.
data = [(X1, y1), (X2, y2)]
H = HessianLinearOperator(model, loss_function, params, data)
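The reason a linear operator can be handed to an eigensolver is that iterative solvers such as scipy.sparse.linalg.eigsh() only need matrix-vector products, never the matrix entries themselves. The following is a minimal sketch of that idea using SciPy's generic LinearOperator class. The random symmetric matrix `A`, the `matvec` helper, and the `counter` dictionary are illustrative assumptions and stand in for the Hessian; this is not how HessianLinearOperator is implemented.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

rng = np.random.default_rng(0)
dim = 50

# Assumption: a random symmetric matrix stands in for the Hessian.
B = rng.standard_normal((dim, dim))
A = (B + B.T) / 2

counter = {"matvecs": 0}


def matvec(v):
    """Return A @ v and count how often the solver asks for a product."""
    counter["matvecs"] += 1
    return A @ v


# The operator exposes only its shape, dtype, and a matrix-vector product.
op = LinearOperator(shape=(dim, dim), matvec=matvec, dtype=A.dtype)

# eigsh interacts with `op` solely through `matvec`.
top_evals, _ = eigsh(op, k=3, which="LA")

# Dense reference: all eigenvalues in ascending order, keep the last three.
reference = np.linalg.eigvalsh(A)[-3:]

print("Match dense reference:", np.allclose(np.sort(top_evals), reference))
print("Matrix-vector products used:", counter["matvecs"])
```

In this sketch each product is a dense multiplication; for a Hessian operator, each product would instead be computed without ever forming the matrix, which is what makes the approach attractive for large models.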
Leading eigenvalues
Through scipy.sparse.linalg.eigsh(), we can obtain the leading
\(k=3\) eigenvalues.
k = 3
which = "LA" # largest algebraic
top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)
print(f"Leading {k} Hessian eigenvalues: {top_k_evals}")
Leading 3 Hessian eigenvalues: [1.4109129 1.4281708 1.4943049]
Verifying results
To double-check this result, let’s compute the Hessian with
functorch, compute all its eigenvalues with
scipy.linalg.eigh(), then extract the top \(k\).
H_functorch = (
    functorch_hessian(model, loss_function, params, data).detach().cpu().numpy()
)
evals_functorch, _ = scipy.linalg.eigh(H_functorch)
top_k_evals_functorch = evals_functorch[-k:]
print(f"Leading {k} Hessian eigenvalues (functorch): {top_k_evals_functorch}")
Leading 3 Hessian eigenvalues (functorch): [1.4109125 1.4281707 1.4943047]
Both results should match.
print(f"Comparing leading {k} Hessian eigenvalues (linear operator vs. functorch).")
report_nonclose(top_k_evals, top_k_evals_functorch)
Comparing leading 3 Hessian eigenvalues (linear operator vs. functorch).
Compared arrays match.
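The two printed results differ in their last digits, so "match" here has to mean agreement up to a numerical tolerance rather than exact equality. As a rough illustration, the sketch below compares the printed values with numpy.allclose(). The tolerances `rtol=1e-5` and `atol=1e-8` are NumPy's defaults and are an assumption on our part; the tolerances used by report_nonclose are not shown in this example.

```python
import numpy as np

# Values copied from the printed output above.
evals_linop = np.array([1.4109129, 1.4281708, 1.4943049])
evals_dense = np.array([1.4109125, 1.4281707, 1.4943047])

# Largest element-wise deviation between the two results.
max_abs_diff = np.max(np.abs(evals_linop - evals_dense))
print(f"Largest absolute difference: {max_abs_diff:.1e}")

# Exact equality fails, while a tolerance-based comparison succeeds.
exact_match = np.array_equal(evals_linop, evals_dense)
# Assumption: NumPy's default tolerances, not necessarily those of report_nonclose.
close_match = np.allclose(evals_linop, evals_dense, rtol=1e-5, atol=1e-8)
print(f"Exactly equal: {exact_match}, close within tolerance: {close_match}")
```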
scipy.sparse.linalg.eigsh() can also compute other subsets of
eigenvalues, as well as their associated eigenvectors. Check out its
documentation for more!
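As a small sketch of those options, the snippet below requests the smallest algebraic (`which="SA"`) and largest magnitude (`which="LM"`) eigenvalues, and then checks the eigenvectors returned for `which="LA"`. To keep it self-contained, it uses a small dense symmetric matrix with a hand-picked spectrum rather than the Hessian from this example; the matrix `A` and the names `spectrum` and `Q` are illustrative assumptions.

```python
import numpy as np
from scipy.sparse.linalg import eigsh

rng = np.random.default_rng(0)

# Assumption: a symmetric test matrix with a known spectrum replaces the Hessian.
spectrum = np.array([-6.0, -3.0, -2.0, -1.0, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0])
Q, _ = np.linalg.qr(rng.standard_normal((spectrum.size, spectrum.size)))
A = Q @ np.diag(spectrum) @ Q.T
A = (A + A.T) / 2  # remove round-off asymmetry

# Smallest algebraic eigenvalues, skipping the eigenvectors.
smallest = np.sort(eigsh(A, k=2, which="SA", return_eigenvectors=False))

# Largest eigenvalues by absolute value (may include negative ones).
largest_mag = np.sort(eigsh(A, k=2, which="LM", return_eigenvectors=False))

# Largest algebraic eigenvalues with eigenvectors; column i pairs with evals[i].
evals, evecs = eigsh(A, k=2, which="LA")

# Each pair should satisfy A v = lambda v up to numerical error.
residual = np.linalg.norm(A @ evecs - evecs * evals)

print("Smallest algebraic:", smallest)
print("Largest magnitude:", largest_mag)
print("Eigenpair residual:", residual)
```

Note the difference between the last two modes: "LA" ranks eigenvalues by signed value, whereas "LM" ranks them by absolute value, which matters whenever the operator has large negative eigenvalues.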