RLMedusa: Reinforcement Learning for Multiple Decoding Heads to Accelerate LLM Inference

Published: 05 Mar 2025, Last Modified: 05 Mar 2025
Venue: SLLM
License: CC BY 4.0
Track: tiny / short paper (up to 4 pages)
Keywords: Inference Acceleration, Reinforcement Learning
TL;DR: We use reinforcement learning to train multiple decoding heads to speed up language models at inference time.
Abstract: Traditional transformer inference generates tokens step by step, with each step depending on the previous one, creating a bottleneck in inference speed. The Medusa technique used LoRA fine-tuning to train multiple decoding heads, each predicting a token a different number of steps ahead, so that multiple tokens can be generated in parallel as a draft that the base model then verifies. In this paper, we propose a reinforcement learning based approach to training multiple decoding heads. We introduce a reward model scheme that uses feed-forward networks to estimate token probabilities from context hidden states and candidate token embeddings. We also discuss how our interpretation of reinforcement learning in language modeling research contrasts with traditional, RLHF-centric interpretations, and report our experiments with RLMedusa.
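For concreteness, below is a minimal PyTorch sketch of the two components the abstract describes: Medusa-style extra decoding heads and a feed-forward reward model that scores a candidate token given the context hidden state. All names and architectural details here (MedusaHead, RewardModel, the SiLU activations, the concatenation of hidden state and candidate embedding, the dimensions) are illustrative assumptions, not the paper's actual implementation.

```python
# Illustrative sketch only; module names and architecture are assumptions,
# not the authors' implementation.
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head that predicts the token k steps ahead
    from the base model's last hidden state."""
    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, hidden_dim) -> logits over the vocabulary
        return self.lm_head(self.act(self.proj(hidden)))

class RewardModel(nn.Module):
    """Feed-forward network estimating the probability of a candidate
    token from the context hidden state and the token's embedding."""
    def __init__(self, hidden_dim: int, embed_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + embed_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, hidden: torch.Tensor, cand_embed: torch.Tensor) -> torch.Tensor:
        # Concatenate context hidden state with candidate token embedding
        # and map to a scalar estimate in (0, 1).
        x = torch.cat([hidden, cand_embed], dim=-1)
        return torch.sigmoid(self.mlp(x)).squeeze(-1)

# Example usage with made-up dimensions:
hidden_dim, embed_dim, vocab = 768, 768, 32000
heads = nn.ModuleList([MedusaHead(hidden_dim, vocab) for _ in range(4)])
reward = RewardModel(hidden_dim, embed_dim)
h = torch.randn(2, hidden_dim)        # last hidden states from the base model
e = torch.randn(2, embed_dim)         # embeddings of candidate draft tokens
logits = [head(h) for head in heads]  # one logit vector per look-ahead position
scores = reward(h, e)                 # estimated token probabilities in (0, 1)
```

In such a setup, the reward model's scores could supply the reward signal for a policy-gradient update of the heads, which is one plausible reading of the abstract's "reinforcement learning based approach"; the specifics of the training loop are left to the paper itself.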
Anonymization: This submission has been anonymized for double-blind review via the removal of identifying information such as names, affiliations, and identifying URLs.
Presenter: ~Aadit_Juneja1
Format: Yes, the presenting author will attend in person if this work is accepted to the workshop.
Funding: Yes, the presenting author of this submission falls under ICLR’s funding aims, and funding would significantly impact their ability to attend the workshop in person.
Submission Number: 71