Abstract: Machine learning models are known to learn spurious correlations, i.e., features that are strongly correlated with class labels but have no causal relationship to them. Relying on such correlations leads to poor generalization and poor performance on data groups in which these correlations do not hold. Existing approaches to mitigating spurious correlations either rely on the availability of group annotations or require access to different model checkpoints to approximate these annotations. We propose PruSC, a method for extracting a spurious-free subnetwork from a dense network. PruSC does not require prior knowledge of the spurious correlations and can mitigate the effect of multiple spurious attributes. Specifically, we observe that empirical risk minimization (ERM) training leads to clusters in representation space that are induced by spurious correlations. We then define a supervised contrastive loss to extract a subnetwork that distorts such clusters, forcing the model to learn only class-specific clusters rather than attribute-class-specific ones. Our method outperforms all annotation-free methods, achieves worst-group accuracy competitive with methods that require annotations, and mitigates the effect of multiple spurious correlations. Our results show that, within a fully trained dense network, there exists a subnetwork that uses only invariant features for classification, thereby eliminating the influence of spurious features.
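To make the mechanism concrete, below is a minimal PyTorch sketch of the standard supervised contrastive (SupCon) loss that the abstract refers to: it pulls together embeddings of same-class samples regardless of their spurious attributes, which is what breaks attribute-induced clusters in representation space. This is the generic formulation (Khosla et al., 2020); the paper's exact objective, and how it drives subnetwork extraction via pruning, may differ, and the function name here is illustrative.

```python
import torch
import torch.nn.functional as F

def supervised_contrastive_loss(features, labels, temperature=0.1):
    """Generic SupCon loss: same-class pairs are positives.

    features: (N, D) embeddings from the (sub)network.
    labels:   (N,) integer class labels.
    """
    features = F.normalize(features, dim=1)
    sim = features @ features.T / temperature  # (N, N) pairwise similarities
    # Exclude self-similarity on the diagonal.
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=features.device)
    # Positives: same class label, excluding self.
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_self
    # Log-softmax over all non-self pairs, numerically stable.
    sim = sim.masked_fill(~not_self, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    # Average log-likelihood of positives per anchor; skip anchors
    # that have no same-class partner in the batch.
    pos_counts = pos_mask.sum(dim=1)
    valid = pos_counts > 0
    sum_log_prob_pos = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(sum_log_prob_pos[valid] / pos_counts[valid]).mean()
```

Because positives are defined purely by class label, minimizing this loss penalizes representations that cluster by (attribute, class) pairs instead of by class alone, which is the signal a subnetwork-selection procedure can optimize against.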
Submission Length: Long submission (more than 12 pages of main content)
Assigned Action Editor: ~Gintare_Karolina_Dziugaite1
Submission Number: 3271