Research Area: Compute efficient LMs
Keywords: Distributed systems, memory-efficient exact attention, long context, large language models
TL;DR: We develop a distributed extension to FlashAttention, specifically optimized for training long-context causal LLMs.
Abstract: FlashAttention effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DistFlashAttention, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DistFlashAttention on Llama-7B and variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DistFlashAttention supports 8x longer sequences with a 4.45-5.64x speedup; compared to Megatron-LM with FlashAttention, it supports 2-8x longer sequences with a 1.24-2.01x speedup. It also achieves 1.67x and 1.26-1.88x speedups over the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
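To illustrate the overlapping key-value communication described in the abstract, below is a minimal sketch (not the authors' implementation; see the linked repository for the actual code). It assumes each rank holds one sequence chunk of Q/K/V and rotates KV chunks around a ring with torch.distributed point-to-point ops, computing attention on the chunk already in hand while the next chunk is in flight, and merging partial results with the online-softmax recurrence. Function names such as `chunk_attention` and `overlapped_ring_attention` are hypothetical, and causal masking and workload balancing are omitted for brevity.

```python
# Illustrative sketch of overlapped KV communication for distributed attention.
# Assumes a torch.distributed process group is already initialized.
import torch
import torch.distributed as dist


def chunk_attention(q, k, v):
    """Attention on one KV chunk; returns unnormalized output plus softmax stats."""
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale      # (..., Tq, Tk)
    m = scores.max(dim=-1, keepdim=True).values                # running-max candidate
    p = torch.exp(scores - m)
    l = p.sum(dim=-1, keepdim=True)                            # softmax denominator piece
    return torch.matmul(p, v), m, l


def overlapped_ring_attention(q, k, v):
    """Each rank holds one chunk of q/k/v; KV chunks rotate around the ring
    while attention on the already-received chunk is computed."""
    rank, world = dist.get_rank(), dist.get_world_size()
    acc = torch.zeros_like(q)
    m_run = torch.full(q.shape[:-1] + (1,), float("-inf"),
                       device=q.device, dtype=q.dtype)
    l_run = torch.zeros_like(m_run)
    send_k, send_v = k.contiguous(), v.contiguous()

    for _ in range(world):
        recv_k, recv_v = torch.empty_like(k), torch.empty_like(v)
        # Launch asynchronous send/recv of the next KV chunk ...
        ops = [
            dist.P2POp(dist.isend, send_k, (rank + 1) % world),
            dist.P2POp(dist.isend, send_v, (rank + 1) % world),
            dist.P2POp(dist.irecv, recv_k, (rank - 1) % world),
            dist.P2POp(dist.irecv, recv_v, (rank - 1) % world),
        ]
        reqs = dist.batch_isend_irecv(ops)

        # ... while computing attention on the current chunk, merged online.
        o, m, l = chunk_attention(q, send_k, send_v)
        m_new = torch.maximum(m_run, m)
        acc = acc * torch.exp(m_run - m_new) + o * torch.exp(m - m_new)
        l_run = l_run * torch.exp(m_run - m_new) + l * torch.exp(m - m_new)
        m_run = m_new

        for req in reqs:
            req.wait()
        send_k, send_v = recv_k, recv_v

    return acc / l_run
```

The point of the sketch is the overlap: `batch_isend_irecv` is issued before the chunk's attention is computed and only waited on afterward, so communication latency is hidden behind computation, which is the behavior the abstract credits for part of the speedup.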
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the COLM Code of Ethics on https://colmweb.org/CoE.html
Author Guide: I certify that this submission complies with the submission instructions as described on https://colmweb.org/AuthorGuide.html
Submission Number: 1068