Fed3R: Recursive Ridge Regression for Federated Learning with strong pre-trained models

Published: 28 Oct 2023, Last Modified: 14 Dec 2023
Venue: FL@FM-NeurIPS’23 Poster
Student Author Indication: Yes
Keywords: Federated Learning, Ridge Regression, Random Features, Statistical Heterogeneity, Client Drift, Destructive Interference, Pre-trained models
TL;DR: Recursive Ridge Regression enables fast convergence in Federated Learning with pre-trained models
Abstract: Current Federated Learning (FL) methods often struggle with high statistical heterogeneity across clients' data, which causes client drift due to biased local solutions. This issue is particularly pronounced in the final classification layer, where it degrades convergence speed and accuracy during model aggregation. To overcome these challenges, we introduce Federated Recursive Ridge Regression (Fed3R). Our method replaces the softmax classifier with a ridge regression-based classifier computed in closed form, ensuring robustness to statistical heterogeneity and drastically reducing convergence time and communication costs. When the feature extractor is fixed, the incremental formulation of Fed3R is equivalent to the exact centralized solution. Because no backpropagation through the feature extractor is required and only a few rounds are needed to converge, Fed3R makes it practical to use higher-capacity pre-trained feature extractors with better predictive performance, which are incompatible with previous FL techniques. We propose Fed3R in three variants; among them, Fed3R-RF significantly improves performance to levels close to centralized training while remaining competitive in total communication costs.
Submission Number: 55
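
The abstract's claim that the incremental formulation matches the exact centralized solution when the feature extractor is fixed can be illustrated with a minimal sketch. It is not the authors' implementation: it assumes the standard closed-form ridge solution W = (ZᵀZ + λI)⁻¹ ZᵀY on one-hot targets, with each client contributing only its local statistics ZᵀZ and ZᵀY. The function names, the synthetic features, and the values of `d`, `C`, and `lam` are all hypothetical choices made for illustration.

```python
import numpy as np

def client_statistics(features, labels, num_classes):
    """Local ridge-regression statistics on fixed features (hypothetical helper)."""
    onehot = np.eye(num_classes)[labels]            # (n, C) one-hot targets
    return features.T @ features, features.T @ onehot

def solve_ridge(A, b, lam):
    """Closed-form ridge solution W = (A + lam * I)^-1 b."""
    return np.linalg.solve(A + lam * np.eye(A.shape[0]), b)

rng = np.random.default_rng(0)
d, C, lam = 16, 4, 0.1                              # assumed sizes and regularizer

# Extreme heterogeneity for illustration: each client holds a single class.
clients = []
for k in range(C):
    n = 20 + 5 * k
    feats = rng.normal(size=(n, d)) + k             # stand-in for extracted features
    clients.append((feats, np.full(n, k)))

# Server side: sum the per-client statistics, then solve once.
A = np.zeros((d, d))
b = np.zeros((d, C))
for feats, labels in clients:
    A_k, b_k = client_statistics(feats, labels, C)
    A += A_k
    b += b_k
W_fed = solve_ridge(A, b, lam)

# Centralized reference: pool all data and solve the same problem.
all_feats = np.vstack([f for f, _ in clients])
all_labels = np.concatenate([l for _, l in clients])
W_central = solve_ridge(*client_statistics(all_feats, all_labels, C), lam)

max_gap = float(np.max(np.abs(W_fed - W_central)))
print(f"max |W_fed - W_central| = {max_gap:.2e}")
```

Because the statistics are plain sums over samples, the aggregated result does not depend on how the data is split across clients, which is why the two solutions agree up to floating-point error even in the one-class-per-client setting assumed here.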