Fast Training of Diffusion Models with Masked Transformers

06/15/2023
by Hongkai Zheng, et al.

We propose an efficient approach to training large diffusion models with masked transformers. While masked transformers have been studied extensively for representation learning, their application to generative learning in the vision domain remains under-explored. Our work is the first to exploit masked training to significantly reduce the training cost of diffusion models. Specifically, we randomly mask out a high proportion (e.g., 50%) of the patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on the unmasked patches and a lightweight transformer decoder that operates on all patches. To promote a long-range understanding of the full set of patches, we add an auxiliary task of reconstructing the masked patches to the denoising score matching objective, which learns the score of the unmasked patches. Experiments on ImageNet-256×256 show that our approach matches the performance of the state-of-the-art Diffusion Transformer (DiT) model while using only 31% of its original training time. Our method thus allows diffusion models to be trained efficiently without sacrificing generative performance.
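The masking step and the combined objective described in the abstract can be illustrated with a minimal sketch in plain Python. This is not the authors' implementation. The function names (`random_mask`, `mse`, `masked_training_loss`), the use of mean squared error for both terms, the choice of reconstruction target, and the `recon_weight` value are all assumptions made for illustration; the abstract specifies none of them.

```python
import random

def random_mask(num_patches, mask_ratio, rng):
    """Return a list of booleans; True marks a masked (dropped) patch."""
    num_masked = int(round(num_patches * mask_ratio))
    masked_ids = set(rng.sample(range(num_patches), num_masked))
    return [i in masked_ids for i in range(num_patches)]

def mse(pred, target):
    """Mean squared error over two lists of equal-length patch vectors."""
    sq = [(p - t) ** 2 for pv, tv in zip(pred, target) for p, t in zip(pv, tv)]
    return sum(sq) / len(sq)

def masked_training_loss(denoise_pred, denoise_target,
                         recon_pred, recon_target,
                         mask, recon_weight=0.1):
    """Denoising loss on unmasked patches plus a weighted reconstruction
    loss on masked patches. The weight is a hypothetical hyperparameter,
    and the sketch assumes at least one patch of each kind."""
    visible = [i for i, m in enumerate(mask) if not m]
    masked = [i for i, m in enumerate(mask) if m]
    denoise_loss = mse([denoise_pred[i] for i in visible],
                       [denoise_target[i] for i in visible])
    recon_loss = mse([recon_pred[i] for i in masked],
                     [recon_target[i] for i in masked])
    total = denoise_loss + recon_weight * recon_loss
    return total, denoise_loss, recon_loss

# Usage: mask half of 16 patches; only the visible ones would reach the encoder.
rng = random.Random(0)
mask = random_mask(16, 0.5, rng)
patches = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
encoder_input = [p for p, m in zip(patches, mask) if not m]
```

With a 50% mask ratio the encoder in this sketch sees half of the patches, which is where the training savings described above would come from; the lightweight decoder would then operate on the full patch sequence.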
