Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift

Published: 01 May 2025, Last Modified: 18 Jun 2025. ICML 2025 poster. License: CC BY 4.0
TL;DR: This paper proposes a prototype-based ensemble method to improve robustness against subpopulation shifts without requiring subgroup labels, outperforming state-of-the-art methods across several benchmark datasets.
Abstract: Subpopulation shift, characterized by a disparity in subpopulation distribution between the training and target datasets, can significantly degrade the performance of machine learning models. Current solutions to subpopulation shift modify empirical risk minimization with re-weighting strategies to improve generalization. These strategies rely on assumptions about the number and nature of subpopulations, and on annotations of group membership, which are unavailable for many real-world datasets. Instead, we propose using an ensemble of diverse classifiers to adaptively capture risk associated with subpopulations. Given a feature extractor network, we replace its standard linear classification layer with a mixture of prototypical classifiers, where each member is trained to classify the data while focusing on features and samples different from those of the other members. In empirical evaluation on nine real-world datasets, covering diverse domains and kinds of subpopulation shift, our method of Diverse Prototypical Ensembles (DPEs) often outperforms the prior state-of-the-art in worst-group accuracy. The code is available at https://github.com/minhto2802/dpe4subpop.
Lay Summary: Machine learning models often struggle when they encounter situations that differ slightly from what they were trained on. This is a major issue when data includes hidden subgroups, such as different types of people, environments, or medical conditions, that are not equally represented. For example, a model trained mostly on healthy patients might not work well on those with rare diseases. Our research introduces a new technique called the Diversified Prototypical Ensemble (DPE) to tackle this problem. Instead of using just one model, we create a group of simple classifiers called prototypes. Each one learns to focus on different patterns or features in the data. We encourage these classifiers to be as different as possible, so together they can cover a broader variety of hidden subgroups. The key benefit of DPE is that it does not require prior knowledge of the subgroups. It can automatically discover and adapt to them using only the data itself. This makes it especially useful in real-world situations where such subgroup labels are missing or hard to define. Across nine challenging datasets, our method consistently outperforms existing solutions and helps make machine learning models more fair and reliable when used in diverse populations.
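To make the idea concrete, here is a minimal NumPy sketch of a prototypical ensemble at inference time, plus one possible diversity penalty. This is an illustration only, not the authors' implementation (see the linked repository for that): the head shapes, the cosine-similarity scoring, and the `diversity_penalty` function are assumptions chosen for simplicity.

```python
import numpy as np

def prototype_logits(x, prototypes):
    # Score each sample by cosine similarity to each class prototype.
    # x: (n_samples, dim); prototypes: (n_classes, dim) -> (n_samples, n_classes)
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    pn = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    return xn @ pn.T

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Toy setup (hypothetical sizes): 3 prototype heads, 2 classes, 4-d features.
rng = np.random.default_rng(0)
heads = [rng.normal(size=(2, 4)) for _ in range(3)]
x = rng.normal(size=(5, 4))  # stand-in for frozen feature-extractor outputs

# Ensemble prediction: average the per-head class probabilities.
probs = np.mean([softmax(prototype_logits(x, P)) for P in heads], axis=0)
preds = probs.argmax(axis=1)

def diversity_penalty(heads):
    # One way to encourage diversity during training: penalize aligned
    # prototypes across heads via mean squared pairwise cosine similarity.
    pens = []
    for i in range(len(heads)):
        for j in range(i + 1, len(heads)):
            a = heads[i] / np.linalg.norm(heads[i], axis=1, keepdims=True)
            b = heads[j] / np.linalg.norm(heads[j], axis=1, keepdims=True)
            pens.append(np.mean(np.sum(a * b, axis=1) ** 2))
    return float(np.mean(pens))
```

In a training loop, a penalty of this kind would be added to the classification loss so that each head is pushed toward prototypes (and hence features) the other heads do not already use, which is the mechanism the abstract describes for covering hidden subgroups without group labels.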
Application-Driven Machine Learning: This submission is on Application-Driven Machine Learning.
Link To Code: https://github.com/minhto2802/dpe4subpop
Primary Area: Deep Learning->Robustness
Keywords: distribution shift, subpopulation shift, spurious correlation, class imbalance, attribute imbalance
Submission Number: 15471