# Wasserstein Barycenter-based Model Fusion and Connections to Loss Landscapes of Neural Networks

We provide code for all the experiments presented in our paper.

The code is organized as follows:
* Source code is in the `src` directory.
* Bash files required to run the experiments along with all the hyperparameters used are in the `bin` directory.

### Requirements

The main dependencies for running the code are:
* pytorch
* torchvision
* torchtext
* torchdata
* PTL
* numpy
* matplotlib
* seaborn
* Python Optimal Transport (POT)


## Running Experiments

Next, we provide detailed instructions on running each experiment.

In general, each experiment has a bash file in the `bin` directory that contains the hyperparameters and random seeds used in the experiment. Uncomment the corresponding command in the relevant bash file before running the experiment. For most of the code, command and argument names are self-explanatory.

### Download Datasets

The AGNEWS and DBpedia datasets can be downloaded by running `bash bin/download_datasets.sh`. The pre-trained GloVe embeddings used in RNN and LSTM training can be downloaded by running `bash bin/download_glove.sh`.
Other datasets are downloaded automatically during training.



### Training Models

All base models for the fusion experiments need to be trained first.
The code for training MLPNet, VGG11 and ResNet18 is in `src/tlp_model_fusion/train_models.py`.
The code for training RNN and LSTM is in `src/tlp_rnn_fusion/train_rnn.py`.
To train a model, uncomment the corresponding command in `bin/run_train_models.sh`.
The model classes MLPNet, VGG11 and ResNet18 are defined in `model.py`, `vgg_models.py` and `resnet_models.py` under the directory `src/tlp_model_fusion`.
The model classes RNN and LSTM are defined in `src/tlp_rnn_fusion/rnn_models.py`.


Running training:
1. Check `bin/run_train_models.sh`.
2. Identify the model to be trained and uncomment the corresponding command.
3. For the model classes MLPNet, VGG11 and ResNet18, the trained models are saved in `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/`, where `<run_id>` is a string built from the relevant training parameters, such as the random seed. For the model classes RNN and LSTM, the trained models are saved in the path provided in the commands.
4. Run `bash bin/run_train_models.sh`.

The model with the best validation accuracy is saved as `best_val_acc_model.pth`, while the model at the end of training is saved as `final_model.pth`.

Note: We use the model with the best validation accuracy for our fusion experiments.
All the required model training can be done using this script.
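
The output layout described above can be reconstructed programmatically when locating checkpoints for the fusion step. The sketch below is illustrative only: the helper name `snapshot_path` and the example values (`demo_experiment`, `mlpnet`, `mnist`, `seed_43`) are hypothetical, and only the directory pattern and the two checkpoint file names come from this README. It applies to MLPNet, VGG11 and ResNet18; RNN and LSTM checkpoints go to the path given in the commands.

```python
from pathlib import PurePosixPath

# Checkpoint file names as stated in this README.
BEST_CHECKPOINT = "best_val_acc_model.pth"
FINAL_CHECKPOINT = "final_model.pth"


def snapshot_path(experiment_name, model_name, dataset_name, run_id, best=True):
    """Build result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/<file>."""
    filename = BEST_CHECKPOINT if best else FINAL_CHECKPOINT
    return PurePosixPath(
        "result",
        experiment_name,
        f"{model_name}_{dataset_name}",
        run_id,
        "snapshots",
        filename,
    )


# Hypothetical values, for illustration only.
print(snapshot_path("demo_experiment", "mlpnet", "mnist", "seed_43"))
```

Passing `best=False` returns the path of `final_model.pth` instead.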


### Fusing FC models with same architecture

The relevant code for FC NN and deep CNN fusion is in `fuse_models.py`, `tlp_fusion.py`, `ad_hoc_ot_fusion.py` and `avg_fusion.py` under dir `src/tlp_model_fusion/`.

Some arguments worth noting for `fuse_models.py` are:
* `fusion_type`: The type of fusion to perform. Accepted values are `tlp` for WB fusion, `ot` for OT fusion and `avg` for vanilla averaging.
* `tlp_*` are the parameters for WB fusion, `ad_hoc_ot_*` are the parameters for OT fusion.
* `tlp_init_type`: The type of initialization for the target model. Accepted values are `identity` for initializing from one of the base models and `None` for random initialization.
* `tlp_cost_choice`, `ad_hoc_cost_choice` denote the type of cost function to be used for the fusion.
* `tlp_ot_solver`, `ad_hoc_ot_solver` specify the solver used for the OT optimization problem.
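
To see why the `fusion_type` choice matters, the toy sketch below contrasts position-wise averaging with averaging after neuron alignment. It is not the code in `avg_fusion.py`, `ad_hoc_ot_fusion.py` or `tlp_fusion.py`: the function names are hypothetical, and a brute-force permutation search stands in for the OT solver, which is only feasible for very small layers. Each model is reduced to a single layer given as a list of per-neuron weight vectors.

```python
from itertools import permutations


def sq_dist(u, v):
    """Squared Euclidean distance between two weight vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def average(rows_a, rows_b):
    """Position-wise average of two lists of neuron weight vectors."""
    return [[(a + b) / 2 for a, b in zip(ra, rb)] for ra, rb in zip(rows_a, rows_b)]


def align(rows_a, rows_b):
    """Reorder rows_b to best match rows_a (brute force, tiny layers only)."""
    n = len(rows_a)
    best = min(
        permutations(range(n)),
        key=lambda p: sum(sq_dist(rows_a[i], rows_b[p[i]]) for i in range(n)),
    )
    return [rows_b[j] for j in best]


# Model B holds the same three hidden neurons as model A, in a different order.
model_a = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
model_b = [model_a[2], model_a[0], model_a[1]]

naive = average(model_a, model_b)                    # mixes unrelated neurons
aligned = average(model_a, align(model_a, model_b))  # averages matching neurons

print("naive  :", naive)
print("aligned:", aligned)
```

In this toy case the aligned average recovers model A exactly, while the naive average does not. In a full network, reordering the neurons of one layer also requires reordering the corresponding input weights of the next layer.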

Running fusion of FC (fully connected) models with the same architecture:
1. Check `bin/run_fuse_fc_models.sh`.
2. WB fusion, OT fusion, Vanilla averaging have been documented in comments.
3. The fused model is saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/fused_model.pth`. The permuted models are saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/permuted_model_{i}.pth` for i=1,2. 
4. Note that the `<run_id>` contains all the relevant arguments as a string, which identifies a fusion run with a specific set of parameters.
5. After uncommenting the command and choosing the correct parameters, run `bash bin/run_fuse_fc_models.sh`.
6. Note down the validation and test accuracy of the fused model and permuted model 2 at the end of each fusion.


### Fusion into different architecture with multiple model counts

Fusing models into a different architecture with multiple model counts can be run using `bin/run_distill_models.sh`:
1. The first part of `bin/run_distill_models.sh`, under the `################## FUSION MULTIPLE MODELS INTO MLPLarge ###################` heading, contains the commands for performing the relevant fusion.
2. The commands for WB fusion and OT fusion can be uncommented separately to run the experiments.
3. Run `bash bin/run_distill_models.sh` for the appropriate command.
4. The fused model is saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/fused_model.pth`.
5. Note down the validation and test accuracy of the fused model at the end of each fusion.


### Fusing deep CNNs

The deep CNN models (VGG11 and ResNet18) can be trained using `bin/run_train_models.sh`.

Run the following steps to perform the fusion for deep CNNs:
1. Fusion for VGG11 models can be done using `bin/run_fuse_vgg_models.sh`.
2. Fusion for ResNet18 models can be done using `bin/run_fuse_resnet_models.sh`.
3. Uncomment the command for the specific fusion type and run the script.
4. As before, the fused model is saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/fused_model.pth`, and the permuted models are saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/permuted_model_{i}.pth` for i=1,2.
5. Note down the validation and test accuracy of the fused model and permuted model 2 at the end of each fusion.


### Fusing RNN and LSTM models

The relevant code for RNN and LSTM fusion is in `fuse_rnn_models.py`, `tlp_fusion_rnn.py` under dir `src/tlp_rnn_fusion/`, and `ad_hoc_ot_fusion.py` and `avg_fusion.py` under dir `src/tlp_model_fusion/`.

Some arguments worth noting for `fuse_rnn_models.py` are:
* `fusion_type`: The type of fusion to perform. Accepted values are `tlp` for GWB fusion, `tlp_no_hidden` for WB fusion, `ot` for OT fusion and `avg` for vanilla averaging. (Note that here `tlp` stands for GWB fusion, not WB fusion as in the FC and CNN cases.)
* `alpha_h` is the hyperparameter that balances the importance of input-to-hidden weights and hidden-to-hidden weights in the GWB fusion.
  * Since we did not normalize the couplings for the first layer in the implementation, we chose `alpha_h` in [50, 1000] instead of [1, 20].
  * `alpha_h` is chosen in [50, 1000] only for hidden layers. For the other layers, we set `alpha_h=1`.
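
The per-layer rule for `alpha_h` can be written down as a small helper. This is a sketch of the rule as stated above, not code from `tlp_fusion_rnn.py`: the function name `alpha_for_layer`, the range check and the three-layer example are assumptions made for illustration.

```python
# Range reported in this README for hidden layers.
ALPHA_H_RANGE = (50, 1000)


def alpha_for_layer(is_hidden_layer, alpha_h):
    """Return the balancing weight for one layer.

    Hidden layers use the tuned alpha_h; every other layer uses 1.
    """
    if not is_hidden_layer:
        return 1.0
    low, high = ALPHA_H_RANGE
    if not low <= alpha_h <= high:
        raise ValueError(f"alpha_h={alpha_h} is outside the range {ALPHA_H_RANGE}")
    return float(alpha_h)


# Hypothetical three-layer model, for illustration only.
layers = [("embedding", False), ("hidden", True), ("classifier", False)]
print({name: alpha_for_layer(is_hidden, 100) for name, is_hidden in layers})
```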

Run the following steps to perform the fusion for RNN and LSTM models:
1. Check `bin/run_fuse_rnn_models.sh` for RNN models, `bin/run_fuse_lstm_models.sh` for LSTM models.
2. GWB fusion, WB fusion, OT fusion and Vanilla averaging for each dataset have been documented in comments. After uncommenting the specific command, run `bash bin/run_fuse_rnn_models.sh` or `bash bin/run_fuse_lstm_models.sh`.
3. The fused model and permuted models are saved as `model_{seed1}_{seed2}_<regularization_para>.pth` and `permuted_model_{i}.pth` for i=1,2 under the path provided in the commands for each case respectively.
4. Note down the validation and test accuracy of the fused model and permuted model 2 at the end of each fusion.


### Visualizations of the loss landscapes

The relevant code for visualizations is in `plane.py`, `plane_plot.py` under dir `src/tlp_model_fusion/`.

Run the following steps to generate the visualizations of the loss landscapes:
1. First, check `bin/run_plane.sh`. This bash file generates the grid planes used for the visualizations.
2. Uncomment the specific command for each combination of model type and dataset and run `bash bin/run_plane.sh`.
3. The grid plane is saved as `result/visualization_<model_name>_<dataset_name>/plane.npz`.
4. Second, check `bin/run_plane_plot.sh`. This bash file generates the figures for the visualizations.
5. Uncomment the specific command for each combination of model type and dataset and run `bash bin/run_plane_plot.sh`.
6. The visualization results are also saved under dir `result/visualization_<model_name>_<dataset_name>/`.
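
Before plotting, it can help to check which arrays a saved `plane.npz` contains. The array names are not listed in this README, so the sketch below makes no assumption about them: it relies only on the fact that a `.npz` file is a zip archive holding one `.npy` entry per array. The helper name `list_npz_arrays` is hypothetical.

```python
import zipfile


def list_npz_arrays(npz_path):
    """Return the array names stored in a .npz archive, without loading the data."""
    with zipfile.ZipFile(npz_path) as archive:
        return [
            name[: -len(".npy")]
            for name in archive.namelist()
            if name.endswith(".npy")
        ]
```

Call it with the path from step 3, for example `list_npz_arrays("result/visualization_<model_name>_<dataset_name>/plane.npz")` with the placeholders filled in. With NumPy installed, `numpy.load(path).files` returns the same names.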


### Single shot distillation 

The single shot distillation experiments can be run using `bin/run_distill_models.sh`:
1. The second part of `bin/run_distill_models.sh`, under the `################ DISTILLATION INTO DIFFERENT ARCHITECTURE #################` heading, contains the commands for performing distillation.
2. Note that the distillation experiments are carried out for different seeds, which are written as a for-loop in the script.
3. The entire script runs distillation for each seed and each choice of OT solver.
4. Uncomment the appropriate sections to run distillation using the WB method or the OT method.
5. Run the script using `bash bin/run_distill_models.sh`.
6. The fused model is saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/fused_model.pth`.
7. Note down the validation and test accuracy of the fused model at the end of each fusion.



### WB fusion for heterogeneous data distributions

The models can be trained on heterogeneous data distributions using `bin/run_train_models.sh`.

After the models are trained, the fusion can be performed as follows:
1. Check `bin/run_fuse_hetero_models.sh`.
2. Note that fusion is performed for models trained from the same seed but different data distributions.
3. The weights for different models can be adjusted using `WA` and `WB` variables.
4. Model A is the model trained to recognize a special digit, while Model B is the model trained for the other digits.
5. Uncomment the appropriate sections and run `bash bin/run_fuse_hetero_models.sh`.
6. The fused model is saved as `result/<experiment_name>/<model_name>_<dataset_name>/<run_id>/snapshots/fused_model.pth`.
7. Note down the validation and test accuracy of the fused model at the end of each fusion.
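
As a rough intuition for the role of `WA` and `WB`, the sketch below shows a weighted average of two already-aligned parameter vectors. This is an assumption-laden simplification, not the computation in `bin/run_fuse_hetero_models.sh` or the fusion code: it assumes the two values act as relative importance weights that are normalized to sum to one, and the function name `weighted_average` and the example numbers are hypothetical.

```python
def weighted_average(params_a, params_b, wa, wb):
    """Combine two aligned parameter vectors with relative weights wa and wb."""
    total = wa + wb
    if total <= 0:
        raise ValueError("wa + wb must be positive")
    return [(wa * a + wb * b) / total for a, b in zip(params_a, params_b)]


# Hypothetical parameters and weights, for illustration only.
model_a_params = [0.0, 2.0]
model_b_params = [10.0, 4.0]
print(weighted_average(model_a_params, model_b_params, wa=1, wb=3))
```

Under this assumption, a larger `WB` relative to `WA` pulls the fused parameters toward Model B.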





