Why Adam Outperforms Gradient Descent on Language Models: A Heavy-Tailed Class Imbalance Problem

Published: 26 Oct 2023, Last Modified: 13 Dec 2023
NeurIPS 2023 Workshop Poster
Keywords: optimization, heavy-tailed, class imbalance, Adam, adaptive methods, sign descent, language models, transformers
TL;DR: We provide experimental evidence that gradient descent struggles to fit classification problems with heavy-tailed imbalanced classes, such as those found in language modeling tasks, while Adam and sign-like methods fit them well.
Abstract: We show that the heavy-tailed class imbalance found in language modeling tasks leads to difficulties in the optimization dynamics. When training with gradient descent, the loss associated with low-frequency classes decreases more slowly than the loss associated with high-frequency classes. Under the heavy-tailed class imbalance found in language modeling, most samples come from classes of low relative frequency, so the average loss decreases slowly overall. Sign-based optimizers such as Adam and sign descent do not suffer from this problem and decrease the loss on all classes. We give evidence of this behavior when training a 2-layer transformer on language data, a linear model on synthetic data whose only distinguishing property is a heavy-tailed class distribution, and a convolutional network on a modified MNIST dataset made to exhibit heavy-tailed class imbalance.
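The frequency effect described in the abstract can be illustrated with a minimal NumPy sketch. This is not the authors' experimental code. It relies on a simplifying assumption of our own: each class k is always paired with the one-hot input e_k, so column k of the weight matrix W only affects class k. With Zipf frequencies π_k ∝ 1/k, the full-batch gradient of the frequency-weighted cross-entropy with respect to column k is π_k (softmax(W e_k) − e_k). Gradient descent therefore takes an effective step scaled by π_k, while sign descent discards that factor. The helper names (`zipf_frequencies`, `weighted_ce_gradient`, `train`), the number of classes, and the learning rates are hypothetical, illustrative choices.

```python
import numpy as np


def zipf_frequencies(num_classes: int) -> np.ndarray:
    """Heavy-tailed class frequencies: pi_k proportional to 1/k (Zipf, exponent 1)."""
    raw = 1.0 / np.arange(1, num_classes + 1)
    return raw / raw.sum()


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def per_class_losses(W: np.ndarray, X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Cross-entropy of each example; here there is one example per class."""
    probs = softmax_rows(X @ W.T)                       # shape (n, K)
    return -np.log(probs[np.arange(len(y)), y])


def weighted_ce_gradient(W, X, y, weights):
    """Gradient of sum_i weights[i] * CE_i with respect to W (shape K x d)."""
    residual = softmax_rows(X @ W.T)                    # p_i for each example
    residual[np.arange(len(y)), y] -= 1.0               # p_i - onehot(y_i)
    return (weights[:, None] * residual).T @ X


def train(method: str, num_classes: int = 50, steps: int = 200, lr: float = 1.0):
    """Full-batch training of a linear softmax model with GD or sign descent."""
    if method not in ("gd", "sign"):
        raise ValueError("method must be 'gd' or 'sign'")
    pi = zipf_frequencies(num_classes)
    X = np.eye(num_classes)       # hypothetical: class k is always seen with input e_k
    y = np.arange(num_classes)    # one representative example per class, weighted by pi
    W = np.zeros((num_classes, num_classes))
    for _ in range(steps):
        g = weighted_ce_gradient(W, X, y, pi)
        # GD uses the raw gradient; sign descent keeps only its elementwise sign.
        W -= lr * (g if method == "gd" else np.sign(g))
    return pi, per_class_losses(W, X, y)


if __name__ == "__main__":
    pi, gd_losses = train("gd", lr=1.0)
    _, sign_losses = train("sign", lr=0.02)
    for name, losses in (("GD", gd_losses), ("sign descent", sign_losses)):
        print(f"{name:>12}: loss on most frequent class {losses[0]:.3f}, "
              f"rarest class {losses[-1]:.3f}, average {pi @ losses:.3f}")
```

Under these assumptions, gradient descent leaves the per-class loss ordered by class frequency (rarer classes retain higher loss), whereas `np.sign` removes the π_k factor, so every class follows the same trajectory from the zero initialization. The two learning rates are arbitrary and not tuned to be comparable, so only the pattern within each method is meaningful.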
Submission Number: 41