Robust gradient estimation in the presence of heavy-tailed noise
Keywords: median, heavy-tailed gradients, clipping
TL;DR: How clipping is related to robust median gradient estimators
Abstract: In applications such as training transformers on NLP tasks, or distributed learning in the presence of corrupted nodes, the stochastic gradients have a heavy-tailed distribution. We argue that in these settings, momentum is not the best-suited method for estimating the gradient. Instead, variants of momentum with different forms of clipping are better suited. Our argument rests on the following observation: in the presence of heavy-tailed noise, the sample median of the gradient is a better estimate than the sample mean. We then devise new iterative methods for computing the sample median on the fly, based on the stochastic proximal point (SPP) method. Applied to different definitions of the median, these SPP methods give rise to both known and new types of clipped momentum estimates. We find that these clipped momentum estimates are more robust at estimating the gradient in the presence of noise drawn from an alpha-stable distribution, and when training a transformer architecture on the PTB and WikiText-2 datasets, in particular when the batch size is large.
Submission Number: 16
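Illustrative sketch (not taken from the submission): the link between the median and clipping claimed in the abstract can be seen in one dimension. The median of the gradient minimizes the expected absolute deviation E|m - g|, and a proximal step on |m - g| with step size eta has a closed form that moves the current estimate m toward the sample g by at most eta, which is a clipped update. The code below compares this update with plain momentum under Cauchy noise (an alpha-stable distribution with alpha = 1). The function names, the scalar setting, the constant step size, and all parameter values are assumptions made for illustration; the estimators in the paper may differ.

```python
import math
import random
import statistics


def cauchy_noise(rng):
    """Standard Cauchy sample (alpha-stable with alpha = 1) via the inverse CDF."""
    return math.tan(math.pi * (rng.random() - 0.5))


def momentum_step(m, g, beta):
    """Plain momentum written as an exponential moving average of gradients."""
    return m + beta * (g - m)


def spp_median_step(m, g, eta):
    """Proximal step on |x - g| starting from m, with step size eta.

    Solves argmin_x |x - g| + (x - m)**2 / (2 * eta). The closed form
    moves m toward g by at most eta, i.e. a clipped update.
    """
    return m + max(-eta, min(eta, g - m))


def demo(true_grad=1.0, steps=5000, beta=0.05, eta=0.05, seed=0):
    """Track a scalar 'gradient' corrupted by Cauchy noise (hypothetical setup)."""
    rng = random.Random(seed)
    samples = [true_grad + cauchy_noise(rng) for _ in range(steps)]
    m_avg = m_clip = 0.0
    for g in samples:
        m_avg = momentum_step(m_avg, g, beta)
        m_clip = spp_median_step(m_clip, g, eta)
    return {
        "sample_mean": statistics.fmean(samples),
        "sample_median": statistics.median(samples),
        "momentum": m_avg,
        "clipped": m_clip,
    }


if __name__ == "__main__":
    for name, value in demo().items():
        print(f"{name:>14}: {value:+.3f}")
```

Because each clipped step changes the estimate by at most eta, a single extreme sample cannot move it far, whereas one large outlier shifts the momentum average by beta times its size. With a constant eta the clipped estimate is expected to hover near the median rather than converge exactly; a decaying step size would be one way to tighten it.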