SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference

23 Sept 2023 (modified: 11 Feb 2024) · Submitted to ICLR 2024
Primary Area: generative models
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: Adaptive Computation, Efficient Inference, Generative Models
Submission Guidelines: I certify that this submission complies with the submission instructions as described on https://iclr.cc/Conferences/2024/AuthorGuide.
TL;DR: SkipDecode, a novel method, enhances autoregressive language model inference by allowing controllable computational budgeting, achieving 2x to 5x speedups and compatibility with batch inferencing and Key-Value caching.
Abstract: Autoregressive large language models (LLMs) have made remarkable progress in various natural language generation tasks. However, the high computing cost and latency resulting from token-by-token generation impede their widespread adoption. To address this issue, several approaches have been proposed that reduce computational cost using early-exit strategies. These strategies enable faster text generation with reduced computation by not applying the full computation graph to each token. While existing token-level early exit methods show promising results for online inference, they cannot be readily applied to batch inferencing and Key-Value (KV) caching. This is because they have to wait until the last token in a batch exits before they can stop computing, which severely limits the practical application of such techniques. In this paper, we propose a simple and effective token-level early exit method, SkipDecode, designed to work seamlessly with batch inferencing and KV caching. It overcomes prior constraints by setting up a single exit point for every token in a batch at each sequence position. It also guarantees a monotonic decrease in exit points, thereby eliminating the need to recompute KV caches for preceding tokens. Rather than terminating computation prematurely as in prior works, our approach bypasses lower to middle layers, devoting most of the computational resources to upper layers and allowing later tokens to benefit from the compute expended on earlier tokens. Our experimental results show that SkipDecode can obtain 2x to 5x inference speedups with negligible regression across a variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7 billion parameters, all while remaining directly compatible with batching and KV caching optimization techniques.
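To illustrate the mechanism described in the abstract, below is a minimal sketch (not the authors' released implementation) of a per-position layer budget that decreases monotonically with sequence position, shared by all sequences in a batch, with the lower layers skipped rather than the upper ones. All names here (layers_for_position, SkipDecodeLM, min_layers, max_layers, decode_step) are illustrative assumptions; KV caching and causal masking are omitted for brevity.

```python
import torch
import torch.nn as nn


def layers_for_position(pos: int, max_len: int,
                        min_layers: int, max_layers: int) -> int:
    """Monotonically non-increasing layer budget over sequence positions.

    Early tokens get more layers, later tokens fewer, so exit points never
    increase and previously stored KV caches never need recomputation.
    (Hypothetical linear schedule; the paper's exact schedule may differ.)
    """
    frac = min(pos, max_len) / max_len
    budget = max_layers - frac * (max_layers - min_layers)
    return int(round(budget))


class SkipDecodeLM(nn.Module):
    """Toy decoder stack showing where the layer-skipping would apply."""

    def __init__(self, num_layers=24, d_model=512, n_heads=8,
                 vocab_size=50272, max_len=1024,
                 min_layers=8, max_layers=24):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Stand-in transformer blocks; a real LLM would use causal decoder
        # layers with KV caching.
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            for _ in range(num_layers))
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.max_len = max_len
        self.min_layers = min_layers
        self.max_layers = max_layers

    def decode_step(self, token_ids: torch.Tensor, pos: int) -> torch.Tensor:
        # One shared exit point for every sequence in the batch at this
        # position, so batched decoding stays aligned across the batch.
        k = layers_for_position(pos, self.max_len,
                                self.min_layers, self.max_layers)
        h = self.embed(token_ids)  # (batch, 1, d_model) for one new token
        # Skip the lower layers and spend the budget on the top-k layers,
        # instead of running the bottom layers and exiting early.
        for block in self.blocks[len(self.blocks) - k:]:
            h = block(h)
        return self.lm_head(h)
```

For example, with max_layers=24, min_layers=8, and max_len=1024, position 0 would use all 24 layers while position 1024 would use only 8, giving a controllable average compute budget per generated sequence.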
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Submission Number: 7405