"""Implements slices of linear operators."""
from __future__ import annotations
from torch import Tensor, device, dtype, zeros
from curvlinops._torch_base import PyTorchLinearOperator
[docs]
class SubmatrixLinearOperator(PyTorchLinearOperator):
    """Class for sub-matrices of linear operators.

    .. note::
        This operator is not compiler-friendly (:func:`torch.compile`). Its
        matrix-vector product dispatches through the wrapped operator's
        ``__matmul__``, and Dynamo cannot proxy a user-defined linear operator as
        an argument, which causes graph breaks.
    """

    def __init__(
        self, A: PyTorchLinearOperator, row_idxs: list[int], col_idxs: list[int]
    ):
        """Store the linear operator and indices of its sub-matrix.

        Represents the sub-matrix ``A[row_idxs, :][:, col_idxs]``.

        Args:
            A: A linear operator.
            row_idxs: The sub-matrix's row indices.
            col_idxs: The sub-matrix's column indices.

        Raises:
            ValueError: If the index lists contain duplicate values, non-integers,
                or out-of-bounds indices.
        """
        self._A = A
        # Validates and stores the indices, and initializes the base class with
        # the resulting shapes.
        self.set_submatrix(row_idxs, col_idxs)

    @property
    def dtype(self) -> dtype:
        """Determine the linear operator's data type.

        Returns:
            The data type of the wrapped linear operator.
        """
        return self._A.dtype

    @property
    def device(self) -> device:
        """Determine the device the linear operator is defined on.

        Returns:
            The device of the wrapped linear operator.
        """
        return self._A.device

    def set_submatrix(self, row_idxs: list[int], col_idxs: list[int]):
        """Define the sub-matrix.

        Validates both index lists against the wrapped operator's shape, then
        (re-)initializes the base class with the sub-matrix's input/output shapes
        and stores copies of the indices.

        Args:
            row_idxs: The sub-matrix's row indices.
            col_idxs: The sub-matrix's column indices.

        Raises:
            ValueError: If the index lists contain duplicate values, non-integers
                (including booleans), or out-of-bounds indices.
        """
        validated = []
        for ax_idx, idxs in enumerate([row_idxs, col_idxs]):
            # Work on a private ``list`` copy: later mutation of the caller's
            # object must not desynchronize indices and shape, and a ``list``
            # (unlike a ``tuple``) is always treated as an index along the first
            # axis when used to index a tensor in ``_matmat``.
            idxs = list(idxs)
            # ``bool`` is a subclass of ``int``; reject it explicitly because a
            # list of booleans would act as a mask when indexing a tensor.
            if any(isinstance(i, bool) or not isinstance(i, int) for i in idxs):
                raise ValueError("Index lists must contain integers.")
            if len(idxs) != len(set(idxs)):
                raise ValueError("Index lists cannot contain duplicates.")
            dim = self._A.shape[ax_idx]
            if any(i < 0 or i >= dim for i in idxs):
                raise ValueError("Index lists contain out-of-bounds indices.")
            validated.append(idxs)

        row_idxs, col_idxs = validated
        # The operator maps vectors of length ``len(col_idxs)`` to vectors of
        # length ``len(row_idxs)``.
        in_shape, out_shape = [(len(col_idxs),)], [(len(row_idxs),)]
        super().__init__(in_shape, out_shape)
        # Only overwrite the stored indices once validation has succeeded.
        self._row_idxs = row_idxs
        self._col_idxs = col_idxs

    def _matmat(self, X: list[Tensor]) -> list[Tensor]:
        """Matrix-matrix multiplication.

        Args:
            X: A list that contains a single tensor, which is the input tensor.

        Returns:
            A list that contains a single tensor, which is the output tensor.
        """
        (M,) = X
        # Embed the input into a zero matrix with as many rows as ``A`` has
        # columns; rows outside ``col_idxs`` stay zero and do not contribute.
        V = zeros(self._A.shape[1], M.shape[-1], dtype=self.dtype, device=self.device)
        V[self._col_idxs] = M
        # Apply the full operator, then keep only the selected rows.
        AV = self._A @ V
        return [AV[self._row_idxs]]

    def _adjoint(self) -> SubmatrixLinearOperator:
        """Return the adjoint of the sub-matrix.

        For that, we need to take the adjoint operator, and swap row and column
        indices.

        Returns:
            The linear operator for the adjoint sub-matrix.
        """
        return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs)