## Demystifying Latent Forgetting in Federated Learning

## Abstract
Federated Learning (FL) enables collaborative model training across decentralized, isolated clients in a privacy-preserving manner, but at the cost of limited control over the data and the training procedure. One of the key challenges in FL is spatial data heterogeneity, which stems from the stratified nature of the underlying data distributions across clients. In addition, FL systems undergo periods in which certain features disappear from the training data pool, resulting in the less studied but critical problem of temporal-spatial data heterogeneity. Such non-uniformity in training data over time introduces a new feature-level latent forgetting that is fundamentally different from the well-studied task-level catastrophic forgetting in continual learning.
This latent forgetting, if not detected and mitigated promptly, can result in poor model performance, especially for certain learning features.
The privacy requirements and temporal-spatial data heterogeneity of FL make the detection and mitigation of latent forgetting challenging.
In this paper, we analyze latent forgetting and propose FedMemo, a privacy-preserving FL framework to control its impact. FedMemo employs an automated mechanism that detects latent forgetting in real time while preserving privacy. FedMemo further introduces a proxy-based two-step aggregation approach to mitigate the impact of latent forgetting. We evaluate FedMemo on a diverse set of vision and language classification tasks in various FL settings, and show that it outperforms state-of-the-art methods by up to $20.06\%$.

## Installation

### Prerequisites
Vision:
* python == 3.9
* torch == 1.13.1
* torchvision == 0.14.1

LLMs:
* transformers == 4.39.3
* peft == 0.10.0
* datasets == 2.18.0
* torch == 2.2.2
* tqdm == 4.66.2
* numpy == 1.24.4
* evaluate == 0.4.1
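
The two lists pin different `torch` versions (1.13.1 for vision, 2.2.2 for LLMs), so they cannot share one environment. The following is a minimal sketch, not part of the FedMemo code, of how the active environment could be checked against these pins. The names `VISION_PINS`, `LLM_PINS`, `installed_version`, and `find_mismatches` are hypothetical helpers introduced here for illustration; only the standard-library `importlib.metadata` is used.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the prerequisite lists above (hypothetical constant names).
VISION_PINS = {"torch": "1.13.1", "torchvision": "0.14.1"}
LLM_PINS = {
    "transformers": "4.39.3",
    "peft": "0.10.0",
    "datasets": "2.18.0",
    "torch": "2.2.2",
    "tqdm": "4.66.2",
    "numpy": "1.24.4",
    "evaluate": "0.4.1",
}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package that deviates from its pin to (expected, found)."""
    mismatches = {}
    for package, expected in pins.items():
        found = lookup(package)
        # Ignore local build suffixes such as "+cu121" or "+cpu" when comparing.
        base = found.split("+")[0] if found else None
        if base != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, pins in (("vision", VISION_PINS), ("LLMs", LLM_PINS)):
        problems = find_mismatches(pins)
        print(name, "OK" if not problems else problems)
```

Running the script inside an environment reports `OK` for the list it satisfies and shows the expected and found versions for the other. The Python version pin from the vision list is not checked here.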

### Dataset
* Download the datasets (CIFAR-10, SVHN) and set their directory with `--path`.

## Run Code

LLMs:

    cd FedMemo/code
    python main.py --lr=0.01


