Abstract: Sequence parallelism (SP) is a prevalent strategy for handling long sequences that exceed the memory limit of a single device. However, for linear sequence modeling methods such as linear attention, existing SP approaches do not take advantage of the right-product-first feature, resulting in sub-optimal communication efficiency and usability. In this paper, we introduce Linear Attention Sequence Parallelism (LASP), an efficient SP approach designed for linear attention-based transformer models. Specifically, we design an efficient point-to-point ring-style communication mechanism that leverages the right-product kernel trick of linear attention and sharply decreases the communication overhead compared with existing SP methods. We enhance the computation efficiency of LASP by performing kernel fusion and intermediate state caching, making the implementation of LASP hardware-friendly on GPUs. Furthermore, we ensure the compatibility of sequence-level LASP with all types of batch-level data parallel methods, which is vital for distributed training on large clusters with very long sequences. We also discuss the generalization of LASP to other linear sequence modeling methods. Extensive experiments on linear attention-based models are conducted with sequence lengths varying from 2K to 4096K. LASP scales sequence length up to 4096K on 128 GPUs, which is 8$\times$ longer than existing SP methods.
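The right-product idea mentioned in the abstract can be illustrated with a minimal single-process sketch. It is not the paper's implementation: it assumes unnormalized causal linear attention without decay, simulates the ring of devices with a plain Python loop over sequence chunks, and uses hypothetical function names (`causal_linear_attention_reference`, `ring_chunked_linear_attention`). The point it shows is that the only state handed from one chunk to the next is a `d_k x d_v` matrix whose size does not depend on sequence length.

```python
import numpy as np


def causal_linear_attention_reference(q, k, v):
    """Left-product form (Q K^T * M) V with causal mask M; quadratic in length."""
    n = q.shape[0]
    mask = np.tril(np.ones((n, n)))
    return ((q @ k.T) * mask) @ v


def ring_chunked_linear_attention(q, k, v, num_chunks):
    """Simulate sequence-parallel ranks in one process (illustrative assumption).

    Each chunk stands in for one device. The running kv_state plays the role
    of the message a rank would receive from its predecessor and forward to
    its successor in a point-to-point ring.
    """
    d_k, d_v = k.shape[1], v.shape[1]
    kv_state = np.zeros((d_k, d_v))  # fixed size, independent of sequence length
    outputs = []
    chunks = zip(
        np.array_split(q, num_chunks),
        np.array_split(k, num_chunks),
        np.array_split(v, num_chunks),
    )
    for q_c, k_c, v_c in chunks:
        c = q_c.shape[0]
        mask = np.tril(np.ones((c, c)))
        # Intra-chunk part: masked left product on the local chunk only.
        intra = ((q_c @ k_c.T) * mask) @ v_c
        # Inter-chunk part: right product with the state from earlier chunks.
        inter = q_c @ kv_state
        outputs.append(intra + inter)
        # Update the state before "sending" it to the next rank.
        kv_state = kv_state + k_c.T @ v_c
    return np.concatenate(outputs, axis=0)
```

Under these assumptions the chunked result matches the quadratic reference, which can be checked with `np.allclose` on random inputs for any number of chunks.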
Submission Length: Regular submission (no more than 12 pages of main content)
Assigned Action Editor: ~Shuangfei_Zhai3
Submission Number: 3917