Fisher-weighted Model Averaging
In this example we implement Fisher-weighted model averaging, a technique described in a NeurIPS 2022 paper (Matena & Raffel, "Merging Models with Fisher-Weighted Averaging"). It requires Fisher-vector products and multiplication with the inverse of a sum of Fisher matrices. The paper uses a diagonal approximation of the Fisher matrices. Instead, we will use the exact Fisher matrices and rely on matrix-free methods for applying the inverse.
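For contrast, the diagonal approximation used in the paper turns the merge into an elementwise weighted average of the task parameters. The sketch below is a minimal NumPy illustration. It uses made-up positive Fisher diagonals `fisher_diags` and parameters `thetas` (hypothetical stand-ins, not curvlinops objects). It also checks that the elementwise formula agrees with the matrix form \((\sum_t \mathbf{F}_t)^{-1} \sum_t \mathbf{F}_t \mathbf{\theta}_t^\star\) when every Fisher is diagonal.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 3, 5  # number of tasks, number of parameters

# Hypothetical per-task quantities: positive Fisher diagonals and trained parameters
fisher_diags = [rng.uniform(0.1, 2.0, size=D) for _ in range(T)]
thetas = [rng.standard_normal(D) for _ in range(T)]

# Diagonal Fisher merging: each parameter is averaged with its own per-task weights
numerator = sum(f * theta for f, theta in zip(fisher_diags, thetas))
denominator = sum(fisher_diags)
merged_diag = numerator / denominator

# Same result from the matrix form (sum_t F_t)^{-1} (sum_t F_t theta_t) with F_t = diag(f_t)
F_sum = sum(np.diag(f) for f in fisher_diags)
rhs = sum(np.diag(f) @ theta for f, theta in zip(fisher_diags, thetas))
merged_full = np.linalg.solve(F_sum, rhs)

print(np.allclose(merged_diag, merged_full))
```

Each merged entry is a convex combination of the per-task values. A parameter with a large Fisher value in one task is therefore pulled toward that task's solution. With full Fisher matrices, the weights also couple different parameters, which is why a linear solve is needed.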
Note
In our setup (regression with the square loss), the Fisher equals the generalized Gauss-Newton (GGN) matrix.
Hence, we work with curvlinops.GGNLinearOperator.
Description: We are given a set of \(T\) tasks (represented by data sets \(\mathcal{D}_t\)) and train a model \(f_\mathbf{\theta}\) on each task independently, using the same criterion function. This yields \(T\) parameter vectors \(\mathbf{\theta}_1^\star, \dots, \mathbf{\theta}_T^\star\), which we would like to combine into a single model \(f_{\mathbf{\theta}^\star}\). To do that, we use the Fisher information matrix \(\mathbf{F}_t\) of each task (given by the data set \(\mathcal{D}_t\) and the trained parameters \(\mathbf{\theta}_t^\star\)). The merged parameters are given by
\[
\mathbf{\theta}^\star = \left(\lambda \mathbf{I} + \sum_{t=1}^T \mathbf{F}_t \right)^{-1} \left( \sum_{t=1}^T \mathbf{F}_t \mathbf{\theta}_t^\star \right)\,.
\]
This requires multiplying with the inverse of the sum of Fisher matrices, extended with a damping term \(\lambda \mathbf{I}\) with \(\lambda > 0\). We will use
curvlinops.CGInverseLinearOperator for that.
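To make the formula concrete, here is a small dense NumPy sketch. Random positive semi-definite matrices stand in for the Fishers, and `fisher_merge` is a hypothetical helper that solves the damped system directly. Forming and factorizing the matrices like this only works at toy scale, which is why this example relies on matrix-free operators and CG instead. The last lines check a special case: if every task has the same invertible Fisher and there is no damping, the merge reduces to the naive average of the parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 3, 4  # number of tasks, number of parameters


def random_psd(dim: int) -> np.ndarray:
    """Random symmetric positive (semi-)definite stand-in for a Fisher matrix."""
    A = rng.standard_normal((dim + 2, dim))
    return A.T @ A


def fisher_merge(fishers, thetas, lam):
    """Dense evaluation of (lam * I + sum_t F_t)^{-1} (sum_t F_t theta_t)."""
    lhs = lam * np.eye(len(thetas[0])) + sum(fishers)
    rhs = sum(F @ theta for F, theta in zip(fishers, thetas))
    return np.linalg.solve(lhs, rhs)


fishers = [random_psd(D) for _ in range(T)]
thetas = [rng.standard_normal(D) for _ in range(T)]
merged = fisher_merge(fishers, thetas, lam=1e-3)

# Special case: identical (invertible) Fishers and no damping give the naive average
shared = random_psd(D)
merged_shared = fisher_merge([shared] * T, thetas, lam=0.0)
print(np.allclose(merged_shared, sum(thetas) / T))
```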
Let’s start with the imports.
import numpy
import torch
from backpack.utils.convert_parameters import vector_to_parameter_list
from scipy import sparse
from scipy.sparse.linalg import aslinearoperator
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from curvlinops import CGInverseLinearOperator, GGNLinearOperator
# make deterministic
torch.manual_seed(0)
numpy.random.seed(0)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Setup
First, we will create a bunch of synthetic regression tasks (i.e. data sets) and an untrained model for each of them.
T = 3 # number of tasks
D_in = 7 # input dimension of each task
D_hidden = 5 # hidden dimension of the architecture we will use
D_out = 3 # output dimension of each task
N = 20 # number of data points per task
batch_size = 7
def make_architecture() -> nn.Sequential:
    """Create a neural network.

    Returns:
        A neural network.
    """
    return nn.Sequential(
        nn.Linear(D_in, D_hidden),
        nn.ReLU(),
        nn.Linear(D_hidden, D_hidden),
        nn.Sigmoid(),
        nn.Linear(D_hidden, D_out),
    )
def make_dataset() -> TensorDataset:
    """Create a synthetic regression data set.

    Returns:
        A synthetic regression data set.
    """
    X, y = torch.rand(N, D_in), torch.rand(N, D_out)
    return TensorDataset(X, y)
models = [make_architecture().to(DEVICE) for _ in range(T)]
data_loaders = [DataLoader(make_dataset(), batch_size=batch_size) for _ in range(T)]
loss_functions = [nn.MSELoss(reduction="mean").to(DEVICE) for _ in range(T)]
Training
Here, we train each model for a small number of epochs.
num_epochs = 10
log_epochs = [0, num_epochs - 1]
for task_idx in range(T):
    model = models[task_idx]
    data_loader = data_loaders[task_idx]
    loss_function = loss_functions[task_idx]
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    for epoch in range(num_epochs):
        for batch_idx, (X, y) in enumerate(data_loader):
            optimizer.zero_grad()
            X, y = X.to(DEVICE), y.to(DEVICE)
            loss = loss_function(model(X), y)
            loss.backward()
            optimizer.step()
            if epoch in log_epochs and batch_idx == 0:
                print(f"Task {task_idx} batch loss at epoch {epoch}: {loss.item():.3f}")
Task 0 batch loss at epoch 0: 0.454
Task 0 batch loss at epoch 9: 0.248
Task 1 batch loss at epoch 0: 0.632
Task 1 batch loss at epoch 9: 0.300
Task 2 batch loss at epoch 0: 0.363
Task 2 batch loss at epoch 9: 0.198
Linear operators
We are now ready to set up the linear operators for the per-task Fishers:
fishers = [
    GGNLinearOperator(
        model,
        loss_function,
        [p for p in model.parameters() if p.requires_grad],
        data_loader,
    )
    for model, loss_function, data_loader in zip(models, loss_functions, data_loaders)
]
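Each of these operators multiplies a vector by the GGN \(\mathbf{J}^\top \mathbf{H} \mathbf{J}\) without ever forming the matrix. Here \(\mathbf{J}\) is the Jacobian of the network outputs with respect to the parameters, and \(\mathbf{H}\) is the Hessian of the loss with respect to the outputs. The following minimal sketch computes that product for `MSELoss(reduction="mean")` on a single batch. It uses `torch.func` (PyTorch 2.0 or later) rather than curvlinops, and makes no claim about how curvlinops implements it. The tiny model, the inputs, and the helper `ggn_vp` are hypothetical, and the result is checked against an explicitly built GGN.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, jvp, vjp

torch.manual_seed(0)
N, D_in, D_out = 4, 3, 2
model = nn.Sequential(nn.Linear(D_in, 5), nn.Sigmoid(), nn.Linear(5, D_out))
X = torch.rand(N, D_in)

names = [name for name, _ in model.named_parameters()]
params = tuple(p.detach() for p in model.parameters())


def f(*ps):
    """Network outputs as a function of the parameters only."""
    return functional_call(model, dict(zip(names, ps)), (X,))


def ggn_vp(v):
    """GGN-vector product J^T H J v for MSELoss(reduction='mean')."""
    _, Jv = jvp(f, params, v)  # forward mode: J v, shape [N, D_out]
    HJv = 2.0 / Jv.numel() * Jv  # Hessian of mean-reduced MSE w.r.t. outputs is 2 / (N * D_out) * I
    _, vjp_fn = vjp(f, *params)
    return vjp_fn(HJv)  # reverse mode: J^T (H J v), one tensor per parameter


v = tuple(torch.randn_like(p) for p in params)
Gv = torch.cat([t.flatten() for t in ggn_vp(v)])

# Dense check: build J explicitly (only feasible for tiny models)
jacobians = jacrev(f, argnums=tuple(range(len(params))))(*params)
J = torch.cat([j.reshape(N * D_out, -1) for j in jacobians], dim=1)
G = J.T @ J * (2.0 / (N * D_out))
print(torch.allclose(Gv, G @ torch.cat([t.flatten() for t in v]), atol=1e-5))
```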
Fisher-weighted Averaging
Next, we also need the trained parameters as NumPy vectors:
# flatten and convert to numpy
thetas = [
    nn.utils.parameters_to_vector((p for p in model.parameters() if p.requires_grad))
    for model in models
]
thetas = [theta.cpu().detach().numpy() for theta in thetas]
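The flattening above follows the order of `model.parameters()`. Later on, BackPACK's `vector_to_parameter_list` handles the way back from a flat vector to parameters. If you prefer to avoid that dependency, PyTorch's own `torch.nn.utils.vector_to_parameters` writes a flat tensor straight into a module's parameters. Here is a minimal round-trip sketch. `make_net` is a hypothetical small network standing in for `make_architecture`, and the float64 NumPy vector mimics a SciPy result. Because `vector_to_parameters` assigns the slices as they are, the vector is cast to the parameters' device and dtype first.

```python
import numpy as np
import torch
from torch import nn

torch.manual_seed(0)


def make_net() -> nn.Sequential:
    """Hypothetical small network standing in for make_architecture()."""
    return nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))


source = make_net()
# Flatten to a float64 NumPy vector, mimicking the result of a NumPy/SciPy computation
vec_np = nn.utils.parameters_to_vector(source.parameters()).detach().numpy().astype(np.float64)

target = make_net()
reference = next(target.parameters())
# Match device and dtype before writing the slices into the parameters
vec = torch.from_numpy(vec_np).to(device=reference.device, dtype=reference.dtype)
nn.utils.vector_to_parameters(vec, target.parameters())

# target now holds the parameters of source, so both networks agree
x = torch.rand(5, 3)
print(torch.allclose(source(x), target(x)))
```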
We are ready to compute the sum of Fisher-weighted parameters (the right-hand side in the above equation):
rhs = sum(fisher @ theta for fisher, theta in zip(fishers, thetas))
In the last step we need to normalize by multiplying with the inverse of the summed Fishers. Let’s first create the linear operator and add a damping term:
dim = fishers[0].shape[0]
identity = sparse.eye(dim)
damping = 1e-3
fisher_sum = aslinearoperator(damping * identity)
for fisher in fishers:
    fisher_sum += fisher
Finally, we define a linear operator for the inverse of the damped Fisher sum:
fisher_sum_inv = CGInverseLinearOperator(fisher_sum)
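The inverse is never formed explicitly. The sketch below illustrates the underlying idea with SciPy directly rather than with curvlinops. `curvlinops.CGInverseLinearOperator` is built around conjugate gradients, but its internals are not reproduced here. Hypothetical low-rank factors `A_t` define matrix-free Fisher stand-ins \(\mathbf{A}_t^\top \mathbf{A}_t\). These are summed with the damping term in the same way as above, and `scipy.sparse.linalg.cg` then solves the damped system. CG needs a symmetric positive definite operator. A sum of Fishers is in general only positive semi-definite, and the damping term makes it definite.

```python
import numpy as np
from scipy.sparse import eye
from scipy.sparse.linalg import LinearOperator, aslinearoperator, cg

rng = np.random.default_rng(0)
T, D, K = 3, 20, 15  # tasks, parameters, rows of each hypothetical factor
damping = 1e-3


def fisher_operator(A: np.ndarray) -> LinearOperator:
    """Matrix-free stand-in for a Fisher F = A^T A: computes A^T (A v) without forming F."""
    return LinearOperator(
        (A.shape[1], A.shape[1]), matvec=lambda v: A.T @ (A @ v), dtype=A.dtype
    )


factors = [rng.standard_normal((K, D)) for _ in range(T)]

# Damped sum of operators, built the same way as fisher_sum above
fisher_sum = aslinearoperator(damping * eye(D))
for A in factors:
    fisher_sum = fisher_sum + fisher_operator(A)

rhs = rng.standard_normal(D)
x, info = cg(fisher_sum, rhs)  # info == 0 signals convergence

# Dense check (only feasible at toy scale)
dense = damping * np.eye(D) + sum(A.T @ A for A in factors)
print(info, np.linalg.norm(dense @ x - rhs) / np.linalg.norm(rhs))
```

CG only ever calls `fisher_sum.matvec`. The same pattern therefore applies to operators like `GGNLinearOperator`, whose matrices would be far too large to store.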
Note
You may want to tweak the convergence criterion of CG using
curvlinops.CGInverseLinearOperator.set_cg_hyperparameters() before
applying the matrix-vector product.
fisher_weighted_params = fisher_sum_inv @ rhs
UserWarning: Input vector is float64, while linear operator is float32. Converting to float32.
Comparison
Let’s compare the performance of the Fisher-averaged parameters with a naive average.
average_params = sum(thetas) / len(thetas)
We initialize two neural networks with those parameters
fisher_model = make_architecture().to(DEVICE)
params = [p for p in fisher_model.parameters() if p.requires_grad]
theta_fisher = vector_to_parameter_list(
    torch.from_numpy(fisher_weighted_params), params
)
for theta, param in zip(theta_fisher, params):
    param.data = theta.to(param.device, param.dtype).data
# same for the naively averaged parameters
average_model = make_architecture().to(DEVICE)
params = [p for p in average_model.parameters() if p.requires_grad]
theta_average = vector_to_parameter_list(torch.from_numpy(average_params), params)
for theta, param in zip(theta_average, params):
    param.data = theta.to(param.device, param.dtype).data
and probe them on one batch of each task:
for task_idx in range(T):
    data_loader = data_loaders[task_idx]
    loss_function = loss_functions[task_idx]
    X, y = next(iter(data_loader))
    X, y = X.to(DEVICE), y.to(DEVICE)
    fisher_loss = loss_function(fisher_model(X), y)
    average_loss = loss_function(average_model(X), y)
    assert fisher_loss < average_loss
    print(f"Task {task_idx} batch loss with Fisher averaging: {fisher_loss.item():.3f}")
    print(f"Task {task_idx} batch loss with naive averaging: {average_loss.item():.3f}")
Task 0 batch loss with Fisher averaging: 0.186
Task 0 batch loss with naive averaging: 0.237
Task 1 batch loss with Fisher averaging: 0.189
Task 1 batch loss with naive averaging: 0.216
Task 2 batch loss with Fisher averaging: 0.215
Task 2 batch loss with naive averaging: 0.261
The Fisher-averaged parameters achieve a lower loss than the naively averaged ones on every task, at least on the training batches we probed.
That’s all for now.