Abstract: Modern large language models (LLMs) excel at fitting finetuning data, but often struggle on unseen examples. In order to teach models genuine reasoning abilities rather than superficial pattern matching, our work aims to better understand how the learning dynamics of LLM finetuning shape downstream generalization. Our analysis focuses on reasoning tasks, whose problem structure allows us to distinguish between memorization (the exact replication of reasoning steps from the training data) and performance (the correctness of the final solution). We find that a model's performance on test prompts can be effectively characterized by a training metric we call pre-memorization train accuracy: the accuracy of model samples on training queries before they begin to copy the exact reasoning steps from the training set. At the dataset level, this metric almost perfectly predicts test accuracy, achieving $R^2 \geq 0.9$ across various models (Llama3 8B, Gemma2 9B), datasets (GSM8k, MATH), and training configurations. At the per-example level, this metric also indicates whether individual model predictions are robust to perturbations in the training query. By connecting a model's learning dynamics to test performance, pre-memorization train accuracy can inform training decisions, such as the makeup of the training data. Our experiments on data curation show that prioritizing examples with low pre-memorization train accuracy leads to 1.5-2x improvements in data efficiency compared to i.i.d. data scaling and other data scaling techniques.
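To make the metric concrete, here is a minimal sketch of how pre-memorization train accuracy could be computed and then used to rank training examples for data curation. The abstract does not say how memorization is detected, so this sketch assumes a hypothetical criterion (`is_memorized`: the sampled reasoning matches the reference solution verbatim after whitespace and case normalization). It also assumes one model sample per checkpoint, in training order. The names `Example`, `pre_memorization_accuracy`, `dataset_pre_memorization_accuracy`, and `select_low_pre_memorization` are illustrative only and are not the authors' implementation.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


def normalize(text: str) -> str:
    # Collapse whitespace and lowercase so trivial formatting differences are ignored.
    return " ".join(text.split()).lower()


def is_memorized(sample: str, reference: str) -> bool:
    # Hypothetical criterion (an assumption, not from the paper): the sampled
    # reasoning reproduces the reference solution verbatim after normalization.
    return normalize(sample) == normalize(reference)


@dataclass
class Example:
    reference: str             # ground-truth reasoning from the training set
    answer: str                # ground-truth final answer
    samples: List[str]         # one model sample per checkpoint, in training order
    sample_answers: List[str]  # final answer extracted from each sample


def pre_memorization_accuracy(ex: Example) -> Optional[float]:
    """Fraction of correct samples at checkpoints before the first memorized one.

    Returns None when the first checkpoint is already memorized, because no
    pre-memorization samples exist for that example.
    """
    correct = []
    for sample, pred in zip(ex.samples, ex.sample_answers):
        if is_memorized(sample, ex.reference):
            break  # stop at the first checkpoint that copies the reference
        correct.append(pred == ex.answer)
    if not correct:
        return None
    return sum(correct) / len(correct)


def dataset_pre_memorization_accuracy(examples: Sequence[Example]) -> float:
    # Average the per-example scores, skipping examples with no defined score.
    scores = [s for s in map(pre_memorization_accuracy, examples) if s is not None]
    return sum(scores) / len(scores) if scores else float("nan")


def select_low_pre_memorization(examples: Sequence[Example], k: int) -> List[int]:
    # Rank example indices by ascending score; undefined scores are placed last.
    scored = [(pre_memorization_accuracy(ex), i) for i, ex in enumerate(examples)]
    scored.sort(key=lambda t: (t[0] is None, t[0] if t[0] is not None else 0.0, t[1]))
    return [i for _, i in scored[:k]]


if __name__ == "__main__":
    examples = [
        Example("2 + 3 = 5", "5", ["2+3 is 5", "2 + 3 = 5"], ["5", "5"]),
        Example("4 * 2 = 8", "8", ["4*2=6", "4 times 2 is 8", "4 * 2 = 8"], ["6", "8", "8"]),
        Example("1 + 1 = 2", "2", ["1 + 1 = 2"], ["2"]),
    ]
    print([pre_memorization_accuracy(ex) for ex in examples])  # [1.0, 0.5, None]
    print(dataset_pre_memorization_accuracy(examples))         # 0.75
    print(select_low_pre_memorization(examples, k=2))          # [1, 0]
```

In this toy run, the second example has the lowest pre-memorization accuracy, so it is selected first. That mirrors the curation idea of prioritizing examples the model does not yet solve before it starts copying the reference reasoning. A real setup would need a memorization test closer to the paper's definition and would likely sample several completions per checkpoint rather than one.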
Lay Summary: Large language models (LLMs) often excel at mimicking training data rather than genuinely learning problem-solving skills, which leads to poor performance on new problems. Our research introduces "pre-memorization train accuracy": a measure of how well models solve training problems before they start copying the training solutions. This metric robustly predicts how well models will generalize to unseen reasoning tasks across various LLMs and datasets. By focusing training on examples with low pre-memorization train accuracy, we can significantly improve data efficiency (1.5-2x), leading to LLMs with more genuine problem-solving abilities.
Primary Area: Deep Learning->Large Language Models
Keywords: reasoning, generalization, memorization
Submission Number: 7398