Normalization Layer Per-Example Gradients are Sufficient to Predict Gradient Noise Scale in Transformers

Published: 25 Sept 2024, Last Modified: 06 Nov 2024, NeurIPS 2024 poster, CC BY 4.0
Keywords: Efficient deep learning, gradient noise scale, critical batch size, language models
TL;DR: While using a trick to compute per-example gradient norms efficiently, we discover that normalization layer statistics predict the GNS accurately.
Abstract: Per-example gradient norms are a vital ingredient for estimating the gradient noise scale (GNS) with minimal variance. Observing the tensor contractions required to compute them, we propose a method with minimal FLOPs in 3D or greater tensor regimes by computing the norms simultaneously with the parameter gradients. Using this method, we are able to observe the GNS of different layers at higher accuracy than previously possible. We find that the total GNS of contemporary transformer models is predicted well by the GNS of only the normalization layers. As a result, focusing only on the normalization layers, we develop a custom kernel that computes the per-example gradient norms while performing the LayerNorm backward pass with zero throughput overhead. Tracking the GNS of only those layers, we are able to guide a practical batch size schedule that reduces training time by 18% on a Chinchilla-optimal language model.
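To make the idea in the abstract concrete, below is a minimal plain-PyTorch sketch (not the paper's fused kernel) of the two ingredients it describes: capturing per-example squared gradient norms of a LayerNorm's weight and bias via an autograd hook, and plugging those norms into the standard simple-GNS estimator of McCandlish et al. (2018). The class name `GNSLayerNorm`, the helper `simple_gns`, and the assumed `(batch, seq_len, d_model)` input layout are illustrative choices, not names from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GNSLayerNorm(nn.LayerNorm):
    """LayerNorm that also records per-example squared gradient norms of its
    weight and bias via an autograd hook. A sketch only: the paper's custom
    kernel computes these inside the LayerNorm backward pass with no overhead.
    Assumes elementwise_affine=True (the default) and input of shape
    (batch, seq_len, d_model)."""

    def forward(self, x):
        # Recompute the normalized input so the backward hook can reuse it.
        xhat = F.layer_norm(x, self.normalized_shape, eps=self.eps)
        self._xhat = xhat.detach()
        y = xhat * self.weight + self.bias
        if y.requires_grad:
            y.register_hook(self._capture_per_example_norms)
        return y

    def _capture_per_example_norms(self, grad_out):
        # Per-example parameter gradients for LayerNorm:
        #   d(weight)_i = sum_t grad_out[i, t] * xhat[i, t]
        #   d(bias)_i   = sum_t grad_out[i, t]
        d_weight = (grad_out * self._xhat).sum(dim=1)  # (batch, d_model)
        d_bias = grad_out.sum(dim=1)                   # (batch, d_model)
        self.per_example_sq_norms = (
            d_weight.pow(2).sum(-1) + d_bias.pow(2).sum(-1)
        )
        return grad_out


def simple_gns(per_example_sq_norms, batch_grad_sq_norm):
    """Simple gradient noise scale B_simple = tr(Sigma) / |G|^2
    (McCandlish et al., 2018), estimated from per-example gradient norms
    (batch size 1) and the squared norm of the mean gradient (batch size B)."""
    B = per_example_sq_norms.numel()
    small = per_example_sq_norms.mean()            # E|g_i|^2 at batch size 1
    big = batch_grad_sq_norm                       # |G_B|^2 at batch size B
    g_sq = (B * big - small) / (B - 1)             # unbiased |G|^2 estimate
    trace_sigma = (small - big) / (1.0 - 1.0 / B)  # unbiased tr(Sigma) estimate
    return trace_sigma / g_sq
```

In use, one would swap `GNSLayerNorm` in for the model's normalization layers, run an ordinary backward pass, read `per_example_sq_norms` from each such layer together with the squared norm of its accumulated parameter gradient, and feed the two into `simple_gns`; the paper's finding is that the GNS measured on these layers alone tracks the total GNS of the transformer well.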
Supplementary Material: zip
Primary Area: Optimization for deep networks
Submission Number: 21305