Keywords: Sharpness-aware minimization, efficient learning, generalization, supervised learning, optimization
TL;DR: We parallelize the two gradient computations in SAM, improving both the efficiency and the generalization of the model.
Abstract: Sharpness-aware minimization (SAM) has been shown to improve the generalization of neural networks. However, each SAM update requires _sequentially_ computing two gradients, effectively doubling the per-iteration cost compared to base optimizers like SGD. We propose a simple modification of SAM, termed SAMPa, which allows us to fully parallelize the two gradient computations. SAMPa achieves a twofold speedup of SAM under the assumption that communication costs between devices are negligible. Empirical results show that SAMPa ranks among the most efficient variants of SAM in terms of computational time. Additionally, our method consistently outperforms SAM across both vision and language tasks. Notably, SAMPa theoretically maintains convergence guarantees even for _fixed_ perturbation sizes, which is established through a novel Lyapunov function. We in fact arrive at SAMPa by treating this convergence guarantee as a hard requirement---an approach we believe is promising for developing SAM-based methods in general. Our code is available at https://github.com/LIONS-EPFL/SAMPa.
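To make the sequential bottleneck mentioned in the abstract concrete, below is a minimal PyTorch sketch of a *standard* SAM update, in which the second gradient can only be computed after the first. This is an illustrative implementation of vanilla SAM, not of SAMPa itself (see the linked repository for the authors' method); the function name `sam_step` and its signature are assumptions for illustration.

```python
import torch

def sam_step(model, loss_fn, x, y, base_opt, rho=0.05):
    """One standard SAM update: two gradient passes that must run sequentially."""
    # First gradient: at the current weights w.
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        grad_norm = torch.norm(
            torch.stack([p.grad.norm(p=2) for p in model.parameters() if p.grad is not None]),
            p=2,
        )
        # Ascent step: perturb w -> w + rho * g / ||g||.
        eps = []
        for p in model.parameters():
            if p.grad is None:
                eps.append(None)
                continue
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)
            eps.append(e)
    model.zero_grad()
    # Second gradient: at the perturbed weights w + eps.
    # In vanilla SAM this pass cannot start before the first one finishes,
    # which is the per-iteration cost SAMPa removes by parallelizing the two computations.
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        # Restore the original weights before the base-optimizer update.
        for p, e in zip(model.parameters(), eps):
            if e is not None:
                p.sub_(e)
    base_opt.step()       # update w using the gradient taken at w + eps
    base_opt.zero_grad()
```

In this sequential form each iteration costs two full forward-backward passes; the paper's claim is that reorganizing the update so the two gradients can be evaluated concurrently on separate devices recovers roughly the per-iteration cost of the base optimizer when communication overhead is negligible.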
Primary Area: Optimization for deep networks
Flagged For Ethics Review: true
Submission Number: 11892