# Communication-Efficient Heterogeneous Federated Learning with Generalized Heavy-Ball Momentum
Official PyTorch implementation of ICLR 2024 paper submission `5005`, **_"Communication-Efficient Heterogeneous Federated Learning with Generalized Heavy-Ball Momentum"_**.

**Authors**: Anonymous

## Abstract
Federated Learning (FL) is the state-of-the-art approach for learning from decentralized data in privacy-constrained scenarios. 
As the current literature reports, the main problems associated with FL are system and statistical challenges: the former demand efficient learning on edge devices, including lower communication bandwidth and frequency, while the latter require algorithms that are robust to non-iid data.
State-of-the-art approaches either guarantee convergence at an increased communication cost or are not robust enough to handle extremely heterogeneous local distributions.
In this work, we propose a novel generalization of _heavy-ball_ momentum and present FedHBM, which effectively addresses statistical heterogeneity in FL without introducing any communication overhead.
We conduct extensive experiments on common FL vision and NLP datasets, showing that FedHBM empirically yields better model quality and faster convergence than the state of the art, especially in pathological non-iid scenarios.
Although FedHBM is designed for cross-silo settings, we show that it is also applicable to moderate-to-high cross-device scenarios, and how good model initializations (e.g., pre-training) can be exploited for prompt acceleration.
Additional experiments on large-scale real-world federated datasets further corroborate the effectiveness of our approach for real-world FL applications.
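
FedHBM builds on a generalization of heavy-ball momentum. As background only, the snippet below is a minimal sketch of *classic* (Polyak) heavy-ball momentum, `theta_{t+1} = theta_t - lr * grad f(theta_t) + beta * (theta_t - theta_{t-1})`, compared with plain gradient descent. It is not FedHBM and not the paper's generalized momentum; the toy ill-conditioned quadratic, `lr`, `beta` and the step count are arbitrary illustrative choices.

```python
import math


def grad(theta, curvature):
    """Gradient of the toy quadratic f(theta) = 0.5 * sum_i a_i * theta_i**2."""
    return [a * t for a, t in zip(curvature, theta)]


def gradient_descent(theta0, curvature, lr, steps):
    """Plain gradient descent: theta <- theta - lr * grad(theta)."""
    theta = list(theta0)
    for _ in range(steps):
        g = grad(theta, curvature)
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta


def heavy_ball(theta0, curvature, lr, beta, steps):
    """Classic heavy-ball: theta_{t+1} = theta_t - lr*grad + beta*(theta_t - theta_{t-1})."""
    prev = list(theta0)   # theta_{t-1}; equal to theta_0, so the first step is plain GD
    theta = list(theta0)  # theta_t
    for _ in range(steps):
        g = grad(theta, curvature)
        new = [t - lr * gi + beta * (t - p) for t, gi, p in zip(theta, g, prev)]
        prev, theta = theta, new
    return theta


if __name__ == "__main__":
    curvature = [1.0, 0.01]  # ill-conditioned quadratic: condition number 100
    start = [1.0, 1.0]
    gd = gradient_descent(start, curvature, lr=1.0, steps=200)
    hb = heavy_ball(start, curvature, lr=1.0, beta=0.9, steps=200)
    print("distance from optimum, GD:", math.hypot(*gd))
    print("distance from optimum, heavy-ball:", math.hypot(*hb))
```

On this problem plain gradient descent crawls along the low-curvature direction, while the momentum term `beta * (theta_t - theta_{t-1})` keeps accumulating progress along it, so heavy-ball ends much closer to the minimum after the same number of steps. With `beta=0` the two functions coincide.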

| WARNING: **this is not production code!** |
| --- |

## Usage

### Requirements
To install the requirements, use pip with the provided requirements file in your (virtual) environment:
```shell
# after having activated your environment
$ pip install -r requirements/requirements.txt
```

### Replicate our results
Run `train.py`, specifying the command-line arguments.
Default arguments are specified in the `./config` folder.
For example, to run the experiments on CIFAR-10, issue:
```shell
# run cifar10 in our cross-silo scenario using LeNet, non-iid
$ python train.py algo=fedhbm model=lenet dataset=cifar10 n_round=10000 algo.params.common.K=100 algo.params.common.C=0.1 algo.params.common.alpha=0
```
```shell
# run cifar10 in our cross-device scenario using LeNet, non-iid
$ python train.py algo=fedhbm model=lenet dataset=cifar10 n_round=10000 algo.params.common.K=500 algo.params.common.C=0.01 algo.params.common.alpha=0
```
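
The cross-silo and cross-device commands differ only in `K` and `C`. Assuming the usual FedAvg notation (not verified against the code), `K` is the total number of clients and `C` the fraction of clients sampled each round, so the cross-silo setting trains 10 of 100 clients per round and the cross-device setting 5 of 500. The helper names below are hypothetical.

```python
import random


def clients_per_round(K, C):
    """Clients sampled per round: max(C * K, 1), rounded to an integer (assumed convention)."""
    return max(round(C * K), 1)


def sample_clients(K, C, rng):
    """Sample distinct client ids uniformly at random, without replacement."""
    return sorted(rng.sample(range(K), clients_per_round(K, C)))


if __name__ == "__main__":
    rng = random.Random(0)
    for name, K, C in [("cross-silo", 100, 0.1), ("cross-device", 500, 0.01)]:
        m = clients_per_round(K, C)
        print(f"{name}: {m} of {K} clients per round, e.g. {sample_clients(K, C, rng)}")
```

Under the same assumption, the StackOverflow command (`K=40000`, `C=0.00125`) samples 50 clients per round.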


```shell
# run cifar10 in our cross-silo scenario using ResNet-20, non-iid
$ python train.py algo=fedhbm model=resnet dataset=cifar10 n_round=10000 algo.params.common.K=100 algo.params.common.C=0.1 algo.params.common.alpha=0
```

```shell
# run cifar10 in our cross-device scenario using ResNet-20, non-iid
$ python train.py algo=fedhbm model=resnet dataset=cifar10 n_round=10000 algo.params.common.K=500 algo.params.common.C=0.01 algo.params.common.alpha=0
```
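
The CIFAR-10 commands above pass `algo.params.common.alpha=0` and are labelled non-iid. We assume here, without checking the files in `./config`, that `alpha=0` selects the pathological split in which every client holds samples of a single class. The snippet below is a hypothetical, stdlib-only sketch of such a split (with an iid split for contrast), not the repository's partitioning code; `pathological_split` and `iid_split` are illustrative names.

```python
import random
from collections import defaultdict


def iid_split(num_samples, num_clients, seed=0):
    """Shuffle all sample indices and deal them round-robin to the clients."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return [indices[k::num_clients] for k in range(num_clients)]


def pathological_split(labels, num_clients, seed=0):
    """Hypothetical single-class-per-client split (assumed meaning of alpha=0).

    Client k owns class classes[k % num_classes]; each class's samples are
    shuffled and dealt round-robin among its owners. Classes without an
    owner (possible when num_clients < num_classes) are dropped.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes = sorted(by_class)

    owners = defaultdict(list)
    for k in range(num_clients):
        owners[classes[k % len(classes)]].append(k)

    clients = [[] for _ in range(num_clients)]
    for c, ks in owners.items():
        idxs = by_class[c][:]
        rng.shuffle(idxs)
        for j, idx in enumerate(idxs):
            clients[ks[j % len(ks)]].append(idx)
    return clients


if __name__ == "__main__":
    # Toy stand-in for CIFAR-10 labels: 1000 samples, 10 balanced classes
    labels = [i % 10 for i in range(1000)]
    niid = pathological_split(labels, num_clients=100)
    iid = iid_split(len(labels), num_clients=100)
    print("classes on client 0 (non-iid):", {labels[i] for i in niid[0]})
    print("distinct classes on client 0 (iid):", len({labels[i] for i in iid[0]}))
```

With 100 clients and 10 balanced classes, every client in the non-iid split receives 10 samples of a single class, whereas an iid client typically sees several classes.
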
The Shakespeare and StackOverflow datasets are available upon request, as we could not include them in the submission due to file size limits.
```shell
# run shakespeare non-iid
$ python train.py algo=fedhbm model=shakespeareLSTM dataset=shakespeare n_round=250 algo.params.common.K=100 algo.params.common.C=0.1 dataset.getter_fn.args.version=niid algo.params.optim.args.lr=1 algo.params.center_server.args.optim.args.lr=1 algo.params.optim.args.weight_decay=0 algo.params.common.B=100 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```

```shell
# run shakespeare iid
$ python train.py algo=fedhbm model=shakespeareLSTM dataset=shakespeare n_round=250 algo.params.common.K=100 algo.params.common.C=0.1 dataset.getter_fn.args.version=iid algo.params.optim.args.lr=1 algo.params.optim.args.weight_decay=0 algo.params.common.B=100 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```

```shell
# run stackoverflow
$ python train.py algo=fedhbm model=stackoverflowLSTM dataset=stackoverflow n_round=1500 algo.params.common.K=40000 algo.params.common.C=0.00125 dataset.getter_fn.args.version=niid algo.params.optim.args.lr=0.3 algo.params.optim.args.weight_decay=0 algo.params.common.B=16 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```
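
The Shakespeare and StackOverflow commands build `exp_name` from other config values with `${...}` interpolation; the single quotes stop the shell from expanding `${...}` itself, so the string reaches Hydra verbatim. The snippet below is a simplified, stdlib-only sketch of how such absolute-key interpolation resolves. It is not OmegaConf, which also supports relative keys, custom resolvers and type checking, and the values for `algo.type` and `dataset.type` are hypothetical placeholders.

```python
import re

# Matches ${some.dotted.key}; nested interpolations are not supported in this sketch.
_INTERPOLATION = re.compile(r"\$\{([^{}]+)\}")


def lookup(cfg, dotted_key):
    """Walk a nested dict following a dotted key such as 'algo.params.common.K'."""
    node = cfg
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(template, cfg):
    """Replace every ${key} in template with str() of the referenced value."""
    return _INTERPOLATION.sub(lambda m: str(lookup(cfg, m.group(1))), template)


if __name__ == "__main__":
    # Hypothetical config values; the real ones come from ./config and the CLI
    cfg = {
        "algo": {"type": "fedhbm", "params": {"common": {"K": 100, "C": 0.1}}},
        "dataset": {"type": "shakespeare", "getter_fn": {"args": {"version": "niid"}}},
    }
    template = ("${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}"
                "_K${algo.params.common.K}_C${algo.params.common.C}")
    print(resolve(template, cfg))  # fedhbm_shakespeare_niid_K100_C0.1
```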


This software uses Hydra to configure experiments. For more information on how to provide command-line
arguments, please refer to the [official documentation](https://hydra.cc/docs/advanced/override_grammar/basic/).
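
Every `key=value` argument in the commands above is a Hydra override. As a rough mental model only, the stdlib sketch below applies dotted overrides to a nested dict of defaults, coercing values to bool, `None`, int or float where possible. It is a simplification, not Hydra: it ignores config-group selection such as `algo=fedhbm` (which in Hydra picks an option from a config group), the `+`/`++`/`~` prefixes, lists and quoting. The default values are hypothetical.

```python
import copy


def parse_value(text):
    """Coerce an override value string to bool, None, int, float or str."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(defaults, overrides):
    """Return a copy of defaults with each 'dotted.key=value' override applied.

    Mirroring Hydra's default behaviour for plain overrides, a key that does
    not already exist raises an error instead of being created.
    """
    cfg = copy.deepcopy(defaults)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node[part]
        if leaf not in node:
            raise KeyError(f"cannot override missing key {key!r}")
        node[leaf] = parse_value(raw)
    return cfg


if __name__ == "__main__":
    # Hypothetical defaults standing in for the YAML files under ./config
    defaults = {"n_round": 1, "algo": {"params": {"common": {"K": 10, "C": 1.0, "alpha": 0.5}}}}
    cfg = apply_overrides(defaults, [
        "n_round=10000",
        "algo.params.common.K=100",
        "algo.params.common.C=0.1",
        "algo.params.common.alpha=0",
    ])
    print(cfg)
```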