HiDRA: A Blazing Fast LM-Head Replacement
Abstract: The output projection (LMHead) in large language models scales linearly with vocabulary size and constitutes a major portion of decoding cost, often exceeding 10\% on sub-1B models, which are frequently deployed on constrained consumer hardware. We propose HiDRA (Hierarchical Decision Routing Architecture), a post-training method that replaces the dense vocabulary projection with a hierarchical sequence of binary decisions, reducing the computational complexity from $\mathcal{O}(V)$ to $\mathcal{O}(\log V)$ dot products per token. The hierarchy is constructed on top of an existing model by recursively splitting the embedding space in half with a single hyperplane at each node. We prove that minimizing the error rate of a given hyperplane is equivalent to solving a Linear Discriminant Analysis optimization problem. On Gemma with WildChat completions, HiDRA achieves a $\sim4.79\times$ practical LMHead speedup on CPU, while reducing top-1 accuracy by only $2\%$ absolute at a moderate routing depth with beam search. More generally, different depths provide a smooth spectrum of operating points, allowing practitioners to select the desired balance between compute and model quality without retraining the base model.
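The routing scheme the abstract describes can be sketched as follows. This is a minimal illustrative sketch, not the paper's implementation: the function name, the heap-style node layout, and the per-node weight arrays are all assumptions introduced here. It only shows how a sequence of $\log_2 V$ hyperplane sign tests replaces a dense $\mathcal{O}(V)$ projection.

```python
import numpy as np

def route_token(hidden, node_w, node_b, depth):
    """Decode one token with O(log V) dot products instead of O(V).

    Walks a binary tree whose internal nodes each hold one hyperplane
    (heap layout: the children of node i are 2i+1 and 2i+2). The sign of
    a single dot product at each node selects one half of the remaining
    vocabulary; the leaf reached after `depth` decisions is the token id.
    """
    node, index = 0, 0
    for _ in range(depth):
        bit = int(hidden @ node_w[node] + node_b[node] > 0.0)
        index = (index << 1) | bit      # accumulate the routing path
        node = 2 * node + 1 + bit       # descend to the chosen child
    return index

# Toy example: a vocabulary of 8 tokens needs 3 binary decisions (7 nodes).
depth = 3
node_w = np.tile(np.array([1.0, 0.0]), (2**depth - 1, 1))  # every node tests the sign of dim 0
node_b = np.zeros(2**depth - 1)
token_id = route_token(np.array([1.0, -2.0]), node_w, node_b, depth)
print(token_id)  # all three decisions go "right", so the path is 0b111 = 7
```

Greedy routing commits to one child per level; the beam search mentioned in the abstract would instead keep the top-$k$ partial paths at each level, trading a constant factor of extra dot products for accuracy.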
Submission Number: 68