Training Multi-Layer Transformers in Almost Linear Time

ICLR 2026 Conference Submission14692 Authors

19 Sept 2025 (modified: 08 Oct 2025), ICLR 2026 Conference Submission, Readers: Everyone, License: CC BY 4.0
Keywords: Attention, Transformers, LLM, Training, Theory
Abstract: The computational complexity of the self-attention mechanism in popular transformer architectures poses significant challenges for training and inference, and becomes the bottleneck for long inputs. Is it possible to significantly reduce the quadratic time complexity of computing the gradients in multi-layer transformer models? This paper proves that a novel fast approximation method can compute these gradients in almost linear time $n^{1+o(1)}$, where $n$ is the input sequence length, while maintaining a polynomially small approximation error $1 / \mathrm{poly}(n)$ across the entire model. Our theory holds for general loss functions and for multi-layer transformer models that contain many practical sub-modules, such as residual connections, causal masks, and multi-head attention. We further validate our approach through numerical experiments, which demonstrate both high approximation fidelity and substantial speedups in practice. We hope that, by improving the efficiency of gradient computation, our theoretical results will facilitate more effective training and deployment of long-context language models.
Primary Area: foundation or frontier models, including LLMs
Submission Number: 14692
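
For context on the quadratic cost mentioned in the abstract, the following is a minimal sketch of exact single-head causal softmax attention in plain Python. It is not the approximation method proposed in the submission; it only illustrates the standard baseline, in which the number of query-key scores grows as $n(n+1)/2$ for sequence length $n$. The function names `causal_attention` and `count_scores`, the toy inputs, and the assumption that `Q`, `K`, and `V` share the same row width are all hypothetical choices made for illustration.

```python
import math

def causal_attention(Q, K, V):
    """Exact single-head causal softmax attention (baseline, not the paper's method).

    Q, K, V: lists of n rows, each a list of d floats (same d assumed for all three).
    Returns (output rows, number of query-key scores computed).
    """
    n, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out, num_scores = [], 0
    for i in range(n):
        # Causal mask: query i attends only to keys 0..i.
        scores = [scale * sum(q * k for q, k in zip(Q[i], K[j]))
                  for j in range(i + 1)]
        num_scores += len(scores)
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * V[j][c] for j, w in enumerate(weights)) / total
                    for c in range(d)])
    return out, num_scores

def count_scores(n, d=2):
    """Run the baseline on toy inputs and report how many scores it computed."""
    rows = [[float((i + c) % 3) for c in range(d)] for i in range(n)]
    _, num_scores = causal_attention(rows, rows, rows)
    return num_scores

if __name__ == "__main__":
    for n in (8, 16, 32):
        print(n, count_scores(n))
```

Doubling `n` roughly quadruples the score count. Exact backpropagation through this computation touches the same set of scores, which is the quadratic gradient cost the abstract asks about reducing.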