Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts

20 Sept 2025 (modified: 11 Feb 2026) · Submitted to ICLR 2026 · CC BY 4.0
Keywords: Transformer inference, efficient transformer architectures, Recurrent Memory Transformers, diagonal batching, inference scheduling, LLM inference acceleration, efficient deep learning, long-context sequence modeling
Abstract: Long-context inference with Transformers is constrained by quadratic attention cost and linear memory growth. Many linear-time alternatives require pretraining from scratch, whereas Recurrent Memory Transformers (RMTs) convert pretrained models into segment-recurrent variants via finetuning, without modifying the original model architecture. However, their sequential memory updates underutilize GPUs. We show that RMT-style architectures with layer-level memory (PRMTs, e.g., ARMT) can be among the most latency-efficient linear approaches when scheduled properly. We introduce Diagonal Batching, a compute-reordering scheme that preserves the exact recurrence while exposing inter-step parallelism by grouping layer-segment operations into "diagonals" that execute concurrently. On LLaMA models (1B/3B/8B) with contexts up to 131,072 tokens on A100/H100 GPUs, Diagonal Batching achieves up to 3.3× lower latency than full-attention inference and a 1.8× speedup over a sequential ARMT baseline, with no custom CUDA kernels. With the right scheduling, PRMTs achieve linear scaling with context length and stand out as competitive, scalable architectures among linear recurrent models.
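To make the scheduling idea concrete, below is a minimal Python sketch of the diagonal (wavefront) ordering described in the abstract. This is an illustration under stated assumptions, not the paper's implementation: the names `segments`, `layers`, and `run_group` are hypothetical, and `run_group` stands in for a batched forward pass over the grouped layer calls. The key property is that the cell (segment s, layer l) depends only on cells from the previous anti-diagonal, so every cell on the same diagonal can run concurrently.

```python
# Minimal sketch of diagonal batching for a layer-level recurrent stack.
# Assumptions (not from the paper text): each layer takes (hidden, memory)
# and returns updated (hidden, memory); `run_group` is a hypothetical
# executor that evaluates a list of such calls as one grouped/batched step.

def diagonal_schedule(num_segments, num_layers):
    """Yield the (segment, layer) cells of each anti-diagonal.

    Cell (s, l) depends on (s - 1, l) (the same layer's recurrent memory)
    and (s, l - 1) (the previous layer's activations), both of which lie
    on the previous diagonal, so cells within one diagonal are independent.
    """
    for d in range(num_segments + num_layers - 1):
        yield [(s, d - s) for s in range(num_segments)
               if 0 <= d - s < num_layers]


def run_diagonal_batching(segments, layers, run_group):
    """Drive the schedule: hidden[s] holds segment s's activations,
    memory[l] holds layer l's recurrent state."""
    hidden = list(segments)            # per-segment activations (inputs)
    memory = [None] * len(layers)      # per-layer recurrent memory states
    for cells in diagonal_schedule(len(segments), len(layers)):
        # One grouped call per diagonal; in practice this would be a
        # batched forward over the grouped layers, not a Python loop.
        updates = run_group(
            [(hidden[s], memory[l], layers[l]) for s, l in cells]
        )
        for (s, l), (h_new, m_new) in zip(cells, updates):
            hidden[s], memory[l] = h_new, m_new
    return hidden  # final per-segment outputs after all layers
```

For example, with 2 segments and 3 layers the schedule is [(0,0)], [(0,1),(1,0)], [(0,2),(1,1)], [(1,2)]: a sequential baseline would issue 6 separate layer calls, while the diagonal ordering issues 4 grouped steps, which is where the latency reduction reported in the abstract comes from.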
Supplementary Material: zip
Primary Area: foundation or frontier models, including LLMs
Submission Number: 24761