Keywords: grokking, generalization, delayed generalization, modular addition, representation learning, weight decay
TL;DR: We argue that post-memorization learning can be understood through the lens of constrained optimization.
Abstract: Grokking is a phenomenon in neural networks in which full generalization occurs only after a substantial delay following complete memorization of the training data. Previous research has linked this delayed generalization to representation learning driven by weight decay, but the precise underlying dynamics remain elusive. In this paper, we argue that post-memorization learning can be understood through the lens of constrained optimization: gradient descent effectively minimizes the weight norm on the zero-loss manifold. We formally prove this in the limit of infinitesimally small learning rates and weight decay coefficients. To further dissect this regime, we introduce an approximation that decouples the learning dynamics of a subset of parameters from the rest of the network. Applying this framework, we derive a closed-form expression for the post-memorization dynamics of the first layer in a two-layer network. Experiments confirm that simulating the training process using our predicted gradients reproduces both the delayed generalization and the representation learning characteristic of grokking.
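A minimal sketch of one natural way to formalize the abstract's constrained-optimization claim, assuming continuous-time gradient flow with weight-decay coefficient $\lambda$; the projection operator $P_{T_\theta\mathcal{M}}$ onto the tangent space of the zero-loss manifold is notation introduced here for illustration and need not match the paper's exact statement:

$$\dot{\theta} = -\nabla L(\theta) - \lambda\,\theta, \qquad \mathcal{M} = \{\theta : L(\theta) = 0\}.$$

In the limit of infinitesimally small learning rate and $\lambda$, once $\theta$ reaches $\mathcal{M}$ the loss gradient cancels the component of $-\lambda\theta$ normal to $\mathcal{M}$, leaving the slow tangential dynamics

$$\dot{\theta} \;\approx\; -\lambda\, P_{T_\theta\mathcal{M}}(\theta),$$

i.e. (Riemannian) gradient flow of $\tfrac{1}{2}\|\theta\|^2$ restricted to the zero-loss manifold, which is the sense in which post-memorization training minimizes the weight norm.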
Supplementary Material: zip
Primary Area: interpretability and explainable AI
Submission Number: 19274