Track: long paper (up to 8 pages)
Keywords: Diffusion model, Generative models, self-correction, inference-time strategy
TL;DR: We propose a plug-and-play fine-tuning algorithm that equips masked diffusion models with self-correction ability.
Abstract: A natural desideratum for generative models is self-correction: detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures or training, or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM (Plug-in Remasking for Inference-time Self-correction of Masked Diffusions), a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores without reinforcement learning (RL) or a verifier. These quality scores are computed in the same forward pass as the MDM's token predictions and are used to detect low-quality tokens. Empirically, PRISM improves MDM inference across domains and scales: Sudoku, unconditional text generation (170M), and code generation with LLaDA (8B).
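The abstract does not specify how the per-token quality scores are turned into remasking decisions. As a minimal sketch, assuming the scores are already available for every position and that a fixed budget of the lowest-scoring decoded tokens is remasked at each step, the detection step might look like the following. The function `remask_lowest_quality`, the constant `MASK_ID`, and the fixed-budget rule are all hypothetical illustrations, not PRISM's actual procedure.

```python
from typing import List, Sequence

# Hypothetical placeholder id for the [MASK] token (assumption, not from PRISM).
MASK_ID = -1


def remask_lowest_quality(
    tokens: Sequence[int],
    quality: Sequence[float],
    num_remask: int,
    mask_id: int = MASK_ID,
) -> List[int]:
    """Return a copy of `tokens` in which the `num_remask` lowest-quality
    already-decoded positions are replaced by `mask_id`.

    Hypothetical sketch: assumes one quality score per position, where a
    lower score means the token is more likely to need revision.
    """
    if len(tokens) != len(quality):
        raise ValueError("tokens and quality must have the same length")
    # Only already-decoded (unmasked) positions are candidates for revision.
    candidates = [i for i, t in enumerate(tokens) if t != mask_id]
    # Sort candidates by ascending quality; ties are broken by position.
    candidates.sort(key=lambda i: (quality[i], i))
    out = list(tokens)
    for i in candidates[: max(0, num_remask)]:
        out[i] = mask_id
    return out


if __name__ == "__main__":
    # Positions 1 and 3 have the lowest scores, so they are remasked.
    print(remask_lowest_quality([5, 7, 9, 2], [0.9, 0.1, 0.8, 0.3], 2))
```

Remasked positions would then be re-predicted by the MDM in later denoising steps. A fixed budget is only one possible rule; a score threshold would be an equally plausible assumption.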
Anonymization: This submission has been anonymized for double-blind review via the removal of identifying information such as names, affiliations, and identifying URLs.
Data Release: We authorize the release of our submission and author names to the public in the event of acceptance.
Submission Number: 32