Learning an Invertible Output Mapping Can Mitigate Simplicity Bias in Neural Networks

Published: 21 Oct 2022, Last Modified: 05 May 2023
Venue: NeurIPS 2022 Workshop DistShift Poster
Readers: Everyone
Keywords: Simplicity Bias, Out-of-distribution robustness, OOD Generalization, Deep Learning
TL;DR: We propose a regularizer that enforces the reconstruction of features from the output logits of neural networks, in order to overcome Simplicity Bias and boost their OOD generalization.
Abstract: Deep Neural Networks (DNNs) are known to be brittle under even minor shifts away from the training distribution. One line of work has demonstrated that the \emph{Simplicity Bias} (SB) of DNNs -- a bias towards learning only the simplest features -- is a key reason for this brittleness. Another recent line of work has surprisingly found that diverse, complex features are in fact learned by the backbone, and that the brittleness arises because the linear classification head relies primarily on the simplest features. To bridge the gap between these two lines of work, we first hypothesize and verify that while SB may not altogether preclude learning complex features, it amplifies simpler features over complex ones. Namely, simple features are replicated several times in the learned representations, while complex features might not be replicated. This phenomenon, which we term the \emph{Feature Replication Hypothesis}, coupled with the \emph{Implicit Bias} of SGD to converge to maximum-margin solutions in the feature space, leads models to rely mostly on the simple features for classification. To mitigate this bias, we propose the \emph{Feature Reconstruction Regularizer (FRR)} to ensure that the learned features can be reconstructed from the logits. Using \emph{FRR} in linear layer training (\emph{FRR-L}) encourages the use of more diverse features for classification. We further propose to finetune the full network while keeping the weights of the linear layer trained using \emph{FRR-L} frozen, in order to refine the learned features and make them more suitable for classification. Using the proposed approach, we demonstrate noteworthy gains on synthetic and semi-synthetic datasets, and also outperform the existing SOTA on the standard OOD benchmark DomainBed.
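The idea of reconstructing features from the logits can be illustrated with a small sketch. It is not the authors' implementation: the abstract does not specify the form of the regularizer, so the linear decoder `D`, the squared-error reconstruction term, the weight `lam`, and the function names below are all assumptions chosen for illustration. The sketch adds a reconstruction penalty to the usual cross-entropy loss of a linear head applied to fixed backbone features.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def frr_style_loss(features, labels, W, b, D, lam=1.0):
    """Cross-entropy plus a feature-reconstruction penalty (illustrative only).

    features: (n, d) backbone features, treated as fixed inputs here
    W, b:     linear head mapping features to (n, k) logits
    D:        hypothetical (k, d) linear decoder mapping logits back to features
    lam:      hypothetical weight on the reconstruction term
    """
    logits = features @ W + b
    reconstruction = logits @ D
    ce = softmax_cross_entropy(logits, labels)
    # Assumed penalty: mean squared L2 distance between features and reconstruction.
    rec = np.mean(np.sum((features - reconstruction) ** 2, axis=1))
    return ce + lam * rec, ce, rec

# Toy usage with random data (no claim about real results).
rng = np.random.default_rng(0)
features = rng.normal(size=(8, 5))
labels = rng.integers(0, 3, size=8)
W = rng.normal(size=(5, 3))
b = np.zeros(3)
D = rng.normal(size=(3, 5))
total, ce, rec = frr_style_loss(features, labels, W, b, D, lam=0.5)
```

In this sketch, a head that discards most feature directions cannot drive the reconstruction term to zero, which is one way to read the abstract's claim that the regularizer encourages the classifier to use more diverse features. In practice the head and decoder would be trained by gradient descent on this combined objective.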