Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels

Published: 06 Mar 2025, Last Modified: 19 Mar 2025, ICLR 2025 FM-Wild Workshop, CC BY 4.0
Keywords: xLSTM, FlashLinearAttention, FlashAttention, Linear Attention, Linear RNN, RNN, mLSTM, Attention, Self-Attention, Transformers, LLM, Kernels, Systems, Efficiency
TL;DR: We introduce TiledFlashLinearAttention, a faster kernel algorithm for linear RNNs and mLSTMs based on improved sequence parallelism.
Abstract: Linear RNNs with gating have recently demonstrated performance competitive with Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, since Transformers rely on the highly efficient FlashAttention kernels (Dao, 2024). Leveraging the chunkwise-parallel formulation of linear RNNs, FlashLinearAttention (FLA) (Yang & Zhang, 2024) shows that linear RNN kernels are faster than FlashAttention by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This causes high memory consumption and IO cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs that enables arbitrarily large chunk sizes by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM (Beck et al., 2024). Second, we propose an mLSTM variant with a sigmoid input gate and reduced computation that achieves even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized FlashAttention, Linear Attention, and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.
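The trade-off between chunk size and materialized states can be illustrated with a minimal pure-Python sketch. It assumes a simplified gated linear RNN with a scalar forget gate f_t, C_t = f_t C_{t-1} + k_t v_t^T and h_t = C_t^T q_t, with no input gate or normalizer, so it is neither the mLSTM nor the TFLA kernel itself. The helper names `recurrent` and `chunkwise` are hypothetical. The chunkwise version splits the sequence into chunks of size L, combines a decayed inter-chunk state term with causal intra-chunk attention, and stores one d_k × d_v state per chunk, i.e. ceil(T / L) states.

```python
import math
import random


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def recurrent(q, k, v, f):
    """Reference recurrence (assumed simplified model):
    C_t = f_t * C_{t-1} + k_t v_t^T and h_t = C_t^T q_t."""
    dk, dv = len(k[0]), len(v[0])
    C = zeros(dk, dv)
    out = []
    for qt, kt, vt, ft in zip(q, k, v, f):
        C = [[ft * C[x][b] + kt[x] * vt[b] for b in range(dv)] for x in range(dk)]
        out.append([sum(qt[x] * C[x][b] for x in range(dk)) for b in range(dv)])
    return out


def chunkwise(q, k, v, f, chunk):
    """Chunkwise-parallel form of the same recurrence.

    Returns the outputs and the states entering each chunk, i.e. the
    ceil(T / chunk) intermediate states a chunkwise kernel would materialize.
    """
    T, dk, dv = len(q), len(k[0]), len(v[0])
    C = zeros(dk, dv)
    states, out = [], []
    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        states.append(C)
        # Cumulative log forget gates in the chunk: a[i] = sum_{j<=i} log f_{s+j}.
        a, acc = [], 0.0
        for t in range(s, e):
            acc += math.log(f[t])
            a.append(acc)
        for i in range(e - s):
            t = s + i
            # Inter-chunk term: incoming state decayed by f_s * ... * f_t.
            h = [math.exp(a[i]) * sum(q[t][x] * C[x][b] for x in range(dk))
                 for b in range(dv)]
            # Intra-chunk term: causal, decay-weighted q.k attention in the chunk.
            for j in range(i + 1):
                w = math.exp(a[i] - a[j]) * dot(q[t], k[s + j])
                for b in range(dv):
                    h[b] += w * v[s + j][b]
            out.append(h)
        # Carry the state to the chunk boundary for the next chunk.
        n = e - s
        C = [[math.exp(a[-1]) * C[x][b]
              + sum(math.exp(a[-1] - a[j]) * k[s + j][x] * v[s + j][b] for j in range(n))
              for b in range(dv)] for x in range(dk)]
    return out, states


if __name__ == "__main__":
    random.seed(0)
    T, dk, dv = 10, 3, 2
    q = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(T)]
    k = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(T)]
    v = [[random.gauss(0, 1) for _ in range(dv)] for _ in range(T)]
    f = [random.uniform(0.5, 1.0) for _ in range(T)]
    ref = recurrent(q, k, v, f)
    for L in (1, 4, 10):
        out, states = chunkwise(q, k, v, f, L)
        err = max(abs(x - y) for r, o in zip(ref, out) for x, y in zip(r, o))
        print(f"chunk={L}: states={len(states)}, matches recurrence: {err < 1e-9}")
```

For T = 10, chunk sizes 1, 4, and 10 store 10, 3, and 1 states respectively. Larger chunks store fewer states but make the L × L intra-chunk computation larger; in the abstract's terms, this is the computation TFLA parallelizes with an additional level of tiling within each chunk.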
Submission Number: 30