Efficient Federated Tumor Segmentation via Parameter Distance Weighted Aggregation and Client Pruning
Abstract: Federated learning has become a popular paradigm that enables multiple distributed clients to collaboratively train a model, providing a promising privacy-preserving solution without data sharing. To make full use of the federated training effort, it is critical to improve both the performance and the generalization capability of the global model based on the diverse data samples provided by the federated cohort. The Federated Tumor Segmentation (FeTS) Challenge 2022 proposes two tasks for participants to improve federated training and evaluation. Specifically, Task 1 seeks effective weight aggregation methods to create the global model given a pre-defined segmentation algorithm. Task 2 aims to find robust segmentation algorithms that perform well on unseen test data from various remote independent institutions. In federated learning, the data collected from different institutions are heterogeneous, which strongly affects the training behavior. This heterogeneity causes clients' local optimization to vary, so the local client updates are not consistent with each other. The vanilla weighted-average aggregation only takes the number of samples into account and ignores the differences among client updates. For Task 1, we devise a parameter distance-based aggregation algorithm to mitigate the drift of client updates. On top of this, we further propose a client pruning strategy that reduces convergence time in the presence of uneven training time across local clients. Our method achieves a convergence score of 0.7433 and an average Dice score of 71.02% on the validation data, which is split from the training data. For Task 2, we propose to use nnU-Net as the backbone and apply test-time batch normalization, which incorporates test-data-specific means and variances to fit the unseen test data distribution during the testing phase.
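To give a concrete picture of the kind of aggregation and pruning the abstract describes, the following Python sketch shows one plausible realization. It is an illustrative assumption, not the authors' exact method: client parameters are represented as dictionaries of NumPy arrays, each client's aggregation weight combines its sample count with the inverse of its parameter distance to the current global model, and `prune_slow_clients` with its `keep_fraction` parameter is a hypothetical rule for dropping the slowest clients.

```python
import numpy as np


def distance_weighted_aggregate(global_params, client_params, client_sizes):
    """Aggregate client models, down-weighting clients whose parameters
    drift far from the current global model (illustrative sketch only;
    the paper's exact weighting scheme is not specified in the abstract)."""
    # Euclidean distance of each client's parameters from the global model.
    dists = np.array([
        np.linalg.norm(np.concatenate([(p - global_params[k]).ravel()
                                       for k, p in params.items()]))
        for params in client_params
    ])
    # Combine sample-count weights with inverse-distance weights.
    size_w = np.asarray(client_sizes, dtype=float)
    dist_w = 1.0 / (dists + 1e-8)
    weights = size_w * dist_w
    weights /= weights.sum()
    # Weighted average of the client parameters, key by key.
    return {k: sum(w * params[k] for w, params in zip(weights, client_params))
            for k in global_params}


def prune_slow_clients(client_times, keep_fraction=0.8):
    """Keep only the fastest fraction of clients for later rounds
    (hypothetical pruning criterion for reducing wall-clock convergence time)."""
    cutoff = np.quantile(client_times, keep_fraction)
    return [i for i, t in enumerate(client_times) if t <= cutoff]
```

In a federated round, the server would call `distance_weighted_aggregate` on the returned client models and could periodically apply `prune_slow_clients` to the per-client training times so that stragglers no longer gate the round duration.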
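For Task 2, test-time batch normalization is commonly implemented by letting BatchNorm layers use the statistics of the current test batch instead of the running averages accumulated during training. The PyTorch sketch below shows this standard trick under that assumption; the helper name `enable_test_time_bn` is ours, and the authors' implementation details are not given in the abstract.

```python
import torch.nn as nn


def enable_test_time_bn(model: nn.Module) -> nn.Module:
    """Make BatchNorm layers compute mean/variance from each test batch
    (one common way to realize test-time batch normalization)."""
    model.eval()  # keep dropout and other layers in inference mode
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.train()                      # use per-batch statistics
            m.track_running_stats = False  # do not update stored statistics
    return model
```

At inference, the model is then run batch by batch so that every test batch contributes its own mean and variance, adapting the normalization to the unseen test data distribution.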