Keywords: transformers, llms, large language models, mamba, linear rnns, distillation, test time compute, inference time compute
TL;DR: Distilling pretrained Transformers such as Llama 3 into pure or hybrid Mamba models leads to better scaling on reasoning tasks thanks to faster inference
Abstract: Recent advancements have demonstrated that the performance of large language models (LLMs) can be significantly enhanced by scaling computational resources at test time.
A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms.
This raises a fundamental question: can models with lower complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget?
To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers.
Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences.
Despite a zero-shot performance drop caused by distillation, both pure and hybrid Mamba models can scale their coverage and accuracy past their Transformer teachers under fixed time budgets, opening a new direction for scaling inference compute.
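To make the test-time-compute setup described above concrete, the following is a minimal sketch (not the paper's implementation) of majority voting over sampled Chain-of-Thought trajectories under a fixed wall-clock budget. The `generate` and `extract_answer` callables are hypothetical stand-ins for a model sampler and an answer parser; the point is simply that a faster model fits more samples into the same budget.

```python
import time
from collections import Counter
from typing import Callable


def majority_vote_under_budget(
    generate: Callable[[str], str],        # hypothetical: prompt -> one sampled CoT completion
    extract_answer: Callable[[str], str],  # hypothetical: completion -> final answer string
    prompt: str,
    time_budget_s: float,
) -> str:
    """Sample CoT trajectories until the wall-clock budget is spent,
    then return the most frequent extracted answer (majority voting)."""
    answers = []
    start = time.monotonic()
    while time.monotonic() - start < time_budget_s:
        completion = generate(prompt)          # one Chain-of-Thought trajectory
        answers.append(extract_answer(completion))
    # Under a fixed time budget, higher generation throughput means more
    # trajectories, and therefore higher coverage for the same cost.
    return Counter(answers).most_common(1)[0][0] if answers else ""
```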
Submission Number: 143