# Step-Ahead Partial Error Feedback (SA-PEF)

<p align="center">
  Official PyTorch implementation of <br>
  <b>Step-Ahead Partial Error Feedback (SA-PEF)</b><br>
  A communication-efficient federated learning algorithm.
</p>

## 📖 Overview

Federated learning (FL) struggles with two coupled issues:

1. **Non-IID client data** → local drift and unstable training.
2. **Aggressive compression / partial participation** → information loss and bias.

**SA-PEF (Step-Ahead Partial Error Feedback)** tackles both by acting on **the path** and **the message**:

- **Compressed updates**: clients send a sparse / quantized update (e.g., Top-k).
- **Error-feedback memory**: the unsent residual is carried forward across rounds.
- **Step-ahead preview (partial)**: before local SGD, each client shifts its model by a _fraction_ $\alpha_r$ of its residual, while the remainder flows through the EF channel.
  This reduces early gradient mismatch **without** extra communication.
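
The three ingredients above can be sketched in a few lines of dependency-free Python. This is only one plausible reading of the description, not the repository's implementation. The names `top_k`, `sapef_client_round`, and `local_train` are hypothetical. The sketch also assumes a sign convention in which the residual is stored in "model update" units, and it assumes that the client's update is measured from the global model, so the previewed share travels inside the update while the remaining $(1-\alpha_r)$ share is added through the EF channel.

```python
import heapq


def top_k(vec, k):
    """Keep the k largest-magnitude entries of vec and zero out the rest."""
    keep = set(heapq.nlargest(k, range(len(vec)), key=lambda i: abs(vec[i])))
    return [v if i in keep else 0.0 for i, v in enumerate(vec)]


def sapef_client_round(global_w, residual, alpha_r, k, local_train):
    """One client round of a step-ahead partial-EF scheme (illustrative sketch).

    local_train(start) is a hypothetical callback that runs local SGD from
    `start` and returns the final local weights.
    """
    # Step-ahead preview: start local training from a model shifted by a
    # fraction alpha_r of the residual (assumed sign convention).
    start = [w + alpha_r * e for w, e in zip(global_w, residual)]
    end = local_train(start)

    # Update measured from the global model, so it already contains the preview.
    delta = [b - a for a, b in zip(global_w, end)]

    # The remaining (1 - alpha_r) share of the residual goes through EF.
    candidate = [d + (1.0 - alpha_r) * e for d, e in zip(delta, residual)]

    # Only the compressed message is sent; the unsent part becomes the new residual.
    message = top_k(candidate, k)
    new_residual = [c - m for c, m in zip(candidate, message)]
    return message, new_residual


if __name__ == "__main__":
    step = [0.1, 0.4, -0.3, 0.0]  # stand-in for the effect of local SGD
    msg, res = sapef_client_round(
        global_w=[0.0, 0.0, 0.0, 0.0],
        residual=[0.5, -0.1, 0.0, 0.2],
        alpha_r=0.5,
        k=1,
        local_train=lambda start: [s + d for s, d in zip(start, step)],
    )
    print("message:", msg)
    print("new residual:", res)
```

By construction, `message` and `new_residual` always sum back to the uncompressed candidate, so nothing is discarded; it is only delayed to a later round.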

### Why it works

- **Lower mismatch, same budget.** Previewing $\alpha_r e_r$ aligns local steps with the global descent direction; partial EF corrects compressor bias.
- **Robust in hard regimes.** Handles extreme sparsity, many local steps, non-IID data, and partial participation, all without error averaging or extra syncs.

### What you get

- **Faster convergence** in the early phase (less staleness).
- **Stable final accuracy** under heavy compression.
- **Lower communication**, plug-and-play with common compressors.

## ⚙️ Installation

We recommend using a conda environment:

```bash
unzip SA-PEF.zip -d SA-PEF
cd SA-PEF
conda env create -f sapef_environment.yaml -n sapef
conda activate sapef
```

## Dependencies

- Python 3.9+
- PyTorch 2.0+
- Torchvision 0.15+
- [Hydra](https://hydra.cc/) 1.3
- [Flower](https://flower.dev/) for FL orchestration
- wandb (optional, for experiment logging)

## 📂 Datasets

Supported benchmarks:

- **CIFAR-10/100** ([torchvision link](https://pytorch.org/vision/stable/datasets.html))
- **Tiny-ImageNet** ([kaggle link](https://www.kaggle.com/c/tiny-imagenet))

Extract datasets into `./data/`.

## 🔬 Baselines

We provide implementations of:

- **FedAvg** ([McMahan et al., 2017](https://arxiv.org/abs/1602.05629))
- **EF** ([Li & Li 2023](https://proceedings.mlr.press/v202/li23o/li23o.pdf))
- **SAEF** ([Step-Ahead Error Feedback](https://arxiv.org/pdf/2008.05823))
- **CSER** ([Compressed SGD with Error Reset](https://arxiv.org/pdf/2007.13221))

## 🚀 Training

### Example: FedAvg, CIFAR-100, 100 clients, Dirichlet(0.1), 10% participation

```bash
CUDA_VISIBLE_DEVICES=0 python -m main --config-name fedavg_base approach=fedavg dataset_name=cifar100 client_resources.num_gpus=0.0375 num_clients=100 clients_per_round=10 num_rounds=200 alpha=0.1 wandb=False
```

### Example: SAPEF/SAEF/EF, CIFAR-100, 100 clients, Dirichlet(0.1), 10% participation

```bash
CUDA_VISIBLE_DEVICES=0,1 python -m main --config-name fedavg_topk approach=sapef dataset_name=cifar100 sparsify_by=0.01 client_resources.num_gpus=0.0375 num_clients=100 clients_per_round=10 num_rounds=5 alpha=0.1 alpha_r=0.84 comp_type=topk wandb=False
```

### Example: CSER, CIFAR-100, 100 clients, Dirichlet(0.1), 10% participation

```bash
CUDA_VISIBLE_DEVICES=0 python -m main --config-name fedavg_topk_cser approach=cser dataset_name=cifar100 sparsify_by=0.01 H=5 reset_frac=0.1 client_resources.num_gpus=0.0375 num_rounds=200 alpha=0.1 num_clients=100 clients_per_round=10 comp_type=topk wandb=False
```

Metrics are automatically saved to `./results/` and optionally tracked via wandb.
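
For intuition about the `alpha` override in the commands above, here is a minimal sketch of Dirichlet label partitioning, a common way to simulate non-IID clients. It is an assumption that the repository partitions data along these lines; the function `dirichlet_partition` and its arguments are hypothetical and not part of the codebase. Smaller `alpha` concentrates each class on fewer clients.

```python
import random


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with per-class Dirichlet(alpha) shares."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)

        # A Dirichlet sample is a vector of Gamma(alpha, 1) draws, normalised.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas) or 1.0  # guard against numerical underflow

        # Cut the shuffled class indices at the cumulative proportions.
        start, cum = 0, 0.0
        for c, g in enumerate(gammas):
            cum += g / total
            if c == num_clients - 1:
                end = len(idxs)  # last client takes whatever is left
            else:
                end = max(start, min(len(idxs), round(cum * len(idxs))))
            clients[c].extend(idxs[start:end])
            start = end
    return clients


if __name__ == "__main__":
    toy_labels = [i % 10 for i in range(1000)]  # 10 balanced toy classes
    parts = dirichlet_partition(toy_labels, num_clients=10, alpha=0.1)
    print([len(p) for p in parts])  # client sizes are typically very uneven
```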

### You may choose any combination from below (respecting the constraints):

```
dataset_name in {cifar10, cifar100, tinyimagenet}

approach in {fedavg, ef, saef, sapef, cser}

comp_type in {dense, topk}               # use topk only with approach in {ef, saef, sapef, cser}

sparsify_by in {0.01, 0.05, 0.10}        # only if comp_type=topk

partitioning in {iid, dirichlet}
alpha (Dirichlet) in {0.1, 0.5}

clients_per_round in {10, 50, 100}       # q ∈ {0.1, 0.5, 1.0} when num_clients=100
num_clients = 100

alpha_r in {0.84, 0.90}                  # only for approach in {sapef}

wandb in {True, False}
```
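
The constraints listed above can be checked before launching a run. The helper below is a hypothetical pre-flight check, not part of the repository; it only encodes the rules stated in this README and takes the overrides as a plain dictionary.

```python
# Allowed values copied from the option list in this README.
ALLOWED = {
    "dataset_name": {"cifar10", "cifar100", "tinyimagenet"},
    "approach": {"fedavg", "ef", "saef", "sapef", "cser"},
    "comp_type": {"dense", "topk"},
    "partitioning": {"iid", "dirichlet"},
}
COMPRESSED_APPROACHES = {"ef", "saef", "sapef", "cser"}


def check_overrides(cfg):
    """Return a list of problems; an empty list means the combination is allowed."""
    problems = []
    for key, allowed in ALLOWED.items():
        if key in cfg and cfg[key] not in allowed:
            problems.append(f"{key}={cfg[key]} is not one of {sorted(allowed)}")
    if cfg.get("comp_type") == "topk" and cfg.get("approach") not in COMPRESSED_APPROACHES:
        problems.append("comp_type=topk requires approach in {ef, saef, sapef, cser}")
    if "sparsify_by" in cfg and cfg.get("comp_type") != "topk":
        problems.append("sparsify_by only applies when comp_type=topk")
    if "alpha_r" in cfg and cfg.get("approach") != "sapef":
        problems.append("alpha_r only applies when approach=sapef")
    return problems


if __name__ == "__main__":
    run = {"approach": "fedavg", "comp_type": "topk", "sparsify_by": 0.01}
    for problem in check_overrides(run):
        print("invalid:", problem)
```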
