Sharpness-Aware Minimization Leads to Low-Rank Features

Published: 21 Sept 2023, Last Modified: 02 Nov 2023, NeurIPS 2023 poster
Keywords: sharpness-aware minimization, low-rank features, understanding feature learning
TL;DR: SAM leads to features of lower rank in a variety of settings: from small MLPs to vision transformers, from regression to multimodal contrastive learning
Abstract: Sharpness-aware minimization (SAM) is a recently proposed method that minimizes the sharpness of the training loss of a neural network. While its generalization improvement is well known and is its primary motivation, we uncover an additional intriguing effect of SAM: a reduction of the feature rank that occurs at different layers of a neural network. We show that this low-rank effect occurs very broadly: for different architectures such as fully-connected networks, convolutional networks, and vision transformers, and for different objectives such as regression, classification, and language-image contrastive training. To better understand this phenomenon, we provide a mechanistic understanding of how low-rank features arise in a simple two-layer network. We observe that a significant number of activations get entirely pruned by SAM, which directly contributes to the rank reduction. We confirm this effect theoretically and check that it can also occur in deep networks, although the overall rank reduction mechanism can be more complex, especially for deep networks with pre-activation skip connections and self-attention layers.
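As a rough illustration of what "feature rank" and "pruned activations" can mean in practice, the sketch below measures the rank of a feature matrix as the number of principal components needed to explain a fixed share of its variance. The function name `feature_rank`, the 99% threshold, and the toy ReLU layer are assumptions made for this example only; they are not taken from the abstract and need not match the paper's exact protocol.

```python
import numpy as np


def feature_rank(features, variance_threshold=0.99):
    """Hypothetical helper: number of principal components needed to
    explain `variance_threshold` of the variance of `features`
    (rows are examples, columns are feature dimensions)."""
    # Center the features so singular values correspond to PCA variances.
    centered = features - features.mean(axis=0, keepdims=True)
    singular_values = np.linalg.svd(centered, compute_uv=False)
    variances = singular_values ** 2
    total = variances.sum()
    if total == 0.0:
        # Constant features carry no variance at all.
        return 0
    explained = np.cumsum(variances) / total
    # First index at which the cumulative share reaches the threshold.
    index = int(np.searchsorted(explained, variance_threshold))
    return min(index + 1, len(variances))


def pruned_relu_features(inputs, weights, num_active):
    """Toy two-layer-style feature map: ReLU features in which only the
    first `num_active` hidden units ever fire (the rest are set to zero
    to imitate activations that are never active on the data)."""
    hidden = np.maximum(inputs @ weights, 0.0)
    hidden[:, num_active:] = 0.0
    return hidden
```

Under these assumptions, a hidden layer in which only k units are ever active has a feature rank of at most k, which gives one simple picture of how pruned activations can lower the rank. For example, with `inputs` of shape (200, 10) and `weights` of shape (10, 32), `feature_rank(pruned_relu_features(inputs, weights, 8))` cannot exceed 8.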
Submission Number: 456