Source code for curvlinops.inverse

"""Implements linear operator inverses."""

from __future__ import annotations

from collections.abc import Callable

from linear_operator.utils.linear_cg import linear_cg
from numpy import column_stack
from scipy.sparse.linalg import lsmr
from torch import Tensor, as_tensor, cat, device, dtype, isnan

from curvlinops._torch_base import PyTorchLinearOperator


class _InversePyTorchLinearOperator(PyTorchLinearOperator):
    """Base class for inverses of PyTorch linear operators.

    Stores the operator ``A`` whose inverse is represented and forwards the
    ``dtype`` and ``device`` queries to it. The actual application of the
    inverse is left to the concrete sub-classes.
    """

    def __init__(self, A: PyTorchLinearOperator):
        """Store the linear operator whose inverse should be represented.

        Args:
            A: PyTorch linear operator whose inverse is formed.

        Raises:
            ValueError: If the passed linear operator is not square, i.e. its
                input and output shapes differ.
        """
        if A._in_shape != A._out_shape:
            # The two message parts are joined with a separating space so the
            # rendered text reads "... an inverse. Got ...".
            raise ValueError(
                "Input linear operator must be square to form an inverse. "
                f"Got {A._in_shape} != {A._out_shape}."
            )
        # The inverse maps between the same spaces as ``A`` (shapes are equal).
        super().__init__(A._in_shape, A._out_shape)
        self._A = A

    @property
    def dtype(self) -> dtype:
        """Determine the linear operator's data type.

        Returns:
            The data type of the wrapped linear operator ``A``.
        """
        return self._A.dtype

    @property
    def device(self) -> device:
        """Determine the device the linear operator is defined on.

        Returns:
            The device of the wrapped linear operator ``A``.
        """
        return self._A.device


class CGInverseLinearOperator(_InversePyTorchLinearOperator):
    """Inverse of a linear operator, applied via conjugate gradients (CG).

    Note:
        The linear systems are solved with GPyTorch's CG implementation
        (``linear_cg``).

    .. note::
        This operator is not compiler-friendly (:func:`torch.compile`). The
        underlying ``linear_cg`` routine uses data-dependent control flow
        (convergence checks on tensor values via ``aten.equal`` and Python
        ``if`` on tensors), which causes graph breaks during tracing.
    """

    def __init__(self, A: PyTorchLinearOperator, **cg_hyperparameters):
        """Set up the CG-based inverse of a linear operator.

        Args:
            A: PyTorch linear operator whose inverse is formed. Must represent
                a symmetric and positive-definite matrix.
            cg_hyperparameters: Keyword arguments that are forwarded unchanged
                to GPyTorch's CG implementation on every multiplication. In
                particular, this includes optional arguments such as
                ``max_iter``, ``tolerance``, and ``preconditioner``. The
                ``preconditioner`` should be a callable that applies a left
                preconditioning operation to a supplied vector. This can be
                implemented via a ``PyTorchLinearOperator``'s ``__matmul__``
                method. For details, see the documentation of the
                ``linear_cg`` function in
                https://github.com/cornellius-gp/linear_operator/blob/main/linear_operator/utils/linear_cg.py.

        Example:
            >>> from torch import allclose, tensor
            >>> from torch.linalg import inv
            >>> from curvlinops import CGInverseLinearOperator
            >>> from curvlinops.diag import DiagonalLinearOperator
            >>> from curvlinops.examples import TensorLinearOperator
            >>> A = tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
            >>> b = tensor([1.0, 2.0, 3.0])
            >>> A_linop = TensorLinearOperator(A)
            >>> A_inv_b = CGInverseLinearOperator(
            ...     A_linop, max_iter=3, max_tridiag_iter=3, tolerance=1e-7
            ... ) @ b
            >>> # Use CG with a simple diagonal preconditioner.
            >>> inverse_diagonal = DiagonalLinearOperator([A.diag().reciprocal()])
            >>> A_inv_b_preconditioned = CGInverseLinearOperator(
            ...     A_linop,
            ...     max_iter=3,
            ...     max_tridiag_iter=3,
            ...     tolerance=1e-7,
            ...     preconditioner=inverse_diagonal.__matmul__,
            ... ) @ b
            >>> A_inv_b_exact = inv(A) @ b
            >>> A_inv_b.round(decimals=4)
            tensor([0.2222, 0.1111, 1.4444])
            >>> A_inv_b_preconditioned.round(decimals=4)
            tensor([0.2222, 0.1111, 1.4444])
            >>> allclose(A_inv_b_exact, A_inv_b_preconditioned)
            True
        """
        super().__init__(A)
        self._cg_hyperparameters = cg_hyperparameters

    def _matmat(self, X: list[Tensor]) -> list[Tensor]:
        """Apply the inverse of A to a matrix given in list format.

        Args:
            X: Matrix for multiplication.

        Returns:
            Result of inverse matrix-vector multiplication, ``A⁻¹ @ X``.
        """
        # Collapse the leading dimensions of every block and stack the blocks
        # into one 2d right-hand side whose columns are the vectors to solve for.
        rhs = cat([block.flatten(end_dim=-2) for block in X])
        _num_rows, num_vecs = rhs.shape

        # One batched CG call solves for all columns at once.
        solution = linear_cg(self._A.__matmul__, rhs, **self._cg_hyperparameters)

        # Cut the stacked solution back into blocks and restore their shapes.
        pieces = solution.split(self._out_shape_flat)
        return [
            piece.reshape(*shape, num_vecs)
            for piece, shape in zip(pieces, self._out_shape)
        ]

    def _adjoint(self) -> CGInverseLinearOperator:
        """Return the linear operator's adjoint: (A^-1)* = (A*)^-1.

        The CG hyper-parameters are passed on unchanged.

        Returns:
            A linear operator representing the adjoint.
        """
        adjoint_of_A = self._A._adjoint()
        return CGInverseLinearOperator(adjoint_of_A, **self._cg_hyperparameters)
class LSMRInverseLinearOperator(_InversePyTorchLinearOperator):
    """Inverse of a PyTorch linear operator, applied via LSMR.

    See https://arxiv.org/abs/1006.0758 for details on the LSMR algorithm.

    Note:
        Internally, this operator uses SciPy's CPU implementation of LSMR as
        PyTorch currently does not offer an LSMR interface that purely relies
        on matrix-vector products.

    .. note::
        This operator is not compiler-friendly (:func:`torch.compile`). The
        matrix-vector product converts tensors to NumPy and calls SciPy's
        ``lsmr``; these non-Torch operations cannot be traced and cause graph
        breaks.
    """

    def __init__(self, A: PyTorchLinearOperator, **lsmr_hyperparameters):
        """Set up the LSMR-based inverse of a linear operator.

        Besides storing ``A``, this converts it once into its SciPy
        counterpart (``A.to_scipy()``), which is what LSMR operates on.

        Args:
            A: Linear operator whose inverse is formed.
            lsmr_hyperparameters: The hyper-parameters that will be passed to
                the LSMR implementation in SciPy. For more detail, see
                https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsmr.html.
        """
        super().__init__(A)
        self._A_scipy = A.to_scipy()
        self._lsmr_hyperparameters = lsmr_hyperparameters

    def _matmat(self, X: list[Tensor]) -> list[Tensor]:
        """Apply the inverse of A to a matrix X given in list format.

        As a side effect, the additional values returned by SciPy's ``lsmr``
        (everything but the solution) are stored per column in
        ``self._lsmr_info``.

        Args:
            X: Matrix for multiplication in list format.

        Returns:
            Result of inverse matrix-matrix multiplication, ``A⁻¹ @ X`` in
            list format.
        """
        # Stack the flattened blocks, then move to NumPy on the CPU using the
        # data type of the SciPy operator.
        stacked = cat([block.flatten(end_dim=-2) for block in X])
        rhs_np = stacked.cpu().numpy().astype(self._A_scipy.dtype)
        _num_rows, num_vecs = rhs_np.shape

        # SciPy's LSMR handles one right-hand side at a time: solve per column.
        solutions = []
        info = []
        for column in rhs_np.T:
            outcome = lsmr(self._A_scipy, column, **self._lsmr_hyperparameters)
            solutions.append(outcome[0])
            info.append(outcome[1:])
        self._lsmr_info = info

        # Re-assemble the columns and go back to a tensor on the operator's
        # device with the operator's data type.
        solution = as_tensor(
            column_stack(solutions), device=self.device, dtype=self.dtype
        )

        # Undo the stacking and flattening.
        pieces = solution.split(self._out_shape_flat)
        return [
            piece.reshape(*shape, num_vecs)
            for piece, shape in zip(pieces, self._out_shape)
        ]

    def _adjoint(self) -> LSMRInverseLinearOperator:
        """Return the linear operator's adjoint: (A^-1)* = (A*)^-1.

        The LSMR hyper-parameters are passed on unchanged.

        Returns:
            A linear operator representing the adjoint.
        """
        adjoint_of_A = self._A._adjoint()
        return LSMRInverseLinearOperator(adjoint_of_A, **self._lsmr_hyperparameters)
class NeumannInverseLinearOperator(_InversePyTorchLinearOperator):
    """Inverse of a linear operator, approximated by a truncated Neumann series.

    See https://en.wikipedia.org/w/index.php?title=Neumann_series&oldid=1131424698#Approximate_matrix_inversion.

    Motivated by

    - Lorraine, J., Vicol, P., & Duvenaud, D. (2020). Optimizing millions of
      hyperparameters by implicit differentiation. In International Conference
      on Artificial Intelligence and Statistics (AISTATS).

    - Wang, A., Nguyen, E., Yang, R., Bae, J., McIlraith, S. A., & Grosse, R. B.
      (2025). Better Training Data Attribution via Better Inverse
      Hessian-Vector Products. In Advances in Neural Information Processing
      Systems (NeurIPS 2025).

    .. warning::
        The Neumann series can be non-convergent. In this case, the iterations
        will become numerically unstable, leading to ``NaN`` values.

    .. warning::
        The Neumann series can converge slowly.
        Use :py:class:`curvlinops.CGInverseLinearOperator` for better accuracy.

    .. note::
        With the default ``check_nan=True``, this operator is not
        compiler-friendly (:func:`torch.compile`): the per-iteration ``isnan``
        check introduces data-dependent branching that causes graph breaks.
        Passing ``check_nan=False`` makes the operator compiler-friendly (the
        truncated series itself uses only traceable PyTorch ops).
    """

    def __init__(
        self,
        A: PyTorchLinearOperator,
        num_terms: int = 100,
        scale: float = 1.0,
        check_nan: bool = True,
        preconditioner: None | Callable[[Tensor], Tensor] = None,
    ):
        r"""Set up the truncated-Neumann-series inverse of a linear operator.

        The Neumann series for an invertible linear operator :math:`\mathbf{A}` is

        .. math::
            \mathbf{A}^{-1}
            =
            \sum_{k=0}^{\infty}
            \left(\mathbf{I} - \mathbf{A} \right)^k\,,

        which is convergent if all eigenvalues satisfy
        :math:`0 < \lambda(\mathbf{A}) < 2`.

        By re-scaling the matrix by ``scale`` (:math:`\alpha`), we have:

        .. math::
            \mathbf{A}^{-1}
            =
            \alpha (\alpha \mathbf{A})^{-1}
            =
            \alpha \sum_{k=0}^{\infty}
            \left(\mathbf{I} - \alpha \mathbf{A} \right)^k\,,

        which is convergent if :math:`0 < \lambda(\mathbf{A}) < \frac{2}{\alpha}`.

        Additionally, we truncate the series at ``num_terms`` (:math:`K`):

        .. math::
            \mathbf{A}^{-1}
            \approx
            \alpha \sum_{k=0}^{K}
            \left(\mathbf{I} - \alpha \mathbf{A} \right)^k\,.

        Args:
            A: Linear operator whose inverse is formed.
            num_terms: Number of terms in the truncated Neumann series.
                Default: ``100``.
            scale: Scale applied to the matrix in the Neumann iteration. Crucial
                for convergence of Neumann series (details above).
                Default: ``1.0``.
            check_nan: Whether to check for NaNs while applying the truncated
                Neumann series. Default: ``True``.
            preconditioner: Optional preconditioner :math:`\mathbf{P}` used in
                the preconditioned Neumann/Richardson iteration
                :math:`\mathbf{A}^{-1} \approx \alpha \sum_{k=0}^{K}
                (\mathbf{I} - \alpha \mathbf{P}\mathbf{A})^k \mathbf{P}`,
                where :math:`\alpha` is given by ``scale``. This preconditioned
                formulation is inspired by Wang et al. (NeurIPS 2025).
                ``preconditioner`` should be a callable that applies a left
                preconditioning operation to a supplied vector or matrix in
                tensor format, e.g. a ``PyTorchLinearOperator``'s
                ``__matmul__`` method. Default: ``None``.
        """
        super().__init__(A)
        self._num_terms = num_terms
        self._scale = scale
        self._check_nan = check_nan
        self._preconditioner = preconditioner

    def _matmat(self, X: list[Tensor]) -> list[Tensor]:
        """Apply the truncated Neumann series to a matrix in list format.

        Args:
            X: Matrix for multiplication in list format.

        Returns:
            Result of inverse matrix-matrix multiplication, ``A⁻¹ @ X``.

        Raises:
            ValueError: If ``NaN`` check is turned on and ``NaN`` values are
                detected.
        """
        if self._preconditioner is None:
            # Plain series: start from X and iterate with A itself.
            start = X
            iteration_operator = self._A._matmat
        else:

            def precondition(mat: list[Tensor]) -> list[Tensor]:
                """Apply the left preconditioner to a matrix in list format."""
                # The preconditioner acts on one stacked 2d tensor, so the
                # blocks are flattened and stacked first, and split and
                # reshaped back afterwards.
                stacked = cat([block.flatten(end_dim=-2) for block in mat])
                preconditioned = self._preconditioner(stacked)
                _num_rows, num_vecs = preconditioned.shape
                pieces = preconditioned.split(self._out_shape_flat)
                return [
                    piece.reshape(*shape, num_vecs)
                    for piece, shape in zip(pieces, self._out_shape)
                ]

            # Preconditioned series: start from P X and iterate with P A.
            start = precondition(X)

            def iteration_operator(mat: list[Tensor]) -> list[Tensor]:
                return precondition(self._A @ mat)

        # ``total`` accumulates the series; ``term`` holds the current summand.
        # Both are clones because they are updated in-place below.
        total = [block.clone() for block in start]
        term = [block.clone() for block in start]

        for idx in range(self._num_terms):
            operator_term = iteration_operator(term)
            # Next summand: term <- term - scale * (operator applied to term).
            term = [
                t.sub_(op_t, alpha=self._scale) for t, op_t in zip(term, operator_term)
            ]
            total = [acc.add_(t) for acc, t in zip(total, term)]

            if not self._check_nan:
                continue
            if any(isnan(acc).any() for acc in total):
                raise ValueError(
                    f"Detected NaNs after application of {idx}-th term."
                    " This is probably because the Neumann series is non-convergent."
                    " Try decreasing `scale` and read the comment on convergence."
                )

        # Final multiplication of the accumulated sum by ``scale``.
        return [acc.mul_(self._scale) for acc in total]

    def _adjoint(self) -> NeumannInverseLinearOperator:
        """Return the linear operator's adjoint: (A^-1)* = (A*)^-1.

        A preconditioner, if present, is replaced by the ``__matmul__`` of the
        adjoint of the linear operator it is bound to.

        Returns:
            A linear operator representing the adjoint.

        Raises:
            NotImplementedError: If the preconditioner's adjoint cannot be
                inferred.
        """
        if self._preconditioner is None:
            adjoint_preconditioner = None
        else:
            # A bound method exposes the object it belongs to via ``__self__``.
            owner = getattr(self._preconditioner, "__self__", None)
            if not isinstance(owner, PyTorchLinearOperator):
                raise NotImplementedError(
                    "Adjoint with a preconditioner is only supported when the "
                    "preconditioner is a bound PyTorchLinearOperator.__matmul__ "
                    "method."
                )
            adjoint_preconditioner = owner.adjoint().__matmul__

        return NeumannInverseLinearOperator(
            self._A._adjoint(),
            num_terms=self._num_terms,
            scale=self._scale,
            check_nan=self._check_nan,
            preconditioner=adjoint_preconditioner,
        )