TL;DR: We introduce a formal framework for understanding the role of sparsity in length generalization in LLMs and provide empirical evidence for it.
Abstract: Training large language models to predict beyond their training context lengths has drawn much attention in recent years, yet the principles driving such length generalization remain underexplored. We propose a new theoretical framework to study length generalization for the next-token prediction task, as performed by decoder-only transformers. Conceptually, we show that length generalization occurs as long as each predicted token depends on a small (fixed) number of previous tokens. We formalize such tasks via a notion we call k-sparse planted correlation distributions, and show that an idealized transformer model that generalizes attention heads successfully length-generalizes on such tasks. As a bonus, our theoretical model allows us to justify positional-embedding modifications that have been introduced to improve length generalization, such as position coupling.
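To make the sparsity notion concrete, below is a minimal, hypothetical sketch (not the paper's exact construction) of sampling from a k-sparse dependency structure: every token past a short prefix is a fixed function of only K earlier tokens. The vocabulary, prefix length, and the sum-mod rule are illustrative assumptions.

```python
import random

VOCAB = list(range(10))   # toy vocabulary of token IDs (assumption)
K = 2                     # each token depends on at most K previous tokens

def sample_sequence(length, prefix_len=4, seed=None):
    """Sample a toy sequence with a K-sparse dependency structure."""
    rng = random.Random(seed)
    seq = [rng.choice(VOCAB) for _ in range(prefix_len)]   # random prefix
    for t in range(prefix_len, length):
        parents = rng.sample(range(t), K)                  # K earlier positions
        # The "planted correlation": the next token is a fixed function
        # (here, a sum mod vocab size) of its K parent tokens only.
        seq.append(sum(seq[p] for p in parents) % len(VOCAB))
    return seq

print(sample_sequence(12, seed=0))
```

Because each prediction touches only K positions regardless of the total sequence length, a model that learns the per-token rule at training lengths can, in principle, apply the same rule at longer lengths.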
We support our theoretical results with experiments on synthetic tasks and natural language, which confirm that a key factor driving length generalization is indeed a "sparse" dependency structure of each token on the previous ones. Further, inspired by our theory, we introduce Predictive Position Coupling, a generalization of position coupling which trains the transformer to predict the position IDs used in a position coupling approach. Predictive Position Coupling thereby broadens the range of tasks to which position coupling can successfully be applied to achieve length generalization.
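The following sketch illustrates the kind of coupled position IDs involved, using a toy digit-addition prompt; the token format, operator positions, and coupling scheme here are illustrative assumptions rather than the paper's exact setup. Under Predictive Position Coupling, the model is trained to emit the next position ID alongside the next token, so it can produce its own coupled IDs on longer inputs.

```python
# Hypothetical sketch of coupled position IDs for 'a+b=answer' (illustrative only).
# Tokens that correspond to each other (the i-th digits of both operands and of the
# answer) share the same position ID.

def coupled_position_ids(a: str, b: str, answer: str, start: int = 1):
    """Return (tokens, position_ids), coupling aligned digits across a, b, and answer."""
    tokens, pos = [], []
    for i, d in enumerate(a):
        tokens.append(d); pos.append(start + i)
    tokens.append('+'); pos.append(start + len(a))
    for i, d in enumerate(b):
        tokens.append(d); pos.append(start + i)      # reuse operand-a positions
    tokens.append('='); pos.append(start + len(a))   # couple '=' with '+' (a choice made here)
    for i, d in enumerate(answer):
        tokens.append(d); pos.append(start + i)      # answer digits coupled too
    return tokens, pos

tokens, pos = coupled_position_ids("314", "277", "591")
# Next-token training pairs would be (tokens[t+1], pos[t+1]) at each step t,
# so the model learns to predict position IDs rather than receiving them at test time.
print(list(zip(tokens, pos)))
```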
Lay Summary: Researchers are trying to train language models (like ChatGPT) to make accurate predictions even when working with texts that are much longer than what they saw during training. But we still don’t fully understand why some models can handle these longer texts so well.
This paper introduces a new way of thinking about the problem. The key idea is that as long as each word (or token) mostly depends on just a few of the words before it, the model can learn to keep going even past its original training length. We build a mathematical framework to capture this idea and show that a simplified version of a transformer model (the kind used in ChatGPT) can succeed under these conditions.
We also explain why certain tricks—like adjusting how the model understands the positions of each word—help with this kind of generalization. Based on our theory, we introduce a new method called Predictive Position Coupling, which teaches the model to guess where in the sequence it is, improving its ability to handle longer inputs across a wider range of tasks.
We back up our claims with experiments on both made-up data and real language, showing that the model’s ability to handle longer sequences can be explained in large part by how many previous tokens each word depends on.
Primary Area: Theory->Domain Adaptation and Transfer Learning
Keywords: Length generalization, Transformers, Positional Encoding
Submission Number: 9371