TRAM: Bridging Trust Regions and Sharpness Aware Minimization

Published: 16 Jan 2024, Last Modified: 12 Mar 2024, ICLR 2024 spotlight
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: sharpness-aware minimization, sam, trust region, optimization, cross-lingual transfer, language modeling
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: We propose a trust-region-motivated variant of SAM that scales learning with a trust region constraint to jointly optimize for low-sharpness parameters and low-curvature representations.
Abstract: Sharpness-aware minimization (SAM) has been reported to improve domain generalization by reducing loss surface curvature in the parameter space. However, generalization during _fine-tuning_ often depends more on the transferability of _representations_ in the function space. Trust-region (TR) methods target this goal by regularizing representation curvature, which reduces catastrophic forgetting of pre-trained, task-agnostic information while the model acquires task-specific skills. We consider unifying these strategies to achieve low curvature in both parameter space and function space and thereby improve out-of-domain (OOD) generalization. We propose **Trust Region Aware Minimization** (TRAM), a SAM algorithm that fine-tunes for low parameter sharpness and for smooth, informative representations that preserve pre-trained structure. TRAM uses a trust region bound to inform the SAM adversarial neighborhood, introducing an awareness of function curvature into the optimization for flatter minima. We empirically validate TRAM on vision (cross-dataset adaptation) and text (OOD language modeling, zero-shot cross-lingual transfer) tasks where robust domain transfer and representation generality are critical. TRAM outperforms SAM- and TR-based optimization across all tasks, notably surpassing competing methods on hard transfer between _anticorrelated_ domains. TRAM establishes a new standard for fine-tuning domain-generalizable models, with minimal additional computation over previous sharpness-aware methods.
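To make "a trust region bound to inform the SAM adversarial neighborhood" more concrete, here is a minimal pure-Python sketch. The `sam_step` function is the standard first-order SAM update (perturb the weights by rho * g / ||g||, then descend using the gradient taken at the perturbed point). The abstract does not say how TRAM derives the radius, so `trust_region_radius`, its `scale` parameter, and the choice of a KL divergence between a reference and a current output distribution are hypothetical stand-ins, not the authors' formulation.

```python
import math


def sam_step(w, grad_fn, lr, rho):
    """One standard SAM update on a list of weights.

    1. Ascend to the approximate worst-case point inside an L2 ball of radius rho:
       epsilon = rho * g / ||g||.
    2. Descend from the ORIGINAL weights using the gradient at w + epsilon.
    """
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) or 1e-12  # avoid division by zero
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def trust_region_radius(ref_logits, cur_logits, scale=1.0):
    """HYPOTHETICAL: derive the SAM radius from a function-space distance.

    Here the radius is proportional to the KL divergence between a reference
    output distribution and the current one. This only illustrates the idea of
    letting a trust-region quantity set the SAM neighborhood; it is not the
    TRAM paper's definition.
    """
    return scale * kl_divergence(softmax(ref_logits), softmax(cur_logits))


if __name__ == "__main__":
    # Toy loss 0.5 * ||w||^2, whose gradient is w itself.
    grad = lambda w: list(w)
    rho = trust_region_radius([2.0, 0.5, -1.0], [1.5, 0.7, -0.8], scale=1.0)
    print("rho:", rho)
    print("updated weights:", sam_step([1.0, -2.0], grad, lr=0.1, rho=rho))
```

Keeping the radius as a plain argument to `sam_step` separates the parameter-space step (SAM) from whatever function-space quantity chooses the neighborhood size. Note that with `rho = 0` the update reduces to ordinary gradient descent.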
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Primary Area: optimization
Submission Number: 5847