Keywords: Masked Diffusion Models, Diffusion Models, Probabilistic Methods
Abstract: Masked diffusion language models (MDLMs) are trained to infill positions in randomly masked sequences, in contrast to traditional next-token prediction (NTP) models. Discussions of MDLMs focus on two benefits: (1) multi-token decoding and (2) any-order decoding. However, we observe that on math and coding tasks, any-order algorithms often underperform or merely match left-to-right sampling, and standard multi-token decoding significantly degrades performance. At inference time, MDLMs compute the conditional distribution of every masked position. A natural question follows: how can we justify this additional compute when left-to-right, one-token-at-a-time decoding is on par with any-order decoding algorithms? These findings warrant rethinking how MDLMs are used. First, we propose multi-token entropy decoding (MED), a simple adaptive sampler that minimizes the error incurred by decoding positions in parallel, based on the conditional entropies of those positions. MED preserves performance across benchmarks while using 3× fewer decoding steps. Second, we propose a reasoning-as-infilling framework. By using MDLMs to infill a reasoning template, we can structure outputs and distinguish reasoning tokens from answer tokens. In turn, this enables measuring answer uncertainty during reasoning and exiting early once the model converges on an answer. Combined with MED, this yields a 69% speed-up on GSM8K with a minimal (0.1%) effect on accuracy. Finally, given an answer, our framework enables sampling from the posterior over reasoning traces conditioned on that answer, even when the model is incorrect. On GSM8K, this generates correct reasoning traces for 43% of problems originally solved incorrectly. Our work demonstrates that the training objective and inference-time compute of MDLMs unlock many new possibilities for inference and post-training methods.
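The core idea of MED, decoding in parallel only those masked positions whose conditional entropy is low enough, can be sketched as follows. This is an illustrative sketch only, not the paper's implementation; the function name `med_step` and the threshold parameter `eps` are our own, and the sketch uses greedy per-position selection for simplicity.

```python
import numpy as np

def med_step(probs, mask, eps=0.2):
    """One sketched MED step.

    probs: (L, V) array of conditional distributions over the vocabulary,
           one per sequence position.
    mask:  (L,) boolean array, True at still-masked positions.
    Decodes every masked position whose conditional entropy (in nats) is
    below eps, always decoding at least the lowest-entropy masked position.
    Returns (tokens, decoded): token ids (-1 where not decoded) and the
    boolean mask of positions decoded this step.
    """
    # Conditional entropy per position; small constant guards log(0).
    ent = -np.sum(probs * np.log(probs + 1e-12), axis=-1)
    ent = np.where(mask, ent, np.inf)        # ignore already-decoded positions
    decoded = mask & (ent < eps)
    if not decoded.any():                    # fall back to one-token decoding
        decoded[np.argmin(ent)] = True
    tokens = probs.argmax(axis=-1)           # greedy choice per position
    return np.where(decoded, tokens, -1), decoded
```

Here a confident position (peaked distribution) is decoded immediately, while a near-uniform position is deferred to a later step, which is how the sampler trades parallelism against the error from decoding uncertain positions jointly.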
Primary Area: generative models
Submission Number: 19124