Improving Generalization in Federated Learning with Highly Heterogeneous Data via Momentum-Based Stochastic Controlled Weight Averaging
Abstract: For federated learning (FL) algorithms such as FedSAM, generalization capability is crucial in real-world applications. In this paper, we revisit the generalization problem in FL and investigate the impact of data heterogeneity on FL generalization. We find that FedSAM usually performs worse than FedAvg when data are highly heterogeneous, and we therefore propose a novel and effective FL algorithm based on Stochastic Weight Averaging (called \texttt{FedSWA}), which aims to find flatter minima under highly heterogeneous data. Moreover, we introduce a new momentum-based stochastic controlled weight averaging FL algorithm (\texttt{FedMoSWA}), which is designed to better align local and global models.
Theoretically, we provide both convergence analysis and generalization bounds for \texttt{FedSWA} and \texttt{FedMoSWA}. We also prove that the optimization and generalization errors of \texttt{FedMoSWA} are smaller than those of its counterparts, including FedSAM and its variants. Empirically, results on CIFAR10/100 and Tiny ImageNet demonstrate the superiority of the proposed algorithms over their counterparts.
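To make the idea of stochastic weight averaging concrete, here is a minimal, hypothetical sketch of generic SWA layered on top of FedAvg. It is not the authors' FedSWA or FedMoSWA (their exact update rules are in the paper and the linked repository); it only shows the core SWA step, a running mean of model snapshots, on a toy problem. Assumptions: each client's loss is the quadratic 0.5 * ||w - center||^2 with a client-specific `center` (to mimic heterogeneous data), Gaussian noise stands in for minibatch sampling, and the server averages global models only after a warm-up of `swa_start` rounds. All function names and parameters are illustrative.

```python
import random
from typing import List, Optional, Tuple

Vector = List[float]


def local_sgd(w: Vector, center: Vector, lr: float, steps: int,
              noise: float, rng: random.Random) -> Vector:
    """Run `steps` noisy SGD steps on the toy client loss 0.5 * ||w - center||^2."""
    w = list(w)
    for _ in range(steps):
        # Exact gradient is (w - center); Gaussian noise mimics minibatch sampling.
        w = [wi - lr * ((wi - ci) + rng.gauss(0.0, noise))
             for wi, ci in zip(w, center)]
    return w


def fedavg_round(w_global: Vector, centers: List[Vector], lr: float, steps: int,
                 noise: float, rng: random.Random) -> Vector:
    """One FedAvg round: every client trains locally, then the server averages."""
    local_models = [local_sgd(w_global, c, lr, steps, noise, rng) for c in centers]
    return [sum(m[j] for m in local_models) / len(local_models)
            for j in range(len(w_global))]


def swa_update(w_swa: Vector, w_new: Vector, n_averaged: int) -> Vector:
    """Running mean over snapshots: w_swa <- (n * w_swa + w_new) / (n + 1)."""
    return [(n_averaged * a + b) / (n_averaged + 1) for a, b in zip(w_swa, w_new)]


def run(centers: List[Vector], rounds: int = 60, swa_start: int = 30,
        lr: float = 0.1, steps: int = 5, noise: float = 0.5,
        seed: int = 0) -> Tuple[Vector, Optional[Vector]]:
    """Return the last global model and the SWA average of global models."""
    rng = random.Random(seed)
    w = [0.0] * len(centers[0])
    w_swa: Optional[Vector] = None
    n_averaged = 0
    for r in range(rounds):
        w = fedavg_round(w, centers, lr, steps, noise, rng)
        if r >= swa_start:  # only average snapshots after a warm-up phase
            w_swa = list(w) if w_swa is None else swa_update(w_swa, w, n_averaged)
            n_averaged += 1
    return w, w_swa


if __name__ == "__main__":
    clients = [[1.0, -2.0], [3.0, 0.0], [-1.0, 4.0], [5.0, 2.0]]
    last, averaged = run(clients)
    print("last global model:", last)
    print("SWA average:      ", averaged)
```

In this toy setting both outputs approach the mean of the client centers, (2.0, 1.0); the SWA average is a mean of many noisy snapshots, so it typically fluctuates less than the last global iterate. The design choice of averaging only after a warm-up mirrors common SWA practice, where early, far-from-converged weights are excluded from the average.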
Lay Summary: (1) Problem: Federated learning (FL) is a powerful method for training machine learning models across multiple devices without sharing data, but it struggles with data heterogeneity, leading to poor generalization.
(2) Solution: We propose two new algorithms, FedSWA and FedMoSWA, which use Stochastic Weight Averaging and momentum-based techniques to find flatter minima in the loss landscape, improving generalization in highly heterogeneous data settings.
(3) Impact: Our methods enhance the practical applicability of FL by improving model performance and reducing the impact of data heterogeneity, making federated learning more effective in real-world applications like healthcare and finance.
Link To Code: https://github.com/junkangLiu0/FedSWA
Primary Area: Optimization->Large Scale, Parallel and Distributed
Keywords: Federated Learning, Generalization, Stochastic Weight Averaging
Submission Number: 11169