Keywords: FlashMask, Efficient Attention Computation, Sparse Mask Representation, Linear Memory Complexity, Low Computational Complexity
Abstract: Recent advancements in large-scale Transformers have significantly benefited from sophisticated attention mechanisms, which are critical for modeling long-context sequences. However, the computational and memory demands of conventional dense attention mask computations, which typically scale as $\mathcal{O}(N^2)$ where $N$ is the sequence length, pose significant challenges. This paper introduces FlashMask, a simple yet effective \emph{exact} attention algorithm designed to substantially reduce both the computational complexity and memory requirements of attention computation. By adopting a novel column-wise sparse representation of attention masks, FlashMask achieves linear $\mathcal{O}(N)$ memory complexity and a computational complexity between $\mathcal{O}(N)$ and $\mathcal{O}(N^2)$. We assess the performance of FlashMask in a variety of masking scenarios, including causal and customized attention masks, demonstrating its versatility and robustness across a wide range of attention patterns and models. Our empirical analysis covers a variety of downstream training modalities, including Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and Reward Model (RM) training. We compare FlashMask against state-of-the-art techniques, most notably FlashAttention. In kernel-level assessments, FlashMask achieves substantial computational speedups of up to 6.7x (SFT), 6.9x (DPO), and 8.3x (RM). Furthermore, in end-to-end training, FlashMask consistently and significantly improves training speed, with accelerations of up to 2.4x (SFT), 4.2x (LoRA), 2.5x (DPO), and 2.6x (RM) across these scenarios without sacrificing model accuracy. Additionally, when applied in the LoRA scenario, FlashMask enables LLaMA2-7B to process sequence lengths of up to 544K, significantly enhancing its capability for long-context inputs.
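To illustrate the idea of a column-wise sparse mask representation described in the abstract, the sketch below shows how a dense $N \times N$ boolean mask whose masked rows form a contiguous interval in each key column can be compressed into two length-$N$ index vectors. This is a minimal, hypothetical Python/NumPy sketch for intuition only, not the official FlashMask implementation or API; the function and variable names are assumptions.

```python
# Hypothetical sketch (not the FlashMask API): compress a dense attention mask
# into a column-wise sparse form, assuming the masked rows in every key column
# form one contiguous interval. Memory drops from O(N^2) to O(N).
import numpy as np


def dense_to_column_sparse(mask: np.ndarray):
    """mask: (N, N) boolean array, True where attention is DISALLOWED.
    Returns two length-N integer vectors (start, end) such that for column j,
    query rows start[j] .. end[j]-1 are masked."""
    n = mask.shape[0]
    start = np.full(n, n, dtype=np.int64)  # default: empty interval
    end = np.zeros(n, dtype=np.int64)
    for j in range(n):
        rows = np.flatnonzero(mask[:, j])
        if rows.size:
            start[j], end[j] = rows[0], rows[-1] + 1
            # Assumption: masked rows in this column are contiguous.
            assert rows.size == end[j] - start[j]
    return start, end


def is_masked(start: np.ndarray, end: np.ndarray, i: int, j: int) -> bool:
    """Query the sparse representation: is query i masked from key j?"""
    return bool(start[j] <= i < end[j])


if __name__ == "__main__":
    N = 8
    # Causal mask: query i may not attend to future keys j > i.
    causal_mask = np.triu(np.ones((N, N), dtype=bool), k=1)
    s, e = dense_to_column_sparse(causal_mask)
    for i in range(N):
        for j in range(N):
            assert is_masked(s, e, i, j) == causal_mask[i, j]
```

Under this representation, an attention kernel can consult the per-column interval to skip fully masked tile blocks entirely and apply element-wise masking only to partially masked blocks, which is how the complexity can fall below $\mathcal{O}(N^2)$ for structured masks.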
Primary Area: Infrastructure (libraries, improved implementation and scalability, distributed solutions)
Submission Number: 16069