['3c3', '< Abstract: Adam has been shown to outperform gradient descent on large language models by a larger margin than on other tasks, but it is unclear why. We show that a key factor in this performance gap is the heavy-tailed class imbalance found in language tasks. When trained with gradient descent, the loss of infrequent words decreases more slowly than the loss of frequent ones. This leads to a slow decrease on the average loss as most samples come from infrequent words. On the other hand, Adam and sign-based methods are less sensitive to this problem. To establish that this behavior is caused by class imbalance, we show empirically that it can be reproduced across architectures and data types, on language transformers, vision CNNs, and linear models. On a linear model with cross-entropy loss, we show that class imbalance leads to imbalanced, correlated gradients and Hessians that have been hypothesized to benefit Adam. We also prove that, in continuous time, gradient descent converges slowly on low-frequency classes while sign descent does not.', '---', '> Abstract: Adam outperforms gradient descent on large language models by a larger margin than on other tasks, a phenomenon that is not fully understood. We identify heavy-tailed class imbalance, prevalent in language tasks, as a key, under-explored factor driving this performance gap. We show that gradient descent makes slow progress on infrequent classes and, because most samples come from infrequent classes, converges slowly overall, while Adam and sign-based methods are less sensitive to this issue. We reproduce this behavior empirically across architectures and data types, on language transformers and vision CNNs, and even on linear models. Furthermore, we show that on linear models with cross-entropy loss, class imbalance naturally induces imbalanced and correlated gradients and Hessians, a property previously hypothesized to benefit Adam. 
Theoretically, we prove that in continuous time, gradient descent exhibits slow convergence on low-frequency classes, a limitation not shared by sign descent.', '6,8c6,8', '< The recent success of large language models such as GPT-3 (Brown et al., 2020) and its successors has relied on costly training procedures at unprecedented scale. A key ingredient in their training is the Adam optimizer (Kingma and Ba, 2015), which outperforms stochastic gradient descent (SGD) on language problems by a large margin. Despite this large performance gap, we have a poor understanding of why Adam works better and it has been difficult to find new optimizers that consistently improve over Adam (Schmidt et al., 2021). Not only is it computationally difficult to validate new optimizers on large models, but we also lack theoretical guidance; we do not know what "problem" Adam solves to outperform SGD.', '< The success of Adam on language transformers has been well documented. Multiple works have found metrics or statistics that correlate with the improved performance of Adam, showing that it yields uniform updates across parameters despite imbalanced gradients (Liu et al., 2020), gives a better descent direction than the gradient (Pan and Li, 2023), and takes a path over which a robust variant of the condition number is smaller (Jiang et al., 2022). But these observations do not provide a mechanism explaining what property of the problem leads to the improved performance of Adam.', '< Plausible mechanisms have been put forward, but they do not provide a complete explanation. Zhang et al. (2020b) show that Adam-like methods are more resilient to heavy-tailed noise, which seems more prominent in language than in vision tasks. But noise is not the primary cause of the gap, as it already appears in deterministic training (Kunstner et al., 2023). An alternative hypothesis is that the magnitude of the gradient and Hessian are correlated, which justifies clipping (Zhang et al., 2020a). 
But to justify methods that normalize element-wise, like Adam and sign-like methods, we additionally need the gradient and Hessian to be correlated across parameters (Crawshaw et al., 2022). While there is empirical evidence for this behavior in neural networks, we do not have a good understanding of why this occurs, nor why this would be more pronounced on language rather than vision tasks.  ', '---', '> The success of large language models (LLMs) such as GPT-3 (Brown et al., 2020) relies on costly training at unprecedented scale, where the Adam optimizer (Kingma and Ba, 2015) outperforms stochastic gradient descent (SGD) by a large margin. This performance gap on language tasks, larger than the one observed in other domains, remains poorly understood. We know little about *why* Adam works better, and it has been difficult to design new optimizers that consistently improve over it (Schmidt et al., 2021). Beyond the computational cost of validating new optimizers on LLMs, we lack theoretical guidance: we do not know what "problem" Adam solves that SGD does not.', "> The empirical success of Adam on language transformers is well established. Prior work has identified metrics and statistics that *correlate* with Adam's superior performance, such as more uniform parameter updates despite imbalanced gradients (Liu et al., 2020), a better descent direction than the gradient (Pan and Li, 2023), and a training path with a smaller robust condition number (Jiang et al., 2022). However, these observations describe *consequences* of Adam's behavior rather than a *mechanism*: they do not identify which property of the problem leads to Adam's improved performance.", '> Several plausible mechanisms have been proposed, yet none fully explain the observed phenomena. For instance, Zhang et al. 
(2020b) suggested that Adam-like methods are more resilient to heavy-tailed noise, which is indeed more prominent in language tasks than in vision. However, noise is not the primary driver of the performance gap, as it persists even under deterministic training (Kunstner et al., 2023). Another hypothesis posits a correlation between the magnitudes of the gradient and Hessian, which could justify clipping methods (Zhang et al., 2020a). For element-wise normalization, as employed by Adam and sign-based methods, an additional requirement is a correlation *across parameters* (Crawshaw et al., 2022). While empirical evidence for this exists in neural networks, the underlying reason for its emergence, particularly its heightened presence in language over vision tasks, remains poorly understood. This paper addresses this critical gap by proposing heavy-tailed class imbalance as a fundamental, problem-centric explanation.', '11,16c11,16', '< Our goal is to answer the following question: what is the "problem" that makes SGD slow on language tasks, that Adam "fixes" to perform better?', "< We argue the problem is what we call heavy-tailed class imbalance, where rare classes account for a large fraction of the data. Language data is imbalanced as some words are much more frequent than others, typically following a power-law. A common modeling assumption is Zipf's law, where the kth most frequent word has frequency ∝ 1/k (Piantadosi, 2014). For language tasks framed as next-token prediction, this property is reflected in the tokens and leads to heavy-tailed class imbalance. This contrasts with typical vision datasets such as MNIST, CIFAR, and ImageNet, which are curated to have uniform classes, but also with imbalanced problems with a small number of classes. 
For example, in binary classification, extreme imbalance implies the minority class has a limited impact on the loss; with an imbalance of 99:1, only 1% of the data comes from the minority class.", '< The performance gap arises because SGD makes slow progress on rare classes, see Figure 1. On a binary problem, slow performance on 1% of the data need not have a large impact on the average loss if we make fast progress on the remaining 99% of the samples. In contrast, the heavy-tailed class imbalance found in language tasks makes it possible for low-frequency classes to account for most of the data and significantly contribute to the loss, leading to slow performance overall.', '< We show that heavy-tailed class imbalance makes SGD slow across tasks in Section 2. We show that modifying vision datasets to exhibit heavy-tailed imbalance leads to slow progress with SGD on architectures where the performance gap with Adam is typically smaller. The impact of heavy-tailed imbalance can even be seen on linear models. Additionally, the performance of SGD improves with techniques that address imbalance such as upweighting rare classes.', '< Our findings provide a simple model where Adam outperforms SGD, a softmax linear model under heavy-tailed class imbalance, which we analyze in Section 3. We show empirically that a correlation between the magnitude of the gradient and Hessian across coordinates, used to justify the benefits of Adam, appears naturally even on a linear model with class imbalance. We provide intuition as to how this pattern emerges through an assignment mechanism that leads to a correlation between class frequencies and the magnitude of the gradient and Hessian across parameters. 
We additionally prove that, on a simple dataset and in continuous time, GD is slow on low-frequency classes while sign descent is insensitive to the class frequencies.', "< We do not claim that class imbalance is the only reason Adam outperforms SGD, as other properties of the data or architectures likely also contribute to this gap. Instead, we show that Adam consistently outperforms SGD under heavy-tailed class imbalance. The difficulty of minimizing the loss of minority classes has been explored for binary problems or problems with few classes (Anand et al., 1993; Francazi et al., 2023), but the recent scaling of large language models to predictions over more than 100 000 classes puts the problem on a new scale. Our findings indicate that heavy-tailed class imbalance has a significant impact on training performance and should be a consideration for future optimizers to perform well on language and other tasks exhibiting heavy-tailed class imbalance. [Figure: panels for GD (with momentum) and Adam (with momentum), with curves for the 50% of samples from the least frequent classes and the 50% from the most frequent classes.] The initial loss is higher for imbalanced MNIST as there are ≈10^4 classes instead of 10, leading to a loss of -log(1/10^4) ≈ 9.2 for a uniform prediction instead of -log(1/10) ≈ 2.3.", '---', '> Our primary objective is to address the fundamental question: what inherent characteristic of language tasks renders SGD inefficient, and how does Adam effectively mitigate this?', "> We posit that the core issue is *heavy-tailed class imbalance*, a distinct form of data distribution where rare classes collectively constitute a substantial portion of the overall data. Language data inherently exhibits such imbalance, with word frequencies typically adhering to a power-law distribution, often approximated by Zipf's law (Piantadosi, 2014), where the kth most frequent word has a frequency proportional to 1/k. 
In next-token prediction, this translates directly to heavy-tailed class imbalance among tokens. This stands in stark contrast to conventional vision datasets (e.g., MNIST, CIFAR, ImageNet), which are typically curated for uniform class distributions, and also differs from imbalanced problems with a small, fixed number of classes. In the latter, such as binary classification with a 99:1 ratio, the minority class's limited impact on the total loss means its slow learning often has a marginal effect on overall performance.", "> This performance gap stems from SGD's slow progress on rare classes, as shown in Figure 1. In problems with a few imbalanced classes, slow progress on the minority class need not have a large impact on the average loss (e.g., 1% of the data in a 99:1 binary problem) if the remaining samples are fit quickly. Under the heavy-tailed imbalance of language tasks, however, low-frequency classes can *collectively* account for most of the data. SGD's slow progress on these individually rare but collectively frequent classes therefore leads to a slow decrease of the average loss.", "> In Section 2, we provide empirical evidence that heavy-tailed class imbalance makes SGD slow across tasks. We show that inducing heavy-tailed imbalance in vision datasets, on architectures where the gap between SGD and Adam is typically smaller, slows SGD's progress. The effect appears even on linear models. We further show that the performance of SGD improves with techniques that address class imbalance, such as upweighting rare classes.", '> In Section 3, we introduce and analyze a simple model where Adam outperforms SGD: a softmax linear model under heavy-tailed class imbalance. 
We show empirically that the correlation between gradient and Hessian magnitudes across coordinates, a property often invoked to justify Adam\'s benefits, *naturally emerges* in this model due to class imbalance. We describe the underlying "assignment mechanism", which links class frequencies to the magnitudes of both the gradient and the Hessian across parameters. We also prove that, on a simple dataset and in continuous time, gradient descent is slow on low-frequency classes, while sign descent is insensitive to class frequencies. This provides theoretical support for the advantage of sign-based and adaptive methods in this setting.', "> We acknowledge that class imbalance is likely not the *sole* reason for Adam's superior performance over SGD, as other properties of the data or architectures may also contribute. Our central argument is that Adam consistently outperforms SGD under heavy-tailed class imbalance. While the difficulty of minimizing the loss of minority classes has been studied in binary or few-class problems (Anand et al., 1993; Francazi et al., 2023), the scaling of large language models to predictions over more than 100,000 classes puts this problem on a new scale. Our findings indicate that heavy-tailed class imbalance has a significant impact on training performance and should be a consideration in the design of future optimizers for language and other tasks exhibiting heavy-tailed class imbalance. [Figure: panels for GD (with momentum) and Adam (with momentum), with curves for the 50% of samples from the least frequent classes and the 50% from the most frequent classes.] 
The initial loss is higher for imbalanced MNIST as there are ≈10^4 classes instead of 10, leading to a loss of -log(1/10^4) ≈ 9.2 for a uniform prediction instead of -log(1/10) ≈ 2.3.", '19,21c19,21', '< Figure 1 suggests a correlation between class frequencies and optimization performance that impacts SGD more than Adam. The goal of this section is to verify that (i) class imbalance is a root cause for the performance gap between SGD and Adam, and (ii) whether this gap can be reproduced with simpler algorithms, such as deterministic optimizers, or using sign descent as a proxy for Adam.', '< To test these hypotheses, we perform experiments focusing on the training loss as our objective is to understand what makes optimization difficult. We use a simple training procedure, with a constant step-size tuned by grid search. For visualization, we split the data into groups of classes with similar frequencies, as in Figure 1. For instance, for 10 groups, the first group corresponds to ≈10% of the samples from the most frequent classes. This grouping is only used for visualization and does not affect training. The models, datasets and training procedures are described in Appendix A.', '< In Appendix B, we give additional information and additional ablation experiments on language models. We show that the heavy-tailed class distribution appears across datasets and tokenizers, and that the separation across class frequencies observed on the training loss in Figure 1 also affects the validation loss. We show that similar dynamics appear on smaller language models, including when training only the last layer while keeping the embedding and attention modules frozen at initialization. Finally, we show that stochasticity is not necessary to reproduce the impact of heavy-tailed class imbalance, and that it also appears when using deterministic updates (i.e., GD instead of SGD). 
As a result, we use deterministic updates whenever possible, denoted by GD in the figures.', '---', '> Figure 1 suggests a correlation between class frequencies and optimization performance that affects SGD more than Adam. The goal of this section is to verify (i) that heavy-tailed class imbalance is a root cause of the performance gap between SGD and Adam, and (ii) whether this gap can be reproduced with simpler algorithms, such as deterministic optimizers, or by using sign descent as a proxy for Adam.', '> To test these hypotheses, we perform experiments focusing on the training loss, as our objective is to understand what makes optimization difficult. We use a simple training procedure with a constant step-size tuned by grid search. For visualization, we group classes with similar frequencies, as in Figure 1. For instance, with 10 groups, the first group corresponds to approximately 10% of the samples, those from the most frequent classes. This grouping is only used for visualization and does not affect training. The models, datasets, and training procedures are described in Appendix A.', '> Appendix B provides further details and ablation experiments on language models. We show that heavy-tailed class distributions appear across datasets and tokenizers, and that the class-frequency separation observed in the training loss (Figure 1) also appears in the validation loss. Similar dynamics emerge in smaller language models, even when training only the last layer while keeping the embedding and attention modules frozen at initialization. Finally, we show that stochasticity is not necessary to reproduce the impact of heavy-tailed class imbalance, as the effect persists with deterministic updates (i.e., GD instead of SGD). 
Consequently, we employ deterministic updates whenever feasible, denoted as GD in the figures.', '24c24', '< Language transformers are often contrasted with vision CNNs, where we do not see a large performance gap between SGD and Adam. Our hypothesis is that a key differentiation between the two settings is the heavy-tailed class imbalance present in language data. In this section, we show that making heavy-tailed vision datasets leads to slower performance with SGD and a larger performance gap with Adam. These experiments show that heavy-tailed imbalance has a significant impact on performance and can make an otherwise "easy" problem into a "hard" one for SGD.', '---', '> Language transformers are often contrasted with vision CNNs, on which the performance gap between SGD and Adam is typically small. Our hypothesis is that heavy-tailed class imbalance, a hallmark of language data, is a key differentiating factor between the two settings. Here, we show that inducing heavy-tailed class distributions in vision datasets slows SGD and widens its performance gap with Adam. These experiments show that heavy-tailed imbalance has a significant impact on optimization and can turn an otherwise "easy" problem into a "hard" one for SGD.', '75,81c75,77', '< Proof idea. Our loss is L(W) = (1/n) Σ_{i=1}^{n} ℓ(W, x_i, y_i), where ℓ is a softmax linear model, ℓ(W, x, y) = -log(σ(Wx)_y), with σ(z)_k = exp(z_k) / Σ_j exp(z_j).', '< (3)', '< Writing p(x) = σ(Wx) for the vector of predicted probabilities, the gradient and Hessian blocks are', '< ∇_{w_k} ℓ(W, x, y) = (p(x)_k - 1[y = k]) x,   ∇²_{w_k} ℓ(W, x, y) = p(x)_k (1 - p(x)_k) x x^⊤.', '< (4)', '< The contribution of a sample (x, y) to the gradient w.r.t. 
w_k primarily depends on whether the sample belongs to class k through the 1[y = k] term, while the contribution to the Hessian block depends on whether the model assigns that sample to class k through p(x)_k. At initialization, p(x)_k = 1/c', '< Grad. ', '---', '> ', "> Proof Sketch for Proposition 2: Our loss is L(W) = (1/n) Σ_{i=1}^{n} ℓ(W, x_i, y_i), where ℓ is a softmax linear model, ℓ(W, x, y) = -log(σ(Wx)_y). The gradient and Hessian blocks for w_k (the row of W corresponding to class k) are averages of per-sample contributions. The contribution of a sample (x, y) to the gradient w.r.t. w_k primarily depends on whether the sample belongs to class k (via the 1[y = k] term), while its contribution to the Hessian block depends on the model's predicted probability for class k, p(x)_k. At initialization, p(x)_k = 1/c for all samples. As training progresses and the model starts to correctly assign samples to their classes (Assumption 1), p(x)_k approaches 1 for samples from class k and is O(1/c) for the others. This assignment mechanism, combined with the averaging over samples, leads to gradient and Hessian blocks weighted by the class frequencies π_k, which yields the observed correlation.", '> Grad.', '99a96,97', '> ', '> Proof Sketch for Theorem 3: The full proof is detailed in Appendix H. In this simplified setting, where samples from each class are represented by standard basis vectors, the dynamics of the parameter matrix W decompose into independent 2-dimensional differential equations, one per class. For gradient flow, the time derivative of the parameters corresponding to class k is scaled by its frequency π_k. This slows the updates for low-frequency classes, resulting in a sublinear convergence rate of Θ(1/(π_k t)) for their loss. In contrast, for continuous-time sign descent, the updates are normalized by the gradient magnitude, which removes the dependence on π_k. 
This allows sign descent to make consistent progress on all classes, yielding a linear convergence rate of Θ(e^{-ct}), independent of the class frequency. This difference in how updates are scaled explains why sign-based methods avoid the slowdown that gradient descent suffers under heavy-tailed class imbalance.', '383c381', '< Answer: [NA] Justification: The main figures do not report error bars, as the figures show detailed trajectories for specific runs that do not lend themselves to showing the behavior averaged over multiple runs. Instead, we account for factors of variability by reproducing the observed behavior, the performance gap across class frequencies, across multiple datasets, architectures, and training procedures.', '---', '> Answer: [No] Justification: The main figures do not report error bars, as they show detailed trajectories of specific runs, which do not lend themselves to averaging over multiple runs. Instead, to assess robustness, we reproduce the core phenomenon (the performance gap across class frequencies) across multiple datasets, architectures, and training procedures. This consistent replication, rather than statistical averaging within a single plot, supports the reliability of our conclusions.', '442c440', '< Answer: [NA] Justification: The paper focuses on foundational research to understand the behavior of generic algorithms used to optimize neural networks. The paper is not tied to a particular application and we do not see a direct path to a social impact.', '---', '> Answer: [Yes] Justification: This paper presents foundational research focused on understanding the optimization dynamics of neural networks, particularly large language models. 
While not tied to a specific application, improvements in optimizer performance, as explored here, could indirectly contribute to broader societal impacts. On the positive side, more efficient and robust training of language models could facilitate advancements in areas like accessible communication technologies, scientific discovery, and educational tools. Conversely, enhanced model capabilities could also be misused for generating disinformation or automating harmful content, underscoring the importance of ethical considerations in the development and deployment of such powerful technologies. We believe that understanding the fundamental mechanisms, such as heavy-tailed class imbalance, is a crucial step towards building more controllable and responsible AI systems.', '846d843', '< ']
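The introduction argues that, under Zipf's law, rare classes can collectively account for a large share of the samples. The figure caption quotes uniform-prediction losses for 10 and ≈10^4 classes. Both claims can be checked numerically. The sketch below assumes an exact Zipf distribution (frequency ∝ 1/k) over 10,000 classes. That assumption is illustrative only and is not the construction of the imbalanced MNIST dataset. The helpers `zipf_frequencies` and `head_classes_for_mass` are hypothetical names introduced here.

```python
import numpy as np


def zipf_frequencies(num_classes):
    """Hypothetical helper: frequencies proportional to 1/k for the k-th most frequent class."""
    weights = 1.0 / np.arange(1, num_classes + 1)
    return weights / weights.sum()


def head_classes_for_mass(pi, mass=0.5):
    """Hypothetical helper: smallest number of most frequent classes whose frequencies sum to >= mass."""
    # pi is sorted in decreasing order, so its cumulative sum is increasing.
    return int(np.searchsorted(np.cumsum(pi), mass) + 1)


pi = zipf_frequencies(10_000)  # assumption: exact Zipf over 10^4 classes

# Fraction of samples coming from all classes outside the 100 most frequent ones.
tail_mass = pi[100:].sum()

# Number of most frequent classes needed to cover half of the samples.
n_head = head_classes_for_mass(pi, mass=0.5)

# Cross-entropy of a uniform prediction over c classes is -log(1/c) = log(c).
uniform_loss_10 = -np.log(1 / 10)
uniform_loss_1e4 = -np.log(1 / 10_000)

print(f"mass outside the top 100 classes: {tail_mass:.2f}")
print(f"classes covering 50% of samples: {n_head} of {len(pi)}")
print(f"uniform loss: {uniform_loss_10:.1f} (10 classes), {uniform_loss_1e4:.1f} (10^4 classes)")
```

Under this assumption, fewer than 100 of the 10,000 classes cover half of the samples. All but the 100 most frequent classes still account for close to half of the data. In a 99:1 binary problem, by contrast, the minority class holds only 1% of the data. The uniform-prediction losses come out at about 2.3 and 9.2, matching the caption.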
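The proof sketch for Proposition 2 gives gradient and Hessian blocks for the softmax linear model, and these can be sanity-checked with central finite differences. This minimal NumPy sketch assumes W stores one row w_k per class and uses the per-sample loss ℓ(W, x, y) = -log(σ(Wx)_y). For this loss, the gradient with respect to w_k is (p(x)_k - 1[y = k]) x, which is the negative of the descent direction. The diagonal Hessian block is p(x)_k (1 - p(x)_k) x x^⊤. The function names, sizes, random seed, and step `eps` are illustrative choices, not taken from the paper.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax: sigma(z)_k = exp(z_k) / sum_j exp(z_j).
    e = np.exp(z - z.max())
    return e / e.sum()


def loss(W, x, y):
    # Per-sample loss of the softmax linear model: -log(sigma(W x)_y).
    return -np.log(softmax(W @ x)[y])


def grad_block(W, x, y, k):
    # Gradient of the loss w.r.t. row w_k: (p(x)_k - 1[y = k]) x.
    p = softmax(W @ x)
    return (p[k] - float(y == k)) * x


def hessian_block(W, x, k):
    # Diagonal Hessian block for row w_k: p(x)_k (1 - p(x)_k) x x^T.
    p = softmax(W @ x)
    return p[k] * (1.0 - p[k]) * np.outer(x, x)


def finite_difference_checks(W, x, y, k, eps=1e-6):
    """Central differences of the loss (gradient) and of grad_block (Hessian block) along row k."""
    d = W.shape[1]
    fd_grad = np.zeros(d)
    fd_hess = np.zeros((d, d))
    for j in range(d):
        E = np.zeros_like(W)
        E[k, j] = eps  # perturb only coordinate j of row w_k
        fd_grad[j] = (loss(W + E, x, y) - loss(W - E, x, y)) / (2 * eps)
        fd_hess[:, j] = (grad_block(W + E, x, y, k) - grad_block(W - E, x, y, k)) / (2 * eps)
    return fd_grad, fd_hess


rng = np.random.default_rng(0)
c, d = 5, 3                      # number of classes, input dimension
W = rng.normal(size=(c, d))      # one row w_k per class
x = rng.normal(size=d)
y = 2                            # label of the sample

for k in range(c):
    fd_grad, fd_hess = finite_difference_checks(W, x, y, k)
    print(k, np.abs(fd_grad - grad_block(W, x, y, k)).max(),
          np.abs(fd_hess - hessian_block(W, x, k)).max())
```

The printed errors should be at the level of finite-difference round-off for every row k. That includes the row of the true class y, where the 1[y = k] term is active.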
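The proof sketch for Theorem 3 rests on a scaling argument, which can be illustrated with a discrete-time simulation. This is a minimal sketch under stated assumptions, not the paper's experiment or its continuous-time proof:

- each class k has a single input, the standard basis vector e_k;
- class frequencies are π_k ∝ 1/k over 100 classes;
- the objective is Σ_k π_k ℓ(W, e_k, k);
- W starts at zero;
- the step sizes `lr` are picked by hand rather than tuned.

With one-hot inputs, column k of the gradient is π_k (σ(W e_k) - e_k). Gradient descent therefore updates rare classes slowly, whereas the sign of the gradient does not depend on π_k. The function names are hypothetical.

```python
import numpy as np


def softmax_columns(Z):
    # Column-wise softmax: column k holds sigma(W e_k), the prediction for input e_k.
    Z = Z - Z.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)


def per_class_losses(W):
    # Loss of class k is -log(sigma(W e_k)_k), the diagonal of the column-wise softmax.
    return -np.log(np.diag(softmax_columns(W)))


def full_gradient(W, pi):
    # Gradient of sum_k pi_k * (-log sigma(W e_k)_k): column k is pi_k (sigma(W e_k) - e_k).
    return (softmax_columns(W) - np.eye(W.shape[0])) * pi[None, :]


def train(pi, direction, steps, lr):
    # Runs W <- W - lr * direction(gradient) from W = 0 and returns the per-class losses.
    W = np.zeros((len(pi), len(pi)))
    for _ in range(steps):
        W -= lr * direction(full_gradient(W, pi))
    return per_class_losses(W)


num_classes = 100
pi = 1.0 / np.arange(1, num_classes + 1)
pi /= pi.sum()  # assumption: Zipf frequencies, most frequent class first

gd_losses = train(pi, direction=lambda g: g, steps=1000, lr=1.0)   # gradient descent
sign_losses = train(pi, direction=np.sign, steps=1000, lr=0.01)    # sign descent

print("GD   loss, most / least frequent class:", gd_losses[0], gd_losses[-1])
print("sign loss, most / least frequent class:", sign_losses[0], sign_losses[-1])
```

With these settings, gradient descent brings the loss of the most frequent class close to zero while the rarest class stays far from it, which is consistent with a rate that degrades as π_k shrinks. Sign descent takes identical steps on every column of W, so all classes reach the same small loss. The two step sizes are not comparable. The point is the dependence on π_k, not which method is faster overall.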
