- Keywords: Federated Learning, Unseen Client Generalization, Structural Causal Model
- Abstract: Generalizing federated learning (FL) models to unseen clients with non-IID data is an important yet unsolved problem. In this work, we tackle it from a novel causal perspective. Specifically, we formulate a training structural causal model (SCM) to explain the challenges of model generalization in a distributed learning paradigm. Based on this, we present a simple yet effective method that uses test-specific and momentum-tracked batch normalization (TsmoBN) to generalize FL models to testing clients. We give a causal analysis by formulating a second, testing SCM and demonstrate that the key factor in TsmoBN is the test-specific statistics (i.e., mean and variance) of the features. These statistics can be seen as a surrogate variable for causal intervention. In addition, by considering generalization bounds in FL, we show that TsmoBN reduces the divergence between training and testing feature distributions, thereby achieving a lower generalization gap than standard model testing. Extensive experiments demonstrate significant improvements in unseen client generalization on three datasets with various types of distribution shift and numbers of clients. Our approach can be flexibly applied to different state-of-the-art federated learning algorithms and is orthogonal to existing domain generalization methods.
- One-sentence Summary: We tackle the unseen client generalization problem through a new causal view of federated learning, which motivates using test-specific and momentum-tracked BN as a causal intervention and significantly improves the generalizability of FL methods.
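
The abstract does not spell out the update rule behind TsmoBN, so the following is only a minimal sketch of the general idea: normalize test features with statistics estimated from the test client's own batches and tracked with momentum, rather than with the training-time running statistics. The class name `MomentumTrackedTestBN`, the exponential-moving-average form of the update, and the default `momentum=0.1` are all assumptions made for illustration and are not taken from the paper.

```python
import numpy as np


class MomentumTrackedTestBN:
    """Illustrative sketch (not the authors' implementation) of batch
    normalization that uses test-specific, momentum-tracked statistics."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum = momentum  # assumed EMA coefficient
        self.eps = eps
        self.mean = None  # tracked test-specific mean
        self.var = None   # tracked test-specific variance
        # Affine parameters would come from the trained FL model;
        # identity values are used here as placeholders.
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

    def __call__(self, x):
        # x has shape (batch_size, num_features)
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        if self.mean is None:
            # First test batch: initialize from the batch itself.
            self.mean, self.var = batch_mean, batch_var
        else:
            # Later batches: exponential moving average over test batches.
            m = self.momentum
            self.mean = (1 - m) * self.mean + m * batch_mean
            self.var = (1 - m) * self.var + m * batch_var
        x_hat = (x - self.mean) / np.sqrt(self.var + self.eps)
        return self.gamma * x_hat + self.beta
```

In this sketch, one instance would be kept per unseen client and called on each incoming test batch, so that the tracked mean and variance reflect that client's feature distribution.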