The Canonical Representation of a Task

TMLR Paper9114 Authors

21 May 2026 (modified: 29 May 2026). Under review for TMLR. Readers: Everyone. License: CC BY 4.0.
Abstract: Generalization in deep learning remains poorly understood, as neural networks fall outside the framework of classical statistical learning theory. To make progress, research has turned to controlled testbeds such as modular arithmetic. On these tasks, models exhibit grokking, i.e., a delayed onset of generalization after the training loss has converged. Prior work has identified empirical regularities in the learned representations associated with this transition, but the mapping between representation structure and generalization behavior remains empirical and descriptive: we lack a predictive theory of why and when generalization occurs. In this work, we provide such a predictive theory for modular arithmetic tasks, including addition, subtraction, multiplication, and division. We introduce the notion of the \textit{canonical representation} of a task: the representation, determined by the target function before any training, that is required for perfect generalization. For modular arithmetic, the canonical representation can be derived from the group structure of the task. We then define \textit{representational deviation} as the discrepancy between the learned representation and the canonical representation that meets a specified target loss. From this, we derive that reaching a prescribed level of generalization requires the representational deviation to fall below a threshold. Finally, we present reproducible experiments that empirically confirm these findings, and we propose a regularizer that accelerates the grokking transition.
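To make the abstract's two central notions concrete, here is a minimal toy sketch under stated assumptions; it is not the authors' construction. For modular addition on Z_p (p prime), it assumes a "canonical" embedding built from the characters of the cyclic group, i.e. the pairs cos(2πka/p), sin(2πka/p) for a hypothetical set of nonzero frequencies k. With this embedding, a fixed decoder based on angle-addition identities recovers (a + b) mod p for every input pair. The "deviation" here is a hypothetical stand-in, the relative Frobenius distance between an embedding and the canonical one, and the "learned" embedding is simulated by adding Gaussian noise. All helper names (`canonical_embedding`, `predict_sum`, `accuracy`, `deviation`, `perturb`) are illustrative inventions.

```python
import math
import random


def canonical_embedding(p, freqs):
    """Embed each residue a in Z_p as [cos(2*pi*k*a/p), sin(2*pi*k*a/p)] for every k in freqs.

    Hypothetical helper: one character-based (Fourier) embedding of the cyclic group Z_p.
    Assumes p is prime and every k in freqs is nonzero mod p.
    """
    return [
        [f(2 * math.pi * k * a / p) for k in freqs for f in (math.cos, math.sin)]
        for a in range(p)
    ]


def predict_sum(ea, eb, p, freqs):
    """Return the residue c maximizing sum_k cos(2*pi*k*(a + b - c)/p), using only ea and eb."""
    best_c, best_score = 0, -math.inf
    for c in range(p):
        score = 0.0
        for i, k in enumerate(freqs):
            ca, sa = ea[2 * i], ea[2 * i + 1]
            cb, sb = eb[2 * i], eb[2 * i + 1]
            # Angle addition: cos and sin of 2*pi*k*(a + b)/p from the two embeddings.
            cab = ca * cb - sa * sb
            sab = sa * cb + ca * sb
            theta = 2 * math.pi * k * c / p
            # cos(x - theta) = cos(x)cos(theta) + sin(x)sin(theta); it equals 1 only when c = a + b mod p.
            score += cab * math.cos(theta) + sab * math.sin(theta)
        if score > best_score:
            best_c, best_score = c, score
    return best_c


def accuracy(emb, p, freqs):
    """Fraction of all p*p pairs (a, b) for which the decoder outputs (a + b) mod p."""
    hits = sum(
        predict_sum(emb[a], emb[b], p, freqs) == (a + b) % p
        for a in range(p)
        for b in range(p)
    )
    return hits / (p * p)


def deviation(emb, canon):
    """Hypothetical deviation: relative Frobenius distance ||E - C||_F / ||C||_F."""
    num = sum((x - y) ** 2 for row_e, row_c in zip(emb, canon) for x, y in zip(row_e, row_c))
    den = sum(y ** 2 for row_c in canon for y in row_c)
    return math.sqrt(num / den)


def perturb(canon, scale, seed=0):
    """Stand-in for a learned embedding: the canonical one plus i.i.d. Gaussian noise."""
    rng = random.Random(seed)
    return [[x + rng.gauss(0.0, scale) for x in row] for row in canon]


if __name__ == "__main__":
    p, freqs = 13, [1, 5]  # hypothetical modulus and frequency set
    C = canonical_embedding(p, freqs)
    for scale in (0.0, 0.001, 0.1, 1.0):
        E = perturb(C, scale)
        print(f"scale={scale}: deviation={deviation(E, C):.4f}, accuracy={accuracy(E, p, freqs):.3f}")
```

Running the script prints deviation and accuracy as the noise grows: the exact canonical embedding decodes every pair correctly, small deviations leave accuracy intact, and large ones destroy it. This is only a qualitative illustration of "deviation below a threshold implies generalization"; the threshold in this toy depends on the chosen decoder and frequencies, and a faithful deviation measure would also need to account for symmetries (e.g. rotations) under which learned representations are equivalent.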
Submission Type: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Erin_Grant1
Submission Number: 9114