Layer Importance for Mathematical Reasoning is Forged in Pre-Training and Invariant after Post-Training
Keywords: mechanistic interpretability, layer ablation, critical layers, post-training, pre-training
TL;DR: Math reasoning in large language models relies on critical layers that form during pre-training and remain stable after post-training, with token representations shifting from syntactic to task-relevant clusters near these critical layers.
Abstract: Large language models improve at math after instruction tuning, reinforcement learning, or knowledge distillation. We ask whether these gains come from major changes in the transformer layers or from smaller adjustments that keep the original structure. Using layer-wise ablation on base and post-trained variants, we find that math reasoning depends on a few critical layers, which stay important across all post-training methods. Removing these layers reduces math accuracy by as much as 80%, whereas factual recall tasks show comparatively smaller drops. This suggests that specialized layers for mathematical reasoning form during pre-training and remain stable afterward. Measuring representation structure with Normalized Mutual Information (NMI), we find that near these critical layers, token representations drift from their original syntactic clusters toward clusters of tokens that are less syntactically related but potentially more useful for the downstream task.
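A minimal sketch of the two measurements the abstract describes, under stated assumptions rather than the authors' actual code: layer-wise ablation is approximated with a forward hook that overrides one decoder layer's output with its input, and the NMI analysis clusters token representations at a layer and compares the assignment to reference syntactic labels. The checkpoint name, the `evaluate_math_accuracy` helper, and the `syntactic_labels` array are placeholders.

```python
import torch
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, not the paper's model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()


def ablate_layer(layer):
    """Attach a hook that replaces the layer's hidden-state output with its input,
    effectively skipping the layer while leaving any extra outputs untouched."""
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            return (inputs[0],) + output[1:]
        return inputs[0]
    return layer.register_forward_hook(hook)


def accuracy_without_layer(layer_idx, eval_fn):
    """Re-evaluate task accuracy with one decoder layer ablated, then restore it."""
    handle = ablate_layer(model.model.layers[layer_idx])  # Llama-style module layout
    try:
        return eval_fn(model, tokenizer)  # eval_fn: hypothetical accuracy helper
    finally:
        handle.remove()


# Per-layer accuracy drop relative to the intact model (sketch):
# drops = {i: base_acc - accuracy_without_layer(i, evaluate_math_accuracy)
#          for i in range(model.config.num_hidden_layers)}


def layer_nmi(hidden_states, syntactic_labels, n_clusters=20):
    """Cluster token representations at one layer (hidden_states: NumPy array of
    shape [n_tokens, hidden_dim]) and score agreement with reference syntactic
    clusters via Normalized Mutual Information."""
    preds = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(hidden_states)
    return normalized_mutual_info_score(syntactic_labels, preds)
```

Using a forward hook rather than deleting the module keeps tensor shapes and cache handling intact, so the same helper can be applied unchanged to the base model and each post-trained variant.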
Submission Number: 231