Learning Generation Orders for Masked Discrete Diffusion Models via Variational Inference

Published: 03 Mar 2026, Last Modified: 07 Apr 2026
Venue: ICLR 2026 DeLTa Workshop Poster
License: CC BY 4.0
Keywords: Discrete Diffusion, Variational Inference, Masked Discrete Diffusion Models, Generative Models
TL;DR: We propose a variational inference scheme for predicting which tokens to unmask during parallel generation in a masked discrete diffusion model.
Abstract: Masked discrete diffusion models (MDMs) are a promising new approach to generative modelling: they can generate tokens in parallel and are therefore potentially more efficient than their autoregressive counterparts. However, achieving an optimal balance between parallel generation and sample quality remains an open problem. Current approaches primarily address this issue through fixed, heuristic parallel sampling methods. Some recent approaches learn the sampling strategy instead, but formulating the problem from the perspective of variational inference remains underexplored. In this work, we propose a variational inference framework for learning parallel generation orders for MDMs. As part of our method, we propose a parameterisation of the approximate posterior over generation orders that facilitates parallelism and efficient sampling during training. Using this method, we conduct preliminary experiments on the GSM8K dataset, where our method performs competitively against heuristic sampling strategies in the regime of highly parallel generation. For example, our method achieves 33.1% accuracy with an average of only 4 generation steps, compared to the 23.7% to 29.0% accuracy achieved by standard competitor methods in the same number of steps. We believe further experiments and analysis of the method will yield valuable insights into the problem of parallel generation with MDMs.
Submission Number: 128