Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness

TMLR Paper 2577 Authors

23 Apr 2024 (modified: 18 Jun 2024) · Under review for TMLR · CC BY-SA 4.0
Abstract: Neural networks (NNs) are known to exhibit simplicity bias: they tend to prefer learning 'simple' features over more 'complex' ones, even when the latter may be more informative. Simplicity bias can lead a model to make biased predictions that generalize poorly out of distribution (OOD) and lack subgroup robustness. To address this, we propose a hypothesis about spurious features that directly connects them to simplicity bias: we hypothesize that, on many datasets, spurious features are simple features that are still predictive of the label. We empirically validate this hypothesis, and subsequently develop a framework that leverages it to learn more robust models. In our proposed framework, we first train a simple model, and then regularize the conditional mutual information of the final model with respect to it. We theoretically study the effect of this regularization and show that it provably reduces reliance on spurious features in certain settings. We also empirically demonstrate the effectiveness of the framework in various problem settings and real-world applications, showing that it mitigates simplicity bias, encourages the use of a richer set of features, enhances OOD generalization, and improves subgroup robustness and fairness.
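For concreteness, the two-stage recipe described in the abstract might look roughly like the sketch below. This is a minimal illustration, not the paper's exact algorithm: the within-class cross-covariance penalty is only a stand-in proxy for conditional mutual information (it captures conditional dependence exactly only under joint-Gaussian assumptions), and `simple_model`, `main_model`, `lambda_cmi`, and the data loader are hypothetical placeholders.

```python
# Sketch of the two-stage recipe from the abstract (not the authors' exact
# method): (1) fit a low-capacity "simple" model with plain ERM, then
# (2) train the main model while penalizing statistical dependence between
# its logits and the simple model's logits, conditioned on the label.
import torch
import torch.nn.functional as F

def conditional_cov_penalty(z_main, z_simple, y, num_classes):
    """Sum over classes of the squared cross-covariance between the two
    representations, computed within each class (a proxy for conditional
    mutual information under a joint-Gaussian assumption)."""
    penalty = z_main.new_zeros(())
    for c in range(num_classes):
        mask = (y == c)
        if mask.sum() < 2:
            continue  # need at least two samples to estimate a covariance
        a = z_main[mask] - z_main[mask].mean(dim=0, keepdim=True)
        b = z_simple[mask] - z_simple[mask].mean(dim=0, keepdim=True)
        cov = a.T @ b / (mask.sum() - 1)
        penalty = penalty + (cov ** 2).sum()
    return penalty

def train_two_stage(simple_model, main_model, loader, num_classes,
                    lambda_cmi=1.0, epochs=10, lr=1e-3, device="cpu"):
    # Stage 1: fit the simple (low-capacity) model with standard ERM.
    opt_s = torch.optim.Adam(simple_model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(simple_model(x), y)
            opt_s.zero_grad(); loss.backward(); opt_s.step()
    simple_model.eval()

    # Stage 2: train the main model; the regularizer discourages its logits
    # from carrying information about the simple model's outputs beyond
    # what the label already explains.
    opt_m = torch.optim.Adam(main_model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = main_model(x)
            with torch.no_grad():
                logits_s = simple_model(x)
            loss = (F.cross_entropy(logits, y)
                    + lambda_cmi * conditional_cov_penalty(
                        logits, logits_s, y, num_classes))
            opt_m.zero_grad(); loss.backward(); opt_m.step()
    return main_model
```

In this sketch, penalizing within-class cross-covariance pushes the main model to rely on features the simple model does not use once the label is accounted for, which is the intuition behind regularizing conditional mutual information with respect to the simple model.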
Submission Length: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Pavel_Izmailov1
Submission Number: 2577