
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_model_merging.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_model_merging.py:


Fisher-weighted Model Averaging
===================================

In this example we implement Fisher-weighted model averaging, a technique
described in this `NeurIPS 2022 paper `_. It requires Fisher-vector products,
and multiplication with the inverse of a sum of Fisher matrices. The paper uses
a diagonal approximation of the Fisher matrices. In addition, we will also use
the exact Fisher matrices and rely on matrix-free methods for applying the
inverse.

.. note::
    In our setup, the Fisher equals the generalized Gauss-Newton matrix.
    Hence, we work with :py:class:`curvlinops.GGNLinearOperator`.

**Description:** We are given a set of :math:`T` tasks (represented by data
sets :math:`\mathcal{D}_t`), and train a model :math:`f_\mathbf{\theta}` on
each task independently using the same criterion function. This yields
:math:`T` parameters :math:`\mathbf{\theta}_1^\star, \dots,
\mathbf{\theta}_T^\star`, and we would like to combine them into a single model
:math:`f_\mathbf{\theta^\star}`. To do that, we use the Fisher information
matrices :math:`\mathbf{F}_t` of each task (given by the data set
:math:`\mathcal{D}_t` and the trained model parameters
:math:`\mathbf{\theta}_t^\star`). The merged parameters are given by

.. math::
    \mathbf{\theta}^\star = \left(\lambda \mathbf{I}
    + \sum_{t=1}^T \mathbf{F}_t \right)^{-1}
    \left( \sum_{t=1}^T \mathbf{F}_t \mathbf{\theta}_t^\star\right)\,.

This requires multiplying with the inverse of the sum of Fisher matrices
(extended with a damping term).
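
To make the formula concrete, here is a minimal sketch of the special case in
which every :math:`\mathbf{F}_t` is diagonal, so that the inverse acts
elementwise. It uses plain Python lists instead of ``curvlinops``; the helper
name ``merge_diagonal`` and the toy numbers are illustrative assumptions and
are not part of the example script.

.. code-block:: Python

    def merge_diagonal(fisher_diags, thetas, damping=0.0):
        """Merge parameters, assuming each task's Fisher is diagonal.

        Hypothetical helper for illustration: ``fisher_diags[t][i]`` is the
        i-th diagonal entry of task t's Fisher, ``thetas[t][i]`` the i-th
        parameter of task t.
        """
        dim = len(thetas[0])
        merged = []
        for i in range(dim):
            # i-th entry of sum_t F_t theta_t (right-hand side of the formula)
            rhs_i = sum(f[i] * th[i] for f, th in zip(fisher_diags, thetas))
            # i-th diagonal entry of lambda * I + sum_t F_t
            denom_i = damping + sum(f[i] for f in fisher_diags)
            merged.append(rhs_i / denom_i)
        return merged


    # Two toy tasks with two parameters each (made-up numbers).
    thetas = [[1.0, 4.0], [3.0, 0.0]]

    # Identity Fishers and zero damping recover the naive average.
    naive = merge_diagonal([[1.0, 1.0], [1.0, 1.0]], thetas)

    # A larger Fisher entry pulls the merged value toward that task's parameter.
    weighted = merge_diagonal([[3.0, 1.0], [1.0, 1.0]], thetas)

In this toy setting, ``naive`` is the plain mean of the two parameter vectors,
while the first coordinate of ``weighted`` moves toward the first task because
its Fisher entry is larger there.
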
If we approximate each Fisher with its diagonal, applying this inverse is easy.
Without this approximation, we will use
:py:class:`curvlinops.CGInverseLinearOperator` for inversion. Naive averaging
corresponds to the special case where the Fisher is the identity.

Let's start with the imports.

.. GENERATED FROM PYTHON SOURCE LINES 38-58

.. code-block:: Python


    from backpack.utils.convert_parameters import vector_to_parameter_list
    from torch import cuda, device, manual_seed, rand
    from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid
    from torch.nn.utils import parameters_to_vector
    from torch.optim import SGD
    from torch.utils.data import DataLoader, TensorDataset

    from curvlinops import (
        CGInverseLinearOperator,
        GGNDiagonalLinearOperator,
        GGNLinearOperator,
    )
    from curvlinops.examples import IdentityLinearOperator

    # make deterministic
    manual_seed(0)

    DEVICE = device("cuda" if cuda.is_available() else "cpu")


.. GENERATED FROM PYTHON SOURCE LINES 59-64

Setup
-----

First, we will create a bunch of synthetic regression tasks (i.e. data sets)
and an untrained model for each of them.

.. GENERATED FROM PYTHON SOURCE LINES 65-103

.. code-block:: Python


    T = 3  # number of tasks
    D_in = 7  # input dimension of each task
    D_hidden = 5  # hidden dimension of the architecture we will use
    D_out = 3  # output dimension of each task
    N = 20  # number of data per task
    batch_size = 7


    def make_architecture() -> Sequential:
        """Create a neural network.

        Returns:
            A neural network.
        """
        return Sequential(
            Linear(D_in, D_hidden),
            ReLU(),
            Linear(D_hidden, D_hidden),
            Sigmoid(),
            Linear(D_hidden, D_out),
        )


    def make_dataset() -> TensorDataset:
        """Create a synthetic regression data set.

        Returns:
            A synthetic regression data set.
        """
        X, y = rand(N, D_in), rand(N, D_out)
        return TensorDataset(X, y)


    models = [make_architecture().to(DEVICE) for _ in range(T)]
    data_loaders = [DataLoader(make_dataset(), batch_size=batch_size) for _ in range(T)]
    loss_functions = [MSELoss(reduction="mean").to(DEVICE) for _ in range(T)]


.. GENERATED FROM PYTHON SOURCE LINES 104-108

Training
--------

Here, we train each model for a small number of epochs.

.. GENERATED FROM PYTHON SOURCE LINES 109-130

.. code-block:: Python


    num_epochs = 10
    log_epochs = [0, num_epochs - 1]

    for task_idx in range(T):
        model = models[task_idx]
        data_loader = data_loaders[task_idx]
        loss_function = loss_functions[task_idx]
        optimizer = SGD(model.parameters(), lr=1e-2)

        for epoch in range(num_epochs):
            for batch_idx, (X, y) in enumerate(data_loader):
                optimizer.zero_grad()
                X, y = X.to(DEVICE), y.to(DEVICE)
                loss = loss_function(model(X), y)
                loss.backward()
                optimizer.step()

                if epoch in log_epochs and batch_idx == 0:
                    print(f"Task {task_idx} batch loss at epoch {epoch}: {loss.item():.3f}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Task 0 batch loss at epoch 0: 0.454
    Task 0 batch loss at epoch 9: 0.248
    Task 1 batch loss at epoch 0: 0.632
    Task 1 batch loss at epoch 9: 0.300
    Task 2 batch loss at epoch 0: 0.363
    Task 2 batch loss at epoch 9: 0.198


.. GENERATED FROM PYTHON SOURCE LINES 131-137

Linear operators
----------------

We are now ready to set up the linear operators for the per-task Fishers.
We collect the per-task Fisher operators for the two Fisher-based strategies
(diagonal and full) in a dictionary. Naive averaging corresponds to using the
identity as Fisher and is handled separately when merging:

.. GENERATED FROM PYTHON SOURCE LINES 138-168

.. code-block:: Python


    per_task_fishers = {
        # Diagonal approximation as used in the seminal paper
        # (Precisely speaking, the seminal paper uses a randomized approximation of the
        # Fisher based on sampling that can be achieved with `mc_samples=1`.
        # For simplicity we compute the exact GGN/Fisher diagonal here.)
        "diag(F)": [
            GGNDiagonalLinearOperator(
                model,
                loss_function,
                {n: p for n, p in model.named_parameters() if p.requires_grad},
                data_loader,
            )
            for model, loss_function, data_loader in zip(
                models, loss_functions, data_loaders
            )
        ],
        "F": [
            GGNLinearOperator(
                model,
                loss_function,
                {n: p for n, p in model.named_parameters() if p.requires_grad},
                data_loader,
            )
            for model, loss_function, data_loader in zip(
                models, loss_functions, data_loaders
            )
        ],
    }


.. GENERATED FROM PYTHON SOURCE LINES 169-173

Fisher-weighted Averaging
-------------------------

Next, we also need the trained parameters as vectors:

.. GENERATED FROM PYTHON SOURCE LINES 174-181

.. code-block:: Python


    # flatten and concatenate
    thetas = [
        parameters_to_vector((p for p in model.parameters() if p.requires_grad)).detach()
        for model in models
    ]


.. GENERATED FROM PYTHON SOURCE LINES 182-184

We are ready to compute the sum of Fisher-weighted parameters (the right-hand
side in the above equation) for each strategy:

.. GENERATED FROM PYTHON SOURCE LINES 185-191

.. code-block:: Python


    rhs = {
        key: sum(F @ theta for F, theta in zip(Fs, thetas))
        for key, Fs in per_task_fishers.items()
    }


.. GENERATED FROM PYTHON SOURCE LINES 192-194

In the last step we need to normalize by multiplying with the inverse of the
summed Fishers. Let's first sum the per-task Fishers for each strategy:

.. GENERATED FROM PYTHON SOURCE LINES 195-202

.. code-block:: Python


    fisher_sums = {}
    for key, Fs in per_task_fishers.items():
        fisher_sums[key] = Fs[0]
        for F in Fs[1:]:
            fisher_sums[key] = F + fisher_sums[key]


.. GENERATED FROM PYTHON SOURCE LINES 203-212

Finally, we compute the merged parameters by applying the inverse of the damped
Fisher sum. For diagonal operators (naive and ``diag(F)``), the inverse is
analytical. For the full Fisher, we use
:py:class:`curvlinops.CGInverseLinearOperator`:

.. note::
    For the full Fisher, you may want to tweak the convergence criterion of CG
    using :py:func:`curvlinops.CGInverseLinearOperator.set_cg_hyperparameters`
    before applying the matrix-vector product.

.. GENERATED FROM PYTHON SOURCE LINES 213-230

.. code-block:: Python


    damping = 1e-3

    merged_params = {"Naive": sum(thetas) / len(thetas)}

    for key in per_task_fishers:
        F_sum = fisher_sums[key]
        if hasattr(F_sum, "inverse"):
            fisher_sum_inv = F_sum.inverse(damping)
        else:
            identity = IdentityLinearOperator(
                [tuple(p.shape) for p in models[0].parameters() if p.requires_grad],
                DEVICE,
                next(models[0].parameters()).dtype,
            )
            fisher_sum_inv = CGInverseLinearOperator(F_sum + damping * identity)

        merged_params[key] = fisher_sum_inv @ rhs[key]


.. GENERATED FROM PYTHON SOURCE LINES 231-236

Comparison
----------

Let's compare the performance of the different strategies. We initialize a
neural network for each:

.. GENERATED FROM PYTHON SOURCE LINES 237-248

.. code-block:: Python


    merged_models = {}

    for key, params_vec in merged_params.items():
        model = make_architecture()
        params = {n: p for n, p in model.named_parameters() if p.requires_grad}
        for theta, param in zip(
            vector_to_parameter_list(params_vec, params.values()), params.values()
        ):
            param.data = theta.to(param.device, param.dtype).data
        merged_models[key] = model


.. GENERATED FROM PYTHON SOURCE LINES 249-250

and probe them on one batch of each task:

.. GENERATED FROM PYTHON SOURCE LINES 251-274

.. code-block:: Python


    losses = {key: [] for key in merged_params}
    header = "\t" + "\t".join(losses.keys())
    print(header)

    for task_idx in range(T):
        data_loader = data_loaders[task_idx]
        loss_function = loss_functions[task_idx]
        X, y = next(iter(data_loader))
        for key, model in merged_models.items():
            losses[key].append(loss_function(model(X), y).item())
        assert losses["F"][-1] < losses["Naive"][-1]
        print(
            "\t".join(
                [f"Task {task_idx}"] + [f"{losses[key][-1]:.3f}" for key in merged_params]
            )
        )

    mean_losses = {key: sum(loss) / len(loss) for key, loss in losses.items()}
    print("\t".join(["Avg"] + [f"{mean_losses[key]:.3f}" for key in merged_params]))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

            Naive   diag(F) F
    Task 0  0.237   0.223   0.184
    Task 1  0.216   0.205   0.193
    Task 2  0.261   0.275   0.203
    Avg     0.238   0.234   0.193


.. GENERATED FROM PYTHON SOURCE LINES 275-279

Merging with the full Fisher yields a lower loss than naive averaging on every
task, at least on the training data. The diagonal approximation improves on
naive averaging on average, but not on Task 2.

That's all for now.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.734 seconds)


.. _sphx_glr_download_basic_usage_example_model_merging.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: example_model_merging.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: example_model_merging.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: example_model_merging.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_