Abstract: Deep neural networks trained by minimizing the average risk can achieve strong average performance, but their performance on a subgroup may degrade if that subgroup is underrepresented in the overall data population. Group distributionally robust optimization (Sagawa et al., 2020a), or group DRO for short, is a widely used baseline for learning models with strong worst-group performance. We note that this method requires group labels for every example at training time and can overfit to small groups, requiring strong regularization. Given a limited amount of group labels at training time, Just Train Twice (Liu et al., 2021), or JTT for short, is a two-stage method that first infers a pseudo group label for every unlabeled example and then applies group DRO based on the inferred group labels. The inference process is also sensitive to overfitting and sometimes involves additional hyperparameters. This paper designs a simple method based on the idea of classifier retraining on independent splits of the training data. We find that a novel sample-splitting procedure achieves robust worst-group performance in the fine-tuning step. When evaluated on benchmark image and text classification tasks, our approach consistently compares favorably with group DRO, JTT, and other strong baselines, whether group labels are available during training or only in validation sets. Importantly, our method relies on a single hyperparameter, which adjusts the fraction of labels used for training feature extractors vs. training classification layers. We justify the rationale of our splitting scheme with a generalization-bound analysis of the worst-group loss.
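To make the two-stage idea concrete, here is a minimal sketch of classifier retraining on an independent split: train the whole network with standard ERM on one portion of the group-labeled data, then freeze the feature extractor and retrain only the classification layer on the held-out portion with a worst-group objective. This is an illustrative assumption, not the authors' released implementation (see the repository linked under Code); the toy data, model, and optimizer settings below are all hypothetical.

```python
# Minimal sketch (not the authors' code) of classifier retraining on an
# independent split of the training data.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data: 2-D features, binary labels, and a group id per example (illustrative).
n, d, num_groups = 200, 2, 4
X = torch.randn(n, d)
y = torch.randint(0, 2, (n,))
g = torch.randint(0, num_groups, (n,))

# Single hyperparameter: fraction of examples used to train the feature extractor.
split_frac = 0.5
perm = torch.randperm(n)
idx_a, idx_b = perm[: int(split_frac * n)], perm[int(split_frac * n):]

# Feature extractor + linear classification head.
feature_extractor = nn.Sequential(nn.Linear(d, 16), nn.ReLU())
classifier = nn.Linear(16, 2)
model = nn.Sequential(feature_extractor, classifier)

# Stage 1: standard ERM training of the full model on split A.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(100):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(X[idx_a]), y[idx_a])
    loss.backward()
    opt.step()

# Stage 2: freeze the feature extractor and retrain only the classifier on the
# independent split B, minimizing the worst-group loss.
for p in feature_extractor.parameters():
    p.requires_grad_(False)
opt = torch.optim.SGD(classifier.parameters(), lr=0.1)
for _ in range(100):
    opt.zero_grad()
    logits = model(X[idx_b])
    group_losses = []
    for k in range(num_groups):
        mask = g[idx_b] == k
        if mask.any():
            group_losses.append(
                nn.functional.cross_entropy(logits[mask], y[idx_b][mask]))
    worst = torch.stack(group_losses).max()  # worst-group objective
    worst.backward()
    opt.step()
```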
Submission Length: Long submission (more than 12 pages of main content)
Changes Since Last Submission: Following up on the reviewers' suggestions, we have made the following edits to the paper.
- Rephrased or removed some sentences to improve precision and fixed grammatical errors.
- Added more references to support that Group DRO is a standard baseline for the group shift setting.
- Spelled out every acronym the first time it is used in each section to improve readability. We now also use the full name instead of the acronym in most places and re-cite the source of each method or acronym.
- Revised Section 3.2 to better explain the motivations behind our method. In particular, we added a paragraph discussing in more detail the different ways to utilize group labels and the role of classifier retraining.
- Added a broader impact statement that discusses the relationship between our work and related work in areas such as fairness.
- Revised Section 5.3 and the introduction to better reflect the limitations of our theoretical results.
Code: https://github.com/timmytonga/crois
Supplementary Material: zip
Assigned Action Editor: ~Aditya_Menon1
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Number: 1170