Linear operators
Hessian
- class curvlinops.HessianLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor] | None,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- num_per_example_loss_terms: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for the Hessian of an empirical risk.
Consider the empirical risk
\[\mathcal{L}(\mathbf{\theta}) = c \sum_{n=1}^{N} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\]
with \(c = \frac{1}{N}\) for reduction='mean' and \(c=1\) for reduction='sum'. The Hessian matrix is
\[\nabla^2_{\mathbf{\theta}} \mathcal{L} = c \sum_{n=1}^{N} \nabla_{\mathbf{\theta}}^2 \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\,.\]
Example
>>> from torch import rand, eye, kron, manual_seed
>>> from torch.nn import Linear, MSELoss
>>> from curvlinops import HessianLinearOperator
>>>
>>> # Create a simple linear model without bias
>>> _ = manual_seed(0) # make deterministic
>>> D_in, D_out = 4, 2
>>> num_data, batch_size = 10, 3
>>> model = Linear(D_in, D_out, bias=False)
>>> params = dict(model.named_parameters())
>>> loss_func = MSELoss(reduction='sum')
>>>
>>> # Generate synthetic dataset and chunk into batches of size batch_size
>>> X, y = rand(num_data, D_in), rand(num_data, D_out)
>>> data = list(zip(X.split(batch_size), y.split(batch_size)))
>>>
>>> # Create Hessian linear operator
>>> H_op = HessianLinearOperator(model, loss_func, params, data)
>>>
>>> # Compare with the known Hessian matrix 2 I ⊗ Xᵀ X
>>> H_mat = 2 * kron(eye(D_out), X.T @ X)
>>> P = sum(p.numel() for p in params.values())
>>> v = rand(P) # generate a random vector
>>> (H_mat @ v).allclose(H_op @ v)
True
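The factor \(c\) only rescales the Hessian, and the Hessian of the risk is the sum of per-example Hessians. The following is a minimal plain-PyTorch sketch that checks both statements for a bias-free linear classifier. It uses torch.autograd.functional.hessian, not curvlinops. CrossEntropyLoss is used because its 'mean' reduction divides by exactly N, whereas PyTorch's MSELoss(reduction='mean') averages over all output entries.

```python
import torch
from torch.autograd.functional import hessian
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
N, D_in, C = 10, 4, 3
X, y = torch.rand(N, D_in), torch.randint(0, C, (N,))
w = torch.rand(C * D_in)  # row-major flattening of a bias-free weight W of shape (C, D_in)

def loss(w_flat, X_batch, y_batch, reduction):
    # Linear classifier f_theta(x) = W x followed by cross-entropy
    W = w_flat.reshape(C, D_in)
    return cross_entropy(X_batch @ W.T, y_batch, reduction=reduction)

H_sum = hessian(lambda v: loss(v, X, y, "sum"), w)    # c = 1
H_mean = hessian(lambda v: loss(v, X, y, "mean"), w)  # c = 1 / N

# Sum of the per-example Hessians, built one data point at a time
H_per_example = sum(
    hessian(lambda v, n=n: loss(v, X[n : n + 1], y[n : n + 1], "sum"), w)
    for n in range(N)
)

print(torch.allclose(H_sum, H_per_example, atol=1e-5))
print(torch.allclose(H_mean, H_sum / N, atol=1e-6))
```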
- SELF_ADJOINT
Whether the linear operator is self-adjoint (True for Hessians).
- Type:
bool
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor] | None,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- num_per_example_loss_terms: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for curvature matrices of empirical risks.
Note
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the mini-batch labels y.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
loss_func – Loss function criterion. Maps predictions and mini-batch labels to a scalar value. If None, there is no loss function and the represented matrix is independent of the loss function.
params – The parameter values at which the curvature matrix is evaluated. A dictionary mapping parameter names to tensors (use dict(model.named_parameters())). The parameter ordering follows dict insertion order.
data – Source from which mini-batches can be drawn, for instance a list of mini-batches [(X, y), ...] or a torch DataLoader. Note that X could be a dict or UserDict; this is useful for custom models. In this case, you must (i) specify the batch_size_fn argument, and (ii) take care of preprocessing like X.to(device) inside of your model.forward() function.
progressbar – Show a progress bar during matrix multiplication. Default: False.
check_deterministic – Probe that model and data are deterministic, i.e. that the data does not use drop_last or data augmentation. Also, the model’s forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: True. This is a safeguard; only turn it off if you know what you are doing.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
num_per_example_loss_terms – Number of per-example loss terms, e.g. the number of tokens in a sequence. Only used by subclasses with NEEDS_NUM_PER_EXAMPLE_LOSS_TERMS = True. If None, it is inferred from the data when needed. Default: None.
batch_size_fn – Must be specified if the X’s in data are not torch.Tensor. It receives the first entry (the input X) of each mini-batch drawn from data and returns its batch size.
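When the inputs X are dictionaries, the parameters above ask for two things: a batch_size_fn, and device handling inside forward(). The following is a minimal sketch of both. The model class DictInputModel and the keys "x" and "mask" are hypothetical and chosen only for illustration. They are not part of curvlinops.

```python
from collections import UserDict

import torch
from torch import nn

class DictInputModel(nn.Module):
    """Hypothetical model that consumes a dict of tensors instead of a single tensor."""

    def __init__(self, D_in: int, D_out: int):
        super().__init__()
        self.linear = nn.Linear(D_in, D_out)

    def forward(self, X):
        # Preprocessing such as moving to the model's device happens here,
        # because the input is not a plain torch.Tensor.
        device = self.linear.weight.device
        x, mask = X["x"].to(device), X["mask"].to(device)
        return self.linear(x * mask)

def batch_size_fn(X) -> int:
    # Receives the input X of a mini-batch and returns its batch size
    return X["x"].shape[0]

model = DictInputModel(4, 2)
X = UserDict({"x": torch.rand(3, 4), "mask": torch.ones(3, 4)})
y = torch.rand(3, 2)
data = [(X, y)]
print(batch_size_fn(X), model(X).shape)
```

These objects would then be passed as model_func=model, data=data and batch_size_fn=batch_size_fn.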
Generalized Gauss-Newton
- class curvlinops.GGNLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping], int] | None = None,
- mc_samples: int = 0,
- seed: int = 2147483647,
Linear operator for the generalized Gauss-Newton matrix of an empirical risk.
Consider the empirical risk
\[\mathcal{L}(\mathbf{\theta}) = c \sum_{n=1}^{N} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\]
with \(c = \frac{1}{N}\) for reduction='mean' and \(c=1\) for reduction='sum'. The GGN matrix is
\[c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n) \right)^\top \left( \nabla_{f_\mathbf{\theta}(\mathbf{x}_n)}^2 \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) \right) \left( \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}_n) \right)\,.\]
Denoting \(\mathbf{f}_n = f_{\mathbf{\theta}}(\mathbf{x}_n)\) and using a matrix square root \(\mathbf{S}_n \mathbf{S}_n^\top = \nabla_{\mathbf{f}_n}^2 \ell(\mathbf{f}_n, \mathbf{y}_n)\), this can be rewritten as
\[c \sum_{n=1}^{N} \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)^\top \mathbf{S}_n \mathbf{S}_n^\top \left( \mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n \right)\,.\]
When mc_samples > 0, the loss Hessian’s square root is approximated via Monte-Carlo sampling. For exponential family losses (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss), the loss Hessian equals \(\mathbb{E}_{\tilde{\mathbf{y}}_n \sim q(\cdot \mid \mathbf{f}_n)} [\nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n) \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n)^\top]\), where \(q\) is the model’s predictive distribution. This expectation is approximated by drawing \(M\) samples \(\tilde{\mathbf{y}}_n^{(m)}\) and using the scaled sampled gradients \(\mathbf{g}_{nm} / \sqrt{M}\), with \(\mathbf{g}_{nm} = \nabla_{\mathbf{f}_n} \ell(\mathbf{f}_n, \tilde{\mathbf{y}}_n^{(m)})\), as the columns of \(\mathbf{S}_n\):
\[\nabla_{\mathbf{f}_n}^2 \ell \approx \frac{1}{M} \sum_{m=1}^{M} \mathbf{g}_{nm} \mathbf{g}_{nm}^\top\,.\]
The MC estimate converges to the exact GGN as \(M \to \infty\).
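For CrossEntropyLoss, the loss Hessian with respect to the logits has the closed form \(\mathrm{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top\) with \(\mathbf{p} = \mathrm{softmax}(\mathbf{f}_n)\). A sampled gradient is \(\mathbf{p} - \mathbf{e}_{\tilde{y}}\), where \(\tilde{y} \sim \mathrm{Cat}(\mathbf{p})\) and \(\mathbf{e}_{\tilde{y}}\) is the one-hot vector of the sampled label.

The following plain-PyTorch sketch makes three simplifying assumptions: a single data point, a bias-free linear model and reduction='sum' (so \(c = 1\)). It builds the exact GGN \(\mathbf{J}^\top \nabla^2\ell\, \mathbf{J}\) and its Monte-Carlo counterpart, where the columns of \(\mathbf{S}_n\) are \(\mathbf{g}_{nm}/\sqrt{M}\). It illustrates the formulas above and is not how GGNLinearOperator is implemented.

```python
import torch
from torch.autograd.functional import hessian
from torch.func import jacrev
from torch.nn.functional import cross_entropy, one_hot

torch.manual_seed(0)
D_in, C, M = 3, 4, 100_000
x, y = torch.rand(D_in), torch.tensor(1)
w = torch.randn(C * D_in)

def f(w_flat):
    # Linear model f_theta(x) = W x, with W stored row-major in w_flat
    return w_flat.reshape(C, D_in) @ x

logits = f(w)
J = jacrev(f)(w)                            # (C, C * D_in) Jacobian of the outputs
p = logits.softmax(dim=0)
H_loss = torch.diag(p) - torch.outer(p, p)  # exact cross-entropy Hessian w.r.t. logits
G_exact = J.T @ H_loss @ J

# Monte-Carlo: sample labels from the predictive distribution Cat(p)
gen = torch.Generator().manual_seed(42)
y_tilde = torch.multinomial(p, M, replacement=True, generator=gen)
grads = p.unsqueeze(0) - one_hot(y_tilde, C).float()  # (M, C), rows are g_nm
S = grads.T / M**0.5                                  # columns g_nm / sqrt(M)
G_mc = J.T @ S @ S.T @ J

# Autograd agrees with the closed-form loss Hessian
H_autograd = hessian(lambda z: cross_entropy(z.unsqueeze(0), y.unsqueeze(0)), logits)
print(torch.allclose(H_autograd, H_loss, atol=1e-6))
print((G_mc - G_exact).abs().max())
```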
- SELF_ADJOINT
Whether the linear operator is self-adjoint.
True for GGNs.
- Type:
bool
- MC_SUPPORTED_LOSSES
Loss functions supported by the MC approximation.
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping], int] | None = None,
- mc_samples: int = 0,
- seed: int = 2147483647,
Linear operator for the GGN of an empirical risk.
Note
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the mini-batch labels y.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
loss_func – Loss function criterion. Maps predictions and mini-batch labels to a scalar value.
params – The parameter values at which the GGN is evaluated. A dictionary mapping parameter names to tensors.
data – Source from which mini-batches can be drawn, for instance a list of mini-batches [(X, y), ...] or a torch DataLoader. Note that X could be a dict or UserDict; this is useful for custom models. In this case, you must (i) specify the batch_size_fn argument, and (ii) take care of preprocessing like X.to(device) inside of your model.forward() function. When using MC sampling, batches must be presented in the same deterministic order (no shuffling!).
progressbar – Show a progress bar during matrix multiplication. Default: False.
check_deterministic – Probe that model and data are deterministic, i.e. that the data does not use drop_last or data augmentation. Also, the model’s forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: True. This is a safeguard; only turn it off if you know what you are doing.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – Must be specified if the X’s in data are not torch.Tensor. It receives the first entry (the input X) of each mini-batch drawn from data and returns its batch size.
mc_samples – Number of Monte-Carlo samples to approximate the loss Hessian. 0 (default) uses the exact GGN. Positive values activate the MC approximation, which is only supported for MSELoss, CrossEntropyLoss, and BCEWithLogitsLoss.
seed – Seed for the internal random number generator used for MC sampling. Only used when mc_samples > 0. Default: 2147483647.
- Raises:
NotImplementedError – If mc_samples > 0 and the loss function is not in MC_SUPPORTED_LOSSES.
- class curvlinops.GGNDiagonalLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- mc_samples: int = 0,
- seed: int = 2147483647,
Diagonal linear operator representing the GGN diagonal.
Computes \(\mathrm{diag}(\mathbf{G})\), where \(\mathbf{G}\) is the generalized Gauss-Newton matrix (see GGNLinearOperator for the full definition). When mc_samples > 0, the loss Hessian is approximated via Monte-Carlo sampling from the model’s predictive distribution (see GGNLinearOperator for details).
Internally uses a GGNDiagonalComputer to compute the diagonal, then initializes the parent DiagonalLinearOperator with the result.
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- mc_samples: int = 0,
- seed: int = 2147483647,
Initialize the GGN diagonal linear operator.
Constructs a GGNDiagonalComputer with the given arguments, computes the diagonal, and passes it to the parent class.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
loss_func – Loss function criterion. Maps predictions and mini-batch labels to a scalar value.
params – The parameter values at which the GGN diagonal is evaluated. A dictionary mapping parameter names to tensors.
data – Source from which mini-batches can be drawn, for instance a list of mini-batches [(X, y), ...] or a torch DataLoader. Note that X could be a dict or UserDict; this is useful for custom models. In this case, you must (i) specify the batch_size_fn argument, and (ii) take care of preprocessing like X.to(device) inside of your model.forward() function. When using MC sampling, batches must be presented in the same deterministic order (no shuffling!).
progressbar – Show a progress bar during computation. Default: False.
check_deterministic – Probe that model and data are deterministic, i.e. that the data does not use drop_last or data augmentation. Also, the model’s forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: True. This is a safeguard; only turn it off if you know what you are doing.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – Function that computes the batch size from input data. For torch.Tensor inputs, this should typically return X.shape[0]. For dict/UserDict inputs, this should return the batch size of the contained tensors.
mc_samples – Number of Monte-Carlo samples to approximate the loss Hessian. 0 (default) uses the exact GGN diagonal. Positive values activate the MC approximation.
seed – Seed for the internal random number generator used for MC sampling. Only used when mc_samples > 0. Default: 2147483647.
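Each diagonal entry of \(\mathbf{G} = c \sum_n (\mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n)^\top \mathbf{S}_n \mathbf{S}_n^\top (\mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n)\) is a sum of squares of the entries in the corresponding row of \((\mathbf{J}_{\mathbf{\theta}} \mathbf{f}_n)^\top \mathbf{S}_n\). Therefore the diagonal can be accumulated without forming \(\mathbf{G}\).

The following plain-PyTorch sketch checks this identity. It assumes MSELoss(reduction='sum') on a bias-free linear model, where the loss Hessian with respect to the outputs is \(2\mathbf{I}\), so \(\mathbf{S}_n = \sqrt{2}\,\mathbf{I}\). It makes no claim about how GGNDiagonalComputer works internally.

```python
import torch
from torch.func import jacrev

torch.manual_seed(0)
N, D_in, D_out = 5, 3, 2
P = D_out * D_in
X = torch.rand(N, D_in)
w = torch.randn(P)

def f(w_flat, x):
    # Linear model, W stored row-major in w_flat
    return w_flat.reshape(D_out, D_in) @ x

S = 2**0.5 * torch.eye(D_out)  # square root of the MSE loss Hessian 2 I (reduction='sum')

G = torch.zeros(P, P)
diag = torch.zeros(P)
for n in range(N):
    J = jacrev(f)(w, X[n])        # (D_out, P) Jacobian for data point n
    JS = J.T @ S                  # (P, D_out)
    G += JS @ JS.T                # full GGN, only formed here for comparison
    diag += (JS**2).sum(dim=1)    # squared row entries give the diagonal

print(torch.allclose(diag, torch.diagonal(G), atol=1e-5))
```

For this model, the entry belonging to weight \(W_{ij}\) is \(2 \sum_n x_{nj}^2\), which is independent of the output index \(i\).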
Fisher (approximate)
- class curvlinops.KFACLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- seed: int = 2147483647,
- fisher_type: str = FisherType.MC,
- mc_samples: int = 1,
- kfac_approx: str = KFACType.EXPAND,
- num_per_example_loss_terms: int | None = None,
- separate_weight_and_bias: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- backend: str = 'hooks',
Linear operator to multiply with the Fisher/GGN’s KFAC approximation.
KFAC approximates the per-layer Fisher/GGN with a Kronecker product. Consider a weight matrix \(\mathbf{W}\) and a bias vector \(\mathbf{b}\) in a single layer. The layer’s Fisher \(\mathbf{F}(\mathbf{\theta})\) for
\[\mathbf{\theta} = \begin{pmatrix} \mathrm{vec}(\mathbf{W}) \\ \mathbf{b} \end{pmatrix}\,,\]
where \(\mathrm{vec}\) denotes column-stacking, is approximated as
\[\mathbf{F}(\mathbf{\theta}) \approx \mathbf{A}_{(\text{KFAC})} \otimes \mathbf{B}_{(\text{KFAC})}\]
(see curvlinops.GGNLinearOperator with mc_samples > 0). Loosely speaking, the first Kronecker factor is the uncentered covariance of the inputs to a layer. The second Kronecker factor is the uncentered covariance of ‘would-be’ gradients w.r.t. the layer’s output. Those ‘would-be’ gradients result from sampling labels from the model’s distribution and computing their gradients.
Kronecker-Factored Approximate Curvature (KFAC) was originally introduced for MLPs in
Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored approximate curvature. International Conference on Machine Learning (ICML),
extended to CNNs in
Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix for convolution layers. International Conference on Machine Learning (ICML),
and generalized to all linear layers with weight sharing in
Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., Hennig, P. (2023). Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures (NeurIPS).
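For a single bias-free Linear layer, the per-example weight gradient is \(\mathbf{g}_n \mathbf{x}_n^\top\), where \(\mathbf{g}_n\) is the gradient with respect to the layer output. Its column-stacked outer product is therefore exactly \((\mathbf{x}_n \mathbf{x}_n^\top) \otimes (\mathbf{g}_n \mathbf{g}_n^\top)\). KFAC replaces the sum of these Kronecker products with a Kronecker product of sums.

The following sketch makes two assumptions. First, it uses the true labels (the empirical-Fisher flavour, cf. FisherType.EMPIRICAL) with reduction='sum'. Second, placing the \(1/N\) factor on \(\mathbf{A}\) is one common convention and not necessarily the one curvlinops uses.

```python
import torch
from torch.nn.functional import cross_entropy

def kfac_vs_exact(N, D_in=3, C=4, seed=0):
    """Exact empirical Fisher of one bias-free Linear layer and its KFAC proxy."""
    torch.manual_seed(seed)
    X, y = torch.randn(N, D_in), torch.randint(0, C, (N,))
    W = torch.randn(C, D_in, requires_grad=True)

    # Gradients w.r.t. the layer outputs z_n = W x_n, one row per data point
    Z = X @ W.T
    (G_out,) = torch.autograd.grad(cross_entropy(Z, y, reduction="sum"), Z)

    # Column-stacked per-example gradients vec(g_n x_n^T) = x_n ⊗ g_n
    grads = torch.stack([torch.kron(X[n], G_out[n]) for n in range(N)])
    EF = grads.T @ grads

    # KFAC: A from the layer inputs, B from the output gradients
    A = X.T @ X / N
    B = G_out.T @ G_out
    return EF, torch.kron(A, B)

EF, K = kfac_vs_exact(N=1)
print(torch.allclose(EF, K, atol=1e-5))  # a single data point: KFAC is exact
EF, K = kfac_vs_exact(N=20)
print((EF - K).norm() / EF.norm())       # several data points: only an approximation
```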
- SELF_ADJOINT
Whether the operator is self-adjoint.
True for KFAC.
- Type:
bool
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- seed: int = 2147483647,
- fisher_type: str = FisherType.MC,
- mc_samples: int = 1,
- kfac_approx: str = KFACType.EXPAND,
- num_per_example_loss_terms: int | None = None,
- separate_weight_and_bias: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- backend: str = 'hooks',
Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Warning
- This is an early prototype with limitations:
Only Linear and Conv2d modules are supported.
The hooks backend assumes each module is called exactly once per forward pass. Weight tying (the same module called multiple times) will silently produce incorrect results. Use backend="make_fx" for weight-tied architectures.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction. Callables require backend="make_fx".
loss_func – The loss function.
params – The parameter values at which the Fisher/GGN is approximated. A dictionary mapping parameter names to tensors.
data – A data loader containing the data of the Fisher/GGN.
progressbar – Whether to show a progress bar when computing the Kronecker factors. Defaults to False.
check_deterministic – Whether to check that the linear operator is deterministic. Defaults to True.
seed – The seed for the random number generator used to draw labels from the model’s predictive distribution. Defaults to 2147483647.
fisher_type – The type of Fisher/GGN to approximate. If FisherType.TYPE2, the exact Hessian of the loss w.r.t. the model outputs is used. This requires as many backward passes as the output dimension, i.e. the number of classes for classification. This is sometimes also called type-2 Fisher. If FisherType.MC, the expectation is approximated by sampling mc_samples labels from the model’s predictive distribution. If FisherType.EMPIRICAL, the empirical gradients are used, which corresponds to the uncentered gradient covariance, or the empirical Fisher. If FisherType.FORWARD_ONLY, the gradient covariances will be identity matrices, see the FOOF method in Benzing, 2022 or ISAAC in Petersen et al., 2023. Defaults to FisherType.MC.
mc_samples – The number of Monte-Carlo samples to use per data point. Must be set to 1 when fisher_type != FisherType.MC. Defaults to 1.
kfac_approx – A string specifying the KFAC approximation that should be used for linear weight-sharing layers, e.g. Conv2d modules or Linear modules that process matrix- or higher-dimensional features. Possible values are KFACType.EXPAND and KFACType.REDUCE. See Eschenhagen et al., 2023 for an explanation of the two approximations. Defaults to KFACType.EXPAND.
num_per_example_loss_terms – Number of per-example loss terms, e.g., the number of tokens in a sequence. The model outputs will have num_data * num_per_example_loss_terms * C entries, where C is the dimension of the random variable we define the likelihood over: for the CrossEntropyLoss it will be the number of classes, for the MSELoss and BCEWithLogitsLoss it will be the size of the last dimension of the model outputs/targets (our convention here). If None, num_per_example_loss_terms is inferred from the data at the cost of one traversal through the data loader. It is expected to be the same for all examples. Defaults to None.
separate_weight_and_bias – Whether to treat weights and biases separately. Defaults to True. Setting this to False is more efficient because gradient covariances are computed once per layer rather than separately for weight and bias.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – Must be specified if the X’s in data are not torch.Tensor. It receives the first entry (the input X) of each mini-batch drawn from data and returns its batch size.
backend – The backend to use for computing Kronecker factors. "hooks" uses forward/backward hooks (default). "make_fx" uses FX graph tracing via the IO collector. Defaults to "hooks".
Note: The "make_fx" backend incurs a significant one-time tracing overhead (seconds for large models) on the first batch. The traced function is cached by batch size, so subsequent batches of the same size reuse it. However, each distinct batch size triggers a re-trace. Use uniform batch sizes in the data loader to avoid repeated tracing.
- Raises:
ValueError – If backend is not supported.
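The description of num_per_example_loss_terms above is plain arithmetic on the size of the model outputs. The following sketch assumes a sequence classifier with CrossEntropyLoss whose outputs have shape (num_data, sequence_length, C). The helper name infer_num_per_example_loss_terms is hypothetical and not a curvlinops function.

```python
import torch

def infer_num_per_example_loss_terms(outputs: torch.Tensor, num_data: int, C: int) -> int:
    # Hypothetical helper: outputs hold num_data * num_per_example_loss_terms * C entries
    total = outputs.numel()
    if total % (num_data * C) != 0:
        raise ValueError("Output size is not divisible by num_data * C.")
    return total // (num_data * C)

num_data, seq_len, C = 8, 16, 10
outputs = torch.randn(num_data, seq_len, C)  # one logit vector per token
print(infer_num_per_example_loss_terms(outputs, num_data, C))  # 16
```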
- det() Tensor
Compute the determinant of the KFAC approximation.
- Returns:
Determinant of the KFAC approximation.
- frobenius_norm() Tensor
Frobenius norm of the KFAC approximation.
- Returns:
Frobenius norm of the KFAC approximation.
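Both quantities can be obtained without forming the full matrix, by combining two standard identities with the block-diagonal structure of KFAC over layers. For \(\mathbf{A} \in \mathbb{R}^{n \times n}\) and \(\mathbf{B} \in \mathbb{R}^{m \times m}\):

- \(\det(\mathbf{A} \otimes \mathbf{B}) = \det(\mathbf{A})^{m} \det(\mathbf{B})^{n}\)
- \(\|\mathbf{A} \otimes \mathbf{B}\|_F = \|\mathbf{A}\|_F \|\mathbf{B}\|_F\)

The following sketch checks these identities numerically on random factors. Whether det() and frobenius_norm() are implemented exactly this way is an assumption.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

def random_spd(d):
    # Random symmetric positive definite matrix standing in for a Kronecker factor
    M = torch.randn(d, d, dtype=dtype)
    return M @ M.T + d * torch.eye(d, dtype=dtype)

# Two hypothetical layers, each with Kronecker factors (A, B)
blocks = [(random_spd(3), random_spd(2)), (random_spd(4), random_spd(3))]

det_factored = torch.tensor(1.0, dtype=dtype)
frob_sq_factored = torch.tensor(0.0, dtype=dtype)
for A, B in blocks:
    n, m = A.shape[0], B.shape[0]
    det_factored *= torch.linalg.det(A) ** m * torch.linalg.det(B) ** n
    frob_sq_factored += (torch.linalg.matrix_norm(A) * torch.linalg.matrix_norm(B)) ** 2
frob_factored = frob_sq_factored.sqrt()

# Dense reference: block-diagonal matrix of the per-layer Kronecker products
K = torch.block_diag(*[torch.kron(A, B) for A, B in blocks])
print(torch.allclose(det_factored, torch.linalg.det(K)))
print(torch.allclose(frob_factored, torch.linalg.matrix_norm(K)))
```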
- inverse(
- damping: float = 0.0,
- use_heuristic_damping: bool = False,
- min_damping: float = 1e-08,
- use_exact_damping: bool = False,
- retry_double_precision: bool = True,
Return the inverse of the KFAC approximation.
Inverts each Kronecker-factored block of the canonical operator and returns the result in parameter space.
- Parameters:
damping – Damping value applied to all Kronecker factors. Default: 0.0.
use_heuristic_damping – Whether to use the heuristic damping strategy by Martens and Grosse, 2015 (Section 6.3). Only supported for one or two factors.
min_damping – Minimum damping value. Only used if use_heuristic_damping is True.
use_exact_damping – Whether to use exact damping, i.e. to invert \((A \otimes B) + \text{damping}\; \mathbf{I}\).
retry_double_precision – Whether to retry the Cholesky decomposition used for inversion in double precision.
- Returns:
Inverse of the KFAC approximation as a linear operator P @ K^-1 @ P^T.
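Without damping, the Kronecker structure carries over to the inverse: \((\mathbf{A} \otimes \mathbf{B})^{-1} = \mathbf{A}^{-1} \otimes \mathbf{B}^{-1}\). Exact damping does not factorize this way. It can still be handled per block through the eigendecompositions \(\mathbf{A} = \mathbf{Q}_A \mathrm{diag}(\mathbf{a}) \mathbf{Q}_A^\top\) and \(\mathbf{B} = \mathbf{Q}_B \mathrm{diag}(\mathbf{b}) \mathbf{Q}_B^\top\), because \(\mathbf{A} \otimes \mathbf{B} + \lambda \mathbf{I} = (\mathbf{Q}_A \otimes \mathbf{Q}_B)\, \mathrm{diag}(\mathbf{a} \otimes \mathbf{b} + \lambda)\, (\mathbf{Q}_A \otimes \mathbf{Q}_B)^\top\).

The following sketch applies this inverse to a vector and compares the result with a dense solve. It illustrates the math behind use_exact_damping=True and is not curvlinops's implementation.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

def random_spd(d):
    M = torch.randn(d, d, dtype=dtype)
    return M @ M.T + 0.1 * torch.eye(d, dtype=dtype)

n, m, damping = 3, 4, 1e-2
A, B = random_spd(n), random_spd(m)
v = torch.randn(n * m, dtype=dtype)

# Eigendecompositions of the symmetric Kronecker factors
a, Q_A = torch.linalg.eigh(A)
b, Q_B = torch.linalg.eigh(B)

def exact_damped_inverse_mv(v):
    # (A ⊗ B + λI)^{-1} v, using (A ⊗ B) vec(V) = vec(A V B^T) for row-major vec
    V = v.reshape(n, m)
    V_rot = Q_A.T @ V @ Q_B                        # rotate into the eigenbasis
    V_rot = V_rot / (torch.outer(a, b) + damping)  # divide by eigenvalues a_i b_j + λ
    return (Q_A @ V_rot @ Q_B.T).flatten()         # rotate back

dense = torch.kron(A, B) + damping * torch.eye(n * m, dtype=dtype)
print(torch.allclose(exact_damped_inverse_mv(v), torch.linalg.solve(dense, v)))

# Undamped: the inverse of a Kronecker product is the Kronecker product of the inverses
K_inv = torch.linalg.inv(torch.kron(A, B))
print(torch.allclose(K_inv, torch.kron(torch.linalg.inv(A), torch.linalg.inv(B))))
```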
- class curvlinops.EKFACLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- seed: int = 2147483647,
- fisher_type: str = FisherType.MC,
- mc_samples: int = 1,
- kfac_approx: str = KFACType.EXPAND,
- num_per_example_loss_terms: int | None = None,
- separate_weight_and_bias: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- backend: str = 'hooks',
Linear operator to multiply with the Fisher/GGN’s EKFAC approximation.
Eigenvalue-corrected Kronecker-Factored Approximate Curvature (EKFAC) was originally introduced in
George, T., Laurent, C., Bouthillier, X., Ballas, N., Vincent, P. (2018). Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis (NeurIPS)
and concurrently in the context of continual learning in
Liu, X., Masana, M., Herranz, L., Van de Weijer, J., Lopez, A., Bagdanov, A. (2018). Rotate your networks: Better weight consolidation and less catastrophic forgetting (ICPR).
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: MSELoss | CrossEntropyLoss | BCEWithLogitsLoss,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- seed: int = 2147483647,
- fisher_type: str = FisherType.MC,
- mc_samples: int = 1,
- kfac_approx: str = KFACType.EXPAND,
- num_per_example_loss_terms: int | None = None,
- separate_weight_and_bias: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
- backend: str = 'hooks',
Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Warning
- This is an early prototype with limitations:
Only Linear and Conv2d modules are supported.
The hooks backend assumes each module is called exactly once per forward pass. Weight tying (the same module called multiple times) will silently produce incorrect results. Use backend="make_fx" for weight-tied architectures.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction. Callables require backend="make_fx".
loss_func – The loss function.
params – The parameter values at which the Fisher/GGN is approximated. A dictionary mapping parameter names to tensors.
data – A data loader containing the data of the Fisher/GGN.
progressbar – Whether to show a progress bar when computing the Kronecker factors. Defaults to False.
check_deterministic – Whether to check that the linear operator is deterministic. Defaults to True.
seed – The seed for the random number generator used to draw labels from the model’s predictive distribution. Defaults to 2147483647.
fisher_type – The type of Fisher/GGN to approximate. If FisherType.TYPE2, the exact Hessian of the loss w.r.t. the model outputs is used. This requires as many backward passes as the output dimension, i.e. the number of classes for classification. This is sometimes also called type-2 Fisher. If FisherType.MC, the expectation is approximated by sampling mc_samples labels from the model’s predictive distribution. If FisherType.EMPIRICAL, the empirical gradients are used, which corresponds to the uncentered gradient covariance, or the empirical Fisher. If FisherType.FORWARD_ONLY, the gradient covariances will be identity matrices, see the FOOF method in Benzing, 2022 or ISAAC in Petersen et al., 2023. Defaults to FisherType.MC.
mc_samples – The number of Monte-Carlo samples to use per data point. Must be set to 1 when fisher_type != FisherType.MC. Defaults to 1.
kfac_approx – A string specifying the KFAC approximation that should be used for linear weight-sharing layers, e.g. Conv2d modules or Linear modules that process matrix- or higher-dimensional features. Possible values are KFACType.EXPAND and KFACType.REDUCE. See Eschenhagen et al., 2023 for an explanation of the two approximations. Defaults to KFACType.EXPAND.
num_per_example_loss_terms – Number of per-example loss terms, e.g., the number of tokens in a sequence. The model outputs will have num_data * num_per_example_loss_terms * C entries, where C is the dimension of the random variable we define the likelihood over: for the CrossEntropyLoss it will be the number of classes, for the MSELoss and BCEWithLogitsLoss it will be the size of the last dimension of the model outputs/targets (our convention here). If None, num_per_example_loss_terms is inferred from the data at the cost of one traversal through the data loader. It is expected to be the same for all examples. Defaults to None.
separate_weight_and_bias – Whether to treat weights and biases separately. Defaults to True. Setting this to False is more efficient because gradient covariances are computed once per layer rather than separately for weight and bias.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – Must be specified if the X’s in data are not torch.Tensor. It receives the first entry (the input X) of each mini-batch drawn from data and returns its batch size.
backend – The backend to use for computing Kronecker factors. "hooks" uses forward/backward hooks (default). "make_fx" uses FX graph tracing via the IO collector. Defaults to "hooks".
Note: The "make_fx" backend incurs a significant one-time tracing overhead (seconds for large models) on the first batch. The traced function is cached by batch size, so subsequent batches of the same size reuse it. However, each distinct batch size triggers a re-trace. Use uniform batch sizes in the data loader to avoid repeated tracing.
- Raises:
ValueError – If backend is not supported.
- det() Tensor
Compute the determinant of the KFAC approximation.
- Returns:
Determinant of the KFAC approximation.
- frobenius_norm() Tensor
Frobenius norm of the KFAC approximation.
- Returns:
Frobenius norm of the KFAC approximation.
- inverse(damping: float = 0.0) _ChainPyTorchLinearOperator
Return the inverse of the EKFAC approximation.
Inverts each eigendecomposed block of the canonical operator and returns the result in parameter space.
- Parameters:
damping – Damping term added to eigenvalues before inversion. Default: 0.0.
- Returns:
Inverse of the EKFAC approximation as a linear operator.
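EKFAC keeps the eigenbasis \(\mathbf{U} = \mathbf{Q}_A \otimes \mathbf{Q}_B\) of the KFAC factors. It replaces the Kronecker-factored eigenvalues with the second moments of the per-example gradients projected onto that basis, \(s_i = \sum_n (\mathbf{U}^\top \mathbf{g}_n)_i^2\) (following George et al., 2018; the sum-reduction scaling is an assumption of this sketch). The damped inverse described above then reads \(\mathbf{U}\,\mathrm{diag}\big(1/(s_i + \text{damping})\big)\,\mathbf{U}^\top\).

The following plain-PyTorch sketch assumes one bias-free Linear layer trained with true labels. It uses the column-stacking convention of the KFAC docstring and checks two things:

- For this fixed eigenbasis, the corrected eigenvalues fit the exact matrix at least as well as KFAC's, measured in Frobenius norm.
- The damped inverse is correct.

```python
import torch
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
dtype = torch.float64
N, D_in, C, damping = 50, 3, 4, 1e-3
X, y = torch.randn(N, D_in, dtype=dtype), torch.randint(0, C, (N,))
W = torch.randn(C, D_in, dtype=dtype, requires_grad=True)

# Output gradients g_n and column-stacked per-example weight gradients x_n ⊗ g_n
Z = X @ W.T
(G_out,) = torch.autograd.grad(cross_entropy(Z, y, reduction="sum"), Z)
grads = torch.stack([torch.kron(X[n], G_out[n]) for n in range(N)])
EF = grads.T @ grads  # exact empirical Fisher of this layer

# KFAC factors and their eigendecompositions
A, B = X.T @ X / N, G_out.T @ G_out
a, Q_A = torch.linalg.eigh(A)
b, Q_B = torch.linalg.eigh(B)
U = torch.kron(Q_A, Q_B)

s_kfac = torch.kron(a, b)                # KFAC eigenvalues in the basis U
s_ekfac = ((grads @ U) ** 2).sum(dim=0)  # eigenvalue correction: projected second moments

KFAC = U @ torch.diag(s_kfac) @ U.T
EKFAC = U @ torch.diag(s_ekfac) @ U.T
print(torch.linalg.matrix_norm(EF - EKFAC) <= torch.linalg.matrix_norm(EF - KFAC))

# Damped inverse of EKFAC: invert the corrected eigenvalues in the eigenbasis
EKFAC_inv = U @ torch.diag(1.0 / (s_ekfac + damping)) @ U.T
identity = torch.eye(D_in * C, dtype=dtype)
print(torch.allclose(EKFAC_inv @ (EKFAC + damping * identity), identity))
```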
- class curvlinops.FisherType(value)
Enum for the Fisher type.
- TYPE2
'type-2' – Type-2 Fisher, i.e. the exact Hessian of the loss w.r.t. the model outputs is used. This requires as many backward passes as the output dimension, i.e. the number of classes for classification.
- Type:
str
- MC
'mc' – Monte-Carlo approximation of the expectation by sampling mc_samples labels from the model’s predictive distribution.
- Type:
str
- EMPIRICAL
'empirical' – Empirical gradients are used, which corresponds to the uncentered gradient covariance, or the empirical Fisher.
- Type:
str
- FORWARD_ONLY
'forward-only' – The gradient covariances will be identity matrices, see the FOOF method in Benzing, 2022 or ISAAC in Petersen et al., 2023.
- Type:
str
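Consider CrossEntropyLoss and a single data point with logits \(\mathbf{f}\) and \(\mathbf{p} = \mathrm{softmax}(\mathbf{f})\). The four types differ only in the \(C \times C\) output-space matrix that enters the gradient-covariance Kronecker factor.

The following sketch builds these matrices under that setting. The helper output_covariance is hypothetical and only mirrors the descriptions above. Its string arguments reuse the enum values, but it is not part of curvlinops.

```python
import torch
from torch.nn.functional import one_hot

def output_covariance(logits, label, fisher_type, mc_samples=1, generator=None):
    """Hypothetical helper: C x C matrix that replaces the loss Hessian for one data point."""
    C = logits.numel()
    p = logits.softmax(dim=0)
    if fisher_type == "type-2":
        # Exact Hessian of cross-entropy w.r.t. the logits
        return torch.diag(p) - torch.outer(p, p)
    if fisher_type == "mc":
        # Gradients p - e_y~ for labels y~ sampled from the predictive distribution
        samples = torch.multinomial(p, mc_samples, replacement=True, generator=generator)
        g = p - one_hot(samples, C).to(p.dtype)
        return g.T @ g / mc_samples
    if fisher_type == "empirical":
        # Gradient for the observed label only: a rank-one matrix
        e_y = torch.zeros_like(p)
        e_y[label] = 1.0
        return torch.outer(p - e_y, p - e_y)
    if fisher_type == "forward-only":
        return torch.eye(C, dtype=p.dtype)
    raise ValueError(f"Unknown fisher_type {fisher_type!r}")

logits, label = torch.tensor([1.0, -0.5, 0.3]), torch.tensor(2)
gen = torch.Generator().manual_seed(0)
for t in ["type-2", "mc", "empirical", "forward-only"]:
    print(t, output_covariance(logits, label, t, mc_samples=10_000, generator=gen))
```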
- class curvlinops.KFACType(value)
Enum for the KFAC approximation type.
KFAC-expand and KFAC-reduce are defined in Eschenhagen et al., 2023.
- EXPAND
'expand' – KFAC-expand approximation.
- Type:
str
- REDUCE
'reduce' – KFAC-reduce approximation.
- Type:
str
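Consider a Linear layer applied to inputs of shape (N, R, D_in), with R weight-sharing positions such as tokens. The two approximations differ in how the shared dimension enters the Kronecker factors.

The following sketch shows only the input-side factor \(\mathbf{A}\), under our reading of Eschenhagen et al., 2023:

- expand treats every (data point, position) pair as an extra input vector.
- reduce averages over positions before forming outer products.

The exact scaling used by curvlinops is an assumption here, and input_factor is a hypothetical helper.

```python
import torch

def input_factor(X: torch.Tensor, kfac_approx: str) -> torch.Tensor:
    """Hypothetical helper: input covariance A for layer inputs X of shape (N, R, D_in)."""
    N, R, D_in = X.shape
    if kfac_approx == "expand":
        # Every (data point, position) pair acts as a separate input vector
        X_flat = X.reshape(N * R, D_in)
        return X_flat.T @ X_flat / (N * R)
    if kfac_approx == "reduce":
        # Average over the weight-sharing dimension before forming outer products
        X_mean = X.mean(dim=1)
        return X_mean.T @ X_mean / N
    raise ValueError(f"Unknown kfac_approx {kfac_approx!r}")

torch.manual_seed(0)
X = torch.randn(8, 5, 3)
A_expand, A_reduce = input_factor(X, "expand"), input_factor(X, "reduce")
print(A_expand.shape, A_reduce.shape)

# Without weight sharing (R = 1) both approximations coincide
X1 = X[:, :1, :]
print(torch.allclose(input_factor(X1, "expand"), input_factor(X1, "reduce")))
```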
Uncentered gradient covariance (empirical Fisher)
- class curvlinops.EFLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor] | None,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- num_per_example_loss_terms: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Uncentered gradient covariance as PyTorch linear operator.
The uncentered gradient covariance is often called ‘empirical Fisher’ (EF).
Consider the empirical risk
\[\mathcal{L}(\mathbf{\theta}) = c \sum_{n=1}^{N} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\]
with \(c = \frac{1}{N}\) for reduction='mean' and \(c=1\) for reduction='sum'. The uncentered gradient covariance matrix is
\[c \sum_{n=1}^{N} \left( \nabla_{\mathbf{\theta}} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) \right) \left( \nabla_{\mathbf{\theta}} \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) \right)^\top\,.\]
- SELF_ADJOINT
Whether the linear operator is self-adjoint.
True for the empirical Fisher.
- Type:
bool
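The uncentered gradient covariance can be formed from per-example gradients. The following minimal sketch uses torch.func (functional_call, grad and vmap) on a small model with CrossEntropyLoss(reduction='mean'), so \(c = 1/N\). It materializes the full \(P \times P\) matrix, so it only suits tiny models. The network architecture is an arbitrary choice for illustration.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
N, D_in, C = 6, 4, 3
model = nn.Sequential(nn.Linear(D_in, 5), nn.Tanh(), nn.Linear(5, C))
params = {name: p.detach() for name, p in model.named_parameters()}
X, y = torch.randn(N, D_in), torch.randint(0, C, (N,))

def per_example_loss(params, x, target):
    # Loss of a single data point: add a batch dimension of size 1
    logits = functional_call(model, params, (x.unsqueeze(0),))
    return nn.functional.cross_entropy(logits, target.unsqueeze(0))

# Per-example gradients, flattened and concatenated in dict insertion order
per_example_grads = vmap(grad(per_example_loss), in_dims=(None, 0, 0))(params, X, y)
grads_flat = torch.cat([g.reshape(N, -1) for g in per_example_grads.values()], dim=1)  # (N, P)

c = 1.0 / N  # reduction='mean'
EF = c * grads_flat.T @ grads_flat
print(EF.shape)
```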
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- loss_func: Callable[[Tensor, Tensor], Tensor] | None,
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- num_per_example_loss_terms: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for curvature matrices of empirical risks.
Note
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch input X to predictions p. ℓ(p, y) maps the prediction to a loss, using the mini-batch labels y.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
loss_func – Loss function criterion. Maps predictions and mini-batch labels to a scalar value. If None, there is no loss function and the represented matrix is independent of the loss function.
params – The parameter values at which the curvature matrix is evaluated. A dictionary mapping parameter names to tensors (use dict(model.named_parameters())). The parameter ordering follows dict insertion order.
data – Source from which mini-batches can be drawn, for instance a list of mini-batches [(X, y), ...] or a torch DataLoader. Note that X could be a dict or UserDict; this is useful for custom models. In this case, you must (i) specify the batch_size_fn argument, and (ii) take care of preprocessing like X.to(device) inside of your model.forward() function.
progressbar – Show a progress bar during matrix-multiplication. Default: False.
check_deterministic – Probe that model and data are deterministic, i.e. that the data does not use drop_last or data augmentation. Also, the model’s forward pass could depend on the order in which mini-batches are presented (BatchNorm, Dropout). Default: True. This is a safeguard; only turn it off if you know what you are doing.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
num_per_example_loss_terms – Number of per-example loss terms, e.g. the number of tokens in a sequence. Only used by subclasses with NEEDS_NUM_PER_EXAMPLE_LOSS_TERMS = True. If None, it is inferred from the data when needed. Default: None.
batch_size_fn – If the X’s in data are not torch.Tensor, this needs to be specified. The intended behavior is to consume the first entry of the iterates from data and return their batch size.
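To make the uncentered gradient covariance (empirical Fisher) formula concrete, here is a minimal NumPy sketch, not the curvlinops implementation. It assumes a bias-free linear model \(f(\mathbf{x}) = W\mathbf{x}\) with per-example loss \(\lVert W\mathbf{x}_n - \mathbf{y}_n\rVert^2\) and flattens \(W\) row-major; `per_example_grads`, `empirical_fisher`, and `empirical_fisher_matvec` are illustrative names. The matrix-free product shows why a linear operator never needs the full \(P \times P\) matrix.

```python
import numpy as np

def per_example_grads(W, X, Y):
    """Per-example gradients of l_n = ||W x_n - y_n||^2 w.r.t. W, flattened row-major."""
    residuals = X @ W.T - Y                              # (N, D_out)
    grads = 2.0 * residuals[:, :, None] * X[:, None, :]  # (N, D_out, D_in): 2 r_n x_n^T
    return grads.reshape(X.shape[0], -1)                 # (N, D_out * D_in)

def reduction_factor(N, reduction):
    """c = 1/N for reduction='mean' and c = 1 for reduction='sum'."""
    return 1.0 / N if reduction == "mean" else 1.0

def empirical_fisher(W, X, Y, reduction="mean"):
    """Dense c * sum_n g_n g_n^T, stacking the per-example gradients g_n as rows of G."""
    G = per_example_grads(W, X, Y)
    return reduction_factor(X.shape[0], reduction) * G.T @ G

def empirical_fisher_matvec(W, X, Y, v, reduction="mean"):
    """Matrix-free product c * sum_n g_n (g_n^T v), never forming the P x P matrix."""
    G = per_example_grads(W, X, Y)
    return reduction_factor(X.shape[0], reduction) * G.T @ (G @ v)

rng = np.random.default_rng(0)
D_in, D_out, N = 4, 2, 10
W = rng.standard_normal((D_out, D_in))
X, Y = rng.standard_normal((N, D_in)), rng.standard_normal((N, D_out))
F = empirical_fisher(W, X, Y)
v = rng.standard_normal(D_out * D_in)
Fv = empirical_fisher_matvec(W, X, Y, v)
```

Being a sum of outer products, the result is symmetric positive semi-definite with rank at most \(N\), and reduction='sum' simply scales the 'mean' version by \(N\).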
Jacobians
- class curvlinops.JacobianLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator of the Jacobian.
- FIXED_DATA_ORDER
Whether the data order must be fixed (True for Jacobians).
- Type:
bool
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for the Jacobian as a PyTorch linear operator.
Consider a model \(f(\mathbf{x}, \mathbf{\theta}): \mathbb{R}^M \times \mathbb{R}^D \to \mathbb{R}^C\) with parameters \(\mathbf{\theta}\) and input \(\mathbf{x}\). Assume we are given a data set \(\mathcal{D} = \{ (\mathbf{x}_n, \mathbf{y}_n) \}_{n=1}^N\) of input-target pairs via batches. The model’s Jacobian \(\mathbf{J}_\mathbf{\theta}\mathbf{f}\) is an \(NC \times D\) matrix with elements
\[\left[ \mathbf{J}_\mathbf{\theta}\mathbf{f} \right]_{(n,c), d} = \frac{\partial [f(\mathbf{x}_n, \mathbf{\theta})]_c}{\partial \theta_d}\,.\]
Note that the data must be supplied in deterministic order.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
params – The parameter values at which the Jacobian is evaluated. A dictionary mapping parameter names to tensors.
data – Iterable of batched input-target pairs.
progressbar – Show a progress bar.
check_deterministic – Check if model and data are deterministic.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – If the X’s in data are not torch.Tensor, this needs to be specified. The intended behavior is to consume the first entry of the iterates from data and return their batch size.
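The element-wise definition of the \(NC \times D\) Jacobian can be checked on a small case. The NumPy sketch below assumes a bias-free linear model \(f(\mathbf{x}_n, W) = W\mathbf{x}_n\) with \(W \in \mathbb{R}^{C \times M}\) flattened row-major, and rows ordered \((n, c)\) with \(n\) outermost; these conventions and the helper names are illustrative assumptions, not curvlinops API. It builds the dense Jacobian and compares it with a matrix-free Jacobian-vector product.

```python
import numpy as np

def linear_model_jacobian(X, C):
    """Dense (N*C) x (C*M) Jacobian of f(x_n, W) = W x_n w.r.t. W flattened row-major.

    Row (n, c) is d[W x_n]_c / dW: it equals x_n in the block of output c, zero elsewhere.
    """
    N, M = X.shape
    J = np.zeros((N, C, C, M))
    for c in range(C):
        J[:, c, c, :] = X
    return J.reshape(N * C, C * M)

def jacobian_matvec(X, V):
    """Matrix-free J @ vec(V): change of all N*C outputs when W moves along direction V."""
    return (X @ V.T).reshape(-1)

rng = np.random.default_rng(0)
N, M, C = 5, 4, 3
X = rng.standard_normal((N, M))
V = rng.standard_normal((C, M))
J = linear_model_jacobian(X, C)
dense_product = J @ V.reshape(-1)
matrix_free_product = jacobian_matvec(X, V)
```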
- class curvlinops.TransposedJacobianLinearOperator(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for the transposed Jacobian.
- FIXED_DATA_ORDER
Whether the data order must be fixed (True for Jacobians).
- Type:
bool
- __init__(
- model_func: Module | Callable[[dict[str, Tensor], Tensor | MutableMapping], Tensor],
- params: dict[str, Tensor],
- data: Iterable[tuple[Tensor | MutableMapping, Tensor]],
- progressbar: bool = False,
- check_deterministic: bool = True,
- num_data: int | None = None,
- batch_size_fn: Callable[[MutableMapping | Tensor], int] | None = None,
Linear operator for the transposed Jacobian as a PyTorch linear operator.
Consider a model \(f(\mathbf{x}, \mathbf{\theta}): \mathbb{R}^M \times \mathbb{R}^D \to \mathbb{R}^C\) with parameters \(\mathbf{\theta}\) and input \(\mathbf{x}\). Assume we are given a data set \(\mathcal{D} = \{ (\mathbf{x}_n, \mathbf{y}_n) \}_{n=1}^N\) of input-target pairs via batches. The model’s transposed Jacobian \((\mathbf{J}_\mathbf{\theta}\mathbf{f})^\top\) is a \(D \times NC\) matrix with elements
\[\left[ (\mathbf{J}_\mathbf{\theta}\mathbf{f})^\top \right]_{d, (n,c)} = \frac{\partial [f(\mathbf{x}_n, \mathbf{\theta})]_c}{\partial \theta_d}\,.\]
Note that the data must be supplied in deterministic order.
- Parameters:
model_func – The neural network’s forward pass, defining the functional relationship (params, X) -> prediction. Either an nn.Module (architecture) or a callable (params_dict, X) -> prediction.
params – The parameter values at which the Jacobian is evaluated. A dictionary mapping parameter names to tensors.
data – Iterable of batched input-target pairs.
progressbar – Show a progress bar.
check_deterministic – Check if model and data are deterministic.
num_data – Number of data points. If None, it is inferred from the data at the cost of one traversal through the data loader.
batch_size_fn – If the X’s in data are not torch.Tensor, this needs to be specified. The intended behavior is to consume the first entry of the iterates from data and return their batch size.
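The transposed Jacobian maps an output-space vector indexed by \((n, c)\) back to parameter space. As a sketch under the same illustrative assumptions (linear model \(W\mathbf{x}_n\), row-major flattening of \(W\), \(n\) outermost; the helper name is hypothetical), the product reduces to \(\sum_n \mathbf{u}_n \mathbf{x}_n^\top\), and the adjoint identity \(\langle \mathbf{J}\mathbf{v}, \mathbf{u}\rangle = \langle \mathbf{v}, \mathbf{J}^\top\mathbf{u}\rangle\) ties it to the forward Jacobian-vector product.

```python
import numpy as np

def transposed_jacobian_matvec(X, u, C):
    """Matrix-free J^T @ u for f(x_n, W) = W x_n; u is indexed by (n, c) with n outermost."""
    U = u.reshape(X.shape[0], C)  # (N, C): one output-space vector u_n per data point
    return (U.T @ X).reshape(-1)  # sum_n u_n x_n^T, flattened row-major like W

rng = np.random.default_rng(0)
N, M, C = 5, 4, 3
X = rng.standard_normal((N, M))
V = rng.standard_normal((C, M))
u = rng.standard_normal(N * C)
jvp = (X @ V.T).reshape(-1)  # J @ vec(V) for the same linear model
vjp = transposed_jacobian_matvec(X, u, C)
```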
Inverses
- class curvlinops.CGInverseLinearOperator(A: PyTorchLinearOperator, **cg_hyperparameters)[source]
Class for inverse linear operators via conjugate gradients.
Note
Internally, this operator uses GPyTorch’s implementation of CG.
Note
This operator is not compiler-friendly (torch.compile()). The underlying linear_cg routine uses data-dependent control flow (convergence checks on tensor values via aten.equal and Python if on tensors), which causes graph breaks during tracing.
- __init__(A: PyTorchLinearOperator, **cg_hyperparameters)[source]
Store the linear operator whose inverse should be represented.
- Parameters:
A – PyTorch linear operator whose inverse is formed. Must represent a symmetric and positive-definite matrix.
cg_hyperparameters – Keyword arguments for GPyTorch’s CG implementation. In particular, this includes optional arguments such as max_iter, tolerance, and preconditioner. The preconditioner should be a callable that applies a left preconditioning operation to a supplied vector. This can be implemented via a PyTorchLinearOperator’s __matmul__ method. For details, see the documentation of the linear_cg function in https://github.com/cornellius-gp/linear_operator/blob/main/linear_operator/utils/linear_cg.py.
Example
>>> from torch import allclose, tensor
>>> from torch.linalg import inv
>>> from curvlinops import CGInverseLinearOperator
>>> from curvlinops.diag import DiagonalLinearOperator
>>> from curvlinops.examples import TensorLinearOperator
>>> A = tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
>>> b = tensor([1.0, 2.0, 3.0])
>>> A_linop = TensorLinearOperator(A)
>>> A_inv_b = CGInverseLinearOperator(
...     A_linop, max_iter=3, max_tridiag_iter=3, tolerance=1e-7
... ) @ b
>>> # Use CG with a simple diagonal preconditioner.
>>> inverse_diagonal = DiagonalLinearOperator([A.diag().reciprocal()])
>>> A_inv_b_preconditioned = CGInverseLinearOperator(
...     A_linop,
...     max_iter=3,
...     max_tridiag_iter=3,
...     tolerance=1e-7,
...     preconditioner=inverse_diagonal.__matmul__,
... ) @ b
>>> A_inv_b_exact = inv(A) @ b
>>> A_inv_b.round(decimals=4)
tensor([0.2222, 0.1111, 1.4444])
>>> A_inv_b_preconditioned.round(decimals=4)
tensor([0.2222, 0.1111, 1.4444])
>>> allclose(A_inv_b_exact, A_inv_b_preconditioned)
True
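For intuition about what the wrapped solver does, here is a textbook conjugate gradient sketch in NumPy. It is not GPyTorch's linear_cg (which adds batching, preconditioning, and tridiagonal bookkeeping); `conjugate_gradient` is an illustrative name. It only touches \(\mathbf{A}\) through a matrix-vector product, which is what makes CG a natural fit for matrix-free linear operators. On the \(3 \times 3\) SPD matrix from the example above, exact arithmetic converges in at most 3 iterations.

```python
import numpy as np

def conjugate_gradient(matvec, b, max_iter=None, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, accessing A only via matvec."""
    x = np.zeros_like(b)
    r = b - matvec(x)          # residual
    p = r.copy()               # search direction
    rs_old = r @ r
    max_iter = b.shape[0] if max_iter is None else max_iter
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol * np.linalg.norm(b):
            break              # converged
        Ap = matvec(p)
        step = rs_old / (p @ Ap)
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next direction, A-conjugate to the previous ones
        rs_old = rs_new
    return x

A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
x = conjugate_gradient(lambda v: A @ v, b)
```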
- class curvlinops.LSMRInverseLinearOperator(A: PyTorchLinearOperator, **lsmr_hyperparameters)[source]
Class for inverse PyTorch linear operators via LSMR.
See https://arxiv.org/abs/1006.0758 for details on the LSMR algorithm.
Note
Internally, this operator uses SciPy’s CPU implementation of LSMR, as PyTorch currently does not offer an LSMR interface that purely relies on matrix-vector products.
Note
This operator is not compiler-friendly (torch.compile()). The matrix-vector product converts tensors to NumPy and calls SciPy’s lsmr; these non-Torch operations cannot be traced and cause graph breaks.
- __init__(A: PyTorchLinearOperator, **lsmr_hyperparameters)[source]
Store the linear operator whose inverse should be represented.
- Parameters:
A – Linear operator whose inverse is formed.
lsmr_hyperparameters – The hyper-parameters that will be passed to the LSMR implementation in SciPy. For more detail, see https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsmr.html.
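The sketch below shows, under stated assumptions, how a matrix-free operator can be handed to SciPy's LSMR: it wraps dense NumPy arrays as a scipy.sparse.linalg.LinearOperator exposing only products with \(\mathbf{A}\) and \(\mathbf{A}^\top\), which is what LSMR consumes. It is not curvlinops' internal code, and `as_scipy_operator` is a hypothetical helper. Keyword arguments such as atol, btol, and maxiter are the kind of values one would pass through lsmr_hyperparameters. Unlike CG, LSMR does not require a symmetric matrix; for a tall matrix it returns the least-squares solution.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, lsmr

def as_scipy_operator(A):
    """Wrap a dense array as a matrix-free SciPy LinearOperator (products with A and A^T)."""
    return LinearOperator(
        A.shape, matvec=lambda v: A @ v, rmatvec=lambda v: A.T @ v, dtype=A.dtype
    )

# Square, invertible case: LSMR recovers A^{-1} b.
A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
x = lsmr(as_scipy_operator(A), b, atol=1e-12, btol=1e-12, maxiter=100)[0]

# Tall case: LSMR returns the least-squares solution argmin_x ||B x - c||.
rng = np.random.default_rng(0)
B, c = rng.standard_normal((6, 3)), rng.standard_normal(6)
x_ls = lsmr(as_scipy_operator(B), c, atol=1e-12, btol=1e-12, maxiter=100)[0]
```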
- class curvlinops.NeumannInverseLinearOperator(
- A: PyTorchLinearOperator,
- num_terms: int = 100,
- scale: float = 1.0,
- check_nan: bool = True,
- preconditioner: None | Callable[[Tensor], Tensor] = None,
Class for inverse linear operators via truncated Neumann series.
Motivated by
Lorraine, J., Vicol, P., & Duvenaud, D. (2020). Optimizing millions of hyperparameters by implicit differentiation. In International Conference on Artificial Intelligence and Statistics (AISTATS).
Wang, A., Nguyen, E., Yang, R., Bae, J., McIlraith, S. A., & Grosse, R. B. (2025). Better Training Data Attribution via Better Inverse Hessian-Vector Products. In Advances in Neural Information Processing Systems (NeurIPS 2025).
Warning
The Neumann series can be non-convergent. In this case, the iterations will become numerically unstable, leading to NaN values.
Warning
The Neumann series can converge slowly. Use curvlinops.CGInverseLinearOperator for better accuracy.
Note
With the default check_nan=True, this operator is not compiler-friendly (torch.compile()): the per-iteration isnan check introduces data-dependent branching that causes graph breaks. Passing check_nan=False makes the operator compiler-friendly (the truncated series itself uses only traceable PyTorch ops).
- __init__(
- A: PyTorchLinearOperator,
- num_terms: int = 100,
- scale: float = 1.0,
- check_nan: bool = True,
- preconditioner: None | Callable[[Tensor], Tensor] = None,
Store the linear operator whose inverse should be represented.
The Neumann series for an invertible linear operator \(\mathbf{A}\) is
\[\mathbf{A}^{-1} = \sum_{k=0}^{\infty} \left(\mathbf{I} - \mathbf{A} \right)^k\,,\]
which is convergent if all eigenvalues satisfy \(0 < \lambda(\mathbf{A}) < 2\).
By re-scaling the matrix by scale (\(\alpha\)), we have:
\[\mathbf{A}^{-1} = \alpha (\alpha \mathbf{A})^{-1} = \alpha \sum_{k=0}^{\infty} \left(\mathbf{I} - \alpha \mathbf{A} \right)^k\,,\]
which is convergent if \(0 < \lambda(\mathbf{A}) < \frac{2}{\alpha}\).
Additionally, we truncate the series at num_terms (\(K\)):
\[\mathbf{A}^{-1} \approx \alpha \sum_{k=0}^{K} \left(\mathbf{I} - \alpha \mathbf{A} \right)^k\,.\]
- Parameters:
A – Linear operator whose inverse is formed.
num_terms – Number of terms in the truncated Neumann series. Default: 100.
scale – Scale applied to the matrix in the Neumann iteration. Crucial for convergence of the Neumann series (details above). Default: 1.0.
check_nan – Whether to check for NaNs while applying the truncated Neumann series. Default: True.
preconditioner – Optional preconditioner \(\mathbf{P}\) used in the preconditioned Neumann/Richardson iteration \(\mathbf{A}^{-1} \approx \alpha \sum_{k=0}^{K} (\mathbf{I} - \alpha \mathbf{P}\mathbf{A})^k \mathbf{P}\), where \(\alpha\) is given by scale. This preconditioned formulation is inspired by Wang et al. (NeurIPS 2025). The preconditioner should be a callable that applies a left preconditioning operation to a supplied vector or matrix in tensor format, e.g. a PyTorchLinearOperator’s __matmul__ method. Default: None.
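The truncated (optionally preconditioned) series above can be applied to a vector with one matrix-vector product per term. This NumPy sketch follows the formula with \(K\) = num_terms, i.e. it sums \(K + 1\) terms including \(k = 0\); whether curvlinops counts terms the same way is an assumption, and `neumann_inverse_apply` is an illustrative name. For the example matrix, Gershgorin's theorem bounds the eigenvalues to \([1, 5]\), so scale = 0.3 satisfies \(\lambda < 2/\alpha\), while scale = 1.0 does not and the series diverges. A Jacobi (inverse-diagonal) preconditioner makes scale = 1.0 convergent.

```python
import numpy as np

def neumann_inverse_apply(matvec, b, num_terms=100, scale=1.0, preconditioner=None):
    """Approximate A^{-1} b by scale * sum_{k=0}^{K} (I - scale * P A)^k P b, K = num_terms."""
    apply_P = preconditioner if preconditioner is not None else (lambda v: v)
    term = apply_P(b)  # k = 0 term: P b
    total = term.copy()
    for _ in range(num_terms):
        term = term - scale * apply_P(matvec(term))  # multiply by (I - scale * P A)
        total = total + term
    return scale * total

A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
x_plain = neumann_inverse_apply(lambda v: A @ v, b, num_terms=100, scale=0.3)

inv_diag = 1.0 / np.diag(A)  # Jacobi preconditioner P = diag(A)^{-1}
x_prec = neumann_inverse_apply(
    lambda v: A @ v, b, num_terms=100, scale=1.0, preconditioner=lambda v: inv_diag * v
)

# Without preconditioning, scale = 1.0 violates lambda_max(A) < 2 and the terms blow up.
x_diverged = neumann_inverse_apply(lambda v: A @ v, b, num_terms=50, scale=1.0)
```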
Sub-matrices
- class curvlinops.SubmatrixLinearOperator(A: PyTorchLinearOperator, row_idxs: list[int], col_idxs: list[int])[source]
Class for sub-matrices of linear operators.
Note
This operator is not compiler-friendly (torch.compile()). Its matrix-vector product dispatches through the wrapped operator’s __matmul__, and Dynamo cannot proxy a user-defined linear operator as an argument, which causes graph breaks.
- __init__(A: PyTorchLinearOperator, row_idxs: list[int], col_idxs: list[int])[source]
Store the linear operator and indices of its sub-matrix.
Represents the sub-matrix
A[row_idxs, :][:, col_idxs].
- Parameters:
A – A linear operator.
row_idxs – The sub-matrix’s row indices.
col_idxs – The sub-matrix’s column indices.
- set_submatrix(row_idxs: list[int], col_idxs: list[int])[source]
Define the sub-matrix.
Internally sets the linear operator’s shape.
- Parameters:
row_idxs – The sub-matrix’s row indices.
col_idxs – The sub-matrix’s column indices.
- Raises:
ValueError – If the index lists contain duplicate values, non-integers, or out-of-bounds indices.
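A sub-matrix product can be computed with products of the full operator alone: embed the input vector at the selected column positions, multiply, and keep the selected rows. The NumPy sketch below illustrates that idea; it is an assumption about one way to realize the sub-matrix, not a description of SubmatrixLinearOperator's internals, and `submatrix_matvec` is a hypothetical helper.

```python
import numpy as np

def submatrix_matvec(matvec, num_cols, row_idxs, col_idxs, v):
    """Compute A[row_idxs, :][:, col_idxs] @ v using only products with the full matrix A."""
    padded = np.zeros(num_cols)
    padded[col_idxs] = v             # embed v at the selected column positions
    return matvec(padded)[row_idxs]  # keep only the selected rows of A @ padded

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 5))
rows, cols = [4, 0, 2], [1, 3]
v = rng.standard_normal(len(cols))
result = submatrix_matvec(lambda x: A @ x, A.shape[1], rows, cols, v)
```

In this sketch, a repeated column index would silently overwrite an entry of padded, which is one concrete reason why duplicate indices are problematic.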
Spectral density approximation
Note
This functionality currently expects SciPy LinearOperator instances.
- curvlinops.lanczos_approximate_spectrum(
- A: PyTorchLinearOperator,
- ncv: int,
- num_points: int = 1024,
- num_repeats: int = 1,
- kappa: float = 3.0,
- boundaries: tuple[float, float] | tuple[float, None] | tuple[None, float] | None = None,
- margin: float = 0.05,
- boundaries_tol: float = 0.01,
Approximate the spectral density p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of A ∈ Rᵈˣᵈ.
Implements algorithm 2 (LanczosApproxSpec) of Papyan, 2020 (https://jmlr.org/papers/v21/20-933.html).
Internally rescales the operator spectrum to the interval [-1; 1] such that the width kappa of the Gaussian bumps used to approximate the delta peaks need not be tweaked.
- Parameters:
A – Symmetric linear operator.
ncv – Number of Lanczos vectors (number of nodes/weights for the quadrature).
num_points – Resolution, i.e. the number of grid points at which the density is evaluated. Default: 1024.
num_repeats – Number of Lanczos quadratures to average the density over. Default: 1. Taken from papyan2020traces, Section D.2.
kappa – Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must be greater than 1. Default: 3. Taken from papyan2020traces, Section D.2.
boundaries – Estimates of the minimum and maximum eigenvalues of A. If left unspecified, they will be estimated internally.
margin – Relative margin added around the spectral boundary. Default: 0.05. Taken from papyan2020traces, Section D.2.
boundaries_tol – (Only relevant if boundaries are not specified.) Relative accuracy used to estimate the spectral boundary. 0 implies machine precision. Default: 1e-2, from https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
- Returns:
Grid points λ and approximated spectral density p(λ) of A.
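The core of Lanczos-based spectral density approximation fits in a few lines. The NumPy sketch below runs Lanczos (with full re-orthogonalization) from a random unit vector, uses the eigenvalues of the resulting tridiagonal matrix as quadrature nodes and the squared first components of its eigenvectors as weights, replaces each node by a Gaussian bump, and averages over repeats. As an assumption for simplicity, the bump width sigma is given directly in the units of the spectrum; the documented function instead rescales the spectrum to [-1; 1] and derives the width from kappa. All names are illustrative.

```python
import numpy as np

def lanczos_tridiagonal(matvec, dim, ncv, rng):
    """Run ncv Lanczos steps (with full re-orthogonalization) from a random unit vector."""
    Q = np.zeros((dim, ncv))
    alphas, betas = np.zeros(ncv), np.zeros(ncv - 1)
    q = rng.standard_normal(dim)
    Q[:, 0] = q / np.linalg.norm(q)
    for j in range(ncv):
        w = matvec(Q[:, j])
        alphas[j] = Q[:, j] @ w
        w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)  # remove components along q_0..q_j
        if j + 1 < ncv:
            betas[j] = np.linalg.norm(w)
            Q[:, j + 1] = w / betas[j]
    return np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)

def lanczos_density(matvec, dim, ncv, grid, sigma, num_repeats=1, seed=0):
    """Gaussian-smoothed spectral density from Lanczos quadrature nodes and weights."""
    rng = np.random.default_rng(seed)
    density = np.zeros_like(grid)
    for _ in range(num_repeats):
        T = lanczos_tridiagonal(matvec, dim, ncv, rng)
        nodes, vecs = np.linalg.eigh(T)
        weights = vecs[0, :] ** 2  # quadrature weights, they sum to 1
        bumps = np.exp(-0.5 * ((grid[:, None] - nodes[None, :]) / sigma) ** 2)
        density += bumps @ weights / (sigma * np.sqrt(2.0 * np.pi))
    return density / num_repeats

rng = np.random.default_rng(1)
B = rng.standard_normal((100, 100))
A = (B + B.T) / 2  # symmetric test matrix
eigs = np.linalg.eigvalsh(A)
grid = np.linspace(eigs[0] - 3.0, eigs[-1] + 3.0, 2048)
density = lanczos_density(lambda v: A @ v, 100, ncv=30, grid=grid, sigma=0.5, num_repeats=3)
```

Because the Ritz values lie inside \([\lambda_{\min}, \lambda_{\max}]\) and the weights sum to one, the returned density integrates to approximately one over a grid that extends a few widths beyond the spectrum.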
- curvlinops.lanczos_approximate_log_spectrum(
- A: PyTorchLinearOperator,
- ncv: int,
- num_points: int = 1024,
- num_repeats: int = 1,
- kappa: float = 1.04,
- boundaries: tuple[float, float] | tuple[float, None] | tuple[None, float] | None = None,
- margin: float = 0.05,
- boundaries_tol: float = 0.01,
- epsilon: float = 1e-05,
Approximate the spectral density
p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of log(|A| + εI) ∈ Rᵈˣᵈ.
Follows the idea of Section C.7 in Papyan, 2020 (https://jmlr.org/papers/v21/20-933.html).
Here, log denotes the natural logarithm (i.e. base e).
Internally rescales the operator spectrum to the interval [-1; 1] such that the width kappa of the Gaussian bumps used to approximate the delta peaks need not be tweaked.
- Parameters:
A – Symmetric linear operator.
ncv – Number of Lanczos vectors (number of nodes/weights for the quadrature).
num_points – Resolution, i.e. the number of grid points at which the density is evaluated. Default: 1024.
num_repeats – Number of Lanczos quadratures to average the density over. Default: 1. Taken from papyan2020traces, Section D.2.
kappa – Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must be greater than 1. Default: 1.04. Obtained by tweaking while reproducing Fig. 15b from papyan2020traces (not specified by the paper).
boundaries – Estimates of the minimum and maximum eigenvalues of \(|A|\). If left unspecified, they will be estimated internally.
margin – Relative margin added around the spectral boundary. Default: 0.05. Taken from papyan2020traces, Section D.2.
boundaries_tol – (Only relevant if boundaries are not specified.) Relative accuracy used to estimate the spectral boundary. 0 implies machine precision. Default: 1e-2, from https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
epsilon – Shift to increase numerical stability. Default: 1e-5. Taken from papyan2020traces, Section D.2.
- Returns:
Grid points λ and approximated spectral density p(λ) of
log(|A| + εI).
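For a symmetric \(A = V \Lambda V^\top\), the matrix absolute value is \(|A| = V |\Lambda| V^\top\), so \(\log(|A| + \varepsilon I)\) has eigenvalues \(\log(|\lambda_i| + \varepsilon)\). Its spectral density is therefore the density of the transformed eigenvalues, which stays well-defined for indefinite or singular \(A\). The small NumPy sketch below forms this matrix densely for illustration only (`log_abs_shifted` is a hypothetical helper); a matrix-free method would never materialize it.

```python
import numpy as np

def log_abs_shifted(A, epsilon=1e-5):
    """Dense log(|A| + eps I) for symmetric A, using |A| = V |Lambda| V^T."""
    lam, V = np.linalg.eigh(A)
    return (V * np.log(np.abs(lam) + epsilon)) @ V.T  # V diag(log(|lam| + eps)) V^T

A = np.array([[2.0, 1.0], [1.0, -3.0]])  # indefinite: one negative eigenvalue
L = log_abs_shifted(A)
```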
- class curvlinops.LanczosApproximateSpectrumCached(
- A: PyTorchLinearOperator,
- ncv: int,
- boundaries: tuple[float, float] | tuple[float, None] | tuple[None, float] | None = None,
- boundaries_tol: float = 0.01,
Class to approximate the spectral density p(λ) = 1/d ∑ᵢ δ(λ - λᵢ) of A ∈ Rᵈˣᵈ.
Caches Lanczos iterations to efficiently produce spectral density approximations with different hyperparameters.
- __init__(
- A: PyTorchLinearOperator,
- ncv: int,
- boundaries: tuple[float, float] | tuple[float, None] | tuple[None, float] | None = None,
- boundaries_tol: float = 0.01,
Initialize.
- Parameters:
A – Symmetric linear operator.
ncv – Number of Lanczos vectors (number of nodes/weights for the quadrature).
boundaries – Estimates of the minimum and maximum eigenvalues of A. If left unspecified, they will be estimated internally.
boundaries_tol – (Only relevant if boundaries are not specified.) Relative accuracy used to estimate the spectral boundary. 0 implies machine precision. Default: 1e-2, from https://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html#examples.
- approximate_spectrum(num_repeats: int = 1, num_points: int = 1024, kappa: float = 3.0, margin: float = 0.05) tuple[Tensor, Tensor][source]
Approximate the spectral density of A.
- Parameters:
num_repeats – Number of Lanczos quadratures to average the density over. Default: 1. Taken from papyan2020traces, Section D.2.
num_points – Resolution, i.e. the number of grid points at which the density is evaluated. Default: 1024.
kappa – Width of the Gaussian used to approximate delta peaks in [-1; 1]. Must be greater than 1. Default: 3. From papyan2020traces, Section D.2.
margin – Relative margin added around the spectral boundary. Default: 0.05. Taken from papyan2020traces, Section D.2.
- Returns:
Grid points λ and approximated spectral density p(λ) of A.
Trace approximation
- curvlinops.hutchinson_trace(A: Tensor | PyTorchLinearOperator, num_matvecs: int, distribution: str = 'rademacher') Tensor[source]
Estimate a linear operator’s trace using the Girard-Hutchinson method.
For details, see
Girard, D. A. (1989). A fast ‘monte-carlo cross-validation’ procedure for large least squares problems with noisy data. Numerische Mathematik.
Hutchinson, M. (1989). A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communication in Statistics—Simulation and Computation.
Let \(\mathbf{A}\) be a square linear operator. We can approximate its trace \(\mathrm{Tr}(\mathbf{A})\) by drawing \(N\) random vectors \(\mathbf{v}_n \sim \mathbf{v}\) from a distribution that satisfies \(\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}\) and computing
\[a := \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n^\top \mathbf{A} \mathbf{v}_n \approx \mathrm{Tr}(\mathbf{A})\,.\]
This estimator is unbiased,
\[\mathbb{E}[a] = \mathrm{Tr}(\mathbb{E}[\mathbf{v}^\top\mathbf{A} \mathbf{v}]) = \mathrm{Tr}(\mathbf{A} \mathbb{E}[\mathbf{v} \mathbf{v}^\top]) = \mathrm{Tr}(\mathbf{A} \mathbf{I}) = \mathrm{Tr}(\mathbf{A})\,.\]
- Parameters:
A – A square linear operator whose trace is estimated.
num_matvecs – Total number of matrix-vector products to use. Must be smaller than the dimension of the linear operator (because otherwise one can evaluate the true trace directly at the same cost).
distribution – Distribution of the random vectors used for the trace estimation. Can be either
'rademacher' or 'normal'. Default: 'rademacher'.
- Returns:
The estimated trace of the linear operator.
Example
>>> from torch import manual_seed, rand
>>> from curvlinops import hutchinson_trace
>>> _ = manual_seed(0)  # make deterministic
>>> A = rand(50, 50)
>>> tr_A = A.trace().item()  # exact trace as reference
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = hutchinson_trace(A, num_matvecs=1).item()
>>> tr_A_high_precision = hutchinson_trace(A, num_matvecs=40).item()
>>> # compute the relative errors
>>> rel_error_low_precision = abs(tr_A - tr_A_low_precision) / abs(tr_A)
>>> rel_error_high_precision = abs(tr_A - tr_A_high_precision) / abs(tr_A)
>>> assert rel_error_low_precision > rel_error_high_precision
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(23.7836, -10.0279, 20.8427)
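The estimator itself is short. This NumPy sketch (an illustration, not the curvlinops implementation; `hutchinson_trace_sketch` is a hypothetical name) averages \(\mathbf{v}^\top \mathbf{A} \mathbf{v}\) over random probes. A diagonal matrix makes the mechanism visible: with Rademacher probes \(v_i^2 = 1\), so every single probe returns \(\sum_i D_{ii} v_i^2 = \mathrm{Tr}(\mathbf{D})\) exactly; the variance comes entirely from off-diagonal entries.

```python
import numpy as np

def hutchinson_trace_sketch(matvec, dim, num_matvecs, distribution="rademacher", seed=0):
    """Average v^T A v over random probes v with E[v v^T] = I."""
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(num_matvecs):
        if distribution == "rademacher":
            v = rng.choice([-1.0, 1.0], size=dim)  # entries +-1 with equal probability
        elif distribution == "normal":
            v = rng.standard_normal(dim)
        else:
            raise ValueError(f"Unknown distribution: {distribution}")
        estimates.append(v @ matvec(v))  # one matrix-vector product per probe
    return float(np.mean(estimates))

# Diagonal matrix: every Rademacher probe gives exactly Tr(D) = 1 + 2 + ... + 10 = 55.
D = np.diag(np.arange(1.0, 11.0))
single_probe = hutchinson_trace_sketch(lambda v: D @ v, dim=10, num_matvecs=1)

# General matrix: the estimate is only correct in expectation.
A = np.random.default_rng(1).random((50, 50))
estimate = hutchinson_trace_sketch(lambda v: A @ v, dim=50, num_matvecs=40)
```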
- curvlinops.hutchpp_trace(A: PyTorchLinearOperator | Tensor, num_matvecs: int, distribution: str = 'rademacher') Tensor[source]
Estimate a linear operator’s trace using the Hutch++ method.
In contrast to vanilla Hutchinson, Hutch++ has lower variance, but requires more memory. The method is presented in
Meyer, R. A., Musco, C., Musco, C., & Woodruff, D. P. (2020). Hutch++: optimal stochastic trace estimation.
Let \(\mathbf{A}\) be a square linear operator whose trace we want to approximate. First, using one third of the available matrix-vector products, we compute an orthonormal basis \(\mathbf{Q}\) of a sub-space spanned by \(\mathbf{A} \mathbf{S}\), where \(\mathbf{S}\) is a tall random matrix with i.i.d. elements. Then, using another third of the matrix-vector products, we compute the trace in this sub-space. Finally, we spend the remaining third on Hutchinson’s estimator in the complementary space, onto which \(\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top\) projects. Let \(3N\) denote the total number of matrix-vector products. We draw \(2N\) random vectors \(\mathbf{v}_n \sim \mathbf{v}\) from a distribution that satisfies \(\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}\), compute \(\mathbf{Q}\) from the first \(N\) vectors, and use the remaining \(N\) vectors to compute the estimator
\[a := \mathrm{Tr}(\mathbf{Q}^\top \mathbf{A} \mathbf{Q}) + \frac{1}{N} \sum_{n = N+1}^{2N} \mathbf{v}_n^\top (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top)^\top \mathbf{A} (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top) \mathbf{v}_n \approx \mathrm{Tr}(\mathbf{A})\,.\]
This estimator is unbiased, \(\mathbb{E}[a] = \mathrm{Tr}(\mathbf{A})\), as the first term is the exact trace in the space spanned by \(\mathbf{Q}\), and the second term is Hutchinson’s unbiased estimator in the complementary space.
- Parameters:
A – A square linear operator whose trace is estimated.
num_matvecs – Total number of matrix-vector products to use. Must be smaller than the dimension of the linear operator (because otherwise one can evaluate the true trace directly at the same cost), and divisible by 3.
distribution – Distribution of the random vectors used for the trace estimation. Can be either
'rademacher' or 'normal'. Default: 'rademacher'.
- Returns:
The estimated trace of the linear operator.
Example
>>> from torch import manual_seed, rand
>>> from curvlinops import hutchpp_trace
>>> _ = manual_seed(0)  # make deterministic
>>> A = rand(50, 50)
>>> tr_A = A.trace().item()  # exact trace as reference
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = hutchpp_trace(A, num_matvecs=3).item()
>>> tr_A_high_precision = hutchpp_trace(A, num_matvecs=30).item()
>>> # compute the relative errors
>>> rel_error_low_precision = abs(tr_A - tr_A_low_precision) / abs(tr_A)
>>> rel_error_high_precision = abs(tr_A - tr_A_high_precision) / abs(tr_A)
>>> assert rel_error_low_precision > rel_error_high_precision
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(23.7836, 15.7879, 19.6381)
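The three thirds of the Hutch++ budget described above map directly onto three batched products with \(\mathbf{A}\). The NumPy sketch below (illustrative only, using Rademacher probes; `hutchpp_sketch` is a hypothetical name, and `matmat` applies \(\mathbf{A}\) to a block of columns) follows that split. It also shows the method's strength: when the rank of \(\mathbf{A}\) is at most \(N\), \(\mathbf{Q}\) captures the whole range of \(\mathbf{A}\), the residual term vanishes, and the estimate is exact up to round-off.

```python
import numpy as np

def hutchpp_sketch(matmat, dim, num_matvecs, seed=0):
    """Hutch++ with Rademacher probes; matmat(M) returns A @ M for a (dim, k) matrix M."""
    if num_matvecs % 3 != 0:
        raise ValueError("num_matvecs must be divisible by 3.")
    N = num_matvecs // 3
    rng = np.random.default_rng(seed)
    S = rng.choice([-1.0, 1.0], size=(dim, N))
    G = rng.choice([-1.0, 1.0], size=(dim, N))
    Q, _ = np.linalg.qr(matmat(S))                 # N products: orthonormal basis of A S
    exact_part = np.trace(Q.T @ matmat(Q))         # N products: trace inside span(Q)
    G_perp = G - Q @ (Q.T @ G)                     # project probes onto the complement
    residual_part = np.trace(G_perp.T @ matmat(G_perp)) / N  # N products: Hutchinson
    return float(exact_part + residual_part)

rng = np.random.default_rng(0)
B = rng.standard_normal((50, 3))
A = B @ B.T  # symmetric, rank 3
estimate = hutchpp_sketch(lambda M: A @ M, dim=50, num_matvecs=30)
```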
- curvlinops.xtrace(A: PyTorchLinearOperator | Tensor, num_matvecs: int, distribution: str = 'rademacher') Tensor[source]
Estimate a linear operator’s trace using the XTrace algorithm.
The method is presented in this paper:
Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis and Applications (SIMAX).
It combines the variance reduction from Hutch++ with the exchangeability principle.
- Parameters:
A – A square linear operator.
num_matvecs – Total number of matrix-vector products to use. Must be even and less than the dimension of the linear operator (because otherwise one can evaluate the true trace directly at the same cost).
distribution – Distribution of the random vectors used for the trace estimation. Can be either
'rademacher' or 'normal'. Default: 'rademacher'.
- Returns:
The estimated trace of the linear operator.
Diagonal approximation
- curvlinops.hutchinson_diag(A: PyTorchLinearOperator | Tensor, num_matvecs: int, distribution: str = 'rademacher') Tensor[source]
Estimate a linear operator’s diagonal using Hutchinson’s method.
For details, see
Bekas, C., Kokiopoulou, E., & Saad, Y. (2007). An estimator for the diagonal of a matrix. Applied Numerical Mathematics.
Let \(\mathbf{A}\) be a square linear operator. We can approximate its diagonal \(\mathrm{diag}(\mathbf{A})\) by drawing \(N\) random vectors \(\mathbf{v}_n \sim \mathbf{v}\) from a distribution \(\mathbf{v}\) that satisfies \(\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}\), and computing the estimator
\[\mathbf{a} := \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n \odot \mathbf{A} \mathbf{v}_n \approx \mathrm{diag}(\mathbf{A})\,.\]
This estimator is unbiased,
\[\mathbb{E}[a_i] = \sum_j \mathbb{E}[v_i A_{i,j} v_j] = \sum_j A_{i,j} \mathbb{E}[v_i v_j] = \sum_j A_{i,j} \delta_{i, j} = A_{i,i}\,.\]
- Parameters:
A – A square linear operator whose diagonal is estimated.
num_matvecs – Total number of matrix-vector products to use. Must be smaller than the dimension of the linear operator (because otherwise one can evaluate the true diagonal directly at the same cost).
distribution – Distribution of the random vectors used for the diagonal estimation. Can be either
'rademacher' or 'normal'. Default: 'rademacher'.
- Returns:
The estimated diagonal of the linear operator.
Example
>>> from torch import manual_seed, rand
>>> from torch.linalg import vector_norm
>>> from curvlinops import hutchinson_diag
>>> _ = manual_seed(0)  # make deterministic
>>> A = rand(40, 40)
>>> diag_A = A.diag()  # exact diagonal as reference
>>> # one- and multi-sample approximations
>>> diag_A_low_precision = hutchinson_diag(A, num_matvecs=1)
>>> diag_A_high_precision = hutchinson_diag(A, num_matvecs=30)
>>> # compute relative residual norms
>>> error_low_precision = (
...     vector_norm(diag_A - diag_A_low_precision) / vector_norm(diag_A)
... ).item()
>>> error_high_precision = (
...     vector_norm(diag_A - diag_A_high_precision) / vector_norm(diag_A)
... ).item()
>>> assert error_low_precision > error_high_precision
>>> round(error_low_precision, 4), round(error_high_precision, 4)
(3.2648, 0.9253)
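The diagonal estimator replaces the inner product of the trace estimator with an element-wise product. In this NumPy sketch (illustrative only; `hutchinson_diag_sketch` is a hypothetical name), entry \(i\) of \(\mathbf{v} \odot \mathbf{A}\mathbf{v}\) equals \(A_{ii} v_i^2 + \sum_{j \neq i} A_{ij} v_i v_j\). With Rademacher probes the first part is exactly \(A_{ii}\) and the cross terms average out, so a diagonal matrix is recovered exactly from a single probe.

```python
import numpy as np

def hutchinson_diag_sketch(matvec, dim, num_matvecs, seed=0):
    """Average v * (A v) over Rademacher probes v; entry i is unbiased for A_ii."""
    rng = np.random.default_rng(seed)
    total = np.zeros(dim)
    for _ in range(num_matvecs):
        v = rng.choice([-1.0, 1.0], size=dim)
        total += v * matvec(v)  # element-wise product v ⊙ (A v)
    return total / num_matvecs

D = np.diag(np.arange(1.0, 11.0))
diag_estimate = hutchinson_diag_sketch(lambda v: D @ v, dim=10, num_matvecs=1)
```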
- curvlinops.xdiag(A: PyTorchLinearOperator | Tensor, num_matvecs: int) Tensor[source]
Estimate a linear operator’s diagonal using the XDiag algorithm.
The method is presented in this paper:
Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis and Applications (SIMAX).
It combines the variance reduction from Diag++ with the exchangeability principle.
- Parameters:
A – A square linear operator.
num_matvecs – Total number of matrix-vector products to use. Must be even and less than the dimension of the linear operator (because otherwise one can evaluate the true diagonal directly at the same cost).
- Returns:
The estimated diagonal of the linear operator.
Frobenius norm approximation
- class curvlinops.hutchinson_squared_fro(A: Tensor | PyTorchLinearOperator, num_matvecs: int, distribution: str = 'rademacher')[source]
Estimate the squared Frobenius norm of a matrix using Hutchinson’s method.
Let \(\mathbf{A} \in \mathbb{R}^{M \times N}\) be some matrix. Its Frobenius norm \(\lVert\mathbf{A}\rVert_\text{F}\) is defined via
\[\lVert\mathbf{A}\rVert_\text{F}^2 = \sum_{m=1}^M \sum_{n=1}^N \mathbf{A}_{m,n}^2 = \text{Tr}(\mathbf{A}^\top \mathbf{A}).\]
Due to the last equality, we can use Hutchinson-style trace estimation to estimate the squared Frobenius norm.
- Parameters:
A – A matrix whose squared Frobenius norm is estimated.
num_matvecs – Total number of matrix-vector products to use. Must be smaller than the minimum dimension of the matrix.
distribution – Distribution of the random vectors used for the trace estimation. Can be either
'rademacher' or 'normal'. Default: 'rademacher'.
- Returns:
The estimated squared Frobenius norm of the matrix.
- Raises:
ValueError – If the matrix is not two-dimensional, or if the number of matrix-vector products is greater than the minimum dimension of the matrix (because then you can evaluate the true squared Frobenius norm directly at the same cost).
Example
>>> from torch.linalg import matrix_norm
>>> from torch import rand, manual_seed
>>> from curvlinops import hutchinson_squared_fro
>>> _ = manual_seed(0)  # make deterministic
>>> A = rand(40, 40)
>>> fro2_A = matrix_norm(A).item() ** 2  # reference: exact squared Frobenius norm
>>> # one- and multi-sample approximations
>>> fro2_A_low_prec = hutchinson_squared_fro(A, num_matvecs=1).item()
>>> fro2_A_high_prec = hutchinson_squared_fro(A, num_matvecs=30).item()
>>> assert abs(fro2_A - fro2_A_low_prec) > abs(fro2_A - fro2_A_high_prec)
>>> round(fro2_A, 1), round(fro2_A_low_prec, 1), round(fro2_A_high_prec, 1)
(530.9, 156.7, 628.9)
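Because \(\mathbf{v}^\top \mathbf{A}^\top \mathbf{A} \mathbf{v} = \lVert \mathbf{A}\mathbf{v} \rVert^2\), one natural way to apply Hutchinson to \(\mathrm{Tr}(\mathbf{A}^\top \mathbf{A})\) needs only products with \(\mathbf{A}\), not with \(\mathbf{A}^\top\). The NumPy sketch below illustrates that idea; it is an assumption, not necessarily how curvlinops evaluates the estimator, and `hutchinson_squared_fro_sketch` is a hypothetical name. If \(\mathbf{A}\) has orthogonal columns, \(\mathbf{A}^\top \mathbf{A}\) is diagonal and a single Rademacher probe is already exact.

```python
import numpy as np

def hutchinson_squared_fro_sketch(matvec, num_cols, num_matvecs, seed=0):
    """Estimate ||A||_F^2 = Tr(A^T A) as the mean of ||A v||^2 over Rademacher probes v."""
    rng = np.random.default_rng(seed)
    samples = []
    for _ in range(num_matvecs):
        v = rng.choice([-1.0, 1.0], size=num_cols)
        Av = matvec(v)
        samples.append(Av @ Av)  # v^T A^T A v
    return float(np.mean(samples))

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((6, 4)))  # (6, 4) with orthonormal columns
d = np.array([1.0, 2.0, 3.0, 4.0])
A = Q * d  # scale column i by d[i]; then A^T A = diag(d**2)
estimate = hutchinson_squared_fro_sketch(lambda v: A @ v, num_cols=4, num_matvecs=1)
```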