TL;DR: Transformers implement parallelized Bayesian inference through constrained belief updating, creating predictable intermediate geometric structures that bridge optimal prediction with attention's architectural constraints.
Abstract: What computational structures emerge in transformers trained on next-token prediction? In this work, we provide evidence that transformers implement constrained Bayesian belief updating---a parallelized version of partial Bayesian inference shaped by architectural constraints. We integrate the model-agnostic theory of optimal prediction with mechanistic interpretability to analyze transformers trained on a tractable family of hidden Markov models that generate rich geometric patterns in neural activations. Our primary analysis focuses on single-layer transformers, revealing how the first attention layer implements these constrained updates, with extensions to multi-layer architectures demonstrating how subsequent layers refine these representations. We find that attention carries out an algorithm with a natural interpretation in the probability simplex and creates representations with a distinctive geometric structure. We show how both the algorithmic behavior and the underlying geometry of these representations can be theoretically predicted in detail---including the attention pattern, OV-vectors, and embedding vectors---by modifying the equations for optimal future-token prediction to account for the architectural constraints of attention. Our approach provides a principled lens on how architectural constraints shape the implementation of optimal prediction, revealing why transformers develop specific intermediate geometric structures.
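To make the optimal-prediction baseline concrete, the following is a minimal sketch of Bayesian belief updating over a hidden Markov model, the recurrent computation whose simplex geometry the transformer representations are compared against. The toy two-state process, the function names, and the NumPy implementation are illustrative assumptions, not code from the linked repository.

```python
import numpy as np

# Minimal sketch of optimal (Bayesian) belief updating for a hidden Markov model.
# The process is specified by token-labeled transition matrices T[x], where
# T[x][i, j] = P(emit token x, move to hidden state j | current hidden state i).
# The belief state eta lives in the probability simplex over hidden states and is
# updated after every observed token.

def stationary_distribution(T_total: np.ndarray) -> np.ndarray:
    """Left eigenvector (eigenvalue 1) of the state-to-state transition matrix."""
    eigvals, eigvecs = np.linalg.eig(T_total.T)
    pi = np.real(eigvecs[:, np.argmax(np.real(eigvals))])
    return pi / pi.sum()

def belief_update(eta: np.ndarray, T_x: np.ndarray) -> np.ndarray:
    """One step of Bayesian filtering: eta' is proportional to eta @ T[x]."""
    unnormalized = eta @ T_x
    return unnormalized / unnormalized.sum()

def belief_trajectory(T: dict[int, np.ndarray], tokens: list[int]) -> np.ndarray:
    """Belief states after each prefix of the observed token sequence."""
    eta = stationary_distribution(sum(T.values()))  # prior = stationary distribution
    beliefs = [eta]
    for x in tokens:
        eta = belief_update(eta, T[x])
        beliefs.append(eta)
    return np.stack(beliefs)

if __name__ == "__main__":
    # Toy 2-state, 2-token HMM (illustrative numbers, not from the paper).
    T = {
        0: np.array([[0.45, 0.05],
                     [0.10, 0.30]]),
        1: np.array([[0.05, 0.45],
                     [0.30, 0.30]]),
    }
    print(belief_trajectory(T, tokens=[0, 1, 1, 0]))
```

Plotting the rows of the returned array in the probability simplex traces out the belief-state geometry that, per the abstract, should appear in constrained form in the transformer's intermediate activations.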
Lay Summary: How do AI systems like ChatGPT actually work inside? While these "transformer" models excel at predicting text, understanding their internal mechanisms has been like trying to reverse-engineer a black box. We faced a fundamental puzzle: optimal text prediction requires recurrent, step-by-step processing, but transformers process everything simultaneously in parallel. We studied transformers trained on simple, controlled data where we could mathematically predict what the optimal solution should look like. We discovered that transformers develop a clever workaround for this parallel-versus-sequential tension. They create intermediate internal representations that approximate optimal recurrent processing using parallel computation. Remarkably, we can predict exactly what these internal structures should look like using mathematical theory. This breakthrough provides a principled way to understand why AI systems learn specific internal representations. Rather than just observing what transformers do after the fact, we can now predict and explain their internal mechanisms based on the fundamental trade-off between optimal computation and architectural constraints. This theoretical framework could help us better understand, interpret, and potentially control large language models by revealing the mathematical principles governing how they balance computational efficiency with optimal prediction.
Link To Code: https://github.com/adamimos/epsilon-transformers/blob/main/examples/intermediate_representations.ipynb
Primary Area: Deep Learning->Theory
Keywords: transformer architectures, computational mechanics, mechanistic interpretability, optimal prediction, neural network representations, belief states, hidden Markov models
Submission Number: 2891