STree: Speculative Tree Decoding for Hybrid State Space Models

Published: 18 Sept 2025, Last Modified: 29 Oct 2025. NeurIPS 2025 poster. License: CC BY 4.0
Keywords: State Space Models, hybrid models, speculative decoding, tree-based decoding
TL;DR: We enable tree-based decoding on SSMs, making speculative decoding with tree-based verification possible for SSMs and hybrid models.
Abstract: Speculative decoding is a technique that leverages hardware concurrency to enable multiple steps of token generation in a single forward pass, thus improving the efficiency of large-scale autoregressive (AR) Transformer models. State-space models (SSMs) are already more efficient than AR Transformers, since their state summarizes all past data, with no need to cache or re-process tokens in a sliding-window context. However, their state can also comprise thousands of tokens, so speculative decoding has recently been extended to SSMs. Existing approaches, however, do not leverage tree-based verification methods, since current SSMs lack the means to compute a token tree efficiently. We propose the first scalable algorithm to perform tree-based speculative decoding in SSMs and hybrid architectures of SSM and Transformer layers. We exploit the structure of accumulated state-transition matrices to facilitate tree-based speculative decoding with minimal overhead relative to current SSM implementations. Along with the algorithm, we describe a hardware-aware implementation that improves upon the naive application of AR Transformer tree-based speculative decoding methods to SSMs. Furthermore, we outperform vanilla speculative decoding with SSMs, even with a baseline drafting model and tree structure, on three different benchmarks, opening up opportunities for further speed-ups in SSM and hybrid-model inference. Code can be found at: https://github.com/wyc1997/stree
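To make the mechanism concrete, here is a minimal sketch, not the authors' released implementation, of tree-structured state computation for an SSM with per-token scalar transitions (as in Mamba-2-style layers; matrix-valued transitions follow the same pattern with matrix products in place of scalar multiplies). The function name `tree_states`, the parent-array tree encoding, and all variable names are illustrative assumptions. It demonstrates the property the abstract relies on: because the accumulated transition from the root to any tree node is just the product of per-token factors along that path, every node's state can be recovered from the root state with one batched, masked matrix product rather than a separate sequential scan per branch.

```python
# Minimal sketch (illustrative, not the paper's code) of computing SSM states
# for every node of a draft token tree in one batched step, assuming a
# Mamba-2-style recurrence h_t = a_t * h_{t-1} + u_t with scalar decay a_t.
import torch

def tree_states(h_root, parents, a, u):
    """h_root: (d,) state after the last verified token (the tree root).
    parents: parents[i] is the index of node i's parent, -1 for root children.
    a: (n,) per-node transition scalars; u: (n, d) per-node inputs B_t x_t.
    Nodes must be listed in parent-first (topological) order."""
    n, d = u.shape
    M = torch.zeros(n, n)        # M[i, j]: product of decays on the path
    root_decay = torch.empty(n)  # from node j (exclusive) to node i (inclusive)
    for i in range(n):
        p = parents[i]
        if p == -1:
            root_decay[i] = a[i]
        else:
            root_decay[i] = a[i] * root_decay[p]
            M[i, :i] = a[i] * M[p, :i]  # extend the parent's path products
        M[i, i] = 1.0
    # One masked matrix product yields all node states from the root state:
    # h_i = (prod of decays on path root -> i) * h_root + sum_j M[i, j] * u_j
    return root_decay[:, None] * h_root + M @ u

# Example: a 4-node draft tree (chain 0 -> 1 -> 2, plus sibling 3 under 0).
h0 = torch.randn(8)
parents = [-1, 0, 1, 0]
a = torch.tensor([0.9, 0.8, 0.7, 0.6])
u = torch.randn(4, 8)
print(tree_states(h0, parents, a, u).shape)  # torch.Size([4, 8])
```

For a fixed draft-tree topology the mask `M` depends only on the per-token decays, so the dominant per-step cost is the single batched matrix product; this is the kind of hardware-friendly structure the abstract alludes to when it mentions exploiting accumulated state-transition matrices.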
Primary Area: Deep learning (e.g., architectures, generative models, optimization for deep networks, foundation models, LLMs)
Submission Number: 5718