Cross-Training with Prototypical Distillation for Improving the Generalization of Federated Learning
Abstract: Cross-training has become a promising strategy for handling the data heterogeneity problem in federated learning: a local model is re-trained across different clients to improve its generalization capability in a privacy-preserving manner. The main idea is to make each local model fit the data of all clients. However, the heterogeneity between data sources may lead the local models to quickly forget the knowledge learned over several rounds of cross-training. To address this problem, this paper presents a novel prototype-guided cross-training mechanism, termed PGCT, which regularizes the change of class-level data representations across clients. It comprises two main modules. The prototype-guided representation learning module employs client-aware prototypes of data patterns, learned by clustering, to guide the learning of consistent representations across feature spaces, which maintains similar decision boundaries across different clients. The prototype-based feature augmentation module uses prototypes as soft attention regularizers to further aggregate rich information and enhance the discrimination of historical features. Experiments were conducted on four datasets, covering performance comparison, an ablation study, and a case study; the results verify that PGCT learns discriminative features for different classes under the guidance of prototypes, leading to better performance than state-of-the-art methods.
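As a rough illustration of the two modules the abstract describes, the sketch below (PyTorch; not the authors' code) builds class-level prototypes from local features, uses them as a consistency regularizer against representation drift during cross-training, and reuses them as soft-attention weights for feature augmentation. All function names are hypothetical, and the per-class-mean prototype construction is an assumption; the paper instead learns client-aware prototypes by clustering.

```python
import torch
import torch.nn.functional as F

def class_prototypes(features, labels, num_classes):
    # Simplest stand-in: per-class mean of local features.
    # (PGCT learns client-aware prototypes by clustering.)
    protos = torch.zeros(num_classes, features.size(1), device=features.device)
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            protos[c] = features[mask].mean(dim=0)
    return protos

def proto_consistency_loss(features, labels, protos):
    # Pull each sample's representation toward its class prototype,
    # regularizing class-level representation change across clients.
    return F.mse_loss(features, protos[labels])

def proto_soft_attention(features, protos, tau=1.0):
    # Prototype-based feature augmentation: weight prototypes by their
    # similarity to each feature (soft attention) and mix them back in.
    attn = F.softmax(features @ protos.t() / tau, dim=1)  # (batch, classes)
    return features + attn @ protos  # augmented features, same shape

# Hypothetical usage with random 64-d features and 10 classes:
feats = torch.randn(32, 64)
labels = torch.randint(0, 10, (32,))
protos = class_prototypes(feats, labels, num_classes=10)
loss = proto_consistency_loss(feats, labels, protos)
augmented = proto_soft_attention(feats, protos)
```

In this reading, the consistency loss plays the role of the prototype-guided representation learning module, while the soft-attention mixing corresponds to the prototype-based feature augmentation module; the paper's exact loss forms and attention formulation may differ.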