To Grok or not to Grok: Disentangling Generalization and Memorization on Corrupted Algorithmic Datasets

Published: 16 Jan 2024, Last Modified: 05 Mar 2024, ICLR 2024 poster
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: Interpretability, Grokking, Label noise, Generalization, Memorization, Representations, Modular Arithmetic
TL;DR: We study an interpretable model and task in which generalizing and memorizing representations are distinguishable. Training with corrupted data, we isolate the performance on corrupted and uncorrupted data, and explain the effect of explicit regularization.
Abstract: Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large. In general, it is very difficult to know whether a network has memorized a particular set of examples, understood the underlying rule, or both. Motivated by this challenge, we study an interpretable model where generalizing representations are understood analytically and are easily distinguishable from the memorizing ones. Namely, we consider multi-layer perceptron (MLP) and Transformer architectures trained on modular arithmetic tasks, where $\xi \cdot 100\%$ of the labels are corrupted (*i.e.* some results of the modular operations in the training set are incorrect). We show that (i) it is possible for the network to memorize the corrupted labels *and* achieve $100\%$ generalization at the same time; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (iii) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization, and achieve $100\%$ accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is ("mechanistically") interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones. Finally, we show that in the presence of regularization, the training dynamics involve two consecutive stages: first, the network undergoes *grokking* dynamics, reaching high train *and* test accuracy; second, it unlearns the memorizing representations, at which point the train accuracy drops abruptly from $100\%$ to $100 (1-\xi)\%$.
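To make the corruption setup concrete, the following is a minimal sketch of how a label-corrupted modular arithmetic dataset could be built. It is not the authors' code: the choice of modular addition, the modulus `p=97`, the function name `make_corrupted_modular_dataset`, and the decision to corrupt by adding a random nonzero offset are all assumptions made for illustration, and the train/test split is omitted.

```python
import numpy as np

def make_corrupted_modular_dataset(p=97, xi=0.2, seed=0):
    """Enumerate all pairs (a, b) with labels (a + b) mod p, then corrupt a fraction xi.

    Hypothetical helper: p, xi, and the corruption scheme are illustrative choices.
    """
    rng = np.random.default_rng(seed)
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    inputs = np.stack([a.ravel(), b.ravel()], axis=1)  # shape (p*p, 2)
    clean = (inputs[:, 0] + inputs[:, 1]) % p           # uncorrupted labels

    n = len(clean)
    n_corrupt = int(round(xi * n))
    idx = rng.choice(n, size=n_corrupt, replace=False)  # examples to corrupt

    labels = clean.copy()
    # An offset in [1, p-1] guarantees the corrupted label differs from the true one.
    offsets = rng.integers(1, p, size=n_corrupt)
    labels[idx] = (clean[idx] + offsets) % p

    is_corrupted = np.zeros(n, dtype=bool)
    is_corrupted[idx] = True
    return inputs, labels, clean, is_corrupted

def accuracy(predictions, targets):
    """Fraction of predictions that match the targets."""
    return float(np.mean(predictions == targets))

if __name__ == "__main__":
    inputs, labels, clean, mask = make_corrupted_modular_dataset()
    # A model that has learned only the rule predicts the clean labels, so its
    # accuracy on the noisy labels is roughly 1 - xi.
    print("accuracy of the pure rule on noisy labels:", accuracy(clean, labels))
    print("fraction of corrupted examples:", mask.mean())
```

Keeping the boolean mask alongside the labels is what allows accuracy to be reported separately on corrupted and uncorrupted examples: a predictor that follows the rule scores zero on the masked subset and perfectly on the rest, which corresponds to the $100 (1-\xi)\%$ train accuracy mentioned in the abstract.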
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
Supplementary Material: zip
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Primary Area: visualization or interpretation of learned representations
Submission Number: 8168