Understanding and Improving Feature Learning for Out-of-Distribution Generalization

Published: 10 Mar 2023, Last Modified: 28 Apr 2023, ICLR 2023 Workshop DG Poster
Keywords: Out-of-Distribution Generalization, Feature Learning, Invariant Learning, Causality
TL;DR: We theoretically reveal that ERM learns both invariant and spurious features and propose a new algorithm to learn richer features than ERM for facilitating OOD generalization.
Abstract: A common explanation for the failure of out-of-distribution (OOD) generalization is that a model trained with empirical risk minimization (ERM) learns spurious features instead of the desired invariant features. However, several recent studies have challenged this explanation, finding that deep networks may already have learned sufficiently good features for OOD generalization. To reconcile these seemingly contradictory findings, we conduct a theoretical investigation and show that ERM learns both spurious and invariant features. Moreover, the quality of the features learned during ERM pre-training significantly affects the final OOD performance, as OOD objectives rarely learn new features; failing to capture all the useful features during pre-training thus limits the final OOD performance. To remedy this issue, we propose Feature Augmented Training (FAT), which forces the model to learn all useful features by retaining the already learned features and augmenting new ones over multiple rounds. In each round, the retention and augmentation are performed on different subsets of the training data that capture distinct features. Extensive experiments show that FAT learns richer features and consistently improves the OOD performance of various objectives.
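The round structure described in the abstract (split the data, retain already learned behaviour on one subset, learn new features on the other) can be sketched as follows. This is a minimal illustrative sketch, not the paper's actual algorithm: the linear model, squared loss, the median-loss split, and the retention penalty that anchors predictions on well-fit examples are all simplifying assumptions made here for clarity.

```python
import numpy as np

def fat_train(X, y, rounds=3, steps=200, lr=0.1, lam=1.0, seed=0):
    """Toy sketch of a FAT-style training loop.

    Each round splits the data by current per-example loss:
    well-fit examples form the retention set, poorly-fit ones the
    augmentation set. Training then minimizes the ERM loss on the
    augmentation set plus a retention penalty that keeps predictions
    on the retention set close to their round-start values.
    """
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=0.01, size=X.shape[1])
    for _ in range(rounds):
        losses = (X @ w - y) ** 2            # per-example loss
        aug = losses > np.median(losses)     # poorly fit: learn new features here
        ret = ~aug                           # well fit: retain behaviour here
        y_ret = X[ret] @ w                   # freeze round-start predictions
        for _ in range(steps):
            # ERM gradient on the augmentation subset
            g_aug = 2 * X[aug].T @ (X[aug] @ w - y[aug]) / max(aug.sum(), 1)
            # retention gradient: stay close to frozen predictions
            g_ret = 2 * X[ret].T @ (X[ret] @ w - y_ret) / max(ret.sum(), 1)
            w -= lr * (g_aug + lam * g_ret)
    return w
```

In this sketch the retention term plays the role of "retaining the already learned features", while plain ERM on the poorly-fit subset plays the role of "augmenting new ones"; the paper's actual objective and subset construction may differ.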
Submission Number: 5