Optimising Clinical Federated Learning through Mode Connectivity-based Model Aggregation
TL;DR: The paper proposes a mode connectivity-based Federated Learning framework that improves convergence and ensures equitable performance across clients in non-IID healthcare settings by aligning the global model with low-loss regions for all clients.
Abstract: Federated Learning (FL) involves a server aggregating local models from clients to compute a global model. However, this process can struggle to position the global model in low-loss regions of the parameter space for all clients, resulting in subpar convergence and inequitable performance across clients. This issue is particularly pronounced in non-IID settings, common in clinical contexts, where variations in data distribution, class imbalance, and training sample sizes result in client heterogeneity. To address this issue, we propose a mode connectivity-based FL framework that ensures the global model resides within the overlapping low-loss regions of all clients in the parameter space. This framework models the low-loss regions as non-linear mode connections between the current global and local models, and optimises to identify an intersection among these mode connections to define the new global model. This approach enhances training stability and convergence, yielding better and more equitable performance compared to standard FL frameworks like federated averaging. Empirical evaluations across multiple healthcare datasets demonstrate the benefits of the proposed framework.
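Illustrative sketch: the abstract contrasts federated averaging with a global model chosen from non-linear mode connections between the current global model and each local model. The minimal Python sketch below shows only these two building blocks, under stated assumptions. The quadratic Bezier parameterisation, the function names `fedavg` and `bezier_point`, and the control point `theta` are hypothetical choices for illustration; the abstract does not specify the curve form, and the optimisation that finds an intersection among the curves is not shown.

```python
from typing import List, Sequence

Vector = List[float]


def fedavg(local_models: Sequence[Vector], num_samples: Sequence[int]) -> Vector:
    """Sample-size-weighted average of client parameter vectors (federated averaging)."""
    total = sum(num_samples)
    dim = len(local_models[0])
    return [
        sum(n * w[i] for w, n in zip(local_models, num_samples)) / total
        for i in range(dim)
    ]


def bezier_point(w_global: Vector, theta: Vector, w_local: Vector, t: float) -> Vector:
    """Point at position t in [0, 1] on a quadratic Bezier curve.

    The curve starts at the current global model (t = 0) and ends at a
    client's local model (t = 1). `theta` is a control point that, in a
    mode-connectivity setting, would be trained so that the loss stays low
    along the curve. The Bezier form is an assumption made for this sketch.
    """
    a, b, c = (1 - t) ** 2, 2 * t * (1 - t), t ** 2
    return [a * g + b * m + c * l for g, m, l in zip(w_global, theta, w_local)]


# Toy usage with two clients and 2-dimensional "models".
clients = [[0.0, 2.0], [4.0, 6.0]]
sizes = [1, 3]
averaged = fedavg(clients, sizes)
midpoint = bezier_point([0.0, 0.0], [1.0, 1.0], [2.0, 0.0], 0.5)
```

In this toy setting, `fedavg` always returns a point on the straight segment between the two clients, whereas `bezier_point` can bend away from that segment depending on `theta`, which is the extra freedom a non-linear connection provides.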
Submission Number: 63