Fundamental limits of learning in sequence multi-index models and deep attention networks: high-dimensional asymptotics and sharp thresholds
TL;DR: We derive the computational limits of Bayesian estimation for sequence multi-index models in high dimensions, with consequences for deep attention neural networks.
Abstract: In this manuscript, we study the learning of deep attention neural networks, defined as the composition of multiple self-attention layers with tied and low-rank weights. We first establish a mapping of such models to sequence multi-index models, a generalization of the widely studied multi-index model to sequential covariates, for which we establish a number of general results. In the context of Bayes-optimal learning, in the limit of large dimension $D$ and proportionally large number of samples $N$, we derive a sharp asymptotic characterization of the optimal performance, as well as of the performance of the best known polynomial-time algorithm for this setting (namely, approximate message-passing), and characterize sharp thresholds on the minimal sample complexity required for better-than-random prediction performance.
Our analysis uncovers, in particular, how the different layers are learned sequentially. Finally, we discuss how this sequential learning can also be observed in a realistic setup.
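The model class studied in the abstract, a composition of self-attention layers with tied, low-rank weights, can be sketched as follows. This is a minimal illustrative sketch, not the paper's exact parameterization: the function names, the $1/\sqrt{D}$ score scaling, and the choice of tying queries and keys to a single weight matrix `W` of rank $r \ll D$ are assumptions made here for concreteness.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable row-wise softmax.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def tied_lowrank_attention(X, W):
    """One self-attention layer with tied, low-rank query/key weights.

    X: (L, D) sequence of L tokens in dimension D.
    W: (D, r) low-rank weight matrix, r << D, shared by queries and keys.
    """
    proj = X @ W                                   # (L, r) low-rank projection
    scores = proj @ proj.T / np.sqrt(X.shape[1])   # (L, L) attention scores
    return softmax(scores, axis=-1) @ X            # (L, D) attended output

def deep_attention(X, weights):
    """Composition of several tied low-rank attention layers."""
    for W in weights:
        X = tied_lowrank_attention(X, W)
    return X

# Toy instantiation (sizes chosen only for illustration).
rng = np.random.default_rng(0)
L, D, r, depth = 8, 64, 2, 3
X = rng.standard_normal((L, D))
weights = [rng.standard_normal((D, r)) / np.sqrt(D) for _ in range(depth)]
Y = deep_attention(X, weights)
print(Y.shape)  # (8, 64): each layer preserves the sequence shape
```

Because each layer depends on the data only through the low-rank projections $XW$, such a network is a natural candidate for the mapping to sequence multi-index models described above.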
Lay Summary: Modern AI systems often rely on attention networks — models that look at relationships between different parts of the input. These are central to breakthroughs like large language models. But how do such complex networks actually learn, and how much data do they need to start making accurate predictions?
We studied deep attention networks in a simplified, low-rank setting, allowing us to derive exact results on the fundamental limits of learning from data. Our analysis provides a comprehensive mathematical framework to understand when learning is possible and how well these models can perform.
One of our most surprising findings is that the network doesn’t learn all at once: instead, its layers are learned sequentially. This means the model gradually builds up complexity, layer by layer — a behavior we also observe in some realistic scenarios.
Link To Code: https://github.com/SPOC-group/SequenceIndexModels
Primary Area: Theory->Learning Theory
Keywords: theory, multi-index model, approximate message-passing, replica method, statistical physics, attention mechanism
Submission Number: 6949