Efficient and Approximate Per-Example Gradient Norms for Gradient Noise Scale

Published: 28 Oct 2023, Last Modified: 30 Nov 2023 · WANT@NeurIPS 2023 Poster
Keywords: gradient norms, hyperparameter elimination, training metrics, foundation models
TL;DR: The efficient per-example gradient norm trick fails for 3D tensors, so we develop an approximation that can be used instead to compute GNS.
Abstract: Gradient Noise Scale (GNS) is valuable to compute because it suggests a compute-efficient batch size during training: small enough to avoid wasting compute and large enough to take advantage of parallelism. While it can be a valuable tool, computing GNS is often cumbersome or expensive due to the difficulty of obtaining gradient norms over a small batch of examples (smaller than the training batch used). An existing trick for collecting "efficient" per-example gradient norms is inefficient in transformer or convolutional models. By assuming activations are normally distributed, we compute an approximate per-example gradient norm that tracks the true per-example gradient norm in practical settings. Using this approximation, we construct a Scaled Output Gradient Noise Scale (SOGNS) that is generally applicable at negligible cost and provides additional feedback to the practitioner during training.
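The "efficient per-example gradient norm" trick the abstract refers to can be sketched for a single linear layer: the per-example weight gradient is the outer product of the layer's input activation and the output gradient, so its squared Frobenius norm factors into a product of two vector norms and never needs to be materialized. The sketch below (NumPy, with illustrative variable names not taken from the paper) checks this identity; it also hints at why 3D activations break the trick, since summing outer products over a sequence dimension no longer factors.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_in, d_out = 4, 8, 3
x = rng.normal(size=(B, d_in))   # per-example inputs to a linear layer
g = rng.normal(size=(B, d_out))  # per-example gradients w.r.t. the layer output

# Naive: materialize each per-example weight gradient outer(x_i, g_i)
# and take its squared Frobenius norm. Cost: O(B * d_in * d_out) memory.
naive = np.array([np.linalg.norm(np.outer(x[i], g[i])) ** 2 for i in range(B)])

# Trick: ||outer(x_i, g_i)||_F^2 = ||x_i||^2 * ||g_i||^2,
# so only two per-example vector norms are needed.
trick = (x ** 2).sum(axis=1) * (g ** 2).sum(axis=1)

assert np.allclose(naive, trick)
```

With a sequence dimension (activations of shape `(B, T, d_in)`, as in transformers), the per-example gradient is a sum of `T` outer products, and its norm no longer factors into cheap vector norms; this is the 3D-tensor failure mode the paper's approximation is designed to work around.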
Submission Number: 48