Correct-N-Contrast: a Contrastive Approach for Improving Robustness to Spurious Correlations

Published: 28 Jan 2022, Last Modified: 13 Feb 2023
ICLR 2022 Submitted
Readers: Everyone
Keywords: spurious correlations, contrastive learning, robustness, group shifts
Abstract: Spurious correlations pose a fundamental challenge for building robust machine learning models. For example, models trained with empirical risk minimization (ERM) may depend on correlations between class labels and spurious features to classify data, even if these relations only hold for certain data groups. This can result in poor performance on other groups that do not exhibit such relations. When group information is available during training, Sagawa et al. (2019) have shown how to improve worst-group performance by optimizing the worst-group loss (GDRO). However, when group information is unavailable, improving worst-group performance is more challenging. For this latter setting, we propose Correct-N-Contrast (CNC), a contrastive learning method to train models more robust to spurious correlations. Our motivating observation is that worst-group performance is related to a representation alignment loss, which measures the distance in feature space between different groups within each class. We prove that the gap between worst-group and average loss for each class is upper bounded by the alignment loss for that class. Thus, CNC aims to improve representation alignment via contrastive learning. First, CNC uses an ERM model to infer the group information. Second, with a careful sampling scheme, CNC trains a contrastive model to encourage similar representations for groups in the same class. We show that CNC significantly improves worst-group accuracy over existing state-of-the-art methods on popular benchmarks, e.g., achieving $7.7\%$ absolute lift in worst-group accuracy on the CelebA data set, and performs almost as well as GDRO trained with group labels. CNC also learns better-aligned representations between different groups in each class, reducing the alignment loss substantially compared to prior methods.
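The two ideas in the abstract, a per-class alignment loss and a sampling scheme driven by ERM predictions, can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The function names are hypothetical, Euclidean distance is an assumed choice of metric, and the specific sampling rule (positives share the anchor's class but differ in ERM prediction, negatives differ in class but share the ERM prediction) is an assumed reading of the "careful sampling scheme". The alignment function takes true group labels, so it is meant as an evaluation diagnostic rather than something available during training.

```python
import numpy as np


def alignment_loss(feats, labels, groups):
    """Per-class mean distance between representations of different groups.

    Assumption: Euclidean distance, averaged over all cross-group pairs
    and then over group pairs within the class.
    """
    out = {}
    for c in np.unique(labels):
        in_class = labels == c
        class_groups = np.unique(groups[in_class])
        pair_means = []
        for i, ga in enumerate(class_groups):
            for gb in class_groups[i + 1:]:
                a = feats[in_class & (groups == ga)]
                b = feats[in_class & (groups == gb)]
                # Pairwise distances between every point in a and every point in b.
                d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
                pair_means.append(d.mean())
        out[int(c)] = float(np.mean(pair_means)) if pair_means else 0.0
    return out


def contrastive_candidates(labels, erm_preds, anchor):
    """Candidate positives and negatives for one anchor index.

    Assumed rule: ERM predictions stand in for the unknown group labels.
    Positives: same class as the anchor, different ERM prediction.
    Negatives: different class from the anchor, same ERM prediction.
    """
    same_class = labels == labels[anchor]
    same_pred = erm_preds == erm_preds[anchor]
    positives = np.flatnonzero(same_class & ~same_pred)
    negatives = np.flatnonzero(~same_class & same_pred)
    return positives, negatives


if __name__ == "__main__":
    feats = np.array([[0.0, 0.0], [0.0, 0.0], [3.0, 4.0], [3.0, 4.0]])
    labels = np.array([0, 0, 0, 0])
    groups = np.array([0, 0, 1, 1])
    print(alignment_loss(feats, labels, groups))

    labels = np.array([0, 0, 1, 1])
    erm_preds = np.array([0, 1, 0, 1])
    print(contrastive_candidates(labels, erm_preds, anchor=0))
```

Under this reading, pulling an anchor toward its positives and away from its negatives pushes the encoder to ignore whatever feature the ERM model relied on, which is what a lower alignment loss would reflect.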
One-sentence Summary: We propose Correct-N-Contrast, a contrastive learning method to substantially improve neural network robustness to spurious correlations.
Supplementary Material: zip
