TL;DR: We propose a theoretically principled optimization objective, together with algorithms to solve it, that generalizes SAM and achieves better test performance.
Abstract: Sharpness-Aware Minimization (SAM) has been shown to improve the generalization of overparameterized models by seeking flat minima on the loss landscape, minimizing the worst-case loss within a neighborhood of the model parameters. Nevertheless, such min-max formulations are computationally challenging, especially when the problem is highly non-convex. Moreover, focusing only on the single worst-case point while ignoring potentially many other points in the neighborhood may be suboptimal when searching for flat minima. In this work, we propose Tilted SAM (TSAM), a smoothed generalization of SAM inspired by exponential tilting that effectively assigns higher priority to perturbations that incur larger losses. TSAM is parameterized by a tilt hyperparameter $t$ and recovers SAM as $t$ approaches infinity. We show that the TSAM objective is smoother than SAM's and thus easier to optimize, and that it explicitly favors flatter minima. We develop algorithms motivated by the discretization of Hamiltonian dynamics to solve TSAM. Empirically, TSAM arrives at flatter local minima and yields better test performance than SAM and ERM baselines across a range of image and text tasks.
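For readers who want the objective concretely: based on standard exponential tilting (the exact neighborhood distribution below is our assumption, not stated in the abstract), the TSAM objective plausibly takes the form

$$\min_{\theta} \; \frac{1}{t} \log \, \mathbb{E}_{\epsilon \sim \mathcal{U}(\{\|\epsilon\|_2 \le \rho\})}\!\left[ e^{\,t\, L(\theta + \epsilon)} \right],$$

which recovers the SAM objective $\min_{\theta} \max_{\|\epsilon\|_2 \le \rho} L(\theta + \epsilon)$ as $t \to \infty$, and the average loss over the neighborhood as $t \to 0$.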
Lay Summary: Sharpness-Aware Minimization (SAM) is a technique that improves the performance of deep learning models by finding "flat" areas on the loss landscape---regions where small changes to the model parameters don't dramatically increase the loss. However, SAM focuses only on the worst-case scenario within a small neighborhood of the parameters, which makes optimization computationally difficult, especially when the model's loss landscape is complex.
We introduce a new approach called Tilted SAM (TSAM), inspired by a method called "exponential tilting." Rather than focusing on the absolute worst case, TSAM assigns greater importance to areas with higher losses, which smooths the optimization and makes it easier to find flat minima, potentially improving model performance. We develop new algorithms to efficiently solve TSAM and demonstrate that it achieves better results than standard SAM and standard training across various image and text tasks; a toy sketch of the tilting idea follows.
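Since the paper's Hamiltonian-dynamics-based solver isn't described on this page, the PyTorch sketch below only illustrates the tilting idea under our own assumptions: sample a few perturbations in a $\rho$-ball around the weights, then reweight their gradients by a softmax of $t$ times the loss, which is the Monte Carlo gradient of the tilted objective above. The function names, sampling scheme, and hyperparameters (`tsam_step`, `rho`, `m`, etc.) are hypothetical, not the authors' implementation.

```python
import torch

@torch.no_grad()
def _perturb(params, eps, sign):
    # Shift the weights along the sampled perturbation (sign=+1) or undo it (sign=-1).
    for p, e in zip(params, eps):
        p.add_(sign * e)

def tsam_step(model, loss_fn, inputs, targets, t=1.0, rho=0.05, m=4, lr=0.1):
    """One tilted update (toy Monte Carlo sketch, NOT the paper's
    Hamiltonian-dynamics algorithm). All hyperparameters are illustrative."""
    params = [p for p in model.parameters() if p.requires_grad]
    losses, grads = [], []
    for _ in range(m):
        # Draw a random perturbation on the sphere of radius rho.
        eps = [torch.randn_like(p) for p in params]
        norm = torch.sqrt(sum((e ** 2).sum() for e in eps))
        eps = [rho * e / norm for e in eps]
        _perturb(params, eps, +1.0)
        loss = loss_fn(model(inputs), targets)
        grads.append(torch.autograd.grad(loss, params))
        _perturb(params, eps, -1.0)
        losses.append(loss.detach())
    # Tilted weights softmax(t * loss): the gradient of (1/t) log E[exp(t * L)]
    # reweights the per-perturbation gradients exactly this way.
    w = torch.softmax(t * torch.stack(losses), dim=0)
    with torch.no_grad():
        for i, p in enumerate(params):
            p.sub_(lr * sum(w[j] * grads[j][i] for j in range(m)))
    return torch.stack(losses).mean().item()
```

Note the two limits: as $t \to \infty$ the softmax puts all weight on the worst sampled perturbation (a sampled analogue of SAM's inner max), while as $t \to 0$ it averages the sampled gradients uniformly.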
Primary Area: Optimization
Keywords: sharpness-aware optimization, exponential tilting, generalization
Submission Number: 7006