Grokking modular arithmetic can be explained by margin maximization

Published: 07 Nov 2023, Last Modified: 13 Dec 2023, M3L 2023 Poster
Keywords: grokking, deep learning, generalization
TL;DR: Grokking for modular arithmetic problems can be explained by gradient descent's implicit bias towards margin maximization
Abstract: We present a margin-based generalization theory explaining the “grokking” phenomenon (Power et al., 2022), where the model generalizes long after overfitting to arithmetic datasets. Specifically, we study two-layer quadratic networks on mod-$p$ arithmetic problems, and show that solutions with maximal margin normalized by the $\ell_\infty$ norm generalize with $\tilde O(p^{5/3})$ samples. To the best of our knowledge, this is the first sample complexity bound strictly better than the trivial $O(p^2)$ complexity for modular addition. Empirically, we find that GD on the unregularized $\ell_2$ or cross-entropy loss tends to maximize the margin. In contrast, we show that kernel-based models, such as networks that are well-approximated by their neural tangent kernel, need $\Omega(p^2)$ samples to achieve non-trivial $\ell_2$ loss. Our theory suggests that grokking might be caused by overfitting in the kernel regime of early training, followed by generalization as gradient descent eventually leaves the kernel regime and maximizes the normalized margin.
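To make the setup concrete, the following is a minimal sketch (not the paper's exact construction or training procedure) of a two-layer network with quadratic activations on one-hot encoded modular addition, together with the classification margin normalized by the $\ell_\infty$ norm of the weights. The modulus p, width m, random initialization, and the degree-3 normalization are illustrative assumptions.

import numpy as np

# Minimal sketch, assuming one-hot inputs and a width-m hidden layer; this
# illustrates the quantities in the abstract, not the paper's exact model.

p = 7          # modulus (toy size; the paper studies general mod-p problems)
m = 64         # hidden width (illustrative choice)
rng = np.random.default_rng(0)

# All pairs (a, b) with label (a + b) mod p; the input is the concatenation
# of one-hot encodings of a and b, so it has dimension 2p.
A, B = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
a, b = A.ravel(), B.ravel()
y = (a + b) % p
X = np.zeros((p * p, 2 * p))
X[np.arange(p * p), a] = 1.0
X[np.arange(p * p), p + b] = 1.0

# Two-layer quadratic network: f(x) = W2 (W1 x)^2, squaring applied elementwise.
W1 = rng.normal(size=(m, 2 * p)) / np.sqrt(2 * p)
W2 = rng.normal(size=(p, m)) / np.sqrt(m)

def forward(X):
    h = X @ W1.T            # pre-activations, shape (n, m)
    return (h ** 2) @ W2.T  # quadratic activation, logits of shape (n, p)

logits = forward(X)

# Multiclass margin: correct-class logit minus the best wrong-class logit,
# minimized over the dataset.
idx = np.arange(len(y))
correct = logits[idx, y]
others = logits.copy()
others[idx, y] = -np.inf
margin = np.min(correct - others.max(axis=1))

# Normalize by the l_infinity norm of the weights; the cube reflects the
# degree-3 homogeneity of f in (W1, W2) (an assumption of this sketch).
linf_norm = max(np.abs(W1).max(), np.abs(W2).max())
normalized_margin = margin / linf_norm ** 3
print("normalized margin (random weights, typically negative):", normalized_margin)

With random (untrained) weights the normalized margin is typically negative; the abstract's claim concerns solutions that fit the training data with maximal normalized margin, which GD on the unregularized loss empirically tends to approach.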
Submission Number: 86