# ON THE GENERALIZATION OF WASSERSTEIN ROBUST FEDERATED LEARNING
This repository implements all experiments in the paper *ON THE GENERALIZATION OF WASSERSTEIN ROBUST FEDERATED LEARNING*.
  
Authors: 

# Software requirements:
- numpy, scipy, torch, Pillow, matplotlib, tqdm, pandas, h5py, comet_ml, scikit-learn==0.20.3
- To download the dependencies: **pip3 install -r requirements.txt**
- The code runs on any PC; a GPU is not required.
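
Before launching experiments, it can help to confirm that the listed packages are importable. The snippet below is a minimal sketch and not part of this repository: the `missing_packages` helper is a hypothetical name, and the only facts it relies on are that Pillow is imported as `PIL` and scikit-learn as `sklearn`. It does not check the pinned scikit-learn version.

```python
import importlib.util

# Map pip package names (from the list above) to their import names.
# Pillow is imported as PIL and scikit-learn as sklearn.
PACKAGES = {
    "numpy": "numpy",
    "scipy": "scipy",
    "torch": "torch",
    "Pillow": "PIL",
    "matplotlib": "matplotlib",
    "tqdm": "tqdm",
    "pandas": "pandas",
    "h5py": "h5py",
    "comet_ml": "comet_ml",
    "scikit-learn": "sklearn",
}


def missing_packages(packages=PACKAGES):
    """Return the pip names of packages whose import name cannot be found."""
    return [
        pip_name
        for pip_name, module_name in packages.items()
        if importlib.util.find_spec(module_name) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed packages are importable.")
```

Any package reported as missing can be installed with the `pip3` command given above.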
  
# Datasets:

- MNIST (100 clients): generated by running <pre><code>python3 data/Mnist/generate_niid_100users.py</code></pre>
- CIFAR (20 clients): This dataset is downloaded and generated automatically when running the algorithms.
- Three-digit Datasets (MNIST, MNISTM, USPS): A subset of Digit-Five Dataset

Download Link: https://drive.google.com/file/d/14YcDLqSYn6FbLMpZqilM5VLNnX06U2Dx/view?usp=sharing

After downloading, all datasets must be stored in the `data` folder.


# Adversarial Training

## Mnist
<pre><code>
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0 --gamma 0.05 --alpha 10 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.1 --numusers 100 --times 10

    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.2 --gamma 0.05 --alpha 10 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.1 --numusers 100 --times 10

    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.4 --gamma 0.05 --alpha 10  --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.1 --numusers 100 --times 10

    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.6 --gamma 0.05 --alpha 10 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.1 --numusers 100 --times 10
    
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.8 --gamma 0.05 --alpha 10  --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.1 --numusers 100 --times 10
    python3 main.py --dataset Mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.1 --numusers 100 --times 10
</code></pre>
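
The Mnist commands above form a grid of five `--robust` levels and four algorithms, where only WAFL receives `--gamma` and `--alpha`. As a convenience, the grid could be generated programmatically. The following is a minimal sketch, not a script shipped with this repository; `build_command` and `build_grid` are hypothetical helper names, and all flag values are copied from the commands above.

```python
# Robustness levels and algorithms used in the Mnist commands above.
ROBUST_LEVELS = ["0", "0.2", "0.4", "0.6", "0.8"]
ALGORITHMS = ["WAFL", "FedAvg", "FedPGD", "FedFGSM"]


def build_command(robust, algorithm):
    """Build one main.py command line for the Mnist experiments."""
    parts = [
        "python3", "main.py",
        "--dataset", "Mnist",
        "--model", "mclr",
        "--batch_size", "64",
        "--learning_rate", "0.001",
        "--robust", robust,
    ]
    # Only the WAFL commands above pass --gamma and --alpha.
    if algorithm == "WAFL":
        parts += ["--gamma", "0.05", "--alpha", "10"]
    parts += [
        "--num_global_iters", "200",
        "--local_epochs", "2",
        "--algorithm", algorithm,
        "--subusers", "0.1",
        "--numusers", "100",
        "--times", "10",
    ]
    return " ".join(parts)


def build_grid():
    """Return all 20 commands, ordered by robust level and then algorithm."""
    return [
        build_command(robust, algorithm)
        for robust in ROBUST_LEVELS
        for algorithm in ALGORITHMS
    ]


if __name__ == "__main__":
    for command in build_grid():
        print(command)
```

Printing the commands rather than executing them keeps the sketch side-effect free; the output can be redirected to a shell script and run from there.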

## Cifar10
<pre><code>
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0 --gamma 0.5 --alpha 1 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.5 --numusers 20 --times 10

    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.2 --gamma 0.5 --alpha 1 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.2 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.5 --numusers 20 --times 10

    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.4 --gamma 0.5 --alpha 1 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.4 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.5 --numusers 20 --times 10
    
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.6 --gamma 0.5 --alpha 1 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.6 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.5 --numusers 20 --times 10

    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.8 --gamma 0.5 --alpha 1 --num_global_iters 200 --local_epochs 2 --algorithm WAFL --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedAvg --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedPGD --subusers 0.5 --numusers 20 --times 10
    python3 main.py --dataset Cifar10 --model cnn --batch_size 64 --learning_rate 0.05 --robust 0.8 --num_global_iters 200 --local_epochs 2 --algorithm FedFGSM --subusers 0.5 --numusers 20 --times 10
</code></pre>

## Domain Adaptation

### Generate source clients' domains data
Before running the Domain Adaptation experiments, generate the source clients' data with *data/fiveDigit/generate_niid_DAUsers.py*:

    usage: python generate_niid_DAUsers.py [-h] [--dataset {mnist,mnistm,usps}]
                                    [--numuser NUMUSER] [--label LABEL]
                                    [--batch_size BATCH_SIZE] [--verbose VERBOSE]

    optional arguments:
    -h, --help            show this help message and exit
    --dataset {mnist,mnistm,usps}
                             Select source dataset  (Default = mnist)
    --numuser    NUMUSER     Total source domain users  (Default = 100)
    --label      LABEL       Number of classes per user (Default = 2)
    --batch_size BATCH_SIZE  Size of batches (Optional) 
    --verbose    VERBOSE     Print logs while generating data (Optional)
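
For readers who want to see how the documented options fit together, here is a minimal `argparse` sketch that would accept the same flags and defaults as the usage text above. It is an illustration only: the actual parser in *generate_niid_DAUsers.py* may be written differently, the `build_parser` name is hypothetical, and the integer types for `--numuser`, `--label`, `--batch_size`, and `--verbose` are assumptions, as are the `None` defaults for the two optional flags.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented command-line options."""
    parser = argparse.ArgumentParser(prog="generate_niid_DAUsers.py")
    parser.add_argument(
        "--dataset",
        choices=["mnist", "mnistm", "usps"],
        default="mnist",
        help="Select source dataset (Default = mnist)",
    )
    parser.add_argument(
        "--numuser", type=int, default=100,
        help="Total source domain users (Default = 100)",
    )
    parser.add_argument(
        "--label", type=int, default=2,
        help="Number of classes per user (Default = 2)",
    )
    # Assumption: no default is documented for these two, so None is used.
    parser.add_argument(
        "--batch_size", type=int, default=None,
        help="Size of batches (Optional)",
    )
    parser.add_argument(
        "--verbose", type=int, default=None,
        help="Print logs while generating data (Optional)",
    )
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())
```

For example, `build_parser().parse_args(["--dataset", "usps", "--numuser", "50"])` selects USPS as the source domain with 50 users and keeps the default of 2 classes per user.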
    
#### MNISTM to MNIST
<pre><code>
    python3 main.py --dataset mnistm2mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust -1 --gamma 0.1 --num_global_iters 10 --local_epochs 2 --algorithm FedAvg --subusers 0.1 --numusers 101 --times 10
    python3 main.py --dataset mnistm2mnist --model mclr --batch_size 64 --learning_rate 0.001 --robust -1 --gamma 0.1 --num_global_iters 10 --local_epochs 2 --algorithm WAFL --subusers 0.1 --numusers 101 --times 10
</code></pre>


