Fisher-weighted Model Averaging

In this example, we implement Fisher-weighted model averaging, a technique described in the NeurIPS 2022 paper “Merging Models with Fisher-Weighted Averaging” (Matena & Raffel). It requires Fisher-vector products and multiplication with the inverse of a sum of Fisher matrices. The paper uses a diagonal approximation of the Fisher matrices. In addition to that approximation, we use the exact Fisher matrices and rely on matrix-free methods to apply the inverse.

Note

In our setup, the Fisher equals the generalized Gauss-Newton matrix. Hence, we work with curvlinops.GGNLinearOperator.
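This can be checked directly under one assumption: the loss is the negative log-likelihood of a Gaussian with fixed variance \(\sigma^2\). A squared-error loss such as MSELoss corresponds to this for a suitable \(\sigma\), up to an additive constant. Let \(\mathbf{J}_n\) be the Jacobian of the model output \(\mathbf{f}_n = f_\mathbf{\theta}(\mathbf{x}_n)\) w.r.t. \(\mathbf{\theta}\), and let \(\ell = \frac{1}{2\sigma^2}\|\mathbf{y} - \mathbf{f}_n\|^2 + \text{const}\). The following is a sketch of the standard argument, not taken from the library's documentation:

\[
\begin{aligned}
\mathbf{G} &= \sum_n \mathbf{J}_n^\top \left(\nabla_{\mathbf{f}_n}^2 \ell\right) \mathbf{J}_n = \frac{1}{\sigma^2} \sum_n \mathbf{J}_n^\top \mathbf{J}_n\,,\\
\mathbf{F} &= \sum_n \mathbf{J}_n^\top\, \mathbb{E}_{\mathbf{y} \sim \mathcal{N}(\mathbf{f}_n, \sigma^2 \mathbf{I})}\left[\nabla_{\mathbf{f}_n} \ell \, (\nabla_{\mathbf{f}_n} \ell)^\top\right] \mathbf{J}_n
= \sum_n \mathbf{J}_n^\top \frac{\mathbb{E}\left[(\mathbf{f}_n - \mathbf{y})(\mathbf{f}_n - \mathbf{y})^\top\right]}{\sigma^4} \mathbf{J}_n
= \frac{1}{\sigma^2} \sum_n \mathbf{J}_n^\top \mathbf{J}_n\,.
\end{aligned}
\]

The Hessian of the loss w.r.t. the output does not depend on the labels, and the expected outer product of output gradients evaluates to the same matrix, so \(\mathbf{F} = \mathbf{G}\). A common normalization, such as the \(1/N\) of a mean reduction, scales both sides equally.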

Description: We are given a set of \(T\) tasks (represented by data sets \(\mathcal{D}_t\)) and train a model \(f_\mathbf{\theta}\) on each task independently using the same criterion function. This yields \(T\) parameter vectors \(\mathbf{\theta}_1^\star, \dots, \mathbf{\theta}_T^\star\), which we would like to combine into a single model \(f_\mathbf{\theta^\star}\). To do so, we use the Fisher information matrix \(\mathbf{F}_t\) of each task, evaluated on the data set \(\mathcal{D}_t\) at the trained parameters \(\mathbf{\theta}_t^\star\). The merged parameters are given by

\[\mathbf{\theta}^\star = \left(\lambda \mathbf{I} + \sum_{t=1}^T \mathbf{F}_t \right)^{-1} \left( \sum_{t=1}^T \mathbf{F}_t \mathbf{\theta}_t^\star\right)\,.\]

This requires multiplying with the inverse of the sum of Fisher matrices, extended by the damping term \(\lambda \mathbf{I}\). If we approximate each Fisher by its diagonal, the sum is also diagonal and its inverse is cheap to apply element-wise. Without this approximation, we use curvlinops.CGInverseLinearOperator, which applies the inverse with the conjugate gradient method. Naive averaging corresponds to the special case where each Fisher is the identity (and \(\lambda \to 0\)).
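With diagonal Fishers, the merging formula decouples across parameters. Entry \(i\) of \(\mathbf{\theta}^\star\) is a weighted mean of the per-task entries, with weights \([\mathbf{F}_t]_{ii}\). The pure-Python sketch below illustrates this with made-up numbers and represents vectors as lists. The helper `fisher_weighted_average_diag` is hypothetical and not part of curvlinops.

```python
def fisher_weighted_average_diag(thetas, fisher_diags, damping=0.0):
    """Merge per-task parameter vectors with diagonal Fishers.

    Computes theta* = (damping * I + sum_t F_t)^{-1} (sum_t F_t theta_t) for diagonal
    F_t, given as lists of their diagonal entries. Entry i is a weighted mean.
    """
    merged = []
    for i in range(len(thetas[0])):
        numerator = sum(F[i] * theta[i] for F, theta in zip(fisher_diags, thetas))
        denominator = damping + sum(F[i] for F in fisher_diags)
        merged.append(numerator / denominator)
    return merged


# Two tasks with two parameters each (made-up numbers for illustration).
toy_thetas = [[1.0, 2.0], [5.0, 4.0]]
toy_fisher_diags = [[3.0, 1.0], [1.0, 1.0]]
print(fisher_weighted_average_diag(toy_thetas, toy_fisher_diags))  # [2.0, 3.0]

# Identity Fishers without damping recover naive averaging.
identity_diags = [[1.0, 1.0], [1.0, 1.0]]
print(fisher_weighted_average_diag(toy_thetas, identity_diags))  # [3.0, 3.0]
```

The first entry is pulled towards task 0's value because task 0 has the larger Fisher entry there. With equal Fisher entries, as for the second parameter, the result is the plain mean.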

Let’s start with the imports.

from backpack.utils.convert_parameters import vector_to_parameter_list
from torch import cuda, device, manual_seed, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
from torch.nn.utils import parameters_to_vector
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

from curvlinops import (
    CGInverseLinearOperator,
    GGNDiagonalLinearOperator,
    GGNLinearOperator,
)
from curvlinops.examples import IdentityLinearOperator

# make deterministic
manual_seed(0)

DEVICE = device("cuda" if cuda.is_available() else "cpu")

Setup

First, we create a few synthetic regression tasks (i.e., data sets) and an untrained model for each of them.

T = 3  # number of tasks
D_in = 7  # input dimension of each task
D_hidden = 5  # hidden dimension of the architecture we will use
D_out = 3  # output dimension of each task
N = 20  # number of data per task
batch_size = 7


def make_architecture() -> Sequential:
    """Create a neural network.

    Returns:
        A neural network.
    """
    return Sequential(
        Linear(D_in, D_hidden),
        ReLU(),
        Linear(D_hidden, D_hidden),
        Sigmoid(),
        Linear(D_hidden, D_out),
    )


def make_dataset() -> TensorDataset:
    """Create a synthetic regression data set.

    Returns:
        A synthetic regression data set.
    """
    X, y = rand(N, D_in), rand(N, D_out)
    return TensorDataset(X, y)


models = [make_architecture().to(DEVICE) for _ in range(T)]
data_loaders = [DataLoader(make_dataset(), batch_size=batch_size) for _ in range(T)]
loss_functions = [MSELoss(reduction="mean").to(DEVICE) for _ in range(T)]

Training

Here, we train each model for a small number of epochs.

num_epochs = 10
log_epochs = [0, num_epochs - 1]

for task_idx in range(T):
    model = models[task_idx]
    data_loader = data_loaders[task_idx]
    loss_function = loss_functions[task_idx]
    optimizer = SGD(model.parameters(), lr=1e-2)

    for epoch in range(num_epochs):
        for batch_idx, (X, y) in enumerate(data_loader):
            optimizer.zero_grad()
            X, y = X.to(DEVICE), y.to(DEVICE)
            loss = loss_function(model(X), y)
            loss.backward()
            optimizer.step()

            if epoch in log_epochs and batch_idx == 0:
                print(f"Task {task_idx} batch loss at epoch {epoch}: {loss.item():.3f}")
Task 0 batch loss at epoch 0: 0.454
Task 0 batch loss at epoch 9: 0.248
Task 1 batch loss at epoch 0: 0.632
Task 1 batch loss at epoch 9: 0.300
Task 2 batch loss at epoch 0: 0.363
Task 2 batch loss at epoch 9: 0.198

Linear operators

We are now ready to set up the per-task Fisher operators. We collect them in a dictionary with one entry per Fisher-based strategy. Naive averaging uses the identity as Fisher and needs no operator:

per_task_fishers = {
    # Diagonal approximation, as used in the original paper.
    # (Strictly speaking, the paper uses a randomized, sampling-based approximation
    # of the Fisher, which can be obtained with `mc_samples=1`. For simplicity, we
    # compute the exact GGN/Fisher diagonal here.)
    "diag(F)": [
        GGNDiagonalLinearOperator(
            model,
            loss_function,
            {n: p for n, p in model.named_parameters() if p.requires_grad},
            data_loader,
        )
        for model, loss_function, data_loader in zip(
            models, loss_functions, data_loaders
        )
    ],
    "F": [
        GGNLinearOperator(
            model,
            loss_function,
            {n: p for n, p in model.named_parameters() if p.requires_grad},
            data_loader,
        )
        for model, loss_function, data_loader in zip(
            models, loss_functions, data_loaders
        )
    ],
}
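The comment above contrasts the exact Fisher diagonal with a sampling-based approximation. The toy, pure-Python sketch below shows that difference for a linear model \(f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}\) with a unit-variance Gaussian likelihood. For this model, the exact Fisher diagonal is the mean of \(x_i^2\) over the data. The helper `fisher_diag_linear_gaussian` and its `mc_samples` argument are hypothetical and do not reflect how curvlinops computes these quantities.

```python
import random


def fisher_diag_linear_gaussian(X, w, mc_samples=None, seed=0):
    """Fisher diagonal of f(x) = w . x under y ~ N(f(x), 1), averaged over the data X.

    mc_samples=None returns the exact diagonal, mean_n x_{n,i}^2. Otherwise, labels
    are sampled from the model mc_samples times per datum and the squared gradients
    of -log p(y | x, w) = 0.5 * (y - w . x)^2 + const are averaged.
    """
    N, D = len(X), len(w)
    if mc_samples is None:
        return [sum(x[i] ** 2 for x in X) / N for i in range(D)]
    rng = random.Random(seed)
    diag = [0.0] * D
    for x in X:
        f = sum(wi * xi for wi, xi in zip(w, x))
        for _ in range(mc_samples):
            y = rng.gauss(f, 1.0)  # label drawn from the model's predictive distribution
            residual = f - y  # derivative of the negative log-likelihood w.r.t. f
            for i in range(D):
                diag[i] += (residual * x[i]) ** 2  # squared gradient w.r.t. w_i
    return [d / (N * mc_samples) for d in diag]


data_rng = random.Random(1)
toy_X = [[data_rng.random() for _ in range(3)] for _ in range(2000)]
toy_w = [0.5, -1.0, 2.0]
exact = fisher_diag_linear_gaussian(toy_X, toy_w)
sampled = fisher_diag_linear_gaussian(toy_X, toy_w, mc_samples=1)
print(exact)
print(sampled)
```

With one sample per datum, the estimate is unbiased but noisy, so it fluctuates around the exact diagonal. Drawing more samples reduces this noise at a higher cost.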

Fisher-weighted Averaging

Next, we also need the trained parameters as vectors:

# flatten and concatenate
thetas = [
    parameters_to_vector((p for p in model.parameters() if p.requires_grad)).detach()
    for model in models
]
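Merging vectors entry by entry only makes sense if all flattened vectors list the parameters in the same order. Every model here is built by make_architecture, so model.parameters() yields them in the same order for each task. The pure-Python sketch below mimics flattening and its inverse (the latter is done later with vector_to_parameter_list). It uses hypothetical helpers `flatten`, `reshape` and `unflatten` on nested lists instead of tensors.

```python
from math import prod


def flatten(tensors):
    """Concatenate nested-list 'tensors' into one flat list, in order."""
    flat = []

    def walk(t):
        if isinstance(t, list):
            for item in t:
                walk(item)
        else:
            flat.append(t)

    for t in tensors:
        walk(t)
    return flat


def reshape(values, shape):
    """Arrange a flat list of prod(shape) values into nested lists of that shape."""
    if not shape:
        return values[0]  # 0-d entry
    step = prod(shape[1:])  # number of values per slice along the first axis
    return [reshape(values[i * step : (i + 1) * step], shape[1:]) for i in range(shape[0])]


def unflatten(vector, shapes):
    """Split a flat list into consecutive chunks and reshape each to its shape."""
    out, offset = [], 0
    for shape in shapes:
        numel = prod(shape)
        out.append(reshape(vector[offset : offset + numel], shape))
        offset += numel
    return out


weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # shape (2, 3)
bias = [7.0, 8.0]  # shape (2,)
vec = flatten([weight, bias])
print(vec)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(unflatten(vec, [(2, 3), (2,)]) == [weight, bias])  # True
```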

We are ready to compute the sum of Fisher-weighted parameters (the right-hand side in the above equation) for each strategy:

rhs = {
    key: sum(F @ theta for F, theta in zip(Fs, thetas))
    for key, Fs in per_task_fishers.items()
}

In the last step we need to normalize by multiplying with the inverse of the summed Fishers. Let’s first sum the per-task Fishers for each strategy:

fisher_sums = {}
for key, Fs in per_task_fishers.items():
    fisher_sums[key] = Fs[0]
    for F in Fs[1:]:
        fisher_sums[key] = F + fisher_sums[key]
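Adding linear operators does not require forming any matrix. The product of a sum with a vector is the sum of the individual products, \((\mathbf{A} + \mathbf{B})\mathbf{v} = \mathbf{A}\mathbf{v} + \mathbf{B}\mathbf{v}\). The toy class below sketches this lazy composition and uses the same accumulation pattern as the loop above. The class `MatVecOperator` and the helper `dense` are hypothetical and not curvlinops' implementation.

```python
class MatVecOperator:
    """Toy matrix-free linear operator defined only by its matrix-vector product."""

    def __init__(self, matvec):
        self.matvec = matvec

    def __matmul__(self, v):
        return self.matvec(v)

    def __add__(self, other):
        # Lazy sum: no matrix is formed, (A + B) v is evaluated as A v + B v.
        return MatVecOperator(lambda v: [a + b for a, b in zip(self @ v, other @ v)])


def dense(matrix):
    """Wrap a dense matrix (list of rows) as a MatVecOperator."""
    return MatVecOperator(
        lambda v: [sum(m * x for m, x in zip(row, v)) for row in matrix]
    )


# Accumulate per-task operators the same way as the loop above.
toy_ops = [dense([[2.0, 0.0], [0.0, 1.0]]), dense([[1.0, 1.0], [1.0, 3.0]])]
toy_sum = toy_ops[0]
for op in toy_ops[1:]:
    toy_sum = op + toy_sum
print(toy_sum @ [1.0, 2.0])  # [5.0, 9.0]
```

Python's built-in sum() starts from the integer 0, so it would fail with this toy class unless the class also defined __radd__. Accumulating from the first operator sidesteps that.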

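Finally, we compute the merged parameters by applying the inverse of the damped Fisher sum. For the diagonal approximation, the summed operator provides an inverse method that inverts the damped diagonal in closed form. For the full Fisher, we use curvlinops.CGInverseLinearOperator. Naive averaging needs no inversion and is simply the mean of the per-task parameters: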

Note

For the full Fisher, you may want to tweak the convergence criterion of CG using curvlinops.CGInverseLinearOperator.set_cg_hyperparameters() before applying the matrix-vector product.

damping = 1e-3

merged_params = {"Naive": sum(thetas) / len(thetas)}
for key in per_task_fishers:
    F_sum = fisher_sums[key]
    if hasattr(F_sum, "inverse"):
        fisher_sum_inv = F_sum.inverse(damping)
    else:
        identity = IdentityLinearOperator(
            [tuple(p.shape) for p in models[0].parameters() if p.requires_grad],
            DEVICE,
            next(models[0].parameters()).dtype,
        )
        fisher_sum_inv = CGInverseLinearOperator(F_sum + damping * identity)
    merged_params[key] = fisher_sum_inv @ rhs[key]
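Conceptually, curvlinops.CGInverseLinearOperator applies the inverse by solving a linear system with the conjugate gradient (CG) method, which only needs matrix-vector products. The pure-Python CG sketch below assumes a symmetric positive definite operator, which the damped Fisher sum is for \(\lambda > 0\). Its `tol` and `maxiter` arguments are illustrative stand-ins for the convergence criterion mentioned in the note above, not the library's parameter names. The toy matrix is made up.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def conjugate_gradient(matvec, b, tol=1e-10, maxiter=None):
    """Approximately solve A x = b for symmetric positive definite A, given only v -> A v.

    Iterates until ||b - A x|| <= tol * ||b|| or maxiter steps have been taken.
    """
    n = len(b)
    maxiter = 10 * n if maxiter is None else maxiter
    x = [0.0] * n
    r = list(b)  # residual b - A x for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    b_norm = dot(b, b) ** 0.5
    for _ in range(maxiter):
        if rs_old**0.5 <= tol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs_old / dot(p, Ap)  # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        beta = rs_new / rs_old  # keeps the new direction A-conjugate to the old ones
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy damped Fisher sum (F_sum + damping * I), only accessed through products.
toy_F_sum = [[4.0, 1.0], [1.0, 3.0]]
toy_damping = 1e-3


def damped_matvec(v):
    return [dot(row, v) + toy_damping * vi for row, vi in zip(toy_F_sum, v)]


toy_rhs = [1.0, 2.0]
toy_merged = conjugate_gradient(damped_matvec, toy_rhs)
print(toy_merged)
```

Loosening `tol` or lowering `maxiter` trades accuracy of the merged parameters for fewer matrix-vector products, each of which costs a pass over a task's data in the real setting.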

Comparison

Let’s compare the performance of the different strategies. For each strategy, we load the merged parameters into a freshly created network:

merged_models = {}
for key, params_vec in merged_params.items():
    model = make_architecture()
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}
    for theta, param in zip(
        vector_to_parameter_list(params_vec, params.values()), params.values()
    ):
        param.data = theta.to(param.device, param.dtype).data
    merged_models[key] = model

and evaluate each network on the first batch of each task:

losses = {key: [] for key in merged_params}
header = "\t" + "\t".join(losses.keys())
print(header)

for task_idx in range(T):
    data_loader = data_loaders[task_idx]
    loss_function = loss_functions[task_idx]
    X, y = next(iter(data_loader))

    for key, model in merged_models.items():
        losses[key].append(loss_function(model(X), y).item())
    assert losses["F"][-1] < losses["Naive"][-1]

    print(
        "\t".join(
            [f"Task {task_idx}"] + [f"{losses[key][-1]:.3f}" for key in merged_params]
        )
    )

mean_losses = {key: sum(loss) / len(loss) for key, loss in losses.items()}
print("\t".join(["Avg"] + [f"{mean_losses[key]:.3f}" for key in merged_params]))
        Naive   diag(F) F
Task 0  0.237   0.223   0.184
Task 1  0.216   0.205   0.193
Task 2  0.261   0.275   0.203
Avg     0.238   0.234   0.193

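With the exact Fisher (F), the merged parameters achieve a lower loss than naive averaging on every task, at least on the training data. The diagonal approximation (diag(F)) lowers the average loss only slightly and is worse than naive averaging on Task 2.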

That’s all for now.
