Finite-Time Analysis of Gradient Descent for Shallow Transformers
TL;DR: We prove that a shallow multi-head Transformer trained with projected GD in the kernel regime needs a width only logarithmic in the sample size, and that its optimization error is independent of the sequence length.
Abstract: Understanding why Transformers perform so well remains challenging due to their non-convex optimization landscape. In this work, we analyze a shallow Transformer with $m$ independent heads trained by projected gradient descent in the kernel regime. Our analysis reveals two main findings: (i) the width required for nonasymptotic guarantees scales only logarithmically with the sample size $n$, and (ii) the optimization error is independent of the sequence length $T$. This contrasts sharply with recurrent architectures, where the optimization error can grow exponentially with $T$. The trade-off is memory: to keep the full context, the Transformer's memory requirement grows with the sequence length. We validate our theoretical results numerically in a teacher–student setting and compare Transformers with recurrent architectures on an autoregressive task.
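To make the setup concrete, the following is a minimal sketch of the analyzed model class and one projected-GD step: a shallow Transformer with $m$ independent softmax-attention heads, trained on the squared loss with the weights projected back onto a Frobenius ball around their initialization (the standard device for staying in the kernel regime). All names, dimensions, the scalar readout, the projection radius, and the finite-difference gradients are illustrative assumptions, not the paper's exact construction.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, m = 8, 16, 4       # sequence length, embedding dim, number of heads (illustrative)
eta, radius = 0.1, 1.0   # step size and projection radius (assumptions)

X = rng.standard_normal((T, d))                            # one input sequence
y = 1.0                                                    # scalar target
Ws0 = [rng.standard_normal((d, d)) / d for _ in range(m)]  # head initializations
a = rng.choice([-1.0, 1.0], size=m)                        # fixed readout signs

def head(X, W):
    # One attention head: row-wise softmax(X W X^T / sqrt(d)) applied to X
    s = X @ W @ X.T / np.sqrt(d)
    A = np.exp(s - s.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ X

def f(X, Ws):
    # Shallow Transformer: mean-pooled average over m independent heads
    return sum(ai * head(X, W).mean() for ai, W in zip(a, Ws)) / m

def project(W, W0):
    # Projection onto a Frobenius ball around the init W0, keeping the
    # trained weights near initialization as in kernel-regime analyses
    D = W - W0
    n = np.linalg.norm(D)
    return W0 + D * (radius / n) if n > radius else W

def pgd_step(Ws):
    # One projected-GD step on the squared loss; gradients are taken by
    # forward finite differences purely for illustration
    loss = lambda Ws: 0.5 * (f(X, Ws) - y) ** 2
    eps, new = 1e-5, []
    for k, W in enumerate(Ws):
        G = np.zeros_like(W)
        for i in range(d):
            for j in range(d):
                Wp = [w.copy() for w in Ws]
                Wp[k][i, j] += eps
                G[i, j] = (loss(Wp) - loss(Ws)) / eps
        new.append(project(W - eta * G, Ws0[k]))
    return new
```

In this sketch the projection is what enforces the kernel-regime assumption: after every gradient step, each head's weights stay within a fixed Frobenius distance of their initialization.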
Submission Number: 1721