FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

Published: 16 Jan 2024, Last Modified: 17 Mar 2024, ICLR 2024 poster
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: convolutions, GPUs, hardware-efficient algorithms, long context, fast fourier transform, I/O awareness
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: We propose FlashFFTConv, a new system that optimizes the FFT convolution algorithm to speed up long convolutions and enable long-sequence applications.
Abstract: Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT), which allows long convolutions to run in $O(N\log N)$ time in sequence length $N$ but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms, (1) partial convolutions and (2) frequency-sparse convolutions, which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 8.7$\times$ over PyTorch and achieves up to 4.4$\times$ speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity and M2-BERT-base to achieve 3.3 points higher GLUE score, matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models, yielding the first DNA model that can process the longest human genes (2.3M base pairs), and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
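To make the two ideas in the abstract concrete, the following is a minimal NumPy sketch, not the FlashFFTConv implementation. It shows (a) a long convolution computed through the FFT and (b) a length-$N$ DFT with $N = N_1 N_2$ rewritten as two dense matrix multiplies plus an elementwise twiddle correction, which is the general kind of factorization that lets matrix multiply units do the work. The function names `fft_conv`, `dft_matrix`, and `fft_via_matmul` are hypothetical, the factorization shown is the standard two-factor Cooley-Tukey form, and the sketch ignores everything that makes the real system fast (kernel fusion, on-chip memory, half precision).

```python
import numpy as np

def fft_conv(u, k):
    """Causal convolution of u with a filter k of the same length, via the FFT."""
    n = u.shape[-1]
    size = 2 * n  # zero-pad so the circular convolution equals the linear one
    u_f = np.fft.rfft(u, n=size)
    k_f = np.fft.rfft(k, n=size)
    # Pointwise product in frequency space, then keep the first n outputs.
    return np.fft.irfft(u_f * k_f, n=size)[..., :n]

def dft_matrix(n):
    """Dense n x n DFT matrix with entries exp(-2*pi*i*j*k/n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)

def fft_via_matmul(x, n1, n2):
    """Length n1*n2 DFT written as two matrix multiplies plus twiddle factors."""
    n = n1 * n2
    a = x.reshape(n1, n2)                  # a[j1, j2] = x[j1*n2 + j2]
    b = dft_matrix(n1) @ a                 # DFT along j1, giving b[k1, j2]
    twiddle = np.exp(-2j * np.pi * np.outer(np.arange(n1), np.arange(n2)) / n)
    c = (b * twiddle) @ dft_matrix(n2)     # DFT along j2, giving c[k1, k2]
    return c.T.reshape(n)                  # output index is k1 + n1*k2

# Usage: both routines agree with NumPy's reference results.
rng = np.random.default_rng(0)
u, k = rng.standard_normal(64), rng.standard_normal(64)
x = rng.standard_normal(64) + 1j * rng.standard_normal(64)
conv_ok = np.allclose(fft_conv(u, k), np.convolve(u, k)[:64])
fft_ok = np.allclose(fft_via_matmul(x, 8, 8), np.fft.fft(x))
```

In this form the heavy work is two dense matrix products, and zeroing a block of the reshaped filter (in time or in frequency) makes the corresponding part of a product skippable. This is offered as intuition for the sparse variants the abstract mentions, under the assumptions stated above.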
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
Supplementary Material: zip
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Primary Area: infrastructure, software libraries, hardware, etc.
Submission Number: 8584