Keywords: Spurious Correlations, Disentangled Representation Learning, Attention
TL;DR: The paper proposes Deep Attention Reweighting to replace Global Average Pooling in CNNs, mitigating spurious correlations by disentangling core and spurious features through attention-based feature aggregation.
Abstract: Spurious correlations in datasets cause Convolutional Neural Networks (CNNs) to learn features that are predictive but not causally relevant to the task, leading to poor generalization and fairness issues. The recently proposed Deep Feature Reweighting (DFR) technique aims to reduce the reliance of an ERM-trained model on spurious correlations by retraining its classification head on a target dataset, achieving state-of-the-art performance on various spurious correlation benchmarks. However, we find that DFR operates on entangled features, which limits its ability to extract core features while removing the influence of spurious ones. Our analysis reveals that this entanglement in CNNs is primarily caused by the commonly used Global Average Pooling (GAP) aggregation mechanism, which indiscriminately collapses information in a feature map across spatial locations into a single feature. To address this, we propose Deep Attention Reweighting (DAR), which replaces the GAP layer with an attention-based aggregation mechanism that adaptively assigns importance to spatial locations of the feature maps, enabling selective suppression of spurious features before they become entangled with the core features. Across various metrics, datasets, and experimental settings, we empirically show that DAR outperforms DFR at resolving the entanglement between core and spurious features, thereby better mitigating spurious correlations.
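The contrast the abstract draws between GAP and attention-based aggregation can be illustrated with a minimal sketch. The snippet below is a generic attention-pooling module, not the paper's actual DAR implementation (whose architecture is not specified here); the choice of a 1x1 convolution to score spatial locations is an assumption for illustration.

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Aggregates a CNN feature map with learned spatial attention weights,
    in place of the uniform averaging performed by Global Average Pooling."""

    def __init__(self, channels: int):
        super().__init__()
        # Hypothetical scoring head: a 1x1 conv assigns one logit per spatial location.
        self.score = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W)
        b, c, h, w = x.shape
        logits = self.score(x).view(b, 1, h * w)   # (B, 1, HW) location scores
        weights = torch.softmax(logits, dim=-1)    # attention over spatial locations
        feats = x.view(b, c, h * w)                # (B, C, HW)
        # Weighted sum over locations: low-weight (e.g. spurious) regions are suppressed.
        return (feats * weights).sum(dim=-1)       # (B, C)

# GAP, by contrast, weights every spatial location equally:
#   gap = x.mean(dim=(2, 3))
```

The key difference is that the softmax weights let the network down-weight spatial regions carrying spurious cues before aggregation, whereas GAP mixes all locations into one feature uniformly.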
Primary Area: unsupervised, self-supervised, semi-supervised, and supervised representation learning
Submission Number: 11085