Enabling Approximate Joint Sampling in Diffusion LMs

17 Sept 2025 (modified: 12 Feb 2026) · ICLR 2026 Conference Desk Rejected Submission · CC BY 4.0
Keywords: diffusion language models, joint sampling, drafting for language models
TL;DR: We present a novel algorithm for approximate joint sampling of multiple tokens from a single forward pass of a diffusion language model
Abstract: In autoregressive language models, each token is sampled by conditioning on all the past tokens; the overall string is thus sampled from the correct underlying joint distribution represented by the model. In contrast, masked diffusion language models generate text by unmasking tokens out of order. Generating an overall string sampled from the correct underlying joint distribution would (again) require unmasking exactly one token per full-model forward pass. The more tokens are unmasked in parallel, the further the string drifts from the true joint distribution; this shows up as a drop in accuracy (but an increase in speed). In this paper we devise a way to {\em approximately} sample multiple tokens from the joint distribution in a single full-model forward pass; we do so by developing a new lightweight single-layer "sampler" on top of an existing large diffusion LM. Multiple tokens are generated via multiple forward passes of only this sampler layer, which is trained to mimic exact joint sampling. We show the effectiveness of approximate joint sampling for both pretrained-only (Dream-7B-Base) and instruction-tuned (Dream-7B-Instruct) models on language modeling and math \& coding tasks. When four tokens are unmasked per denoising step, our sampling algorithm achieves a MAUVE score of 0.87 (vs. a marginal baseline of 0.31) w.r.t. the true joint distribution.
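To make the decoding scheme in the abstract concrete, here is a minimal sketch of the loop it describes: one expensive full-model forward pass per denoising step, followed by several cheap passes of a small sampler layer, each conditioning on the tokens unmasked earlier in the same step. All function names and internals here are assumptions for illustration; stand-ins replace the real Dream-7B model and the trained sampler layer.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, HID, MASK = 16, 8, -1  # toy vocabulary/hidden sizes; MASK marks unfilled slots

def full_model_forward(tokens):
    """Stand-in for the large diffusion LM: one per-position hidden state."""
    return rng.normal(size=(len(tokens), HID))

def sampler_layer(hiddens, tokens):
    """Stand-in for the lightweight single-layer sampler.
    Unlike the cached hiddens, it also sees tokens unmasked
    earlier within the current denoising step."""
    unmasked_sum = sum(t for t in tokens if t != MASK)
    return rng.normal(size=(len(tokens), VOCAB)) + 0.01 * unmasked_sum

def denoise_step(tokens, k=4):
    hiddens = full_model_forward(tokens)           # one expensive pass
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    for pos in masked[:k]:                         # k cheap sampler passes
        logits = sampler_layer(hiddens, tokens)[pos]
        p = np.exp(logits - logits.max())
        p /= p.sum()
        tokens[pos] = int(rng.choice(VOCAB, p=p))  # later picks condition on this
    return tokens

tokens = [MASK] * 8
while MASK in tokens:                              # 2 full-model passes for 8 tokens
    tokens = denoise_step(tokens, k=4)
print(tokens)
```

The key contrast with plain parallel unmasking is the inner loop: instead of sampling all k positions independently from one set of marginals, each sampler pass re-reads the partially filled sequence, approximating sequential joint sampling at a fraction of the cost of k full-model passes.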
Primary Area: generative models
Submission Number: 9811