Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling

Published: 16 Jun 2024, Last Modified: 19 Jul 2024. Venue: HiLD at ICML 2024 Poster. License: CC BY 4.0
Keywords: Deep Learning Theory, Optimal Hyperparameter Transfer, Sharpness Aware Minimization, Infinite Width Limits, Signal Propagation Theory, Tensor Programs
Abstract: Sharpness Aware Minimization (SAM) improves performance across a variety of neural architectures and datasets. As models are continually scaled up to improve performance, a rigorous understanding of SAM's scaling behavior is paramount. To this end, we study the infinite-width limit of neural networks trained with SAM, using the Tensor Programs framework. Our findings reveal that, in wide neural networks, the dynamics of standard SAM effectively reduce to applying SAM only in the last layer, even with optimal hyperparameters. In contrast, we identify a unique parameterization with layerwise perturbation scaling, which we call the maximal update and perturbation parameterization ($\mu$P$^2$), that ensures all layers are both feature learning and effectively perturbed in the limit. Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that $\mu$P$^2$ is the only parameterization that achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales. Moreover, we provide an intuitive condition for deriving $\mu$P$^2$ for other perturbation rules, such as Adaptive SAM and SAM-ON, which likewise ensures balanced perturbation effects across all layers.
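To make "layerwise perturbation scaling" concrete, here is a minimal sketch of one SAM step in which each layer's share of the perturbation is multiplied by its own factor. It assumes a toy parameter layout (a dict mapping layer names to flat lists of weights), a user-supplied gradient function, and hypothetical per-layer multipliers `layer_scale`; it does not reproduce the width-dependent scaling exponents of $\mu$P$^2$ derived in the paper. Setting a layer's multiplier to zero mimics the regime described above, in which that layer is effectively unperturbed.

```python
import math
from typing import Callable, Dict, List, Optional

# Hypothetical parameter layout: layer name -> flat list of weights.
Params = Dict[str, List[float]]


def global_norm(grads: Params) -> float:
    """Euclidean norm of the gradient concatenated over all layers."""
    return math.sqrt(sum(g * g for layer in grads.values() for g in layer))


def sam_step(
    params: Params,
    grad_fn: Callable[[Params], Params],
    lr: float,
    rho: float,
    layer_scale: Optional[Dict[str, float]] = None,
    eps: float = 1e-12,
) -> Params:
    """One SAM step with optional per-layer perturbation multipliers.

    layer_scale is an illustrative assumption, not the paper's rule: a missing
    entry defaults to 1.0, which recovers standard (globally normalized) SAM
    for that layer.
    """
    scale = layer_scale or {}
    grads = grad_fn(params)
    norm = global_norm(grads) + eps  # eps guards against a zero gradient
    # Ascent step: move each layer along its share of the normalized gradient,
    # rescaled by that layer's (hypothetical) multiplier.
    perturbed = {
        name: [w + scale.get(name, 1.0) * rho * g / norm
               for w, g in zip(params[name], grads[name])]
        for name in params
    }
    # Descent step: gradient evaluated at the perturbed point, applied to the
    # original (unperturbed) weights.
    sam_grads = grad_fn(perturbed)
    return {
        name: [w - lr * g for w, g in zip(params[name], sam_grads[name])]
        for name in params
    }


# Toy quadratic loss 0.5 * c * w**2 per weight, with a different curvature c per layer.
CURVATURE = {"hidden": 1.0, "last": 2.0}


def quadratic_grad(params: Params) -> Params:
    return {name: [CURVATURE[name] * w for w in ws] for name, ws in params.items()}


if __name__ == "__main__":
    start = {"hidden": [1.0], "last": [1.0]}
    # Standard SAM: every layer perturbed with multiplier 1.0.
    print(sam_step(start, quadratic_grad, lr=0.1, rho=0.1))
    # Only the last layer perturbed: the hidden layer gets a plain gradient step.
    print(sam_step(start, quadratic_grad, lr=0.1, rho=0.1, layer_scale={"hidden": 0.0}))
```

With `layer_scale={"hidden": 0.0}`, the hidden layer receives exactly the gradient-descent update while only the last layer feels the perturbation, which is the degenerate "last-layer SAM" behavior the abstract attributes to standard SAM in wide networks. A parameterization like $\mu$P$^2$ instead chooses the per-layer factors so that every layer is effectively perturbed as width grows.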
Student Paper: Yes
Submission Number: 16