Fast Training of Diffusion Models with Masked Transformers

Published: 05 Mar 2024 · Last Modified: 05 Mar 2024 · Accepted by TMLR
Abstract: We propose an efficient approach to training large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning in the vision domain has received far less attention. Our work is the first to exploit masked training to significantly reduce the training cost of diffusion models. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder that operates on the full set of patches. To promote a long-range understanding of the full image, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective, which learns the score on unmasked patches. Experiments on ImageNet-256×256 and ImageNet-512×512 show that our approach achieves generative performance competitive with, and in some cases better than, the state-of-the-art Diffusion Transformer (DiT), using only around 30% of its original training time. Our method thus offers a promising way to train large transformer-based diffusion models efficiently without sacrificing generative performance. Our code is available at https://github.com/Anima-Lab/MaskDiT.
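
To make the training scheme in the abstract concrete, below is a minimal PyTorch sketch of one masked training step. Everything here is an illustrative assumption rather than the authors' implementation: the module names (MaskedDiffusionTransformer, gather_patches), layer counts, patch dimensions, the toy cosine noise schedule, and the lambda_mae=0.1 weight are all hypothetical, and timestep/class conditioning is omitted for brevity. The actual code is in the linked repository.

```python
# Hedged sketch of masked diffusion-transformer training on patchified
# inputs of shape (B, N, D). All sizes and names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

def gather_patches(x, ids):
    """Select patches of x (B, N, D) at indices ids (B, K)."""
    return torch.gather(x, 1, ids.unsqueeze(-1).expand(-1, -1, x.size(-1)))

class MaskedDiffusionTransformer(nn.Module):
    """Asymmetric encoder-decoder: a deep encoder on unmasked patches only,
    plus a lightweight decoder on the full patch sequence (assumed sizes)."""

    def __init__(self, num_patches=256, patch_dim=16, dim=768, mask_ratio=0.5):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_embed = nn.Linear(patch_dim, dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        enc_layer = nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=12)
        dec_layer = nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True)
        self.decoder = nn.TransformerEncoder(dec_layer, num_layers=2)
        self.head = nn.Linear(dim, patch_dim)  # per-patch prediction

    def forward(self, x_patches):
        B, N, _ = x_patches.shape
        tokens = self.patch_embed(x_patches) + self.pos_embed

        # Randomly drop mask_ratio of the patches; the deep encoder only
        # sees the survivors, which is where the compute saving comes from.
        num_keep = int(N * (1 - self.mask_ratio))
        ids = torch.rand(B, N, device=tokens.device).argsort(dim=1)
        keep_ids, mask_ids = ids[:, :num_keep], ids[:, num_keep:]
        encoded = self.encoder(gather_patches(tokens, keep_ids))

        # The lightweight decoder runs on the full sequence: encoded visible
        # tokens scattered back to their positions, mask tokens elsewhere.
        full = self.mask_token.expand(B, N, -1).clone()
        full.scatter_(1, keep_ids.unsqueeze(-1).expand(-1, -1, encoded.size(-1)), encoded)
        return self.head(self.decoder(full + self.pos_embed)), keep_ids, mask_ids

def maskdit_loss(model, x0_patches, lambda_mae=0.1):
    """Denoising score matching on unmasked patches plus an auxiliary
    reconstruction loss on masked patches (the 0.1 weight is an assumption)."""
    B = x0_patches.size(0)
    t = torch.rand(B, device=x0_patches.device)
    # Toy cosine noise schedule (assumption; any standard schedule fits here).
    a = torch.cos(0.5 * torch.pi * t).view(B, 1, 1)
    s = torch.sin(0.5 * torch.pi * t).view(B, 1, 1)
    eps = torch.randn_like(x0_patches)
    xt = a * x0_patches + s * eps  # diffused input

    pred, keep_ids, mask_ids = model(xt)
    # Score matching only where the encoder saw real patches ...
    dsm = F.mse_loss(gather_patches(pred, keep_ids), gather_patches(eps, keep_ids))
    # ... plus reconstructing the diffused input at masked positions, which
    # pushes the model toward a long-range understanding of the full image.
    mae = F.mse_loss(gather_patches(pred, mask_ids), gather_patches(xt, mask_ids))
    return dsm + lambda_mae * mae
```

In a standard training loop this reduces to loss = maskdit_loss(model, patches); loss.backward(). The efficiency lever is that the deep encoder processes only half the tokens per step, while the shallow decoder restores the full sequence so the auxiliary reconstruction loss can act on masked positions.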
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission: We have updated our paper with the following revisions:
1. Add results of experiments on ImageNet-512×512 (suggested by Reviewers difn and FZfh).
   a. Modifications and additions have been made to the abstract, the second-to-last paragraph of the introduction, Sections 4.1, 4.2, and 4.3, and the conclusion to reflect the new results.
   b. Table 2 is added to present the generative performance comparison on ImageNet-512×512.
   c. Figure 4 is revised to present the training efficiency comparison on ImageNet-512×512.
2. Fix typos in the abstract, introduction, and Section 3.2 (thanks to Reviewer eJnL).
3. Remove redundant citations.
All changes are highlighted in a deep blue-green color for ease of identification.
Code: https://github.com/Anima-Lab/MaskDiT
Assigned Action Editor: ~Laurent_Dinh1
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Number: 1671