Source code for curvlinops.norm.hutchinson

"""Hutchinson-style matrix norm estimation."""

from torch import Tensor, column_stack

from curvlinops._torch_base import PyTorchLinearOperator
from curvlinops.sampling import random_vector


def hutchinson_squared_fro(
    A: Tensor | PyTorchLinearOperator,
    num_matvecs: int,
    distribution: str = "rademacher",
) -> Tensor:
    r"""Estimate the squared Frobenius norm of a matrix using Hutchinson's method.

    Let :math:`\mathbf{A} \in \mathbb{R}^{M \times N}` be some matrix. Its Frobenius
    norm :math:`\lVert\mathbf{A}\rVert_\text{F}` is defined via:

    .. math::
        \lVert\mathbf{A}\rVert_\text{F}^2
        = \sum_{m=1}^M \sum_{n=1}^N \mathbf{A}_{m,n}^2
        = \text{Tr}(\mathbf{A}^\top \mathbf{A})
        = \text{Tr}(\mathbf{A} \mathbf{A}^\top).

    Due to the last two equalities, we can use Hutchinson-style trace estimation to
    estimate the squared Frobenius norm. The estimate averages
    :math:`\lVert \mathbf{A} \mathbf{g} \rVert_2^2` over ``num_matvecs`` random
    vectors :math:`\mathbf{g}`. The random vectors are drawn in the smaller of the two
    dimensions of :math:`\mathbf{A}`: the matrix is transposed first if it is wider
    than tall.

    Args:
        A: A matrix whose squared Frobenius norm is estimated. This can be a tensor or
            a linear operator. It must expose ``shape``, ``T``, ``device``, ``dtype``
            and support ``@`` with a matrix.
        num_matvecs: Total number of matrix-vector products to use. Must be positive
            and smaller than the minimum dimension of the matrix.
        distribution: Distribution of the random vectors used for the trace
            estimation. Can be either ``'rademacher'`` or ``'normal'``.
            Default: ``'rademacher'``.

    Returns:
        The estimated squared Frobenius norm of the matrix.

    Raises:
        ValueError: If the matrix is not two-dimensional, if the number of
            matrix-vector products is not positive, or if it is greater than or
            equal to the minimum dimension of the matrix. In the last case you can
            evaluate the true squared Frobenius norm directly at the same cost.

    Example:
        >>> from torch.linalg import matrix_norm
        >>> from torch import rand, manual_seed
        >>> _ = manual_seed(0) # make deterministic
        >>> A = rand(40, 40)
        >>> fro2_A = matrix_norm(A).item()**2 # reference: exact squared Frobenius norm
        >>> # one- and multi-sample approximations
        >>> fro2_A_low_prec = hutchinson_squared_fro(A, num_matvecs=1).item()
        >>> fro2_A_high_prec = hutchinson_squared_fro(A, num_matvecs=30).item()
        >>> assert abs(fro2_A - fro2_A_low_prec) > abs(fro2_A - fro2_A_high_prec)
        >>> round(fro2_A, 1), round(fro2_A_low_prec, 1), round(fro2_A_high_prec, 1)
        (530.9, 156.7, 628.9)
    """
    if len(A.shape) != 2:
        raise ValueError(f"A must be a matrix. Got shape {A.shape}.")
    # Without this guard, num_matvecs <= 0 fails deep inside torch.column_stack
    # (empty list) or divides by zero below.
    if num_matvecs < 1:
        raise ValueError(f"num_matvecs ({num_matvecs}) must be positive.")

    dim = min(A.shape)
    if num_matvecs >= dim:
        raise ValueError(
            f"num_matvecs ({num_matvecs}) must be less than the minimum dimension of A."
        )

    # Instead of A^T @ A, use A @ A^T if the matrix is wider than tall. Both traces
    # equal the squared Frobenius norm. After the transpose, A has shape
    # (max_dim, dim), so the random vectors live in the smaller dimension `dim`.
    if A.shape[1] > A.shape[0]:
        A = A.T

    # One random probe vector per column: G has shape (dim, num_matvecs).
    G = column_stack(
        [random_vector(dim, distribution, A.device, A.dtype) for _ in range(num_matvecs)]
    )
    AG = A @ G
    # Summing squared entries column-wise gives ||A g||^2 per probe. Dividing by the
    # number of probes averages them into the trace estimate.
    return (AG**2 / num_matvecs).sum()