Abstract: Understanding the learning process and the embedded computation in transformers is becoming a central goal for the development of interpretable AI. In the present study, we introduce a hierarchical filtering procedure for data models of sequences on trees, allowing us to hand-tune the range of positional correlations in the data. Leveraging this controlled setting, we provide evidence that vanilla encoder-only transformers can approximate the exact inference algorithm when trained on root classification and masked language modeling tasks, and study *how* this computation is discovered and implemented. We find that correlations at larger distances, corresponding to increasing layers of the hierarchy, are sequentially included by the network during training.
By comparing attention maps from models trained with varying degrees of filtering and by probing the successive encoder layers, we find clear evidence that correlations are reconstructed on successive length scales corresponding to the levels of the hierarchy, which we relate to a plausible implementation of the exact inference algorithm within the same architecture.
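To make the setting concrete, here is a minimal, illustrative sketch of a hierarchically filtered generative model on a binary tree. It is not the exact procedure from the paper; the vocabulary size, depth, transition tensors, and the `FILTER_LEVEL` parameter (which decouples the levels closest to the root, thereby removing the longest-range positional correlations) are assumptions made for illustration.

```python
import numpy as np

# Hypothetical parameters (illustrative, not the paper's actual values).
VOCAB_SIZE = 4
DEPTH = 5          # leaves form a sequence of length 2**DEPTH
FILTER_LEVEL = 2   # levels above this depth are decoupled from their parents

rng = np.random.default_rng(0)

# One transition tensor per level: P[level][parent] is a distribution over
# flattened (left_child, right_child) pairs, i.e. VOCAB_SIZE**2 entries.
P = [rng.dirichlet(np.ones(VOCAB_SIZE**2), size=VOCAB_SIZE) for _ in range(DEPTH)]

def generate(root=None):
    """Sample one leaf sequence from the (filtered) hierarchical model."""
    if root is None:
        root = int(rng.integers(VOCAB_SIZE))
    level_symbols = [root]
    for level in range(DEPTH):
        next_symbols = []
        for parent in level_symbols:
            if level < FILTER_LEVEL:
                # Filtered rule near the root: children drawn independently
                # of the parent, erasing correlations at the largest distances.
                pair = int(rng.integers(VOCAB_SIZE**2))
            else:
                # Unfiltered rule: children depend on the parent symbol.
                pair = int(rng.choice(VOCAB_SIZE**2, p=P[level][parent]))
            next_symbols.extend(divmod(pair, VOCAB_SIZE))
        level_symbols = next_symbols
    return np.array(level_symbols)

sequence = generate()
print(sequence.shape)  # (2**DEPTH,) leaves, read left to right
```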
Lay Summary: Modern artificial intelligence models called transformers power many state-of-the-art technologies, most prominently Large Language Models such as ChatGPT, Gemini, and Claude. These models work by detecting and using intricate patterns and correlations in text, but how they learn to do this is still not fully understood.
In our study, we investigate this question using a simplified setting where we generate artificial sequences with a built-in layered structure, somewhat like how language is organized into phrases, subphrases, and words. This controlled setup lets us adjust how much structure is present and observe how it is learned by deep neural networks.
We show that standard transformer models, when trained on this kind of data, begin to approximate a known step-by-step method for processing such structures, called belief propagation, even without being explicitly told to do so. Remarkably, this approximation is learned progressively during training, and implemented in a way that is both intuitive and interpretable within the model’s architecture. This provides insight into how current AI systems can, through training alone, come to perform structured algorithmic computations.
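For readers who want to see what the exact inference looks like, the following is a minimal sketch of belief propagation on the toy tree model above: a single upward sweep of messages from the observed leaves to the root, yielding a posterior over the root symbol. It assumes the `VOCAB_SIZE` and transition tensors defined in the earlier sketch (and a uniform root prior), and is meant as an illustration rather than the paper's implementation.

```python
def infer_root(leaves, transitions):
    """Belief propagation (upward pass) on a full binary tree: returns the
    posterior over the root symbol given the observed leaf sequence.

    `transitions[level][parent]` is a distribution over flattened
    (left_child, right_child) pairs, as in the generative sketch above.
    """
    # Leaf messages: one-hot vectors on the observed symbols.
    messages = [np.eye(VOCAB_SIZE)[leaf] for leaf in leaves]
    # Sweep from the bottom level up to the root.
    for level in reversed(range(len(transitions))):
        T = transitions[level].reshape(VOCAB_SIZE, VOCAB_SIZE, VOCAB_SIZE)
        new_messages = []
        for i in range(0, len(messages), 2):
            left, right = messages[i], messages[i + 1]
            # Marginalize over both children for every parent value.
            msg = np.einsum('plr,l,r->p', T, left, right)
            new_messages.append(msg / msg.sum())
        messages = new_messages
    return messages[0]

posterior = infer_root(generate(), P)  # posterior over the root symbol
```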
Link To Code: https://github.com/emanuele-moscato/tree-language
Primary Area: Deep Learning->Attention Mechanisms
Keywords: Transformers, Belief Propagation, mechanistic explanation, structured data, hierarchical data model, attention, masked language modeling
Submission Number: 10111