Unveiling Induction Heads: Provable Training Dynamics and Feature Learning in Transformers

Published: 25 Sept 2024, Last Modified: 06 Nov 2024 · NeurIPS 2024 poster · CC BY-NC 4.0
Keywords: transformer, in-context learning, Markov chain, n-gram model, feature learning, training dynamics, induction head
TL;DR: We theoretically prove that a two-layer transformer model learns to perform an "induction head" mechanism after training on n-gram Markovian data
Abstract: In-context learning (ICL) is a cornerstone of large language model (LLM) functionality, yet its theoretical foundations remain elusive due to the complexity of transformer architectures. In particular, most existing work only theoretically explains how the attention mechanism facilitates ICL under certain data models. It remains unclear how the other building blocks of the transformer contribute to ICL. To address this question, we study how a two-attention-layer transformer is trained to perform ICL on $n$-gram Markov chain data, where each token in the Markov chain statistically depends on the previous $n$ tokens. We analyze a sophisticated transformer model featuring relative positional embedding, multi-head softmax attention, and a feed-forward layer with normalization. We prove that the gradient flow with respect to a cross-entropy ICL loss converges to a limiting model that performs a generalized version of the "induction head" mechanism with a learned feature, resulting from the congruous contribution of all the building blocks. Specifically, the first attention layer acts as a copier, copying past tokens within a given window to each position, and the feed-forward network with normalization acts as a selector that generates a feature vector by looking only at the informationally relevant parents from the window. Finally, the second attention layer is a classifier that compares these features with the feature at the output position, and uses the resulting similarity scores to generate the desired output. Our theory is further validated by simulation experiments.
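The copier-selector-classifier description above can be made concrete with a small numerical sketch. The NumPy function below (`generalized_induction_head`, with parameters such as `parent_offsets` and `temp` chosen here for illustration and not taken from the paper) predicts the next token of an in-context sequence by forming a feature from a selected set of parent tokens at each position, comparing the feature at the output position with the features at earlier positions, and voting with the tokens that followed each match. It is a minimal sketch of the limiting mechanism, assuming known parent offsets and one-hot features; it is not the paper's transformer parameterization or its training dynamics.

```python
import numpy as np

def generalized_induction_head(tokens, vocab_size, parent_offsets, temp=1.0):
    """Minimal sketch of a generalized induction head on one token sequence.

    tokens         : 1-D int array, the in-context sequence x_0, ..., x_{T-1}
    vocab_size     : number of distinct tokens
    parent_offsets : offsets kept by the selector (0 = current token,
                     1 = previous token, ...), i.e. the relevant parents
    temp           : softmax temperature of the second-layer comparison
    Returns an estimated distribution over the next token x_T.
    """
    T = len(tokens)
    one_hot = np.eye(vocab_size)[tokens]          # (T, vocab_size)
    max_off = max(parent_offsets)
    if T - 1 <= max_off:
        raise ValueError("sequence too short for the chosen parent offsets")

    # Layer 1 (copier) + feed-forward selector: position s carries the
    # one-hot codes of its selected parents, concatenated into a feature.
    def feature(s):
        return np.concatenate([one_hot[s - k] for k in parent_offsets])

    query = feature(T - 1)                        # feature at the output position

    # Layer 2 (classifier): compare the query with features at earlier
    # positions whose next token is observed, then take an attention-weighted
    # vote over the tokens that followed each matching context.
    cand = range(max_off, T - 1)
    scores = np.array([query @ feature(s) for s in cand]) / temp
    attn = np.exp(scores - scores.max())
    attn /= attn.sum()
    next_dist = attn @ one_hot[[s + 1 for s in cand]]
    return next_dist                              # (vocab_size,)
```

As a usage example, calling `generalized_induction_head(tokens, vocab_size=3, parent_offsets=(0, 1))` on a sequence drawn from a 2-gram Markov chain matches each pair of most recent tokens against all earlier pairs and aggregates the observed successors, which is the behavior the abstract attributes to the trained two-layer model.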
Primary Area: Optimization (convex and non-convex, discrete, stochastic, robust)
Submission Number: 19091