
# FedGO: Federated Ensemble Distillation with GAN-based Optimality

This repository contains the implementation of the paper *FedGO: Federated Ensemble Distillation with GAN-based Optimality*.

Abstract: For federated learning in practical settings, a significant challenge is the considerable diversity of data across clients. To tackle this data heterogeneity issue, it has been recognized that federated ensemble distillation is effective. Federated ensemble distillation requires an unlabeled dataset on the server, which could either be an extra dataset the server already possesses or a dataset generated by training a generator using a data-free approach. Then, it proceeds by generating pseudo-labels for the unlabeled data based on the predictions of client models and training the server model using this pseudo-labeled dataset. Consequently, the efficacy of ensemble distillation hinges on the quality of these pseudo-labels, which, in turn, poses a challenge of appropriately assigning weights to client models for each data point, particularly in scenarios with data heterogeneity. In this work, we suggest a provably near-optimal weighting method for federated ensemble distillation, inspired by theoretical results in generative adversarial networks. Our weighting method utilizes client discriminators, trained at the clients based on a generator distributed from the server and their own datasets.
Our comprehensive experiments on various image classification tasks illustrate that our method significantly improves the performance over baselines, under various scenarios with and without an extra server dataset. Furthermore, we provide an extensive analysis of the additional communication burden, privacy leakage, and computational burden caused by our weighting method.
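
As a rough illustration of the weighting idea described in the abstract, the sketch below combines client predictions for a single unlabeled sample using weights derived from client discriminator outputs. It assumes each weight is proportional to the discriminator odds `d / (1 - d)`, motivated by the standard GAN result that an optimal discriminator outputs `p_client(x) / (p_client(x) + p_gen(x))`. The exact rule used in this repository may differ, and the function names are hypothetical.

```python
# Illustrative sketch only; not the repository's implementation.
# Assumption: the weight of client k on sample x is proportional to the
# discriminator odds d_k(x) / (1 - d_k(x)).


def odds_weights(disc_outputs, eps=1e-6):
    """Turn per-client discriminator outputs in (0, 1) into normalized weights."""
    # Clip to avoid division by zero when a discriminator saturates at 0 or 1.
    clipped = [min(max(d, eps), 1.0 - eps) for d in disc_outputs]
    odds = [d / (1.0 - d) for d in clipped]
    total = sum(odds)
    return [o / total for o in odds]


def weighted_pseudo_label(client_probs, weights):
    """Weighted average of per-client class-probability vectors for one sample."""
    num_classes = len(client_probs[0])
    return [
        sum(w * probs[c] for w, probs in zip(weights, client_probs))
        for c in range(num_classes)
    ]


if __name__ == "__main__":
    # Two hypothetical clients: the first discriminator rates the sample as
    # much more typical of its local data than the second does.
    disc_outputs = [0.9, 0.5]
    client_probs = [[0.8, 0.2], [0.3, 0.7]]
    weights = odds_weights(disc_outputs)
    print("weights:", weights)
    print("pseudo-label:", weighted_pseudo_label(client_probs, weights))
```

In this toy setting the first client dominates the pseudo-label because its discriminator odds are larger; a uniform average would instead treat both clients equally regardless of how well the sample matches their data.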

## Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

The CIFAR-10/100 datasets will be automatically downloaded to the data folder when the training code is executed. For experiments using the ImageNet100 dataset, you need to download it from the following Kaggle link: [https://www.kaggle.com/datasets/ambityga/imagenet100](https://www.kaggle.com/datasets/ambityga/imagenet100). After downloading, downsample the images to 32x32 (we used the box method for downsampling) and place the train and validation image folders into the data folder.
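
The downsampling step could be done along the following lines. This is a minimal sketch assuming Pillow is installed; the source and destination paths, the accepted file extensions, and the function names are hypothetical and should be adapted to where the Kaggle archive was extracted and to the folder names the training code expects.

```python
# Sketch: downsample images to 32x32 with a box filter using Pillow.
# All paths and helper names are hypothetical placeholders.
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def plan_outputs(src_root, dst_root):
    """Map every image under src_root to the same relative path under dst_root."""
    src_root, dst_root = Path(src_root), Path(dst_root)
    return [
        (src, dst_root / src.relative_to(src_root))
        for src in sorted(src_root.rglob("*"))
        if src.is_file() and src.suffix.lower() in IMAGE_SUFFIXES
    ]


def downsample_tree(src_root, dst_root, size=(32, 32)):
    """Resize each image to `size` with box resampling, keeping the folder layout."""
    from PIL import Image  # imported lazily so path planning works without Pillow

    # Newer Pillow versions expose filters under Image.Resampling;
    # older versions provide Image.BOX directly.
    box = getattr(Image, "Resampling", Image).BOX
    for src, dst in plan_outputs(src_root, dst_root):
        dst.parent.mkdir(parents=True, exist_ok=True)
        with Image.open(src) as img:
            img.convert("RGB").resize(size, resample=box).save(dst)


# Example call (hypothetical paths):
# downsample_tree("imagenet100_raw/train", "data/train")
```

The per-class subfolders are preserved, so the output tree mirrors the input tree with every image reduced to 32x32.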

## Toy Example


The Jupyter Notebook file `toy.ipynb` contains the code for implementing the toy example. By executing each cell sequentially from top to bottom, you can obtain the results of the toy example.


## Training and Evaluation


The `main.py` script supports training on CIFAR-10/100 and ImageNet100.

To run federated training on CIFAR-10 with Dirichlet parameter α=0.1, use the following command:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dset_c cifar10 --dset_s cifar10 --combine gan --diri_alpha 0.1 --anneal True --gen_load False --diff_disc_ep False --disc_ep 30 --eval False
```

To train on CIFAR-100 with α=0.1, use the following command:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dset_c cifar100 --dset_s cifar100 --combine gan --diri_alpha 0.1 --anneal True --gen_load False --diff_disc_ep False --disc_ep 30 --eval False
```

To train on ImageNet100 with α=0.1, use the following command:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dset_c imagenet100 --dset_s imagenet100 --combine gan --diri_alpha 0.1 --anneal True --gen_load False --diff_disc_ep False --disc_ep 10 --eval False
```

To run experiments with α=0.05, change the `--diri_alpha` value from 0.1 to 0.05.
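
To give a feel for what the Dirichlet parameter α controls, the sketch below splits a labeled dataset across clients so that, for each class, the client shares are drawn from a symmetric Dirichlet(α) distribution. This is one common way to simulate heterogeneous clients and uses only the standard library; it is an illustration, not the partitioning code used by `main.py`, and all names in it are hypothetical.

```python
# Illustrative sketch of Dirichlet label partitioning; not the repository's code.
import random
from collections import Counter


def sample_dirichlet(alpha, size, rng):
    """Draw one symmetric Dirichlet(alpha) sample by normalizing Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(size)]
    total = sum(draws)
    if total == 0.0:  # guard against underflow for very small alpha
        return [1.0 / size] * size
    return [d / total for d in draws]


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign sample indices to clients; per class, client shares follow Dirichlet(alpha)."""
    rng = random.Random(seed)
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    clients = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        shares = sample_dirichlet(alpha, num_clients, rng)
        # Convert cumulative shares into cut points over the shuffled indices.
        start, cumulative = 0, 0.0
        for client_id, share in enumerate(shares):
            cumulative += share
            if client_id == num_clients - 1:
                end = len(indices)  # last client takes the remainder
            else:
                end = min(len(indices), round(cumulative * len(indices)))
            clients[client_id].extend(indices[start:end])
            start = end
    return clients


if __name__ == "__main__":
    labels = [i % 10 for i in range(1000)]  # synthetic labels, 10 classes
    for alpha in (0.1, 0.05):
        parts = dirichlet_partition(labels, num_clients=5, alpha=alpha, seed=0)
        print(f"alpha={alpha}")
        for client_id, part in enumerate(parts):
            print("  client", client_id, dict(Counter(labels[i] for i in part)))
```

Smaller α values tend to concentrate each class on fewer clients, which is why α=0.05 is the more heterogeneous of the two settings.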

For CIFAR-10, it took 5 days and 14 hours using an RTX 3090 GPU to obtain results from 5 repeated experiments.

The FL algorithms provided and the respective `--combine` parameters are as follows:
1. FedAVG: `avg`
2. FedDF: `df`
3. FedGKD: `df_gkd`
4. FedGO (ours): `gan`

To compare different weighting methods, use the following parameter values:
1. Variance based: `logit_var`
2. Entropy based: `em_entropy_soft`
3. Domain based: `gan_dafkd`

To change the discriminator training epochs, set `--diff_disc_ep` to `True` and adjust the `--disc_ep` parameter to the desired number of epochs.

To change the client dataset, modify the `--dset_c` parameter. To change the server's unlabeled dataset, modify the `--dset_s` parameter. This allows experiments with various client-server dataset combinations.
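
When sweeping over several of these settings, it can help to generate the command lines programmatically. The sketch below is a hypothetical helper that uses only the flags shown in the commands above; it keeps `--anneal`, `--gen_load`, and `--eval` at the values from those commands, and it assumes the default discriminator epochs follow the examples (30 for CIFAR, 10 for ImageNet100). Whether every flag applies equally to every `--combine` method is not stated here, so treat the output as a starting point.

```python
# Sketch: build (and optionally launch) main.py runs for several settings.
# Helper names are hypothetical; only flags from the README commands are used.
import os
import shlex
import subprocess


def build_command(dset_c, dset_s, combine, diri_alpha, disc_ep=None):
    """Return the argv list for one training run."""
    # Assumption: defaults mirror the README examples.
    default_ep = 10 if dset_c == "imagenet100" else 30
    custom_ep = disc_ep is not None
    return [
        "python", "main.py",
        "--dset_c", dset_c,
        "--dset_s", dset_s,
        "--combine", combine,
        "--diri_alpha", str(diri_alpha),
        "--anneal", "True",
        "--gen_load", "False",
        # A custom epoch count requires --diff_disc_ep True.
        "--diff_disc_ep", str(custom_ep),
        "--disc_ep", str(disc_ep if custom_ep else default_ep),
        "--eval", "False",
    ]


def run(cmd, gpu="0", dry_run=True):
    """Print the command; execute it only when dry_run is False."""
    print(f"CUDA_VISIBLE_DEVICES={gpu} " + shlex.join(cmd))
    if not dry_run:
        env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu}
        subprocess.run(cmd, env=env, check=True)


if __name__ == "__main__":
    for combine in ("avg", "df", "df_gkd", "gan"):
        for alpha in (0.1, 0.05):
            run(build_command("cifar10", "cifar10", combine, alpha))
```

By default the helper only prints the commands; pass `dry_run=False` to `run` to launch them one after another.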

Training is repeated five times to measure test accuracy, and the server model is saved in the `model` folder after 100 rounds.

If a saved server model is available, you can evaluate it by setting the `--eval` parameter to `True` instead of `False`. This reports the test accuracy of the server models trained with each of the seeds 3 through 7.

## Pre-trained Models

For the setting where both the client and server datasets are CIFAR-10 (Dirichlet α=0.1 and α=0.05), a WGAN-GP model trained from scratch is uploaded to the `model` folder. To skip generator training and load this pre-trained model, add the argument `--gen_load True` to the training command. Due to storage limitations, the generator model is uploaded only for CIFAR-10.

## Results

|                     | CIFAR-10 α=0.1       | CIFAR-10 α=0.05       | CIFAR-100 α=0.1       | CIFAR-100 α=0.05       | ImageNet100 α=0.1       | ImageNet100 α=0.05       |
|---------------------|-----------------------|------------------------|------------------------|-------------------------|--------------------------|---------------------------|
| **Central Training**| 85.33 ± 0.25          | 85.33 ± 0.25           | 51.72 ± 0.65           | 51.72 ± 0.65            | 43.20 ± 1.00             | 43.20 ± 1.00              |
| **FedAVG**          | 58.65 ± 5.75          | 46.61 ± 8.54           | 38.93 ± 0.74           | 36.66 ± 0.97            | 29.44 ± 0.41             | 27.58 ± 0.88              |
| **FedProx**         | 64.69 ± 2.15          | 55.56 ± 9.86           | 38.21 ± 0.95           | 34.44 ± 1.26            | 29.96 ± 0.66             | 26.99 ± 0.97              |
| **FedDF**           | 71.56 ± 5.09          | 59.53 ± 9.88           | 42.74 ± 1.22           | 37.18 ± 1.03            | 33.48 ± 1.00             | 30.94 ± 1.60              |
| **FedGKD<sup>&plus;</sup>**      | 72.59 ± 4.10          | 59.96 ± 8.60           | 43.35 ± 1.14           | 40.47 ± 1.00            | 34.10 ± 0.67             | 31.42 ± 0.93              |
| **DaFKD**           | 71.52 ± 5.56          | 67.51 ± 10.77          | 44.12 ± 2.25           | 39.50 ± 0.85            | 33.34 ± 0.69             | 31.59 ± 1.46              |
| **FedGO (ours)**   | **79.62 ± 4.36**      | **72.35 ± 9.01**       | **44.66 ± 1.27**       | **41.04 ± 0.99**        | **34.20 ± 0.71**         | **31.70 ± 1.55**          |


