Training-time Selection of Linear Vs. Softmax Attention in Layer-based Hybrid Transformers

20 Sept 2025 (modified: 03 Dec 2025), ICLR 2026 Conference Withdrawn Submission, CC BY 4.0
Keywords: LLM, Transformer, Attention, Hybrid, Language-modeling, Efficient LLM, Linear Attention, KV-cache
TL;DR: A training method to construct memory-efficient and accurate layer-based hybrid LLMs
Abstract: Given a prompt of initial length $M$ from which to generate $N$ tokens, the memory requirement of current transformer-based LLMs grows as $\mathcal{O}(M + N)$, while inference time grows as $\mathcal{O}(MN + N^2)$. Although these models have achieved remarkable results across many tasks, these growth rates, through memory-size constraints and time requirements, limit how far output accuracy can be improved by increasing the context length. Linear attention mechanisms offer $\mathcal{O}(N)$ time and constant memory complexity, but they fall short on some language-modeling tasks, especially compared with softmax attention over the full context. A natural direction is to design hybrid models that combine the strengths of both approaches. In this work, we propose a training-based method for constructing such hybrid models. Given a maximum tolerable incremental loss, the method seeks the optimal layer-based hybrid configuration of a transformer: it replaces a softmax attention block with its linear counterpart whenever doing so does not increase the loss beyond that tolerance. We evaluate the resulting hybrid models on various language-modeling benchmarks. In some cases, the hybrid models obtained by this method cut the LLM's total context-cache peak memory usage by up to 40\% while affecting accuracy minimally (a 1\% increase in perplexity). Furthermore, we observe that, in some cases, our training method even reduces the task-specific loss (e.g., cross-entropy) compared with an all-softmax-attention configuration. The proposed method therefore makes the model more efficient in memory usage and compute intensity and, in some cases, also improves accuracy, i.e., reduces perplexity.
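To make the memory argument concrete, the following is a back-of-envelope sketch of peak cache size for a layer-based hybrid. It rests on assumptions that do not come from the paper. Each softmax layer caches keys and values for all $M + N$ tokens. Each linear-attention layer keeps a fixed per-head state of size $d \times d$ plus a $d$-dimensional normalizer, assuming a feature map that preserves the head dimension $d$. All tensors use 2-byte elements. The names `ModelDims` and `peak_cache_bytes` and the model dimensions in the usage lines are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ModelDims:
    n_heads: int
    head_dim: int
    bytes_per_elem: int = 2  # assumption: fp16/bf16 storage


def softmax_layer_cache_bytes(dims: ModelDims, seq_len: int) -> int:
    # Keys and values for every cached token and every head: grows with seq_len.
    return 2 * seq_len * dims.n_heads * dims.head_dim * dims.bytes_per_elem


def linear_layer_state_bytes(dims: ModelDims) -> int:
    # Per head (assumed): a head_dim x head_dim state sum_t phi(k_t) v_t^T
    # plus a head_dim normalizer sum_t phi(k_t). Independent of seq_len.
    per_head = dims.head_dim * dims.head_dim + dims.head_dim
    return dims.n_heads * per_head * dims.bytes_per_elem


def peak_cache_bytes(pattern: str, dims: ModelDims, prompt_len: int, new_tokens: int) -> int:
    # pattern: one character per layer, "S" = softmax attention, "L" = linear attention.
    # The cache peaks after the last generated token, i.e. at M + N tokens.
    seq_len = prompt_len + new_tokens
    total = 0
    for kind in pattern:
        if kind == "S":
            total += softmax_layer_cache_bytes(dims, seq_len)
        elif kind == "L":
            total += linear_layer_state_bytes(dims)
        else:
            raise ValueError(f"unknown layer kind: {kind!r}")
    return total


# Hypothetical 32-layer model: replace the first and last six layers with linear attention.
dims = ModelDims(n_heads=32, head_dim=128)
full = peak_cache_bytes("S" * 32, dims, prompt_len=4096, new_tokens=1024)
hybrid = peak_cache_bytes("L" * 6 + "S" * 20 + "L" * 6, dims, prompt_len=4096, new_tokens=1024)
print(f"peak cache reduction: {1 - hybrid / full:.1%}")  # → peak cache reduction: 37.0%
```

The softmax term scales linearly with $M + N$, while the linear-attention term stays constant. The achievable saving is therefore roughly the fraction of layers converted, and it increases with context length. These numbers are illustrative only.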
We show that the early and final layers can usually be replaced with linear attention layers, while the middle layers must retain softmax attention; the exact pattern differs from dataset to dataset.
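The selection criterion ("replace a softmax block as long as the loss stays within a tolerance") can be illustrated with a simple search. This is not the paper's training-time method. It is a simpler, hypothetical post-hoc greedy search. `eval_loss` is an assumed callback that returns a validation loss for a given per-layer pattern; in practice, each call might require fine-tuning or evaluating a candidate model. `tolerance` is assumed here to be an absolute increase over the all-softmax baseline.

```python
from typing import Callable, Sequence


def greedy_hybrid_pattern(
    n_layers: int,
    eval_loss: Callable[[Sequence[str]], float],
    tolerance: float,
) -> list[str]:
    # Hypothetical greedy search: start from all-softmax ("S") and repeatedly switch
    # the single layer whose conversion to linear attention ("L") hurts the loss least,
    # stopping once any further switch would exceed baseline loss + tolerance.
    pattern = ["S"] * n_layers
    budget = eval_loss(pattern) + tolerance
    candidates = list(range(n_layers))
    while candidates:
        trial_losses = {}
        for i in candidates:
            trial = pattern.copy()
            trial[i] = "L"
            trial_losses[i] = eval_loss(trial)
        best = min(trial_losses, key=trial_losses.get)
        if trial_losses[best] > budget:
            break
        pattern[best] = "L"
        candidates.remove(best)
    return pattern


# Toy, hypothetical loss: each converted layer adds a fixed penalty; middle layers are costly.
costs = [0.001, 0.002, 0.05, 0.08, 0.06, 0.003]


def toy_loss(pattern: Sequence[str]) -> float:
    return 2.0 + sum(c for c, kind in zip(costs, pattern) if kind == "L")


print(greedy_hybrid_pattern(6, toy_loss, tolerance=0.01))  # → ['L', 'L', 'S', 'S', 'S', 'L']
```

The toy penalties were chosen so that the middle layers are expensive to replace, which mirrors the qualitative pattern described above. The greedy loop needs $\mathcal{O}(L^2)$ loss evaluations for $L$ layers, which is one reason a training-time selection may be preferable to this kind of search.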
Primary Area: foundation or frontier models, including LLMs
Submission Number: 23796