Keywords: OOD Generalization
Abstract: Out-of-Distribution (OOD) generalization is a central challenge in machine learning. Models often fail on unseen data, not because of an inability to learn robust signals, but because they $\textit{preferentially learn spurious, dataset-specific correlations that are highly predictive on in-distribution examples}$. Existing solutions typically focus on searching for invariant features, yet often overlook a more fundamental question: $\textbf{what properties of the training data cause models to learn these non-invariant ``shortcut'' features in the first place?}$ In this work, we present a different perspective on OOD generalization. We argue that failures to generalize are a direct consequence of models learning the strongest features in the training data, which are often spurious. Guided by this, we reframe OOD generalization not as a search for invariance, but as the $\textit{problem of identifying and mitigating the influence of these overly dominant features}$. Under this new perspective, we develop a novel primitive for quantifying feature strength across a training set. This primitive gives rise to a targeted regularization algorithm that weakens a model's reliance on the identified strongest features, thereby compelling it to learn more robust and causally stable signals. Our method demonstrates substantial improvements in generalization across a wide range of OOD benchmarks, improving OOD accuracy by up to $2\times$ over standard training and significantly outperforming existing baselines without compromising in-distribution performance.
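To make the abstract's two-step recipe concrete (quantify per-feature strength, then regularize the strongest features), here is a minimal, hypothetical sketch. The abstract does not specify the actual primitive, so this assumes feature strength is scored by the average gradient magnitude of the loss with respect to each penultimate-layer feature, and that the regularizer penalizes activation mass on the top-k strongest dimensions. All names (`feature_strength`, `dominance_penalty`, `model.backbone`, `model.head`, `top_k`, `lambda_reg`) are illustrative assumptions, not the paper's method.

```python
# Illustrative sketch only: one plausible instantiation of
# "quantify feature strength, then down-weight the strongest features".
import torch
import torch.nn.functional as F


def feature_strength(model, loader, device="cpu"):
    """Score each penultimate feature by the mean |d loss / d feature| over the data."""
    strengths, n_batches = None, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        feats = model.backbone(x)                  # assumed penultimate representation
        loss = F.cross_entropy(model.head(feats), y)
        (grads,) = torch.autograd.grad(loss, feats)
        batch_strength = grads.abs().mean(dim=0)   # one score per feature dimension
        strengths = batch_strength if strengths is None else strengths + batch_strength
        n_batches += 1
    return strengths / n_batches


def dominance_penalty(feats, strengths, top_k=16):
    """Penalize activation mass concentrated on the k strongest features."""
    idx = torch.topk(strengths, top_k).indices
    return feats[:, idx].pow(2).mean()


def train_step(model, x, y, strengths, optimizer, lambda_reg=0.1):
    """Standard cross-entropy step plus the dominance regularizer."""
    feats = model.backbone(x)
    logits = model.head(feats)
    loss = F.cross_entropy(logits, y) + lambda_reg * dominance_penalty(feats, strengths)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In this reading, `feature_strength` plays the role of the abstract's "primitive" and `dominance_penalty` the targeted regularizer; the actual scoring rule, penalty form, and hyperparameters would follow the paper rather than this sketch.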
Primary Area: interpretability and explainable AI
Submission Number: 8132