How to set AdamW's weight decay as you scale model and dataset size

Published: 01 May 2025, Last Modified: 18 Jun 2025 · ICML 2025 poster · CC BY 4.0
Abstract: The scaling of the optimal AdamW weight decay hyperparameter with model and dataset size is critical as we seek to build larger models, but is poorly understood. We show that weights learned by AdamW can be understood as an exponential moving average (EMA) of recent updates. This gives key insights into how to set the weight decay in AdamW, and how the weight decay should scale with model and dataset size. In particular, the key hyperparameter for an exponential moving average is the EMA timescale. Intuitively, the EMA timescale can be understood as the number of recent iterations the EMA averages over. We find that the optimal timescale, measured in epochs, is roughly constant as we change model and dataset size. Moreover, given a learning rate, there is a one-to-one mapping from the EMA timescale to the weight decay hyperparameter. Thus, if the optimal EMA timescale is constant, then as the dataset size increases, the optimal weight decay should fall, and as the model size increases, the optimal weight decay should increase (if we follow the muP recommendation for scaling the learning rate). We validate these scaling rules on ResNet-18 and Vision Transformers trained on CIFAR-10 and ImageNet, and on NanoGPT pre-training on OpenWebText. Finally, we find that as training progresses, muP's learning rate scaling breaks down for AdamW unless the weight decay is scaled appropriately.
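The abstract's core argument (a roughly constant EMA timescale in epochs, plus a one-to-one mapping from timescale to weight decay at a given learning rate) can be turned into a back-of-the-envelope calculation. The sketch below is an illustration only, not the authors' released code: it assumes the usual decoupled AdamW update, under which the EMA timescale is roughly 1/(lr·wd) iterations, and the function name and concrete numbers are hypothetical.

```python
def weight_decay_for_timescale(lr: float,
                               timescale_epochs: float,
                               dataset_size: int,
                               batch_size: int) -> float:
    """Pick an AdamW weight decay that keeps the EMA timescale (in epochs) fixed.

    Assumption (not from the paper verbatim): with the decoupled update
    w <- (1 - lr*wd) * w - lr * adam_step, the EMA timescale is ~ 1/(lr*wd) iterations.
    """
    iters_per_epoch = dataset_size / batch_size
    timescale_iters = timescale_epochs * iters_per_epoch
    # timescale_iters ~= 1 / (lr * wd)  =>  wd = 1 / (lr * timescale_iters)
    return 1.0 / (lr * timescale_iters)


if __name__ == "__main__":
    # Illustrative numbers only: doubling the dataset halves the implied weight
    # decay; halving the learning rate (e.g. under muP width scaling) doubles it.
    base = weight_decay_for_timescale(lr=1e-3, timescale_epochs=5,
                                      dataset_size=1_000_000, batch_size=512)
    bigger_data = weight_decay_for_timescale(lr=1e-3, timescale_epochs=5,
                                             dataset_size=2_000_000, batch_size=512)
    smaller_lr = weight_decay_for_timescale(lr=5e-4, timescale_epochs=5,
                                            dataset_size=1_000_000, batch_size=512)
    print(f"baseline wd         = {base:.2e}")
    print(f"2x dataset wd       = {bigger_data:.2e}")
    print(f"0.5x learning rate  = {smaller_lr:.2e}")
```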
Lay Summary: We all know about muP for learning rates. What about the weight decay? It turns out that the weight decay is actually **more** important for pre-training over long timescales. We show how the weight decay scales with model and dataset size.
Primary Area: Deep Learning
Keywords: Hyperparameter transfer, AdamW, Optimization
Submission Number: 3340