# FedMAE: Federated Self-Supervised Learning with One-Block Masked Auto-Encoder

The PyTorch implementation of the paper "FedMAE: Federated Self-Supervised Learning with One-Block Masked Auto-Encoder", submitted to ICLR.

## Environment Installation

1. Install Anaconda from the [Anaconda official website](https://www.anaconda.com/).
2. Run the following commands to create the virtual environment:

```
cd federated_MAE
conda env create -f environment.yaml
```
3. Run the following command to activate the new virtual environment:
```
conda activate fedMae
```
4. Find the appropriate command on the [PyTorch official website](https://pytorch.org/) and use it to install PyTorch into the virtual environment.

5. If a library is missing, install it with:
```
pip install <package-name>
```
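To see which libraries are missing before running **main.py**, a small standard-library check such as the sketch below can help. The module list is an assumption (check **environment.yaml** for the actual dependencies), and a module's import name can differ from its pip package name.

```python
import importlib.util

# Assumed module list; check environment.yaml for the real dependencies.
REQUIRED_MODULES = ["torch", "torchvision", "numpy"]


def find_missing(modules):
    """Return the top-level modules that cannot be found in the active environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
        print("Try: pip install " + " ".join(missing))
    else:
        print("All listed modules are installed.")
```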

## Instructions

### -- Dataset Preparation

Our code supports the following datasets:

0. CIFAR-10 (available through torchvision)
1. CIFAR-100 (available through torchvision)
2. Food-101 (available through torchvision)
3. ImageNet (must be downloaded from the [ImageNet official website](https://image-net.org/index.php); extraction instructions: [link](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4))
4. Mini-ImageNet (built from the downloaded ImageNet; see the repo [Tools for mini-ImageNet Dataset](https://github.com/yaoyao-liu/mini-imagenet-tools#about-mini-ImageNet))
5. ImageNet32 (must be downloaded from the [ImageNet official website](https://image-net.org/index.php))
6. RoadSign (must be downloaded from [Kaggle](https://www.kaggle.com/datasets/sergeykulakin/russian-road-signs-categories-dataset))
7. Mini-INat2021 (must be downloaded from the repo [iNaturalist Competition Datasets](https://github.com/visipedia/inat_comp/tree/master/2021))

For datasets that must be downloaded separately, remember to set the correct path to the data folder in **prepare_dataset.py** after downloading.
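A newly added dataset usually also needs per-channel normalization statistics. The repo ships **find_mean_std.py** for this; the standard-library sketch below only illustrates the computation (per-channel running sums of values and squared values) and is not that script. It assumes each image is a nested list in channel-first order, `[C][H][W]`, with pixel values already scaled to [0, 1].

```python
import math


def channel_mean_std(images):
    """Per-channel mean and population std over a list of [C][H][W] images."""
    num_channels = len(images[0])
    sums = [0.0] * num_channels
    sq_sums = [0.0] * num_channels
    count = 0  # number of pixels per channel seen so far
    for img in images:
        for c, channel in enumerate(img):
            for row in channel:
                for v in row:
                    sums[c] += v
                    sq_sums[c] += v * v
        count += len(img[0]) * len(img[0][0])  # H * W
    means = [s / count for s in sums]
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    stds = [math.sqrt(max(sq / count - m * m, 0.0)) for sq, m in zip(sq_sums, means)]
    return means, stds
```

For real image data, a vectorized version over tensors is much faster; the arithmetic is the same.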

### -- Federated Pretraining

Example 1:
* Pretrain with CIFAR-10 dataset
* Client data is IID
* Default federated learning settings
* No data on server
```
python main.py -p pretrain -d 0 -ra 0 -sp iid -rd 200 -le 10 -tc 100 -sc 5 
```
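In the IID setting, each client receives a uniformly random, roughly equal share of the training set. A minimal standard-library sketch of such a split is shown below; `iid_split` is a hypothetical name, and whether **prepare_dataset.py** divides data exactly this way is an assumption.

```python
import random


def iid_split(num_samples, num_clients, seed=0):
    """Shuffle sample indices and deal them into num_clients near-equal shards."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Strided slicing gives client c every num_clients-th index starting at c,
    # so shard sizes differ by at most one.
    return [indices[c::num_clients] for c in range(num_clients)]


# Example: CIFAR-10's 50,000 training images spread over 100 clients.
shards = iid_split(50_000, 100)
print(len(shards), len(shards[0]))  # 100 500
```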

Example 2:
* Pretrain with CIFAR-100 dataset
* Client data is non-IID, partitioned with Dirichlet parameter 0.1
* No data on server
```
python main.py -p pretrain -d 1 -ra 0 -sp dir -alpha 0.1
```
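A common recipe for Dirichlet non-IID partitioning draws, for each class, a vector of client proportions from a symmetric Dirichlet distribution with concentration α and splits that class's samples accordingly; a smaller α (such as 0.1) yields more skewed clients. The sketch below implements this generic recipe with the standard library, obtaining Dirichlet samples by normalizing independent Gamma(α, 1) draws. It is an assumption that **prepare_dataset.py** partitions data this way, and `dirichlet_split` is a hypothetical name.

```python
import random


def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Partition sample indices across clients using per-class Dirichlet(alpha) proportions."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    clients = [[] for _ in range(num_clients)]
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # A Dirichlet(alpha, ..., alpha) sample is a normalized vector of Gamma(alpha, 1) draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        props = [g / total for g in gammas]
        # Convert proportions into cut points over this class's shuffled indices.
        start, cum = 0, 0.0
        for c, p in enumerate(props):
            cum += p
            end = len(idxs) if c == num_clients - 1 else int(round(cum * len(idxs)))
            clients[c].extend(idxs[start:end])
            start = end
    return clients
```

For example, `dirichlet_split(train_labels, num_clients=100, alpha=0.1)` returns one index list per client; every sample is assigned to exactly one client, but client sizes and class mixes vary widely.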

### -- Finetuning (requires a pretrained checkpoint)

Example 1:
* Federated pretraining has completed
* Finetuning with CIFAR-10 dataset
* Use 100% labeled data
```
python main.py -p finetune -d 0 -ra 1 
```

Example 2:
* Federated pretraining has completed
* Finetuning with CIFAR-100 dataset
* Use 10% labeled data
```
python main.py -p finetune -d 1 -ra 0.1 
```
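In the finetuning examples, `-ra` appears to set the fraction of labeled data (1 for 100%, 0.1 for 10%). One common way to build such a subset is class-balanced sampling, sketched below with the standard library. This is an assumption about the approach, not a description of what **prepare_dataset.py** does, and `labeled_subset` is a hypothetical helper.

```python
import random


def labeled_subset(labels, ratio, seed=0):
    """Pick about `ratio` of the samples from every class; return sorted indices."""
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    chosen = []
    for idxs in by_class.values():
        k = max(1, round(len(idxs) * ratio))  # keep at least one sample per class
        chosen.extend(rng.sample(idxs, k))
    return sorted(chosen)
```

The returned indices could then be wrapped with `torch.utils.data.Subset(dataset, indices)` to build the labeled training set.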


## File Structure

```
├── util/ <code in this directory is taken from the MAE repo>
├── debug.py <code for debugging>
├── engine_finetune.py <the training engine for finetuning>
├── engine_pretrain.py <the training engine for pretraining>
├── environment.yaml <specification of the conda environment>
├── federated_learning.py <code for model aggregation and model cascading>
├── find_mean_std.py <script that computes the mean and std of a dataset>
├── imagenetLoad.py <dataloader code for ImageNet32>
├── main.py <entry point for running experiments>
├── model_mae.py <implementation of the MAE model>
├── model_ViT.py <implementation of the ViT model>
├── prepare_dataset.py <code for loading datasets, preparing dataloaders, and partitioning data>
├── readme.md <this README file>
├── train_client.py <main program for client pretraining>
└── train_server.py <main program for finetuning>
```
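**federated_learning.py** handles model aggregation. To illustrate what server-side aggregation typically looks like, the sketch below is a generic FedAvg-style average of client parameters, weighting each client by its number of training samples. It uses plain dicts of float lists instead of PyTorch state dicts so it runs with the standard library alone; whether the repo weights clients this way is an assumption, and `fedavg` is a hypothetical name.

```python
def fedavg(client_params, client_sizes):
    """Average client parameter dicts, weighting each client by its sample count."""
    total = sum(client_sizes)
    averaged = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


clients = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
print(fedavg(clients, [1, 3]))  # {'w': [2.5, 5.0]}
```

With PyTorch, the same loop would run over the tensors returned by each client model's `state_dict()`.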

## Reference

This code is based on the repository [Masked Autoencoders: A PyTorch Implementation](https://github.com/facebookresearch/mae), the PyTorch implementation of the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377):
```
@Article{MaskedAutoencoders2021,
  author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick},
  journal = {arXiv:2111.06377},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  year    = {2021},
}
```

The MAE repo in turn acknowledges the following projects:
 * [DeiT repo](https://github.com/facebookresearch/deit)
 * [timm](https://github.com/rwightman/pytorch-image-models)
 * [ELECTRA](https://github.com/google-research/electra)
 * [BEiT](https://github.com/microsoft/unilm/tree/master/beit)
 * [MoCo v3](https://github.com/facebookresearch/moco-v3)
 * [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py)


