# Communication-Efficient Heterogeneous Federated Learning with Generalized Heavy-Ball Momentum
Official PyTorch implementation of the ICLR 2025 paper submission `6420`, **_"Communication-Efficient Heterogeneous Federated Learning with Generalized Heavy-Ball Momentum"_**.

**Authors**: Anonymous

## Abstract
Federated Learning (FL) has emerged as the state-of-the-art approach for learning from decentralized data in privacy-constrained scenarios. However, system and statistical challenges hinder real-world applications, which demand efficient learning from edge devices and robustness to heterogeneity. Despite significant research efforts, existing approaches (i) are not sufficiently robust, (ii) do not perform well in large-scale scenarios, and (iii) are not communication efficient. In this work, we propose a novel _Generalized Heavy-Ball Momentum_ (GHBM), proving that it enjoys an improved theoretical convergence rate w.r.t. existing FL methods based on classical momentum in _partial participation_, without relying on bounded data heterogeneity. Then, we present FedHBM as an adaptive, communication-efficient by-design instance of GHBM. Extensive experimentation on vision and language tasks, in both controlled and realistic large-scale scenarios, confirms our theoretical findings, showing that GHBM substantially improves the state of the art, especially in large-scale scenarios with high data heterogeneity and low client participation.

| WARNING: **this is not production code!** |
| --- |

## Usage

### Requirements
To install the dependencies, use pip with the provided requirements file in your (virtual) environment:
```shell
# after having activated your environment
$ pip install -r requirements/requirements.txt
```

### Replicate our results
Run ```train.py``` with the desired command-line arguments. Default arguments are specified in the ```./config``` folder.
For example, to run the experiments on CIFAR10, issue:
```shell
# run cifar10 in our cross-silo scenario using LeNet, non-iid
$ python train.py algo=fedhbm model=lenet dataset=cifar10 n_round=10000 algo.params.common.K=100 algo.params.common.C=0.1 algo.params.common.alpha=0
```
```shell
# run cifar10 in our cross-device scenario using LeNet, non-iid
$ python train.py algo=fedhbm model=lenet dataset=cifar10 n_round=10000 algo.params.common.K=500 algo.params.common.C=0.01 algo.params.common.alpha=0
```


```shell
# run cifar10 in our cross-silo scenario using ResNet-20, non-iid
$ python train.py algo=fedhbm model=resnet dataset=cifar10 n_round=10000 algo.params.common.K=100 algo.params.common.C=0.1 algo.params.common.alpha=0
```

```shell
# run cifar10 in our cross-device scenario using ResNet-20, non-iid
$ python train.py algo=fedhbm model=resnet dataset=cifar10 n_round=10000 algo.params.common.K=500 algo.params.common.C=0.01 algo.params.common.alpha=0
```
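
The cross-silo and cross-device commands above differ only in `K` and `C`. As a rough sanity check, and assuming that `K` is the total number of clients and `C` is the fraction of clients sampled in each round (an interpretation of the flag names, not something confirmed by the code), the number of participating clients per round can be estimated as follows. The helper `clients_per_round` is hypothetical and not part of this repository.

```python
def clients_per_round(num_clients: int, participation: float) -> int:
    """Estimate how many clients take part in one round.

    Assumption: K is the total number of clients and C is the fraction
    sampled per round. The repository may sample clients differently.
    """
    # round() absorbs floating-point noise in the product;
    # max(1, ...) is a choice of this sketch so that a round is never empty.
    return max(1, round(num_clients * participation))


# (K, C) pairs taken from the CIFAR10 commands above
settings = [("cross-silo", 100, 0.1), ("cross-device", 500, 0.01)]
for name, K, C in settings:
    print(f"{name}: K={K}, C={C} -> {clients_per_round(K, C)} clients per round")
```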
For Shakespeare and StackOverflow, you may request the datasets: we cannot include them in the submission due to the file size limit.
```shell
# run shakespeare non-iid
$ python train.py algo=fedhbm model=shakespeareLSTM dataset=shakespeare n_round=250 algo.params.common.K=100 algo.params.common.C=0.1 dataset.getter_fn.args.version=niid algo.params.optim.args.lr=1 algo.params.center_server.args.optim.args.lr=1 algo.params.optim.args.weight_decay=0 algo.params.common.B=100 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```

```shell
# run shakespeare iid
$ python train.py algo=fedhbm model=shakespeareLSTM dataset=shakespeare n_round=250 algo.params.common.K=100 algo.params.common.C=0.1 dataset.getter_fn.args.version=iid algo.params.optim.args.lr=1 algo.params.optim.args.weight_decay=0 algo.params.common.B=100 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```

```shell
# run stackoverflow
$ python train.py algo=fedhbm model=stackoverflowLSTM dataset=stackoverflow n_round=1500 algo.params.common.K=40000 algo.params.common.C=0.00125 dataset.getter_fn.args.version=niid algo.params.optim.args.lr=0.3 algo.params.optim.args.weight_decay=0 algo.params.common.B=16 exp_name='${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}_K${algo.params.common.K}_C${algo.params.common.C}'
```
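
The `exp_name` argument in the commands above uses `${...}` interpolation, which is resolved against the final configuration. The single quotes keep the shell from trying to expand `${...}` itself. The following is a minimal sketch of how such a pattern expands, not the actual resolver used by Hydra: the `resolve` helper is hypothetical, and the values of `algo.type` and `dataset.type` in the example config are assumed rather than read from `./config`.

```python
import re


def resolve(template: str, cfg: dict) -> str:
    """Replace each ${dotted.key} with the value stored at that path in cfg."""

    def lookup(match: re.Match) -> str:
        node = cfg
        for part in match.group(1).split("."):
            node = node[part]  # raises KeyError if the path does not exist
        return str(node)

    return re.sub(r"\$\{([^}]+)\}", lookup, template)


# Hypothetical resolved config for the non-iid Shakespeare run.
# The values of algo.type and dataset.type are assumptions.
cfg = {
    "algo": {"type": "fedhbm", "params": {"common": {"K": 100, "C": 0.1}}},
    "dataset": {"type": "shakespeare", "getter_fn": {"args": {"version": "niid"}}},
}
template = (
    "${algo.type}_${dataset.type}_${dataset.getter_fn.args.version}"
    "_K${algo.params.common.K}_C${algo.params.common.C}"
)
exp_name = resolve(template, cfg)
print(exp_name)
```

Under these assumed values, each run gets a name that encodes the algorithm, dataset, split version, `K`, and `C`, which helps keep the outputs of different settings apart.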


This software uses Hydra to configure experiments. For more information on how to provide command-line
arguments, please refer to the [official documentation](https://hydra.cc/docs/advanced/override_grammar/basic/).
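
To illustrate how dotted arguments such as `algo.params.common.K=100` address nested configuration entries, here is a simplified sketch in plain Python. It is not Hydra's parser: the helpers `parse_value` and `apply_overrides` are hypothetical, only `key=value` pairs are handled, and entries such as `algo=fedhbm`, which presumably select an option from a config group, are not modeled.

```python
def parse_value(raw: str):
    """Convert a raw string to int or float when possible, else keep it as str."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Write each 'a.b.c=value' override into the nested dict cfg (in place)."""
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            # create intermediate levels on demand
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return cfg


cfg = apply_overrides(
    {},
    ["n_round=10000", "algo.params.common.K=100", "algo.params.common.C=0.1"],
)
print(cfg)
```

In the real setup, the starting point would be the defaults from `./config` rather than an empty dictionary, with command-line overrides applied on top.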