Keywords: Transformer inference, Efficient transformer architectures, Recurrent Memory Transformers, diagonal batching, inference scheduling, LLM inference acceleration, efficient deep learning, long-context sequence modeling
Abstract: Long-context inference with Transformers is constrained by quadratic attention
and linear memory growth. Many linear-time alternatives require pretraining from
scratch, whereas Recurrent Memory Transformers (RMTs) convert pretrained
models into segment-recurrent variants via finetuning without modifying the original
model architecture. However, their sequential memory updates underutilize GPUs.
We show that RMT-style architectures with layer-level memory (PRMTs, e.g., ARMT)
can be among the most latency-efficient linear approaches when scheduled
properly. We introduce Diagonal Batching, a compute-reordering scheme that
preserves exact recurrence while exposing inter-step parallelism by grouping
layers and executing the resulting "diagonals" of the segment-layer grid
concurrently. On LLaMA models (1B/3B/8B) with inputs of up to 131,072
tokens on A100/H100 GPUs, Diagonal Batching achieves up to 3.3× lower latency than
full-attention inference and 1.8× lower latency than a sequential ARMT baseline,
with no custom CUDA kernels. With the right scheduling, PRMTs achieve linear scaling with
custom CUDA kernels. With the right scheduling, PRMTs achieve linear scaling with
context length and stand out as competitive, scalable architectures among linear
recurrent models.
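To make the scheduling idea concrete, the following is a minimal Python sketch of the wavefront ("diagonal") schedule the abstract describes, under the assumption of layer-level memory as in ARMT. The names `diagonal_batching`, `layer_step`, `hidden`, and `memory` are hypothetical placeholders for illustration, not the authors' implementation.

# Minimal sketch of the diagonal (wavefront) schedule, assuming layer-level
# memory as in ARMT. Names are hypothetical, not the authors' code.
from typing import Any, Callable, List, Tuple

def diagonal_batching(
    segments: List[Any],      # the long input split into S segments
    num_layers: int,          # L transformer layers, each with its own memory
    layer_step: Callable[[List[int], List[Any], List[Any]], List[Tuple[Any, Any]]],
) -> List[Any]:
    """Run the S x L (segment, layer) grid in anti-diagonal waves.

    Cell (s, l) depends only on the hidden state from (s, l-1) and the memory
    from (s-1, l), so all cells with s + l == t are mutually independent and
    can be executed in one grouped call.
    """
    S, L = len(segments), num_layers
    hidden = [[None] * L for _ in range(S)]   # hidden[s][l]: layer-l output on segment s
    memory = [[None] * L for _ in range(S)]   # memory[s][l]: layer-l memory after segment s

    for t in range(S + L - 1):                # one wave per anti-diagonal
        cells = [(s, t - s) for s in range(S) if 0 <= t - s < L]
        layers = [l for _, l in cells]
        xs = [segments[s] if l == 0 else hidden[s][l - 1] for s, l in cells]
        mems = [None if s == 0 else memory[s - 1][l] for s, l in cells]
        # `layer_step` is assumed to run the listed layers as one batched launch
        # and return a (hidden, memory) pair per cell on the diagonal.
        for (s, l), (h, m) in zip(cells, layer_step(layers, xs, mems)):
            hidden[s][l], memory[s][l] = h, m

    return [hidden[s][L - 1] for s in range(S)]  # final-layer output per segment

Executed strictly sequentially, the grid requires S·L dependent steps, whereas the wave schedule needs only S + L - 1 grouped launches; this is the inter-step parallelism referred to above.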
Supplementary Material: zip
Primary Area: foundation or frontier models, including LLMs
Submission Number: 24761