Keywords: Invariance, Regularization, Math Reasoning
Abstract: Mathematical reasoning is a central aspect of evaluating language models. For many modular arithmetic tasks, prior work has observed the phenomenon of grokking, in which training accuracy converges to nearly 100\% while test performance lags behind for an extended number of epochs before finally reaching perfect accuracy. In this paper, we find that injecting invariant structures into modular arithmetic tasks can significantly reduce the number of training steps required for grokking. Specifically, let $g$ denote a label-invariant transformation and $x$ denote an input. For modular addition, $a + b \pmod{p}$, transforming the input into $a + i \pmod{p}$ and $b - i \pmod{p}$ leaves the outcome unchanged. Given a math reasoning task and a set of invariant transformation rules, our approach applies one of the transformations $g$ to the input $x$ (similar to data augmentation). We then interpolate the transformed input $g(x)$ with the original input $x$. Finally, we add noise to the weights before computing the gradient to reduce the sharpness of the loss surface. On three modular arithmetic tasks, this approach reduces the number of grokking steps by more than $60$\% compared to existing sharpness-reduction and acceleration methods. The approach also applies to out-of-domain samples: on six text-based arithmetic and graph-algorithmic tasks, it improves the test accuracy of LLMs by $16.5$\% and $69$\%. Lastly, to further validate our approach, we provide a generalization bound for learning invariant function classes that depends on a Hessian distance measure.
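The three steps described in the abstract (a label-invariant transformation, interpolation of $g(x)$ with $x$, and a gradient evaluated at noise-perturbed weights) can be sketched for modular addition as follows. This is a minimal, stdlib-only Python illustration under stated assumptions, not the authors' implementation: the one-hot input encoding, the uniform choice of the shift $i$, the uniform interpolation weight, the Gaussian weight noise, and all function names (`invariant_shift`, `encode`, `interpolate`, `augmented_example`, `perturbed_gradient`) are hypothetical.

```python
import random


def invariant_shift(a, b, i, p):
    """Label-invariant map g_i for modular addition:
    (a, b) -> ((a + i) mod p, (b - i) mod p), which preserves (a + b) mod p."""
    return (a + i) % p, (b - i) % p


def encode(a, b, p):
    """Hypothetical input encoding: concatenated one-hot vectors for a and b."""
    x = [0.0] * (2 * p)
    x[a] = 1.0
    x[p + b] = 1.0
    return x


def interpolate(x, gx, lam):
    """Convex combination lam * x + (1 - lam) * g(x), computed elementwise."""
    return [lam * u + (1.0 - lam) * v for u, v in zip(x, gx)]


def augmented_example(a, b, p, rng):
    """Build one training pair: sample a shift i, apply g_i, and mix with x.
    Since g_i is label-invariant, the label stays (a + b) mod p."""
    i = rng.randrange(p)              # assumption: shift drawn uniformly from Z_p
    ga, gb = invariant_shift(a, b, i, p)
    lam = rng.random()                # assumption: lam ~ Uniform[0, 1)
    x = interpolate(encode(a, b, p), encode(ga, gb, p), lam)
    return x, (a + b) % p


def perturbed_gradient(grad_fn, weights, sigma, rng):
    """Evaluate grad_fn at weights plus i.i.d. Gaussian noise (assumed N(0, sigma^2)).
    The caller applies the returned gradient to the unperturbed weights."""
    noisy = [w + rng.gauss(0.0, sigma) for w in weights]
    return grad_fn(noisy)


if __name__ == "__main__":
    rng = random.Random(0)
    x, y = augmented_example(3, 5, 7, rng)
    print(len(x), y)  # 14 1
    # Toy quadratic loss 0.5 * sum(w^2), whose gradient is w itself.
    w = [1.0, -2.0]
    g = perturbed_gradient(lambda v: list(v), w, 0.01, rng)
    w = [wi - 0.1 * gi for wi, gi in zip(w, g)]
    print(w)
```

Because $g$ preserves the label, the interpolated input keeps the original label $(a + b) \bmod p$, so no label mixing is needed as in standard mixup. Whether the interpolation should act on raw inputs, as here, or on learned embeddings is a design choice this sketch does not settle.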
Submission Number: 108