Eigenvalues

This example demonstrates how to compute a subset of eigenvalues of a linear operator, using scipy.sparse.linalg.eigsh(). Concretely, we will compute leading eigenvalues of the Hessian.

As always, imports go first.

from contextlib import redirect_stderr
from io import StringIO

import numpy
import scipy
import torch
from torch import nn

from curvlinops import HessianLinearOperator
from curvlinops.examples.functorch import functorch_hessian
from curvlinops.utils import allclose_report

# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)

Setup

We will use synthetic data, consisting of two mini-batches, a small MLP, and mean-squared error as loss function.

N = 20
D_in = 7
D_hidden = 5
D_out = 3

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X1, y1 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)
X2, y2 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE)

model = nn.Sequential(
    nn.Linear(D_in, D_hidden),
    nn.ReLU(),
    nn.Linear(D_hidden, D_hidden),
    nn.Sigmoid(),
    nn.Linear(D_hidden, D_out),
).to(DEVICE)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}

loss_function = nn.MSELoss(reduction="mean").to(DEVICE)

Linear operator

We are ready to set up the linear operator. In this example, we will use the Hessian.

data = [(X1, y1), (X2, y2)]
H = HessianLinearOperator(model, loss_function, params, data).to_scipy()

Leading eigenvalues

Through scipy.sparse.linalg.eigsh(), we can obtain the leading \(k=3\) eigenvalues.

k = 3
which = "LM"  # largest magnitude
top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)

print(f"Leading {k} Hessian eigenvalues: {top_k_evals}")
Leading 3 Hessian eigenvalues: [1.41091264 1.42817086 1.49430477]

Verifying results

To double-check this result, let’s compute the Hessian with functorch, compute all its eigenvalues with torch.linalg.eigh(), then extract the top \(k\). Since eigh returns the eigenvalues in ascending order, these are the last \(k\) entries.

H_functorch = functorch_hessian(model, loss_function, params, data).detach()
evals_functorch, _ = torch.linalg.eigh(H_functorch)
top_k_evals_functorch = evals_functorch[-k:]

print(f"Leading {k} Hessian eigenvalues (functorch): {top_k_evals_functorch}")
Leading 3 Hessian eigenvalues (functorch): tensor([1.4109, 1.4282, 1.4943])

Both results should match.

print(f"Comparing leading {k} Hessian eigenvalues (linear operator vs. functorch).")
assert allclose_report(top_k_evals, top_k_evals_functorch.double(), rtol=1e-4)
Comparing leading 3 Hessian eigenvalues (linear operator vs. functorch).

scipy.sparse.linalg.eigsh() can also compute other subsets of eigenvalues, as well as their associated eigenvectors. See its documentation for details.
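As a rough illustration of those other subsets, the sketch below runs eigsh with different `which` values on a toy operator. The diagonal matrix and the variable names are assumptions made for this illustration and are not part of curvlinops. It also shows that "LM" (largest magnitude) and "LA" (largest algebraic) differ as soon as the spectrum contains negative eigenvalues of large magnitude.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

# Toy symmetric operator (assumption for illustration): a diagonal matrix with
# known, partly negative eigenvalues, exposed only through matrix-vector products.
diag = np.concatenate(([-6.0], np.linspace(-2.0, 3.0, 28), [4.0]))
A = LinearOperator(
    shape=(diag.size, diag.size),
    matvec=lambda v: diag * np.ravel(v),
    dtype=np.float64,
)

# Largest magnitude: selects the eigenvalues -6 and 4.
evals_lm = eigsh(A, k=2, which="LM", return_eigenvectors=False)

# Largest algebraic: selects the eigenvalues 3 and 4.
evals_la = eigsh(A, k=2, which="LA", return_eigenvectors=False)

# Smallest algebraic (-6 and -2), together with the eigenvectors (one per column).
evals_sa, evecs_sa = eigsh(A, k=2, which="SA")

print("LM:", np.sort(evals_lm))
print("LA:", np.sort(evals_la))
print("SA:", np.sort(evals_sa), "eigenvector array shape:", evecs_sa.shape)
```

For a Hessian with strongly negative curvature directions, "LM" and taking the last entries of a full eigh decomposition would therefore not select the same eigenvalues.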

Power iteration versus eigsh

Here, we compare the query efficiency of scipy.sparse.linalg.eigsh() with the power iteration method, a simple method to compute the leading eigenvalues (in terms of magnitude). We re-use the implementation from the PyHessian library and adapt it to work with SciPy linear operators and NumPy arrays rather than PyTorch tensors:

def power_method(
    A: scipy.sparse.linalg.LinearOperator,
    max_iterations: int = 100,
    tol: float = 1e-3,
    k: int = 1,
) -> tuple[numpy.ndarray, numpy.ndarray]:
    """Compute the top-k eigenpairs of a linear operator using power iteration.

    Code modified from PyHessian, see
    https://github.com/amirgholami/PyHessian/blob/72e5f0a0d06142387fccdab2226b4c6bae088202/pyhessian/hessian.py#L111-L156

    Args:
        A: Linear operator of dimension ``D`` whose top eigenpairs will be computed.
        max_iterations: Maximum number of iterations. Defaults to ``100``.
        tol: Relative tolerance between two consecutive iterations that has to be
            reached for convergence. Defaults to ``1e-3``.
        k: Number of eigenpairs to compute. Defaults to ``1``.

    Returns:
        The eigenvalues as array of shape ``[k]`` in ascending order, and their
        corresponding eigenvectors as array of shape ``[k, D]`` (one per row).
    """
    eigenvalues = []
    eigenvectors = []

    def normalize(v: numpy.ndarray) -> numpy.ndarray:
        return v / numpy.linalg.norm(v)

    def orthonormalize(v: numpy.ndarray, basis: list[numpy.ndarray]) -> numpy.ndarray:
        for basis_vector in basis:
            v -= numpy.dot(v, basis_vector) * basis_vector
        return normalize(v)

    computed_dim = 0
    while computed_dim < k:
        eigenvalue = None
        v = normalize(numpy.random.randn(A.shape[0]))

        for _ in range(max_iterations):
            v = orthonormalize(v, eigenvectors)
            Av = A @ v

            tmp_eigenvalue = v.dot(Av)
            v = normalize(Av)

            if eigenvalue is None:
                eigenvalue = tmp_eigenvalue
            elif abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
                break
            else:
                eigenvalue = tmp_eigenvalue

        eigenvalues.append(eigenvalue)
        eigenvectors.append(v)
        computed_dim += 1

    # sort in ascending order and convert into arrays
    eigenvalues = numpy.array(eigenvalues[::-1])
    eigenvectors = numpy.array(eigenvectors[::-1])

    return eigenvalues, eigenvectors

Let’s compute the top-3 eigenvalues via power iteration and verify that they roughly match. Note that we use a smaller tol than the PyHessian default to get better convergence, and that the comparison still needs a relatively loose tolerance (rtol=2e-2, versus rtol=1e-4 when comparing eigsh with eigh).

top_k_evals_power, _ = power_method(H, tol=1e-4, k=k)
print(f"Comparing leading {k} Hessian eigenvalues (functorch vs. power).")
assert allclose_report(
    top_k_evals_functorch.double(), top_k_evals_power, rtol=2e-2, atol=1e-6
)
Comparing leading 3 Hessian eigenvalues (functorch vs. power).

This indicates that the power method achieves poorer accuracy than eigsh. But does it require fewer matrix-vector products in return? To answer this, let’s turn on the linear operator’s progress bar, which allows us to count the number of matrix-vector products invoked by both eigen-solvers:

H = HessianLinearOperator(
    model, loss_function, params, data, progressbar=True
).to_scipy()

# determine number of matrix-vector products used by `eigsh`
with StringIO() as buf, redirect_stderr(buf):
    top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which)
    # The tqdm progressbar will print "matmat" for each batch in a matrix-vector
    # product. Therefore, we need to divide by the number of batches
    queries_eigsh = buf.getvalue().count("matmat") // len(data)
print(f"eigsh used {queries_eigsh} matrix-vector products.")

# determine number of matrix-vector products used by power iteration
with StringIO() as buf, redirect_stderr(buf):
    top_k_evals_power, _ = power_method(H, k=k, tol=1e-4)
    # The tqdm progressbar will print "matmat" for each batch in a matrix-vector
    # product. Therefore, we need to divide by the number of batches
    queries_power = buf.getvalue().count("matmat") // len(data)
print(f"Power iteration used {queries_power} matrix-vector products.")

assert queries_power > queries_eigsh
eigsh used 21 matrix-vector products.
Power iteration used 72 matrix-vector products.

Sadly, power iteration does not offer computational benefits either: it consumes more matrix-vector products than eigsh (72 versus 21 here). While it is elegant and simple, it cannot compete with eigsh, at least in the comparison provided here.
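One likely reason for the high query count is that the leading eigenvalues are clustered: the convergence rate of power iteration is governed by the ratio of the two largest eigenvalue magnitudes. The sketch below illustrates this on diagonal toy matrices. The helper `power_iteration_steps` is hypothetical and not part of curvlinops or PyHessian, and the clustered spectrum simply re-uses the rounded eigenvalues printed earlier rather than the full Hessian.

```python
import numpy as np


def power_iteration_steps(eigenvalues, tol=1e-4, max_iterations=10_000):
    """Run power iteration on diag(eigenvalues); return (steps, estimate).

    Hypothetical helper for illustration. Stops once the Rayleigh quotient
    changes by less than ``tol`` (relative) between consecutive steps.
    """
    diag = np.asarray(eigenvalues, dtype=float)
    # Deterministic start vector with equal weight on every eigenvector.
    v = np.ones_like(diag) / np.sqrt(diag.size)
    previous = None
    for step in range(1, max_iterations + 1):
        Av = diag * v  # matrix-vector product with the diagonal matrix
        estimate = v @ Av  # Rayleigh quotient (v has unit norm)
        v = Av / np.linalg.norm(Av)
        if previous is not None and abs(estimate - previous) < tol * abs(previous):
            return step, estimate
        previous = estimate
    return max_iterations, previous


# Well-separated spectrum: the ratio of the two largest eigenvalues is 0.1.
steps_separated, est_separated = power_iteration_steps([1.0, 0.1])

# Clustered spectrum: rounded leading eigenvalues from above (ratio about 0.96).
steps_clustered, est_clustered = power_iteration_steps([1.4943, 1.4282, 1.4109])

print(f"separated: {steps_separated} steps, estimate {est_separated:.4f}")
print(f"clustered: {steps_clustered} steps, estimate {est_clustered:.4f}")
```

In this sketch the clustered spectrum needs many more steps, and its final estimate can still be off by more than tol, because the stopping rule measures the change between consecutive steps rather than the actual error.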

Therefore, we recommend using eigsh for computing eigenvalues. This method becomes accessible because curvlinops interfaces with SciPy’s linear operators.
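The same SciPy interface also suggests another way to count matrix-vector products without parsing progress bar output: wrap the operator in a thin counting layer. The following is only a sketch under stated assumptions. `CountingOperator` is a hypothetical helper, not part of curvlinops or SciPy, and a small random positive semi-definite matrix stands in for the Hessian.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, aslinearoperator, eigsh


class CountingOperator(LinearOperator):
    """Hypothetical wrapper that counts matrix-vector products."""

    def __init__(self, operator):
        self.inner = aslinearoperator(operator)
        self.num_matvecs = 0
        super().__init__(dtype=self.inner.dtype, shape=self.inner.shape)

    def _matvec(self, x):
        # Every product requested by a solver passes through here.
        self.num_matvecs += 1
        return self.inner.matvec(x)


# Toy stand-in for the Hessian (assumption): a symmetric PSD matrix.
rng = np.random.default_rng(0)
B = rng.standard_normal((50, 50))
S = B @ B.T

op = CountingOperator(S)
evals, _ = eigsh(op, k=3, which="LM")
print(f"eigsh used {op.num_matvecs} matrix-vector products.")
```

Because the wrapper only implements `_matvec`, products with a matrix fall back to SciPy's column-by-column default and are counted once per column.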

