G-TRACER: Expected Sharpness Optimization

TMLR Paper2530 Authors

16 Apr 2024 (modified: 17 Sept 2024). Rejected by TMLR. License: CC BY 4.0
Abstract: We propose a new regularization scheme for the optimization of deep learning architectures, G-TRACER ("Geometric TRACE Ratio"), which promotes generalization by seeking minima with low mean curvature, and which has a sound theoretical basis as an approximation to a natural-gradient-descent-based optimization of a generalized variational objective. The G-TRACER penalty, which can be interpreted as the metric trace of the Hessian (the Laplace-Beltrami operator) with respect to the Fisher information metric, is simply added to the loss function, so curvature-regularized optimizers (e.g. SGD-TRACER and Adam-TRACER) are easy to implement as modifications to existing optimizers and do not require extensive tuning. We show that the method can be interpreted as penalizing, in the neighborhood of a minimum, the difference between the mean value of the loss and the value at the minimum, in a way that adjusts for the natural geometry of the parameter space induced by the KL divergence. We show that the method converges to a neighborhood (whose size depends on the regularization strength) of a local minimum of the unregularized objective, and demonstrate promising performance on a number of benchmark computer vision and NLP datasets, with a particular focus on challenging problems characterized by a low signal-to-noise ratio, or by an absence of natural data augmentations and other regularization schemes.
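The two interpretations in the abstract (metric trace of the Hessian as the Laplace-Beltrami operator, and the gap between the mean loss near a minimum and the loss at the minimum) can be linked by a short second-order argument. This is a sketch under standard assumptions, not the paper's derivation: $L$ is a smooth loss with a critical point $\theta^\ast$, $H = \nabla^2 L(\theta^\ast)$, $G = (g_{ij})$ is the metric evaluated at $\theta^\ast$ with inverse entries $g^{ij}$ and Christoffel symbols $\Gamma^k_{ij}$, repeated indices are summed, $\Delta$ uses the divergence-of-gradient sign convention, and the perturbation covariance $\Sigma = \rho\,G^{-1}$ with scale $\rho > 0$ is an illustrative assumption.

```latex
\begin{align*}
\Delta L &= \frac{1}{\sqrt{\det G}}\,\partial_i\!\left(\sqrt{\det G}\,g^{ij}\,\partial_j L\right)
          = g^{ij}\left(\partial_i\partial_j L - \Gamma^{k}_{ij}\,\partial_k L\right),\\
\nabla L(\theta^\ast) = 0 \;&\Longrightarrow\; \Delta L(\theta^\ast) = g^{ij}\,\partial_i\partial_j L(\theta^\ast) = \operatorname{Tr}\!\left(G^{-1}H\right),\\
\mathbb{E}_{\varepsilon\sim\mathcal{N}(0,\Sigma)}\!\left[L(\theta^\ast+\varepsilon)\right] - L(\theta^\ast)
  &\approx \nabla L(\theta^\ast)^{\top}\mathbb{E}[\varepsilon] + \tfrac{1}{2}\,\mathbb{E}\!\left[\varepsilon^{\top}H\varepsilon\right]
  = \tfrac{1}{2}\operatorname{Tr}(H\Sigma),\\
\Sigma = \rho\,G^{-1} \;&\Longrightarrow\; \mathbb{E}\!\left[L(\theta^\ast+\varepsilon)\right] - L(\theta^\ast)
  \approx \tfrac{\rho}{2}\operatorname{Tr}\!\left(G^{-1}H\right) = \tfrac{\rho}{2}\,\Delta L(\theta^\ast).
\end{align*}
```

The Christoffel term drops out only because the gradient vanishes at $\theta^\ast$, which is why the identification with the Laplace-Beltrami operator holds at a critical point rather than everywhere.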
Submission Length: Long submission (more than 12 pages of main content)
Previous TMLR Submission Url: https://openreview.net/forum?id=OBijPYcL9u
Changes Since Last Submission: The reviewers of the original submission helpfully highlighted a number of key areas for improvement, with which we fully agree and which we have addressed in this significantly revised manuscript, as follows:
1. Clarity and structure of the paper
2. Better explanation of the theory and of the connection between flatness and the proposed regularization scheme
3. Full details of experiment hyperparameters, comparison with existing benchmarks, demonstration that the models are well trained, and further experimental evidence
4. Concerns about the scalability of the inverse Hessian approximation
5. Improved literature review
6. Overall claims

Taking each of these points in turn:
1. We have completely reworked the presentation to clarify the material and aid understanding, adding derivation sketches and moving the detailed derivations to appendices in order to simplify the presentation and better communicate the key ideas.
2. We have set out in detail the theory linking geometric flatness and the G-TRACER regularization scheme, highlighting in particular its connection to the Laplace-Beltrami operator: at a critical point, $\mathrm{Tr}(G^{-1}H)$ equals the Laplace-Beltrami operator $\Delta$ (also known as the manifold Laplacian) applied to the loss. $\Delta$ generalizes the Laplacian to Riemannian manifolds and defines an invariant, geometric quantity which, by analogy with $\mathrm{Tr}(H)$ in Euclidean space, measures the average deviation from flatness while adjusting for the curvature of the manifold. Crucially, this is not an assumption: it emerges naturally from the variational optimization of a generalized Bayes objective using the KL metric. In particular, for a multivariate Gaussian variational approximation, the trace penalty corresponds to smoothing the loss surface with a kernel estimated online.
3. While our paper is primarily theory-driven, we have significantly reworked the experiments and results section, added experiments and results on vision transformers (ViTs), added a benchmarking cross-check using a ResNet-18 architecture and publicly available results, and established consistency with results in the SAM literature. The exact experimental setting for all tasks is fully set out and follows standard best practice. In particular, we show that results for ResNet-20 are in line with (in fact, competitive with) the results in the key papers [1] and [2]. As a further consistency check with practice, we follow the training protocol in https://github.com/weiaicunzai/pytorch-cifar100/tree/master?tab=readme-ov-file (stepwise learning-rate decay over 200 epochs, with learning rates $[.1, .02, .004, .0008]$ starting at epochs $[0, 60, 120, 160]$) with larger architectures, e.g. ResNet-18 (11M parameters); we match the expected results for SGD and see similar improvements over SGD (75.8\% vs 75.1\% accuracy) and SAM (75.3\% accuracy, $\rho=.05$). Although we had previously grouped them together, we emphasize that the three NLP tasks are distinct, intricate, and challenging (all three are part of the SuperGLUE benchmark). Similarly, the baseline and noisy CIFAR-100 problems are distinct, and the differences between them affect the relative advantages of each method. We have updated this section to reflect the additional experiments and the diversity of tasks performed, together with full details of the experimental settings and hyperparameter choices.
4. We have reworked the presentation to further emphasize that the G-TRACER scheme is based on a diagonal empirical Fisher approximation, which is very cheap to compute, so the algorithm presents no scalability issues.
5. We have added an extensive review of the literature on flatness as well as on SAM variants, and benchmarked our baseline results against those in the key papers [1] and [2], enabling comparisons with other SAM variants.
6. We have revised the claims to emphasize that our aim is not to show that our method is uniformly better than SAM, but that it is a practically relevant, competitive, and principled regularization scheme with a sound theoretical motivation, which boosts performance on challenging low signal-to-noise-ratio problems (where it performs strongly relative to baseline methods and SAM) in the absence of data augmentations and other regularizations. Our "model" for this setting is challenging variants of known benchmarks: training vision transformers (ViTs) from scratch (no pretraining), heavily noise-corrupted CIFAR-100 with no augmentation, and three challenging NLP tasks in BERT fine-tuning (BoolQ, WiC, RTE). Establishing that the method is competitive with SAM across a very broad range of tasks (including large-scale tasks) will be the subject of a follow-up, purely experimental paper.

[1] Möllenhoff et al., SAM as an Optimal Relaxation of Bayes, 2023.
[2] Kwon, J. et al., Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks, 2021.
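Points 2 and 4 above rest on two facts: a diagonal empirical Fisher is cheap to form from per-sample gradients, and a Gaussian perturbation with covariance proportional to the inverse metric raises the expected loss at a minimum by a scaled metric trace, $\tfrac{\rho}{2}\mathrm{Tr}(G^{-1}H)$. The sketch below checks this numerically on a toy least-squares problem, where the identity is exact because the loss is quadratic and its gradient vanishes at the minimizer. Everything here is an illustrative assumption rather than the authors' SGD-TRACER or Adam-TRACER update: the least-squares loss, the covariance $\Sigma = \rho\,G^{-1}$, the value of `rho`, the `damping` term, and the helper names `loss_batch` and `diag_empirical_fisher`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy problem: least-squares regression with n samples, d parameters.
n, d = 200, 5
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d) + 0.5 * rng.standard_normal(n)


def loss_batch(thetas):
    """Loss 0.5 * mean((X theta - y)^2), evaluated for each row of thetas."""
    residuals = thetas @ X.T - y                # shape (k, n)
    return 0.5 * np.mean(residuals**2, axis=1)


def diag_empirical_fisher(theta, damping=1e-8):
    """Diagonal empirical Fisher: mean of squared per-sample gradients at theta."""
    residuals = X @ theta - y                   # shape (n,)
    per_sample_grads = X * residuals[:, None]   # gradient of 0.5 * r_i^2, shape (n, d)
    return np.mean(per_sample_grads**2, axis=0) + damping


# Exact minimizer of the unregularized loss, so the gradient vanishes there.
theta_star, *_ = np.linalg.lstsq(X, y, rcond=None)
H = X.T @ X / n                                 # exact Hessian of the loss
G_diag = diag_empirical_fisher(theta_star)

rho = 0.01                                      # hypothetical smoothing scale
sigma_diag = rho / G_diag                       # Sigma = rho * G^{-1}, kept diagonal

# Closed form: (rho / 2) * Tr(G^{-1} H); with diagonal G only diag(H) contributes.
penalty = 0.5 * rho * np.sum(np.diag(H) / G_diag)

# Monte Carlo estimate of E[L(theta* + eps)] - L(theta*), with eps ~ N(0, Sigma).
eps = rng.standard_normal((20_000, d)) * np.sqrt(sigma_diag)
gap = loss_batch(theta_star + eps).mean() - loss_batch(theta_star[None, :])[0]

print(f"closed-form penalty: {penalty:.4f}, Monte Carlo gap: {gap:.4f}")
```

The two printed numbers should agree up to Monte Carlo error. For a non-quadratic loss, or away from a minimum, the agreement holds only to second order.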
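The step schedule quoted in point 3 (learning rates $[.1, .02, .004, .0008]$ starting at epochs $[0, 60, 120, 160]$) amounts to multiplying the learning rate by 0.2 at each milestone. A minimal sketch, assuming epochs are counted from 0 and the rate is updated once per epoch; `step_lr` is a hypothetical helper, not code from the paper or the linked repository:

```python
def step_lr(epoch, milestones=(60, 120, 160), base_lr=0.1, gamma=0.2):
    """Piecewise-constant learning rate for a 0-indexed epoch.

    The rate is base_lr * gamma**k, where k is the number of milestones
    already reached (epoch >= milestone).
    """
    passed = sum(epoch >= m for m in milestones)
    return base_lr * gamma**passed


# Learning rate at the start of each phase of a 200-epoch run.
for start in (0, 60, 120, 160):
    print(start, round(step_lr(start), 6))
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)`, stepped once per epoch, should produce the same values, assuming no additional warm-up phase is used.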
Assigned Action Editor: ~Lechao_Xiao2
Submission Number: 2530