# Source code for curvlinops.diagonal.hutchinson

"""Hutchinson-style matrix diagonal estimation."""

from torch import Tensor, column_stack, einsum

from curvlinops._torch_base import PyTorchLinearOperator
from curvlinops.sampling import random_vector
from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim


def hutchinson_diag(
    A: PyTorchLinearOperator | Tensor,
    num_matvecs: int,
    distribution: str = "rademacher",
) -> Tensor:
    r"""Estimate a linear operator's diagonal using Hutchinson's method.

    For details, see

    - Bekas, C., Kokiopoulou, E., & Saad, Y. (2007). An estimator for the
      diagonal of a matrix. Applied Numerical Mathematics.

    Let :math:`\mathbf{A}` be a square linear operator. We can approximate its
    diagonal :math:`\mathrm{diag}(\mathbf{A})` by drawing :math:`N` random
    vectors :math:`\mathbf{v}_n \sim \mathbf{v}` from a distribution
    :math:`\mathbf{v}` that satisfies
    :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}`, and computing
    the estimator

    .. math::
        \mathbf{a}
        := \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n \odot \mathbf{A} \mathbf{v}_n
        \approx \mathrm{diag}(\mathbf{A})\,.

    This estimator is unbiased,

    .. math::
        \mathbb{E}[a_i]
        = \sum_j \mathbb{E}[v_i A_{i,j} v_j]
        = \sum_j A_{i,j} \mathbb{E}[v_i v_j]
        = \sum_j A_{i,j} \delta_{i, j}
        = A_{i,i}\,.

    Args:
        A: A square linear operator whose diagonal is estimated.
        num_matvecs: Total number of matrix-vector products to use. Must be at
            least 1 and smaller than the dimension of the linear operator
            (because otherwise one can evaluate the true diagonal directly at
            the same cost).
        distribution: Distribution of the random vectors used for the diagonal
            estimation. Can be either ``'rademacher'`` or ``'normal'``.
            Default: ``'rademacher'``.

    Returns:
        The estimated diagonal of the linear operator.

    Raises:
        ValueError: If ``num_matvecs`` is smaller than 1. Errors raised by the
            squareness and upper-bound checks on ``A`` and ``num_matvecs``
            propagate unchanged.

    Example:
        >>> from torch import manual_seed, rand
        >>> from torch.linalg import vector_norm
        >>> _ = manual_seed(0)  # make deterministic
        >>> A = rand(40, 40)
        >>> diag_A = A.diag()  # exact diagonal as reference
        >>> # one- and multi-sample approximations
        >>> diag_A_low_precision = hutchinson_diag(A, num_matvecs=1)
        >>> diag_A_high_precision = hutchinson_diag(A, num_matvecs=30)
        >>> # compute relative residual norms
        >>> error_low_precision = (vector_norm(diag_A - diag_A_low_precision) / vector_norm(diag_A)).item()
        >>> error_high_precision = (vector_norm(diag_A - diag_A_high_precision) / vector_norm(diag_A)).item()
        >>> assert error_low_precision > error_high_precision
        >>> round(error_low_precision, 4), round(error_high_precision, 4)
        (3.2648, 0.9253)
    """
    dim = assert_is_square(A)
    # Reject an empty sample explicitly: otherwise ``column_stack([])`` fails
    # with an opaque error, and the final division would be by a value <= 0.
    if num_matvecs < 1:
        raise ValueError(f"num_matvecs must be at least 1, got {num_matvecs}.")
    assert_matvecs_subseed_dim(A, num_matvecs)

    # Each column of G is one independent random probe vector v_n.
    G = column_stack(
        [
            random_vector(dim, distribution, A.device, A.dtype)
            for _ in range(num_matvecs)
        ]
    )
    # Row i of the einsum is sum_n G[i, n] * (A G)[i, n] = sum_n (v_n ⊙ A v_n)_i,
    # so dividing by N yields the Monte-Carlo average from the docstring.
    return einsum("ij,ij->i", G, A @ G) / num_matvecs