Do Language Models Plan Ahead for Future Tokens?

Research Area: Science of LMs
Keywords: probing, mechanistic interpretability, future tokens, transformers
TL;DR: We propose a theoretical notion of transformer models "pre-caching" (preparing information useful for future token inference); we exhibit pre-caching in a synthetic data setting and measure its occurrence in language models.
Abstract: Do transformers "think ahead" during inference at a given position? It is known that transformers prepare information in the hidden states of the forward pass at time step $t$ that is then used in future forward passes at $t+\tau$. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present during training result in the model computing features at $t$ that are irrelevant to the present inference task but useful for the future, and breadcrumbs, in which the features most relevant to time step $t$ are already the same as those that would most benefit inference at time $t+\tau$. We test these hypotheses by training language models without propagating gradients to past time steps, a scheme we formalize as myopic training. In a constructed synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis, though pre-caching increases with model scale.
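
As a rough illustration of what "off-diagonal gradient terms" and myopic training might mean, consider a two-position toy model. This is only a sketch under stated assumptions, not the setup used in the paper: the scalar model, the shared weight `w`, the fixed mixing coefficient `u`, and the function names `forward` and `gradients` are all hypothetical. The hidden state at position 1 feeds into position 2, so the loss at position 2 contributes a gradient through position 1's hidden state; the myopic gradient drops exactly that term.

```python
# Toy two-position model (hypothetical, for illustration only).
#   h1 = w * x1              hidden state at position 1
#   h2 = w * x2 + u * h1     position 2 also reads position 1's hidden state
#   L  = (h1 - y1)^2 + (h2 - y2)^2

def forward(w, x1, x2, u):
    """Return the hidden states (h1, h2) of the toy model."""
    h1 = w * x1
    h2 = w * x2 + u * h1
    return h1, h2


def gradients(w, x1, x2, y1, y2, u):
    """Split dL/dw into diagonal and off-diagonal parts.

    diagonal:     each loss differentiated through its own position only
    off_diagonal: the position-2 loss differentiated through h1
    myopic:       the gradient with the off-diagonal term removed,
                  i.e. h1 is treated as a constant inside the position-2 loss
    """
    h1, h2 = forward(w, x1, x2, u)
    diagonal = 2 * (h1 - y1) * x1 + 2 * (h2 - y2) * x2
    off_diagonal = 2 * (h2 - y2) * u * x1
    return {
        "full": diagonal + off_diagonal,
        "myopic": diagonal,
        "off_diagonal": off_diagonal,
    }


if __name__ == "__main__":
    g = gradients(w=0.5, x1=1.0, x2=2.0, y1=1.0, y2=3.0, u=0.5)
    for name, value in g.items():
        print(name, value)
```

In this toy, a nonzero `off_diagonal` term is the only channel through which the position-2 loss can shape what position 1 computes; training with the `myopic` gradient removes that channel, which is the intuition behind using it to separate pre-caching from breadcrumbs.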
