6 Jun 2024 | Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, Michalis K. Titsias
This paper introduces a simplified and generalized masked diffusion framework for discrete data, aiming to unlock the full potential of masked diffusion models. The authors show that the continuous-time variational objective of masked diffusion models can be expressed as a weighted integral of cross-entropy losses, and that this formulation supports training generalized masked diffusion models with state-dependent masking schedules.

When evaluated on OpenWebText, their models outperform prior diffusion language models at GPT-2 scale and perform well on zero-shot language modeling tasks. On pixel-level image modeling, their models achieve performance competitive with or better than autoregressive models, reaching 2.78 bits per dimension on CIFAR-10 and 3.42 bits per dimension on ImageNet 64×64.

The paper also relates masked diffusion models to existing work, including continuous-time Markov chains and score parameterization. The generalization to state-dependent masking schedules yields improved predictive performance: in experiments, the models outperform previous discrete diffusion models on text and image data, and the generalized model further improves likelihoods. The authors note remaining limitations, such as masked diffusion models not yet matching autoregressive models on some tasks and being prone to overfitting. They conclude that the framework offers a simple and effective approach to masked diffusion, with promising results on discrete data tasks.
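Schematically, the weighted-cross-entropy form of the objective can be written as follows. The notation here is assumed for illustration: alpha_t is the masking schedule (the probability a token remains unmasked at time t), m is the mask token, x is the one-hot clean token, and mu_theta is the model's predicted distribution over clean tokens.

\[
\mathcal{L}(x) \;=\; \int_0^1 \frac{\alpha_t'}{1-\alpha_t}\,
\mathbb{E}_{q(z_t \mid x)}\!\left[\,\delta_{z_t,\,m}\; x^\top \log \mu_\theta(z_t, t)\,\right] \mathrm{d}t .
\]

Since \(\alpha_t' \le 0\) and \(\log \mu_\theta \le 0\), the integrand is a nonnegative, schedule-weighted cross-entropy evaluated only at masked positions.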
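To make the weighted-cross-entropy objective concrete, here is a minimal NumPy sketch of a Monte Carlo estimator under a simple linear masking schedule. All names, the schedule, and the toy uniform "model" are illustrative assumptions for this sketch, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

VOCAB = 4     # toy vocabulary size; [MASK] is an extra index
MASK = VOCAB
L = 8         # sequence length

def alpha(t):
    """Linear masking schedule: probability a token is still unmasked at time t."""
    return 1.0 - t

def alpha_prime(t):
    return -1.0

def loss_estimate(x, log_prob_fn, n_samples=2000):
    """Monte Carlo estimate of the weighted integral of cross-entropy losses:
    each sample weighs the cross-entropy of the model's clean-token prediction,
    summed over masked positions, by -alpha'(t) / (1 - alpha(t))."""
    total = 0.0
    for _ in range(n_samples):
        t = rng.uniform(0.05, 1.0)            # truncate near t=0 to tame variance
        keep = rng.random(L) < alpha(t)       # each token survives with prob alpha(t)
        z = np.where(keep, x, MASK)           # partially masked sequence z_t
        logp = log_prob_fn(z, t)              # (L, VOCAB) log-probs over clean tokens
        ce = -logp[np.arange(L), x]           # per-position cross-entropy vs. the data
        w = -alpha_prime(t) / (1.0 - alpha(t))
        total += w * ce[~keep].sum()          # only masked positions contribute
    return total / n_samples

# Toy "model": uniform prediction, standing in for a trained network.
x = rng.integers(0, VOCAB, size=L)
uniform_model = lambda z, t: np.log(np.full((L, VOCAB), 1.0 / VOCAB))
est = loss_estimate(x, uniform_model)
# For the linear schedule this estimate concentrates around L * log(VOCAB).
```

With the linear schedule, the weight 1/t exactly cancels the expected fraction of masked tokens, so the uniform-model estimate concentrates around L * log(VOCAB); a state-dependent schedule would replace the scalar alpha(t) with a per-token (or per-state) masking probability and adjust the weight accordingly.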