Bringing Stability to Diffusion: Decomposing and Reducing Variance of Training Masked Diffusion Models
Keywords: Diffusion LLMs, Masked Diffusion Models, Training Variance, Training Stability, Mask Schedule, Mask Sampling
TL;DR: We stabilize MDM training by deriving a variance decomposition and introducing two core methods: P-POTS, which is Pareto-optimal among all unbiased $t$-samplers, and MIRROR, which reduces masking-pattern noise via negatively correlated samples. Experiments show clear gains in final performance and run-to-run stability.
Abstract: Masked diffusion models (MDMs) are a promising alternative to autoregressive models (ARMs), but they suffer from **inherently** much higher training variance. High variance leads to noisier gradient estimates and unstable optimization, so even equally strong pretrained MDMs and ARMs that are competitive at initialization often diverge after task-specific training, with MDMs falling far behind. To date, there has been no theoretical explanation of this gap, nor a systematic solution. In this paper, we derive **the first decomposition** of MDM training variance into three sources: {A} masking pattern noise, {B} masking rate noise, and {C} data noise; ARMs are affected only by {C}. This cleanly explains the fundamental training gap. Building on this foundation, we design six variance-reduction methods, including two core methods: (1) P-POTS, a **Pareto-optimal** $t$-sampler that minimizes training variance by sampling harder $t$ values more often while taking appropriately smaller update steps, and (2) MIRROR, which uses negatively correlated samples to reduce {A}. Experiments show that, compared to standard MDM training, our methods improve accuracy by **7–8\%** on complex reasoning tasks while simultaneously reducing run-to-run variability to **near ARM levels**, substantially narrowing the gap with strong ARM baselines; in most settings, even the best baseline runs remain below the worst run of our method.
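The abstract does not specify MIRROR's construction, but the phrase "negatively correlated samples" suggests the classical antithetic-sampling idea. Below is a minimal, hypothetical sketch (not the paper's algorithm): two masks drawn from shared uniforms `u` and their mirror `1-u` are negatively correlated, so averaging their losses cancels much of the masking-pattern noise {A} relative to i.i.d. masks; `masked_loss` is a toy stand-in for the per-example training loss.

```python
import random
import statistics

def masked_loss(tokens, mask):
    # Toy stand-in for a per-example loss over masked positions.
    return sum(v for v, m in zip(tokens, mask) if m)

def iid_estimate(tokens, t, n_pairs, rng):
    # Baseline: 2*n_pairs independent masks at masking rate t.
    samples = []
    for _ in range(2 * n_pairs):
        mask = [rng.random() < t for _ in tokens]
        samples.append(masked_loss(tokens, mask))
    return statistics.mean(samples)

def antithetic_estimate(tokens, t, n_pairs, rng):
    # Negatively correlated pairs: each mask and its "mirror"
    # reuse the same uniforms u via u < t and (1 - u) < t.
    samples = []
    for _ in range(n_pairs):
        us = [rng.random() for _ in tokens]
        samples.append(masked_loss(tokens, [u < t for u in us]))
        samples.append(masked_loss(tokens, [1 - u < t for u in us]))
    return statistics.mean(samples)

rng = random.Random(0)
tokens = [float(i) for i in range(16)]
t = 0.5  # at t=0.5 the mirrored mask is the exact complement
iid = [iid_estimate(tokens, t, 8, rng) for _ in range(200)]
anti = [antithetic_estimate(tokens, t, 8, rng) for _ in range(200)]
# The antithetic estimator has strictly lower variance here.
print(statistics.pvariance(anti) < statistics.pvariance(iid))
```

At `t = 0.5` the mirrored mask is the exact complement, so each pair's losses sum to a constant and the pair-averaged estimate has zero variance; for other `t` the cancellation is partial but the negative correlation still reduces variance.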
Primary Area: foundation or frontier models, including LLMs
Submission Number: 9788