Measuring In-Context Computation Complexity via Hidden State Prediction

Published: 01 May 2025 · Last Modified: 18 Jun 2025 · ICML 2025 poster · CC BY 4.0
TL;DR: Hidden state unpredictability in sequence models is a meaningful measure for in-context reasoning complexity
Abstract: Detecting when a neural sequence model does "interesting" computation is an open problem. The next token prediction loss is a poor indicator: low loss can stem from trivially predictable sequences that are uninteresting, while high loss may reflect unpredictable but also irrelevant information that can be ignored by the model. We propose a better metric: measuring the model's ability to predict its own future hidden states. We show empirically that this metric, in contrast to the next token prediction loss, correlates with the intuitive interestingness of the task. To measure predictability, we introduce the architecture-agnostic "prediction of hidden states" (PHi) layer that serves as an information bottleneck on the main pathway of the network (e.g., the residual stream in Transformers). We propose a novel learned predictive prior that enables us to measure the novel information gained in each computation step, which serves as our metric. We show empirically that our metric predicts the description length of formal languages learned in-context, the complexity of mathematical reasoning problems, and the correctness of self-generated reasoning chains.
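To make the abstract's description concrete, below is a minimal, hypothetical sketch of a PHi-style layer in PyTorch. It is not the authors' implementation (see the linked repository for that); the Gaussian latent, the GRU prior, and all names such as `PHiLayer` and `d_latent` are illustrative assumptions. The sketch only shows the core idea: a bottleneck posterior conditioned on the current hidden state, a learned autoregressive prior over past latents, and a per-step KL term that quantifies how much new information each computation step carries.

```python
# Illustrative sketch of a PHi-style layer (assumed architecture, not the paper's exact code).
import torch
import torch.nn as nn


class PHiLayer(nn.Module):
    """Information bottleneck on the hidden-state pathway with a learned
    autoregressive prior. The per-step KL between the posterior q(z_t | h_t)
    and the prior p(z_t | z_{<t}) acts as the 'information gained' metric."""

    def __init__(self, d_model: int, d_latent: int):
        super().__init__()
        self.post_enc = nn.Linear(d_model, 2 * d_latent)            # posterior params from h_t
        self.prior = nn.GRU(d_latent, d_latent, batch_first=True)   # summarizes past latents z_{<t}
        self.prior_head = nn.Linear(d_latent, 2 * d_latent)         # prior params from that summary
        self.decode = nn.Linear(d_latent, d_model)                  # project back onto main pathway

    def forward(self, h: torch.Tensor):
        # h: (batch, seq, d_model) hidden states from the main network.
        mu_q, logvar_q = self.post_enc(h).chunk(2, dim=-1)
        z = mu_q + torch.randn_like(mu_q) * (0.5 * logvar_q).exp()  # reparameterized sample

        # The prior only sees past latents (shift right by one step).
        z_prev = torch.cat([torch.zeros_like(z[:, :1]), z[:, :-1]], dim=1)
        ctx, _ = self.prior(z_prev)
        mu_p, logvar_p = self.prior_head(ctx).chunk(2, dim=-1)

        # Per-step KL(q || p) between diagonal Gaussians: high values mean the
        # current hidden state was hard to predict from the preceding computation.
        kl = 0.5 * (logvar_p - logvar_q
                    + ((mu_q - mu_p) ** 2 + logvar_q.exp()) / logvar_p.exp() - 1)
        phi_per_step = kl.sum(dim=-1)  # (batch, seq)

        return self.decode(z), phi_per_step
```

In this reading, the reconstruction pathway keeps the layer usable as a drop-in bottleneck on the residual stream, while `phi_per_step` is the quantity one would plot over a sequence to see where the model is doing "interesting" in-context computation.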
Lay Summary: Large language models, like ChatGPT, have become very powerful, in large part due to their ability to learn _in-context_. This means that these models can make novel inferences at test time. We propose a method that pinpoints when, and to what extent, a model performs this kind of inference. In other words, we detect when a model is 'thinking hard'. We accomplish this by measuring how predictable the model's current internal states are, given its previous states: if they are predictable, no information has been gained; if they are not, the model has made new inferences. This method gives us insight into which tasks are challenging and interesting for a model, in the sense of requiring complex in-context computation. Intriguingly, we find that for reasoning tasks, such as solving a complex mathematical problem, the 'in-context computation complexity' measured by our method is predictive of whether the model will arrive at the correct solution. This is particularly true for difficult problems where more 'thought' is required.
Link To Code: https://github.com/vincentherrmann/predicting-hidden-states
Primary Area: Deep Learning->Sequential Models, Time series
Keywords: in-context learning, interpretability, transformers
Submission Number: 16096