Fast and Accurate Language Model Decoding via Parallel Token Processing

Published: 10 Oct 2024, Last Modified: 19 Nov 2024 · AFM 2024 Oral · License: CC BY 4.0
Keywords: Autoregressive model, efficient decoding, parallel token processing
TL;DR: Built on the proposed parallel token processing mechanism, ParaDecode accelerates autoregressive decoding while ensuring output parity, without the need for auxiliary models or changes to original model parameters.
Abstract: Autoregressive decoding suffers from an inherent efficiency bottleneck due to its sequential token generation process, where each token must be generated before the next can be processed. This sequential dependency significantly limits the ability to fully exploit the parallel processing power of modern hardware. While speculative decoding and layer skipping offer promising speedups, both approaches come with drawbacks. Speculative decoding relies on a secondary small "drafter" model, which not only increases memory overhead but may also be unavailable in many cases, since the drafter must share the same tokenizer and vocabulary as the main model for compatibility between generated and verified tokens. Layer skipping, on the other hand, can cause discrepancies in the generated output compared to standard autoregressive decoding, because skipped layers do not compute the key-value (KV) cache that plays a crucial role in predicting future tokens. In this work, we introduce a fast and accurate decoding method, ParaDecode, which accelerates autoregressive decoding while ensuring output parity, without the need for auxiliary models or changes to the original model parameters. Our approach is driven by the observation that many tokens, particularly simple or highly predictable ones, can be accurately predicted from intermediate-layer representations without requiring computation through the entire model: once the model reaches sufficient confidence, further layers are unlikely to alter the prediction. ParaDecode therefore emits a token at an intermediate layer when confidence is sufficiently high. This allows computation of the next token to begin immediately, in parallel with completing the KV cache computation for the early-predicted token in its remaining layers. This parallelism, implemented using batched matrix operations, enables simultaneous processing of multiple tokens across different layers, thereby maximizing hardware utilization and reducing overall decoding latency. To ensure output consistency, a final verification step guarantees that early-predicted tokens match the results of standard autoregressive decoding. Experiments show that ParaDecode achieves up to a 1.53× speedup across various generation tasks.
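To make the mechanism concrete, below is a minimal, sequential sketch of the confidence-based early-exit and verification idea described in the abstract; it is not the authors' implementation. It assumes a hypothetical decoder exposed as a list of transformer blocks plus a final norm and LM head, and an illustrative confidence threshold. The actual speedup in ParaDecode comes from batching the early-exited token's remaining layers together with the next token's forward pass, which this sketch omits for clarity.

```python
# Simplified, sequential illustration of confidence-based early exit with
# verification (assumed interfaces; not the ParaDecode implementation).
import torch

@torch.no_grad()
def decode_step_with_early_exit(blocks, final_norm, lm_head, hidden,
                                exit_layer: int, threshold: float = 0.9):
    """One decoding step.

    `hidden`: embedded input of shape (batch, seq_len, d_model).
    A token is proposed at `exit_layer` when the intermediate softmax is
    confident; the remaining blocks are still run so the KV cache stays exact,
    and the proposal is verified against the full-depth prediction.
    """
    early_token = None
    for i, block in enumerate(blocks):
        hidden = block(hidden)
        if i == exit_layer:
            probs = torch.softmax(lm_head(final_norm(hidden[:, -1])), dim=-1)
            conf, tok = probs.max(dim=-1)
            if conf.item() >= threshold:
                early_token = tok  # next-token computation could start here
    full_token = lm_head(final_norm(hidden[:, -1])).argmax(dim=-1)
    # Verification: the early prediction is kept only if it matches
    # standard autoregressive decoding, preserving output parity.
    accepted = early_token is not None and bool((early_token == full_token).all())
    return (early_token if accepted else None), full_token


# Toy usage with stand-in modules (illustration only):
if __name__ == "__main__":
    d_model, vocab = 16, 100
    blocks = [torch.nn.Linear(d_model, d_model) for _ in range(8)]
    early, full = decode_step_with_early_exit(
        blocks, torch.nn.LayerNorm(d_model), torch.nn.Linear(d_model, vocab),
        torch.randn(1, 1, d_model), exit_layer=3)
    print("early proposal accepted:", early is not None, "| token:", full.item())
```

In the method as described, the remaining-layer computation of an accepted early token and the early layers of the next token are packed into a single batched forward pass, which is where the latency reduction comes from.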
Submission Number: 118