Causality Inspired Federated Learning for OOD Generalization

Published: 01 May 2025 · Last Modified: 18 Jun 2025 · ICML 2025 poster · CC BY 4.0
Abstract: The out-of-distribution (OOD) generalization problem in federated learning (FL) has recently attracted significant research interest. A common approach, derived from centralized learning, is to extract causal features which exhibit causal relationships with the label. However, in FL, the global feature extractor typically captures only the invariant causal features shared across clients and thus discards many other causal features that are potentially useful for OOD generalization. To address this problem, we propose FedUni, a simple yet effective architecture trained to extract all possible causal features from any input. FedUni consists of a comprehensive feature extractor, designed to identify a union of all causal feature types in the input, followed by a feature compressor, which discards potentially "inactive" causal features. With this architecture, FedUni can benefit from collaborative training in FL while avoiding the cost of model aggregation (i.e., extracting only invariant features). In addition, to further enhance the feature extractor's ability to capture causal features, FedUni adds a causal intervention module on the client side, which employs a counterfactual generator to produce counterfactual examples that simulate distribution shifts. Extensive experiments and theoretical analysis demonstrate that our method significantly improves OOD generalization performance.
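The extract-then-compress pipeline described above can be sketched in a few lines. This is a minimal illustrative mock-up, not the authors' implementation: the function names, the linear-ReLU extractor, the binary gate, and the additive counterfactual perturbation are all assumptions chosen only to show the data flow (wide feature union → gating out inactive features → intervention via counterfactual inputs).

```python
# Hypothetical sketch of the FedUni data flow; names and operations are
# illustrative assumptions, not the paper's actual architecture.

def comprehensive_extractor(x, W):
    # Wide ReLU(W x) standing in for the "union of all causal feature types":
    # more output dimensions than any single client's invariant features.
    return [max(sum(w * xi for w, xi in zip(row, x)), 0.0) for row in W]

def feature_compressor(z, gate):
    # Zero out features flagged inactive for this input; in the paper the
    # gate would be learned, here it is a fixed boolean mask.
    return [zi if g else 0.0 for zi, g in zip(z, gate)]

def counterfactual_example(x, shift):
    # Client-side intervention: perturb the input to simulate a
    # distribution shift (a stand-in for the counterfactual generator).
    return [xi + di for xi, di in zip(x, shift)]

# Toy forward pass with hand-picked numbers.
x = [1.0, -1.0]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 candidate causal features
gate = [True, False, True]                  # second feature deemed inactive

z = comprehensive_extractor(x, W)           # [1.0, 0.0, 0.0]
h = feature_compressor(z, gate)             # [1.0, 0.0, 0.0]
x_cf = counterfactual_example(x, [0.1, 0.1])
```

The key design point the sketch mirrors is that the extractor is deliberately wider than any single environment requires, and selectivity is deferred to the compressor, so collaborative training can populate the full feature union without aggregation collapsing it to the shared invariant subset.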
Lay Summary: Deep learning models often struggle when facing unfamiliar data, especially in real-world settings where information is collected from many different sources. This is a major challenge for federated learning, a method where many devices train a shared model without sharing their raw data. One popular approach is to teach the model to focus on "causal features", which are parts of the input that truly affect the outcome and are more stable across different environments. However, existing methods usually extract only the features shared by all devices, which ignores many useful ones. To solve this, we propose a new method called FedUni. It teaches the model to collect all possible causal features across devices, and then selectively filter out the ones that are irrelevant in new situations. We also introduce a way for each device to simulate "what-if" versions of its data, helping the model better understand cause and effect. Our experiments show that this leads to much more reliable performance when the model sees unfamiliar data. By making deep learning models more flexible and robust, this research can help ensure safer and more trustworthy AI systems in real-world applications.
Primary Area: Optimization->Large Scale, Parallel and Distributed
Keywords: federated learning, causality, out-of-distribution generalization
Flagged For Ethics Review: true
Submission Number: 9705