Fast and Memory-Efficient Multi-Sequence Generation via Structured Masking

Published: 21 Jun 2024 · Last Modified: 24 Jul 2024 · ES-FoMo-II 2024 Poster · CC BY 4.0
Keywords: Multi-sequence sampling, parallel sampling, KV cache, LLM, attention, sparse masking, generation
TL;DR: A faster and more memory-efficient PyTorch implementation of multi-sequence sampling in LLMs.
Abstract: Many applications of large language models (LLMs) require drawing multiple samples from a single prompt, also known as multi-sequence generation. Current open-source approaches (e.g., HuggingFace) achieve this by replicating the prompt multiple times and treating each replica as an independent prompt within a batch. This approach is highly memory-inefficient because the key-value (KV) cache stores a separate copy of the prompt for every sequence. In this work, we present \name{}, an alternative exact and memory-efficient strategy for multi-sequence generation that requires storing each prompt only once. To achieve exactness, we design a structured masking strategy that ensures newly sampled tokens in each sequence attend only to their predecessor tokens within the same sequence. Further, we propose a novel attention computation algorithm, based on intermixing matrix multiplications with diagonalized matrices, that has the same theoretical runtime as the baseline approach and is generally faster in practice. Empirically, we demonstrate that \name{} achieves consistent improvements in both generation time and memory consumption across a range of generation scenarios carefully controlled for prompt length, generation length, and number of generated sequences. Our core technique will be open-sourced and can be implemented in fewer than 50 lines of PyTorch.
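To make the structured-masking idea concrete, below is a minimal PyTorch sketch of the shared-prompt attention pattern the abstract describes: the prompt KV cache is stored once and shared by all N sequences, while each sequence's new token attends only to the prompt and to its own previously generated tokens. This is an illustrative assumption of how such a kernel could look, not the paper's released implementation; the function name `shared_prompt_attention`, the tensor layout, and the toy shapes are all hypothetical, and the diagonalized-matrix speedup mentioned in the abstract is not shown here.

```python
# Minimal sketch (hypothetical, not the authors' code) of exact attention for
# multi-sequence generation with a single shared copy of the prompt KV cache.
import torch
import torch.nn.functional as F

def shared_prompt_attention(q, k_prompt, v_prompt, k_gen, v_gen):
    """
    q:        [N, H, 1, D]  one new query token per sequence (N sequences, H heads)
    k_prompt: [H, P, D]     prompt keys, stored once and shared by all sequences
    v_prompt: [H, P, D]     prompt values, stored once
    k_gen:    [N, H, T, D]  per-sequence keys of previously generated tokens
    v_gen:    [N, H, T, D]  per-sequence values of previously generated tokens
    returns:  [N, H, 1, D]
    """
    scale = q.shape[-1] ** -0.5
    # Scores against the shared prompt: broadcast the single prompt copy over N sequences.
    s_prompt = torch.einsum("nhqd,hpd->nhqp", q, k_prompt) * scale   # [N, H, 1, P]
    # Scores against each sequence's own generated tokens only; the block-diagonal
    # structure of the mask is realized implicitly by the batched einsum.
    s_gen = torch.einsum("nhqd,nhtd->nhqt", q, k_gen) * scale        # [N, H, 1, T]
    # Joint softmax over prompt + own-sequence positions: exact, no cross-sequence leakage.
    attn = F.softmax(torch.cat([s_prompt, s_gen], dim=-1), dim=-1)
    a_prompt, a_gen = attn.split([k_prompt.shape[-2], k_gen.shape[-2]], dim=-1)
    out = torch.einsum("nhqp,hpd->nhqd", a_prompt, v_prompt)
    out += torch.einsum("nhqt,nhtd->nhqd", a_gen, v_gen)
    return out

# Toy usage: 4 sequences share one 16-token prompt; each has generated 3 tokens so far.
N, H, P, T, D = 4, 2, 16, 3, 8
out = shared_prompt_attention(
    torch.randn(N, H, 1, D),
    torch.randn(H, P, D), torch.randn(H, P, D),
    torch.randn(N, H, T, D), torch.randn(N, H, T, D),
)
print(out.shape)  # torch.Size([4, 2, 1, 8])
```

The key design point this sketch illustrates is that memory grows as P + N·T rather than N·(P + T), since the prompt keys and values appear once regardless of how many sequences are sampled.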
Supplementary Material: pdf
Submission Number: 80