Hide & Seek: Transformer Symmetries Obscure Sharpness & Riemannian Geometry Finds It

Published: 01 May 2025, Last Modified: 18 Jun 2025 · ICML 2025 Spotlight Poster · CC BY 4.0
Abstract: The concept of sharpness has been successfully applied to traditional architectures like MLPs and CNNs to predict their generalization. For transformers, however, recent work reported only a weak correlation between flatness and generalization. We argue that existing sharpness measures fail for transformers because transformers have much richer symmetries in their attention mechanism, which induce directions in parameter space along which the network, or its loss, remains identical. We posit that sharpness must fully account for these symmetries, and thus we redefine it on the quotient manifold obtained by quotienting out the transformer symmetries, thereby removing the ambiguities they introduce. Leveraging tools from Riemannian geometry, we propose a fully general notion of sharpness, defined in terms of a geodesic ball on the symmetry-corrected quotient manifold. In practice, the geodesics must be approximated: truncating them at first order yields existing adaptive sharpness measures, and we demonstrate that including higher-order terms is crucial to recover the correlation with generalization. We present results on diagonal networks with synthetic data, and show that our geodesic sharpness reveals strong correlation for real-world transformers on both text and image classification tasks.
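As a rough, illustrative formalization of the contrast drawn in the abstract (the notation below is ours, not taken from the paper): adaptive sharpness measures the worst-case loss increase within a rescaled Euclidean ball around the weights $w$, whereas geodesic sharpness replaces that ball with a geodesic ball on the quotient manifold obtained by factoring out the symmetry group $G$.

\[
S^{\mathrm{adapt}}_{\rho}(w) \;=\; \max_{\|\delta \oslash c(w)\|_{p} \,\le\, \rho} \bigl[\, L(w + \delta) - L(w) \,\bigr],
\qquad
S^{\mathrm{geo}}_{\rho}([w]) \;=\; \max_{\substack{v \in T_{[w]}(\mathcal{M}/G) \\ \|v\|_{[w]} \,\le\, \rho}} \bigl[\, L\bigl(\exp_{[w]}(v)\bigr) - L(w) \,\bigr].
\]

Here $c(w)$ denotes an elementwise rescaling (commonly $|w|$), $\mathcal{M}/G$ is the quotient of parameter space by the symmetry group, and $\exp_{[w]}$ is the Riemannian exponential map. Expanding the geodesic $t \mapsto \exp_{[w]}(t v)$ only to first order in $t$ flattens the geodesic ball back into a rescaled Euclidean ball, recovering the adaptive form above; the abstract's claim is that the higher-order terms of the geodesic are what restore the correlation with generalization for transformers.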
Lay Summary: In deep learning, understanding why some neural networks make better predictions than others is an important problem. One popular idea used to explain this is called sharpness. Sharpness looks at the shape of the network's loss landscape, a surface that shows how well or poorly the network performs as its internal parameters are changed slightly. Generally, if this landscape is "flat," small changes don't hurt performance much, and the model is more likely to generalize well to data it has not seen before. This idea works well for older types of neural networks like MLPs (multilayer perceptrons) and CNNs (convolutional neural networks). But for transformers, this relationship breaks down: researchers have found that sharpness, as it is usually measured, does not reliably predict whether a transformer will generalize well. We argue that the problem isn't with the idea of sharpness itself, but with how it is measured in transformers. Transformers admit many ways of changing their internal parameters without actually changing how the model behaves; these are called symmetries, and they confuse traditional sharpness measurements. Using tools from differential geometry, we introduce a more accurate definition of sharpness that takes these symmetries into account, and we find that once they are corrected for, sharpness remains a useful predictor of generalization.
Primary Area: Deep Learning->Theory
Keywords: generalization, symmetry, sharpness, flatness, riemannian geometry, loss landscape
Submission Number: 15939