"""Spectral analysis methods for PyTorch linear operators.
From Papyan, 2020:
- Traces of class/cross-class structure pervade deep learning spectra. Journal
of Machine Learning Research (JMLR), https://jmlr.org/papers/v21/20-933.html
"""
from math import log, sqrt
from scipy.linalg import eigh_tridiagonal
from scipy.sparse.linalg import eigsh
from torch import (
Tensor,
as_tensor,
diag_embed,
linspace,
randn,
zeros,
zeros_like,
)
from torch.distributions import Normal
from torch.linalg import eigh, vector_norm
from curvlinops._torch_base import PyTorchLinearOperator
def lanczos_approximate_spectrum(
    A: PyTorchLinearOperator,
    ncv: int,
    num_points: int = 1024,
    num_repeats: int = 1,
    kappa: float = 3.0,
    boundaries: (
        tuple[float, float] | tuple[float, None] | tuple[None, float] | None
    ) = None,
    margin: float = 0.05,
    boundaries_tol: float = 1e-2,
) -> tuple[Tensor, Tensor]:
    """Approximate the spectral density p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of A ∈ Rᵈˣᵈ.

    Implements algorithm 2 (:code:`LanczosApproxSpec`) of Papyan, 2020
    (https://jmlr.org/papers/v21/20-933.html).

    Internally rescales the operator spectrum to the interval [-1; 1] such that
    the width ``kappa`` of the Gaussian bumps used to approximate the delta peaks
    need not be tweaked.

    Args:
        A: Symmetric linear operator.
        ncv: Number of Lanczos vectors (number of nodes/weights for the quadrature).
        num_points: Resolution (number of grid points). Default: ``1024``.
        num_repeats: Number of Lanczos quadratures to average the density over.
            Must be at least ``1``. Default: ``1``. Taken from papyan2020traces,
            Section D.2.
        kappa: Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must
            be greater than 1. Default: ``3``. Taken from papyan2020traces, Section D.2.
        boundaries: Estimates of the minimum and maximum eigenvalues of ``A``. If left
            unspecified, they will be estimated internally.
        margin: Relative margin added around the spectral boundary. Default: ``0.05``.
            Taken from papyan2020traces, Section D.2.
        boundaries_tol: (Only relevant if ``boundaries`` are not specified). Relative
            accuracy used to estimate the spectral boundary. ``0`` implies machine
            precision. Default: ``1e-2``, from
            https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.

    Returns:
        Grid points λ and approximated spectral density p(λ) of A.

    Raises:
        ValueError: If ``num_repeats`` is smaller than ``1``.
    """
    # Without at least one repetition there is no grid to return.
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be at least 1. Got {num_repeats}.")

    boundaries = approximate_boundaries(A, tol=boundaries_tol, boundaries=boundaries)

    average_density = zeros(num_points, device=A.device, dtype=A.dtype)
    for n in range(num_repeats):
        # every repetition draws a fresh random start vector inside fast_lanczos
        lanczos_iter = fast_lanczos(A, ncv)
        grid, density = lanczos_approximate_spectrum_from_iter(
            lanczos_iter, boundaries, num_points, kappa, margin
        )
        # running mean over the repetitions processed so far
        average_density = (1 - 1 / (n + 1)) * average_density + density / (n + 1)

    return grid, average_density
def lanczos_approximate_spectrum_from_iter(
    lanczos_iter: tuple[Tensor, Tensor],
    boundaries: tuple[float, float],
    num_points: int,
    kappa: float,
    margin: float,
) -> tuple[Tensor, Tensor]:
    """Compute a spectrum approximation from a Lanczos iteration.

    Every Lanczos eigenvalue (quadrature node) is smeared into a Gaussian bump
    that is weighted by the squared first component of its eigenvector. The
    bumps are evaluated on a grid of the spectrum rescaled to [-1; 1] and summed.

    Args:
        lanczos_iter: Pair ``(evals, evecs)`` from a Lanczos run.
        boundaries: Approximate minimum and maximum eigenvalues of the operator.
        num_points: Number of grid points.
        kappa: Width parameter for the Gaussian bumps. Must be greater than 1.
        margin: Relative margin added around the spectral boundary.

    Returns:
        Grid points (equally spaced between the padded boundaries) and the
        estimated spectral density at these points.

    Raises:
        ValueError: If fewer than two Lanczos eigenvalues are supplied, or if
            ``kappa`` is not greater than 1. The bump width is undefined in both
            cases.
    """
    evals, evecs = lanczos_iter
    ncv = evals.shape[0]
    if ncv < 2:
        raise ValueError(f"Need at least 2 Lanczos eigenvalues. Got {ncv}.")
    if kappa <= 1:
        raise ValueError(f"kappa must be greater than 1. Got {kappa}.")

    eval_min, eval_max = boundaries
    padding = margin * (eval_max - eval_min)
    eval_min, eval_max = eval_min - padding, eval_max + padding

    # use normalized operator ``(A - c I) / d`` whose spectrum lies in [-1; 1]
    c = (eval_max + eval_min) / 2
    d = (eval_max - eval_min) / 2

    device, dtype = evals.device, evals.dtype

    # estimate on grid [-1; 1]; shapes ``(1, num_points)`` and ``(ncv, 1)``
    # broadcast to ``(ncv, num_points)`` without materializing repeated copies
    grid_norm = linspace(-1, 1, num_points, device=device, dtype=dtype)
    nodes = ((evals - c) / d).reshape(ncv, 1)
    # dividing by ``d`` converts the density on [-1; 1] back to the original scale
    weights = (evecs[0, :] ** 2 / d).reshape(ncv, 1)

    # width of Gaussian bump in [-1; 1]
    sigma = 2 / (ncv - 1) / sqrt(8 * log(kappa))
    bumps = Normal(nodes, sigma).log_prob(grid_norm.reshape(1, num_points)).exp()
    density = (weights * bumps).sum(0)

    grid = linspace(eval_min, eval_max, num_points, device=device, dtype=dtype)
    return grid, density
class _LanczosSpectrumCached:
    """Shared base for Lanczos-based spectrum approximations with a cache.

    Lanczos iterations are stored after they have been computed so that
    approximations with other hyperparameters can re-use them.
    """

    def __init__(self, A: PyTorchLinearOperator, ncv: int):
        """Store the operator and set up an empty cache.

        Args:
            A: Symmetric linear operator.
            ncv: Number of Lanczos vectors (number of nodes/weights for the quadrature).
        """
        self._A, self._ncv = A, ncv
        self._lanczos_iters: list[tuple[Tensor, Tensor]] = []

    def _get_lanczos_iters(self, num_iters: int) -> list[tuple[Tensor, Tensor]]:
        """Return the first ``num_iters`` cached Lanczos iterations.

        Iterations that are not yet cached are computed and appended first.

        Args:
            num_iters: Number of Lanczos iterations to return.

        Returns:
            Slice ``[:num_iters]`` of the cached ``(evals, evecs)`` pairs.
        """
        cache = self._lanczos_iters
        num_missing = num_iters - len(cache)
        # ``range`` of a non-positive number is empty, so nothing is recomputed
        for _ in range(num_missing):
            cache.append(fast_lanczos(self._A, self._ncv))
        return cache[:num_iters]
class LanczosApproximateSpectrumCached(_LanczosSpectrumCached):
    """Approximate the spectral density p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of A ∈ Rᵈˣᵈ.

    Caches Lanczos iterations to efficiently produce spectral density approximations
    with different hyperparameters.
    """

    def __init__(
        self,
        A: PyTorchLinearOperator,
        ncv: int,
        boundaries: (
            tuple[float, float] | tuple[float, None] | tuple[None, float] | None
        ) = None,
        boundaries_tol: float = 1e-2,
    ):
        """Initialize.

        Args:
            A: Symmetric linear operator.
            ncv: Number of Lanczos vectors (number of nodes/weights for the quadrature).
            boundaries: Estimates of the minimum and maximum eigenvalues of ``A``. If
                left unspecified, they will be estimated internally.
            boundaries_tol: (Only relevant if ``boundaries`` are not specified).
                Relative accuracy used to estimate the spectral boundary. ``0`` implies
                machine precision. Default: ``1e-2``, from
                https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
        """
        super().__init__(A, ncv)
        # the boundaries are estimated once and shared by all approximations
        self._boundaries = approximate_boundaries(
            A, tol=boundaries_tol, boundaries=boundaries
        )

    def approximate_spectrum(
        self,
        num_repeats: int = 1,
        num_points: int = 1024,
        kappa: float = 3.0,
        margin: float = 0.05,
    ) -> tuple[Tensor, Tensor]:
        """Approximate the spectral density of A.

        Args:
            num_repeats: Number of Lanczos quadratures to average the density over.
                Must be at least ``1``. Default: ``1``. Taken from
                papyan2020traces, Section D.2.
            num_points: Resolution. Default: ``1024``.
            kappa: Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must
                be greater than 1. Default: ``3``. From papyan2020traces, Section D.2.
            margin: Relative margin added around the spectral boundary.
                Default: ``0.05``. Taken from papyan2020traces, Section D.2.

        Returns:
            Grid points λ and approximated spectral density p(λ) of A.

        Raises:
            ValueError: If ``num_repeats`` is smaller than ``1``.
        """
        # A non-positive count would otherwise index an empty list or, through
        # negative slicing of the cache, average with a negative divisor.
        if num_repeats < 1:
            raise ValueError(f"num_repeats must be at least 1. Got {num_repeats}.")

        spectra = [
            lanczos_approximate_spectrum_from_iter(
                lanczos_iter, self._boundaries, num_points, kappa, margin
            )
            for lanczos_iter in self._get_lanczos_iters(num_repeats)
        ]
        # all repetitions share boundaries, margin and resolution, hence the grid
        grid = spectra[0][0]
        spectrum = sum(density for _, density in spectra) / num_repeats

        return grid, spectrum
def lanczos_approximate_log_spectrum(
    A: PyTorchLinearOperator,
    ncv: int,
    num_points: int = 1024,
    num_repeats: int = 1,
    kappa: float = 1.04,
    boundaries: (
        tuple[float, float] | tuple[float, None] | tuple[None, float] | None
    ) = None,
    margin: float = 0.05,
    boundaries_tol: float = 1e-2,
    epsilon: float = 1e-5,
) -> tuple[Tensor, Tensor]:
    """Approximate the spectral density ``p(λ) = 1/d ∑ᵢ δ(λ - λᵢ)`` of ``log(|A| + εI) ∈ Rᵈˣᵈ``.

    Follows the idea of Section C.7 in Papyan, 2020
    (https://jmlr.org/papers/v21/20-933.html).

    Here, log denotes the natural logarithm (i.e. base e).

    Internally rescales the operator spectrum to the interval [-1; 1] such that
    the width ``kappa`` of the Gaussian bumps used to approximate the delta peaks
    need not be tweaked.

    Args:
        A: Symmetric linear operator.
        ncv: Number of Lanczos vectors (number of nodes/weights for the quadrature).
        num_points: Resolution (number of grid points). Default: ``1024``.
        num_repeats: Number of Lanczos quadratures to average the density over.
            Must be at least ``1``. Default: ``1``. Taken from papyan2020traces,
            Section D.2.
        kappa: Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must
            be greater than 1. Default: ``1.04``. Obtained by tweaking while reproducing
            Fig. 15b from papyan2020traces (not specified by the paper).
        boundaries: Estimates of the minimum and maximum eigenvalues of :math:`|A|`. If
            left unspecified, they will be estimated internally.
        margin: Relative margin added around the spectral boundary. Default: ``0.05``.
            Taken from papyan2020traces, Section D.2.
        boundaries_tol: (Only relevant if ``boundaries`` are not specified). Relative
            accuracy used to estimate the spectral boundary. ``0`` implies machine
            precision. Default: ``1e-2``, from
            https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
        epsilon: Shift to increase numerical stability. Default: ``1e-5``. Taken from
            papyan2020traces, Section D.2.

    Returns:
        Grid points λ and approximated spectral density p(λ) of ``log(|A| + εI)``.

    Raises:
        ValueError: If ``num_repeats`` is smaller than ``1``.
    """
    # Without at least one repetition there is no grid to return.
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be at least 1. Got {num_repeats}.")

    boundaries = approximate_boundaries_abs(
        A, tol=boundaries_tol, boundaries=boundaries
    )

    average_density = zeros(num_points, device=A.device, dtype=A.dtype)
    for n in range(num_repeats):
        # every repetition draws a fresh random start vector inside fast_lanczos
        lanczos_iter = fast_lanczos(A, ncv)
        grid, density = lanczos_approximate_log_spectrum_from_iter(
            lanczos_iter, boundaries, num_points, kappa, margin, epsilon
        )
        # running mean over the repetitions processed so far
        average_density = (1 - 1 / (n + 1)) * average_density + density / (n + 1)

    return grid, average_density
def lanczos_approximate_log_spectrum_from_iter(
    lanczos_iter: tuple[Tensor, Tensor],
    boundaries: tuple[float, float],
    num_points: int,
    kappa: float,
    margin: float,
    epsilon: float,
) -> tuple[Tensor, Tensor]:
    """Compute a log-spectrum approximation from a Lanczos iteration.

    The Lanczos eigenvalues are mapped to ``log(|λ| + ε)``, rescaled to [-1; 1]
    and smeared into Gaussian bumps that are weighted by the squared first
    component of the associated eigenvector.

    Args:
        lanczos_iter: Pair ``(evals, evecs)`` from a Lanczos run.
        boundaries: Approximate spectral boundary of ``|A|``.
        num_points: Number of grid points.
        kappa: Width parameter for the Gaussian bumps. Must be greater than 1.
        margin: Relative margin added around the boundary (applied in log-space).
        epsilon: Positive shift for numerical stability.

    Returns:
        Grid points and estimated spectral density. The grid is equally spaced in
        log-space and mapped back through ``exp``, i.e. its points live in the
        domain of ``|λ| + ε``. The density is evaluated at these points and
        contains the ``1 / (d · grid)`` factor from the change of variables.

    Raises:
        ValueError: If fewer than two Lanczos eigenvalues are supplied, or if
            ``kappa`` is not greater than 1. The bump width is undefined in both
            cases.
    """
    evals, evecs = lanczos_iter
    ncv = evals.shape[0]
    if ncv < 2:
        raise ValueError(f"Need at least 2 Lanczos eigenvalues. Got {ncv}.")
    if kappa <= 1:
        raise ValueError(f"kappa must be greater than 1. Got {kappa}.")

    log_eval_min, log_eval_max = (log(boundary + epsilon) for boundary in boundaries)
    padding = margin * (log_eval_max - log_eval_min)
    log_eval_min, log_eval_max = log_eval_min - padding, log_eval_max + padding

    # use normalized operator ``(log(|A| + εI) - c I) / d`` with spectrum in [-1; 1]
    c = (log_eval_max + log_eval_min) / 2
    d = (log_eval_max - log_eval_min) / 2

    device, dtype = evals.device, evals.dtype

    # estimate on grid [-1; 1] and map the grid back to the ``|λ| + ε`` domain
    grid_norm = linspace(-1, 1, num_points, device=device, dtype=dtype)
    grid_out = (grid_norm * d + c).exp()

    # shapes ``(ncv, 1)`` and ``(1, num_points)`` broadcast to ``(ncv, num_points)``
    # without materializing repeated copies
    log_evals = (evals.abs() + epsilon).log()
    nodes = ((log_evals - c) / d).reshape(ncv, 1)
    weights = (evecs[0, :] ** 2).reshape(ncv, 1)

    # width of Gaussian bump in [-1; 1]
    sigma = 2 / (ncv - 1) / sqrt(8 * log(kappa))
    bumps = Normal(nodes, sigma).log_prob(grid_norm.reshape(1, num_points)).exp()
    density = (weights * bumps).sum(0) / (d * grid_out)

    return grid_out, density
class LanczosApproximateLogSpectrumCached(_LanczosSpectrumCached):
    """Class to approximate p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of log(|A| + εI) ∈ Rᵈˣᵈ.

    Caches Lanczos iterations to efficiently produce spectral density approximations
    with different hyperparameters.
    """

    def __init__(
        self,
        A: PyTorchLinearOperator,
        ncv: int,
        boundaries: (
            tuple[float, float] | tuple[float, None] | tuple[None, float] | None
        ) = None,
        boundaries_tol: float = 1e-2,
    ):
        """Initialize.

        Args:
            A: Symmetric linear operator.
            ncv: Number of Lanczos vectors (number of nodes/weights for the quadrature).
            boundaries: Estimates of the minimum and maximum eigenvalues of ``|A|``. If
                left unspecified, they will be estimated internally.
            boundaries_tol: (Only relevant if ``boundaries`` are not specified).
                Relative accuracy used to estimate the spectral boundary. ``0`` implies
                machine precision. Default: ``1e-2``, from
                https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
        """
        super().__init__(A, ncv)
        # the boundaries of |A| are estimated once and shared by all approximations
        self._boundaries = approximate_boundaries_abs(
            A, tol=boundaries_tol, boundaries=boundaries
        )

    def approximate_log_spectrum(
        self,
        num_repeats: int = 1,
        num_points: int = 1024,
        kappa: float = 3.0,
        margin: float = 0.05,
        epsilon: float = 1e-5,
    ) -> tuple[Tensor, Tensor]:
        """Approximate the spectral density of log(|A| + εI).

        Args:
            num_repeats: Number of Lanczos quadratures to average the density over.
                Must be at least ``1``. Default: ``1``. Taken from
                papyan2020traces, Section D.2.
            num_points: Resolution. Default: ``1024``.
            kappa: Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must
                be greater than 1. Default: ``3``. From papyan2020traces, Section D.2.
                NOTE(review): ``lanczos_approximate_log_spectrum`` defaults to
                ``1.04`` instead; confirm which default is intended.
            margin: Relative margin added around the spectral boundary.
                Default: ``0.05``. Taken from papyan2020traces, Section D.2.
            epsilon: Shift to increase numerical stability. Default: ``1e-5``. Taken
                from papyan2020traces, Section D.2.

        Returns:
            Grid points λ and approximated spectral density p(λ) of log(|A| + εI).

        Raises:
            ValueError: If ``num_repeats`` is smaller than ``1``.
        """
        # A non-positive count would otherwise index an empty list or, through
        # negative slicing of the cache, average with a negative divisor.
        if num_repeats < 1:
            raise ValueError(f"num_repeats must be at least 1. Got {num_repeats}.")

        spectra = [
            lanczos_approximate_log_spectrum_from_iter(
                lanczos_iter, self._boundaries, num_points, kappa, margin, epsilon
            )
            for lanczos_iter in self._get_lanczos_iters(num_repeats)
        ]
        # all repetitions share boundaries, margin and resolution, hence the grid
        grid = spectra[0][0]
        spectrum = sum(density for _, density in spectra) / num_repeats

        return grid, spectrum
def fast_lanczos(
    A: PyTorchLinearOperator, ncv: int, use_eigh_tridiagonal: bool = False
) -> tuple[Tensor, Tensor]:
    """Run Lanczos iterations without reorthogonalization (for large problems).

    Implements algorithm 2 of Papyan, 2020 (https://jmlr.org/papers/v21/20-933.html).

    The iteration starts from a normalized random Gaussian vector and builds the
    diagonal and off-diagonal of a tri-diagonal matrix, which is then
    eigen-decomposed.

    Note:
        The off-diagonal entries are used as divisors without a check for zero,
        i.e. a breakdown of the iteration is not handled.

    Args:
        A: Symmetric linear operator.
        ncv: Number of Lanczos vectors.
        use_eigh_tridiagonal: Whether to use eigh_tridiagonal to eigen-decompose the
            tri-diagonal matrix. Default: ``False``. Setting this value to ``True``
            results in faster eigen-decomposition, but is less stable.

    Returns:
        Eigenvalues and eigenvectors of the tri-diagonal matrix built up during
        Lanczos iterations. ``evecs[:, i]`` is normalized eigenvector of ``evals[i]``.
    """
    device, dtype = A.device, A.dtype
    tensor_kwargs = {"device": device, "dtype": dtype}

    # diagonal and off-diagonal entries of the tri-diagonal matrix
    diagonal = zeros(ncv, **tensor_kwargs)
    off_diagonal = zeros(ncv - 1, **tensor_kwargs)
    dim = A.shape[1]

    # normalized random start vector
    current = randn(dim, **tensor_kwargs)
    current /= vector_norm(current)
    previous = None

    last_step = ncv - 1
    for step in range(ncv):
        residual = A @ current
        if step > 0:
            residual = residual - off_diagonal[step - 1] * previous

        diagonal[step] = (residual * current).sum()
        residual -= diagonal[step] * current

        # the final step only contributes a diagonal entry
        if step != last_step:
            off_diagonal[step] = vector_norm(residual)
            residual /= off_diagonal[step]
            previous, current = current, residual

    if not use_eigh_tridiagonal:
        # assemble the dense tri-diagonal matrix and decompose it with PyTorch
        tridiagonal = (
            diag_embed(diagonal)
            + diag_embed(off_diagonal, offset=1)
            + diag_embed(off_diagonal, offset=-1)
        )
        evals, evecs = eigh(tridiagonal)
        return evals, evecs

    # SciPy operates on NumPy arrays, convert there and back
    evals_np, evecs_np = eigh_tridiagonal(
        diagonal.detach().cpu().numpy(), off_diagonal.detach().cpu().numpy()
    )
    evals = as_tensor(evals_np, device=device, dtype=dtype)
    evecs = as_tensor(evecs_np, device=device, dtype=dtype)
    return evals, evecs
def approximate_boundaries(
    A: PyTorchLinearOperator,
    tol: float = 1e-2,
    boundaries: (
        tuple[float, float] | tuple[float, None] | tuple[None, float] | None
    ) = None,
) -> tuple[float, float]:
    """Approximate λₘᵢₙ(A) and λₘₐₓ(A) using SciPy's ``eigsh``.

    Args:
        A: Symmetric linear operator.
        tol: Relative accuracy used by ``eigsh``. ``0`` implies machine precision.
            Default: ``1e-2``, from
            https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
        boundaries: A tuple of floats that specifies known parts of the boundaries
            which consequently won't be recomputed. Default: ``None``.

    Returns:
        Estimates of λₘᵢₙ and λₘₐₓ as Python floats.
    """
    eval_min, eval_max = (None, None) if boundaries is None else boundaries

    # Nothing to estimate: avoid converting the operator to SciPy at all.
    if eval_min is not None and eval_max is not None:
        return float(eval_min), float(eval_max)

    eigsh_kwargs = {"tol": tol, "return_eigenvectors": False}
    A_scipy = A.to_scipy()

    if boundaries is None:
        # one eigenvalue from each end of the spectrum; sort explicitly rather
        # than relying on the order in which ``eigsh`` returns them
        eval_min, eval_max = sorted(eigsh(A_scipy, k=2, which="BE", **eigsh_kwargs))
    else:
        if eval_min is None:
            (eval_min,) = eigsh(A_scipy, k=1, which="SA", **eigsh_kwargs)
        if eval_max is None:
            (eval_max,) = eigsh(A_scipy, k=1, which="LA", **eigsh_kwargs)

    # ``eigsh`` yields NumPy scalars; return plain floats as annotated
    return float(eval_min), float(eval_max)
def approximate_boundaries_abs(
    A: PyTorchLinearOperator,
    tol: float = 1e-2,
    boundaries: (
        tuple[float, float] | tuple[float, None] | tuple[None, float] | None
    ) = None,
) -> tuple[float, float]:
    """Approximate λₘᵢₙ(|A|) and λₘₐₓ(|A|) using SciPy's ``eigsh``.

    Args:
        A: Symmetric linear operator.
        tol: Relative accuracy used by ``eigsh``. ``0`` implies machine precision.
            Default: ``1e-2``, from
            https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
        boundaries: A tuple of floats that specifies known parts of the boundaries
            which consequently won't be recomputed. Default: ``None``.

    Returns:
        Estimates of λₘᵢₙ and λₘₐₓ of :math:`|A|` as Python floats. The absolute
        value is applied to supplied and estimated boundaries alike.
    """
    eval_min, eval_max = (None, None) if boundaries is None else boundaries

    # Only convert the operator to SciPy if something must be estimated.
    if eval_min is None or eval_max is None:
        eigsh_kwargs = {"tol": tol, "return_eigenvectors": False}
        A_scipy = A.to_scipy()

        if eval_max is None:
            # eigenvalue of largest magnitude
            (eval_max,) = eigsh(A_scipy, k=1, which="LM", **eigsh_kwargs)
        if eval_min is None:
            # eigenvalue of smallest magnitude
            (eval_min,) = eigsh(A_scipy, k=1, which="SM", **eigsh_kwargs)

    # ``eigsh`` yields NumPy scalars; return plain floats as annotated
    return float(abs(eval_min)), float(abs(eval_max))