Eigenvalues
This example demonstrates how to compute a subset of eigenvalues of a linear
operator, using scipy.sparse.linalg.eigsh(). Concretely, we will compute
leading eigenvalues of the Hessian.
As always, imports go first.
from contextlib import redirect_stderr
from io import StringIO
import numpy
import scipy
import torch
from torch import nn
from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.utils import allclose_report
# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)
Setup
We will use synthetic data, consisting of two mini-batches, a small MLP, and mean-squared error as loss function.
N = 20
D_in = 7
D_hidden = 5
D_out = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X1, y1 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)
X2, y2 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)
model = nn.Sequential(
    nn.Linear(D_in, D_hidden),
    nn.ReLU(),
    nn.Linear(D_hidden, D_hidden),
    nn.Sigmoid(),
    nn.Linear(D_hidden, D_out),
).to(DEVICE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
loss_function = nn.MSELoss(reduction="mean").to(DEVICE)
Linear operator
We are ready to set up the linear operator. In this example, we will use the Hessian.
data = [(X1, y1), (X2, y2)]
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()
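For readers new to SciPy's linear operator interface, the following minimal sketch (not curvlinops code) illustrates the kind of object that eigen-solvers such as eigsh consume: an operator described only by its shape, data type, and matrix-vector product, so the matrix itself never has to be stored. The toy operator X^T X / N, the array X, and the helper matvec are illustrative assumptions and do not appear in the example above.

```python
import numpy
from scipy.sparse.linalg import LinearOperator

rng = numpy.random.default_rng(0)
N, D = 20, 7
X = rng.standard_normal((N, D))


def matvec(v: numpy.ndarray) -> numpy.ndarray:
    """Apply X^T X / N to a vector without forming the D x D matrix."""
    return X.T @ (X @ v) / N


# The operator is defined by its shape, dtype, and matrix-vector product only.
A = LinearOperator((D, D), matvec=matvec, dtype=X.dtype)

v = rng.standard_normal(D)
dense = X.T @ X / N  # explicit matrix, built here only for comparison
max_abs_error = numpy.abs(A @ v - dense @ v).max()
print(f"Max. abs. difference to the dense product: {max_abs_error:.2e}")
```

The same pattern scales to operators whose dense form would not fit into memory, which is the situation the Hessian of a neural network is typically in.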
Leading eigenvalues
Through scipy.sparse.linalg.eigsh(), we can obtain the leading
\(k=3\) eigenvalues.
k = 3
which = "LM" # largest magnitude
top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)
print(f"Leading {k} Hessian eigenvalues: {top_k_evals}")
Leading 3 Hessian eigenvalues: [1.41091264 1.42817086 1.49430477]
Verifying results
To double-check this result, let’s compute the Hessian with
functorch, compute all its eigenvalues with
torch.linalg.eigh(), then extract the top \(k\).
H_functorch = functorch_hessian(model, loss_function, params, data).detach()
evals_functorch, _ = torch.linalg.eigh(H_functorch)
top_k_evals_functorch = evals_functorch[-k:]
print(f"Leading {k} Hessian eigenvalues (functorch): {top_k_evals_functorch}")
Leading 3 Hessian eigenvalues (functorch): tensor([1.4109, 1.4282, 1.4943])
Both results should match.
print(f"Comparing leading {k} Hessian eigenvalues (linear operator vs. functorch).")
assert allclose_report(top_k_evals, top_k_evals_functorch.double(), rtol=1e-4)
Comparing leading 3 Hessian eigenvalues (linear operator vs. functorch).
scipy.sparse.linalg.eigsh() can also compute other subsets of
eigenvalues (for example, the smallest ones), as well as the associated
eigenvectors. Check out its documentation for more!
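As a small illustration of these options, the sketch below calls eigsh on a synthetic symmetric matrix with known eigenvalues 1, 2, ..., 30 instead of the Hessian, so that the results are easy to check. The matrix construction and all variable names are assumptions made for this illustration; the which values "LA" and "SA" and the return_eigenvectors flag are part of SciPy's eigsh interface.

```python
import numpy
from scipy.sparse.linalg import eigsh

rng = numpy.random.default_rng(0)
D = 30

# Build a symmetric matrix with known eigenvalues 1, 2, ..., D.
Q, _ = numpy.linalg.qr(rng.standard_normal((D, D)))
known_evals = numpy.arange(1.0, D + 1.0)
A = (Q * known_evals) @ Q.T
A = (A + A.T) / 2  # remove round-off asymmetry

k = 3
# "LA": largest algebraic eigenvalues, "SA": smallest algebraic eigenvalues.
largest_evals, largest_evecs = eigsh(A, k=k, which="LA")
smallest_evals = eigsh(A, k=k, which="SA", return_eigenvectors=False)

# Each column of `largest_evecs` is an eigenvector: A v = lambda v.
residual = numpy.abs(A @ largest_evecs - largest_evecs * largest_evals).max()

print(f"Largest {k} eigenvalues: {numpy.sort(largest_evals)}")
print(f"Smallest {k} eigenvalues: {numpy.sort(smallest_evals)}")
print(f"Max. eigenpair residual: {residual:.2e}")
```

Skipping the eigenvectors with return_eigenvectors=False is a reasonable choice when only the spectrum is of interest.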
Power iteration versus eigsh
Here, we compare the query efficiency (the number of matrix-vector products
required) of scipy.sparse.linalg.eigsh() with the power iteration method, a
simple method to compute the leading eigenvalues (in terms of magnitude). We
re-use the implementation from the PyHessian library and adapt it to work with
NumPy arrays rather than PyTorch tensors:
def power_method(
    A: scipy.sparse.linalg.LinearOperator,
    max_iterations: int = 100,
    tol: float = 1e-3,
    k: int = 1,
) -> tuple[numpy.ndarray, numpy.ndarray]:
    """Compute the top-k eigenpairs of a linear operator using power iteration.

    Code modified from PyHessian, see
    https://github.com/amirgholami/PyHessian/blob/72e5f0a0d06142387fccdab2226b4c6bae088202/pyhessian/hessian.py#L111-L156

    Args:
        A: Linear operator of dimension ``D`` whose top eigenpairs will be computed.
        max_iterations: Maximum number of iterations. Defaults to ``100``.
        tol: Relative tolerance between two consecutive iterations that has to be
            reached for convergence. Defaults to ``1e-3``.
        k: Number of eigenpairs to compute. Defaults to ``1``.

    Returns:
        The eigenvalues as array of shape ``[k]`` in ascending order, and their
        corresponding eigenvectors as array of shape ``[k, D]`` (one per row).
    """
    eigenvalues = []
    eigenvectors = []

    def normalize(v: numpy.ndarray) -> numpy.ndarray:
        return v / numpy.linalg.norm(v)

    def orthonormalize(v: numpy.ndarray, basis: list[numpy.ndarray]) -> numpy.ndarray:
        # project out the eigenvectors that were already found
        for basis_vector in basis:
            v -= numpy.dot(v, basis_vector) * basis_vector
        return normalize(v)

    computed_dim = 0
    while computed_dim < k:
        eigenvalue = None
        v = normalize(numpy.random.randn(A.shape[0]))

        for _ in range(max_iterations):
            v = orthonormalize(v, eigenvectors)
            Av = A @ v
            tmp_eigenvalue = v.dot(Av)
            v = normalize(Av)

            if eigenvalue is None:
                eigenvalue = tmp_eigenvalue
            elif abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
                break
            else:
                eigenvalue = tmp_eigenvalue

        eigenvalues.append(eigenvalue)
        eigenvectors.append(v)
        computed_dim += 1

    # sort in ascending order and convert into arrays
    eigenvalues = numpy.array(eigenvalues[::-1])
    eigenvectors = numpy.array(eigenvectors[::-1])

    return eigenvalues, eigenvectors
Let’s compute the top-3 eigenvalues via power iteration and verify they roughly match.
Note that we use a smaller tol value (1e-4) than the PyHessian default value
to get better convergence, and that the comparison still needs a much looser relative
tolerance (rtol=2e-2) than the one we used when comparing eigsh with eigh (rtol=1e-4).
top_k_evals_power, _ = power_method(H, tol=1e-4, k=k)
print(f"Comparing leading {k} Hessian eigenvalues (eigsh vs. power).")
assert allclose_report(
    top_k_evals_functorch.double(), top_k_evals_power, rtol=2e-2, atol=1e-6
)
Comparing leading 3 Hessian eigenvalues (eigsh vs. power).
This indicates that the power method achieves poorer accuracy than eigsh. But
does it therefore require fewer matrix-vector products? To answer this, let’s turn on
the linear operator’s progress bar, which allows us to count the number of
matrix-vector products invoked by both eigen-solvers:
H = HessianLinearOperator(
    model, loss_function, params, data, progressbar=True
).to_scipy()

# determine number of matrix-vector products used by `eigsh`
with StringIO() as buf, redirect_stderr(buf):
    top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)
    # The tqdm progressbar will print "matmat" for each batch in a matrix-vector
    # product. Therefore, we need to divide by the number of batches
    queries_eigsh = buf.getvalue().count("matmat") // len(data)
print(f"eigsh used {queries_eigsh} matrix-vector products.")

# determine number of matrix-vector products used by power iteration
with StringIO() as buf, redirect_stderr(buf):
    top_k_evals_power, _ = power_method(H, k=k, tol=1e-4)
    # The tqdm progressbar will print "matmat" for each batch in a matrix-vector
    # product. Therefore, we need to divide by the number of batches
    queries_power = buf.getvalue().count("matmat") // len(data)
print(f"Power iteration used {queries_power} matrix-vector products.")

assert queries_power > queries_eigsh
eigsh used 21 matrix-vector products.
Power iteration used 72 matrix-vector products.
Sadly, the power iteration does not offer computational benefits either: it consumed
more matrix-vector products than eigsh (72 versus 21). While it is elegant and simple,
it cannot compete with eigsh, at least in the comparison provided here.
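One plausible explanation for this outcome (an interpretation added here, not a claim of the original example) is the small gap between the leading eigenvalues. For a symmetric matrix, power iteration converges at a rate governed by the ratio of the two largest eigenvalue magnitudes, and the values reported above (about 1.49 and 1.43) give a ratio of roughly 0.96. The sketch below illustrates the effect on two synthetic diagonal matrices; the helper power_iteration_steps and both spectra are hypothetical and chosen only for illustration.

```python
import numpy


def power_iteration_steps(
    A: numpy.ndarray, tol: float = 1e-8, max_iterations: int = 10_000
) -> int:
    """Count matrix-vector products until the eigenvalue estimate stabilizes.

    Stops once the Rayleigh quotient changes by less than ``tol`` (relative)
    between two consecutive iterations.
    """
    rng = numpy.random.default_rng(0)
    v = rng.standard_normal(A.shape[0])
    v /= numpy.linalg.norm(v)
    eigenvalue = None

    for step in range(1, max_iterations + 1):
        Av = A @ v  # one matrix-vector product per step
        new_eigenvalue = v @ Av
        v = Av / numpy.linalg.norm(Av)
        converged = (
            eigenvalue is not None
            and abs(new_eigenvalue - eigenvalue) < tol * abs(new_eigenvalue)
        )
        if converged:
            return step
        eigenvalue = new_eigenvalue

    return max_iterations


# Two spectra with the same leading eigenvalue but different gaps.
well_separated = numpy.diag([1.0, 0.5, 0.4, 0.3, 0.2])  # ratio 0.50
nearly_tied = numpy.diag([1.0, 0.96, 0.4, 0.3, 0.2])  # ratio 0.96

steps_separated = power_iteration_steps(well_separated)
steps_tied = power_iteration_steps(nearly_tied)
print(f"Ratio 0.50: {steps_separated} matrix-vector products")
print(f"Ratio 0.96: {steps_tied} matrix-vector products")
```

Krylov-subspace methods such as the Lanczos-based solver behind eigsh are generally less sensitive to this ratio, which is consistent with the query counts observed in this example.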
Therefore, we recommend using eigsh for computing eigenvalues. This method
becomes accessible because curvlinops interfaces with SciPy’s linear
operators.
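As an aside, matrix-vector products can also be counted without parsing progress-bar output, by wrapping the product in a function that increments a counter. The following sketch shows this on a synthetic positive semi-definite matrix standing in for the Hessian; the counter dictionary, the helper counting_matvec, and the toy matrix are assumptions for illustration and are not part of curvlinops.

```python
import numpy
from scipy.sparse.linalg import LinearOperator, eigsh

rng = numpy.random.default_rng(0)
D = 50
B = rng.standard_normal((D, D))
A_dense = B @ B.T / D  # symmetric toy matrix standing in for the Hessian

counter = {"matvecs": 0}


def counting_matvec(v: numpy.ndarray) -> numpy.ndarray:
    """Multiply by the toy matrix and record that one product was used."""
    counter["matvecs"] += 1
    return A_dense @ v


A = LinearOperator((D, D), matvec=counting_matvec, dtype=A_dense.dtype)

k = 3
top_k_evals = eigsh(A, k=k, which="LM", return_eigenvectors=False)
queries_eigsh = counter["matvecs"]
print(f"eigsh used {queries_eigsh} matrix-vector products.")
```

A mutable dictionary is used for the counter so that the sketch also works when pasted inside a function, where a plain integer would require a nonlocal declaration.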