How Transformers Learn Regular Language Recognition: A Theoretical Study on Training Dynamics and Implicit Bias

Published: 01 May 2025 · Last Modified: 18 Jun 2025 · ICML 2025 poster · CC BY 4.0
TL;DR: We analyze how a one-layer transformer learns to solve two regular language tasks: even pairs and parity check. The model can learn to solve even pairs directly and parity check through CoT.
Abstract: Language recognition tasks are fundamental in natural language processing (NLP) and have been widely used to benchmark the performance of large language models (LLMs). These tasks also play a crucial role in explaining the working mechanisms of transformers. In this work, we focus on two representative tasks in the category of regular language recognition, known as 'even pairs' and 'parity check', whose aim is to determine whether certain subsequences occur an even number of times in a given sequence. Our goal is to explore how a one-layer transformer, consisting of an attention layer followed by a linear layer, learns to solve these tasks, by theoretically analyzing its training dynamics under gradient descent. While even pairs can be solved directly by a one-layer transformer, parity check needs to be solved by integrating Chain-of-Thought (CoT), either into the inference stage of a transformer well trained on the even pairs task, or into the training of a one-layer transformer. For both problems, our analysis shows that the joint training of the attention and linear layers exhibits two distinct phases. In the first phase, the attention layer grows rapidly, mapping data sequences to separable vectors. In the second phase, the attention layer becomes stable, while the linear layer grows logarithmically and converges in direction to a max-margin hyperplane that correctly separates the attention-layer outputs into positive and negative samples, and the loss decreases at a rate of $O(1/t)$. Our experiments validate these theoretical results.
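To make the two tasks concrete, below is a minimal Python sketch of the labeling rules, assuming the conventions common in the formal-language benchmark literature: 'even pairs' asks whether the total number of 'ab' and 'ba' substrings is even (equivalently, whether the first and last tokens agree), and 'parity check' asks whether the number of 'b' tokens is even. The binary alphabet {a, b} and the function names are illustrative choices, not fixed by the abstract.

```python
# Sketch of the two regular-language tasks (illustrative conventions:
# binary alphabet {'a', 'b'}; label 1 = "even", 0 = "odd").

def even_pairs_label(seq: str) -> int:
    """Is the total count of 'ab' and 'ba' substrings even?"""
    # The count of 'ab' plus 'ba' substrings equals the number of adjacent
    # token changes, which is even iff the first and last tokens agree.
    changes = sum(seq[i] != seq[i + 1] for i in range(len(seq) - 1))
    return int(changes % 2 == 0)

def parity_label(seq: str) -> int:
    """Is the number of 'b' tokens even?"""
    return int(seq.count('b') % 2 == 0)

assert even_pairs_label("abba") == 1                        # two changes: even
assert even_pairs_label("abba") == int("abba"[0] == "abba"[-1])
assert parity_label("abab") == 1                            # two 'b's: even
assert parity_label("abb") == 1                             # two 'b's: even
```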
Lay Summary: Language recognition tasks are essential tools in natural language processing (NLP), both for evaluating large language models and for understanding how they work. In this study, we take a closer look at two such tasks, namely 'even pairs' and 'parity check', which test whether certain patterns appear an even number of times in a sequence. We investigate how a one-layer transformer model learns to solve these tasks. Through mathematical analysis, we track how the model changes as it is trained with gradient descent. We find that training happens in two stages: first, the attention layer quickly learns to highlight useful parts of the input; then, the linear layer gradually learns to make the final decision, eventually drawing a clear boundary between correct and incorrect answers. Interestingly, we theoretically show that a model trained on the even pairs task can solve the parity check task through Chain-of-Thought (CoT) reasoning, as sketched below, and that CoT can also be incorporated into training to further enhance the power of transformers. Finally, we confirm our theoretical insights with experiments, showing that the model behaves just as our analysis predicts.
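One way to see why CoT makes parity tractable is to unroll it into a chain of prefix parities, so that each generated step requires only a single local update. The sketch below is an illustrative decomposition of this idea, not necessarily the exact CoT format used in the paper.

```python
# Illustrative CoT decomposition of parity check (hypothetical format):
# the model emits the running parity of 'b's after each token, so every
# step is a one-bit update, and the last emitted bit decides the answer.

def parity_with_cot(seq: str) -> tuple[list[int], int]:
    steps = []
    parity = 0                   # 0 = even so far, 1 = odd so far
    for tok in seq:
        parity ^= (tok == 'b')   # flip the running parity on each 'b'
        steps.append(parity)
    return steps, int(parity == 0)  # label 1 iff the count of 'b's is even

steps, label = parity_with_cot("abba")
print(steps, label)  # [0, 1, 0, 0] 1 -> even number of 'b's
```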
Primary Area: Theory->Learning Theory
Keywords: Transformers, Training dynamics, Implicit bias, Language recognition
Submission Number: 12758