Understanding Hallucinations in Diffusion Models through Mode Interpolation
Abstract
The paper studies a failure mode of diffusion models called "hallucination", in which the model generates samples that lie completely outside the support of the training data distribution. The authors investigate this phenomenon through experiments on 1D and 2D Gaussian mixtures, as well as a synthetic "Simple Shapes" dataset. They find that diffusion models exhibit "mode interpolation": the model smoothly interpolates between nearby modes of the training data, generating samples in regions that have negligible probability under the true data distribution.
The authors analyze the learned score function of the diffusion models and show that the neural networks learn a smooth approximation of the true, discontinuous score function, which leads to the mode interpolation behavior. They further observe that hallucinated samples exhibit high variance in the trajectory of the predicted final (denoised) sample during the reverse diffusion process. Based on this observation, they propose a detection metric that filters out over 95% of hallucinated samples while retaining 96% of in-support samples.
Finally, the authors study the implications of hallucinations for recursive generative model training, where models are trained on their own generated outputs. They show that the proposed detection mechanism can mitigate the model collapse that occurs during recursive training on datasets such as 2D Gaussians, Simple Shapes, and MNIST.
Q&A
[01] Understanding Mode Interpolation and Hallucination
1. What is the central phenomenon studied in the paper? The central phenomenon studied in the paper is "hallucination" in diffusion models, where the models generate samples that lie completely outside the support of the training data distribution.
2. How do the authors define hallucination and mode interpolation? The authors define hallucinations as samples generated by the model that lie entirely outside the support of the real data distribution. Mode interpolation occurs when the model generates a sample that directly interpolates (in input space) between two samples in the support set, such that the interpolation itself lies in the hallucination set; a compact formalization is sketched after this list.
3. How do the authors investigate hallucinations through simplified synthetic datasets? The authors run experiments on 1D and 2D Gaussian mixture datasets, as well as a synthetic "Simple Shapes" dataset. These simplified setups make the mode interpolation behavior of diffusion models directly observable; a minimal data-generation sketch follows this list.
4. What key observation do the authors make about the learned score function of diffusion models? The authors observe that the neural networks learn a smooth approximation of the true, discontinuous score function, particularly in the regions between disjoint modes of the data distribution. This smoothing produces mode interpolation: the model generates samples in regions with negligible probability under the true data distribution. The closed-form mixture score is written out after this list.
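The definitions in Q2 can be written compactly; the notation below paraphrases the paper's setup rather than quoting it, with supp(·) denoting the support of a distribution.

```latex
% Hallucinations: generated samples that fall outside the data support
\mathcal{H} \;=\; \{\, x \sim p_\theta \;:\; x \notin \operatorname{supp}(p_{\mathrm{data}}) \,\}

% Mode interpolation: a hallucinated sample that is a convex combination
% (in input space) of two in-support samples
x \;=\; \lambda x_1 + (1-\lambda)\, x_2, \qquad
x_1, x_2 \in \operatorname{supp}(p_{\mathrm{data}}),\;\;
\lambda \in (0,1),\;\;
x \in \mathcal{H}.
```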
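A minimal data-generation sketch for the kind of 1D Gaussian-mixture setup described in Q3. The mode locations, standard deviation, and the k-sigma support test are illustrative placeholders, not the paper's exact configuration; the point is only that a generated sample lying between two well-separated modes gets flagged as out of support, i.e., a hallucination.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative 1D Gaussian mixture with three well-separated modes
# (means and std are placeholders, not the paper's exact values).
means = np.array([1.0, 2.0, 3.0])
std = 0.05

def sample_mixture(n):
    """Draw n training samples from the mixture with equal mode weights."""
    modes = rng.integers(0, len(means), size=n)
    return means[modes] + std * rng.standard_normal(n)

def in_support(x, k=4.0):
    """A sample counts as in-support if it lies within k std-devs of some mode."""
    dists = np.abs(x[:, None] - means[None, :])
    return (dists < k * std).any(axis=1)

train = sample_mixture(10_000)
print("training samples outside support:", int((~in_support(train)).sum()))

# A generated sample near 1.5 interpolates between the modes at 1.0 and 2.0
# and is flagged as out of support, i.e., a hallucination.
generated = np.array([1.02, 1.5, 2.97])
print("in-support mask for generated samples:", in_support(generated))
```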
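For intuition behind Q4, the true score of a Gaussian mixture has a closed form (a standard identity, not an equation taken from the paper). With well-separated modes, the posterior weights γ_i(x) switch almost discontinuously in the low-density region between modes, so the true score changes very sharply there; a neural network fits a smooth approximation instead, which is what lets reverse diffusion place samples between modes.

```latex
p(x) \;=\; \sum_i w_i\, \mathcal{N}(x;\, \mu_i,\, \sigma^2 I)
\quad\Longrightarrow\quad
\nabla_x \log p(x) \;=\; \sum_i \gamma_i(x)\, \frac{\mu_i - x}{\sigma^2},
\qquad
\gamma_i(x) \;=\; \frac{w_i\, \mathcal{N}(x;\, \mu_i,\, \sigma^2 I)}
                      {\sum_j w_j\, \mathcal{N}(x;\, \mu_j,\, \sigma^2 I)}.
```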
[02] Diffusion Models Know When They Hallucinate
1. What insight do the authors gain about hallucinated samples based on the trajectory of the predicted values? The authors observe that, for hallucinated samples, the trajectory of the predicted final (denoised) sample exhibits high variance towards the end of the reverse diffusion process, in contrast to the low-variance trajectories of in-support samples; the statistic is made concrete after this list.
2. How do the authors use this observation to develop a metric for detecting hallucinations? The authors use the variance of the predicted final sample over the last steps of the reverse diffusion process as a score to separate hallucinated from in-support samples. This metric filters out over 95% of hallucinated samples while retaining 96% of in-support samples; a toy filtering implementation follows the list.
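To make the observation in Q1 concrete: under the standard DDPM parameterization, every reverse step yields a prediction of the clean sample, and the detection statistic can be written as the variance of that prediction over the final reverse steps. The exact window T' and any normalization are assumptions here, not values taken from the paper.

```latex
\hat{x}_0(x_t, t) \;=\; \frac{x_t - \sqrt{1-\bar{\alpha}_t}\;\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}},
\qquad
s \;=\; \operatorname{Var}_{\,t \in \{1,\dots,T'\}}\!\big[\hat{x}_0(x_t, t)\big],
\quad T' \ll T,
```

where small t indexes the last few reverse steps and a large s flags a likely hallucination.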
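A toy implementation of the filtering idea in Q2, assuming the per-step x̂0 predictions were recorded during sampling. The function names, the window size, and the percentile-based cutoff are illustrative choices, not the paper's exact procedure.

```python
import numpy as np

def trajectory_variance(x0_preds, last_k=50):
    """Variance of the predicted clean sample over the last_k reverse steps.

    x0_preds: array of shape (num_steps, num_samples, *sample_dims), where
              x0_preds[s] is the x0 prediction at reverse step s
              (earliest step first, final step last).
    Returns one scalar score per sample (mean variance over dimensions).
    """
    tail = x0_preds[-last_k:]                     # final reverse steps only
    per_dim_var = tail.var(axis=0)                # variance across those steps
    return per_dim_var.reshape(per_dim_var.shape[0], -1).mean(axis=1)

def filter_hallucinations(x0_preds, percentile=95.0):
    """Keep samples whose trajectory variance falls below a percentile cutoff."""
    scores = trajectory_variance(x0_preds)
    threshold = np.percentile(scores, percentile)
    return scores <= threshold, scores

# Placeholder trajectories: 1000 reverse steps, 256 two-dimensional samples.
rng = np.random.default_rng(0)
fake_preds = rng.standard_normal((1000, 256, 2)) * 0.01
keep, scores = filter_hallucinations(fake_preds)
print(f"kept {keep.sum()} of {keep.size} samples")
```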
[03] Implications on Recursive Model Training
1. What is the key insight about the impact of hallucinations in the context of recursive generative model training? The authors observe that hallucinated samples strongly influence what the next generation learns during recursive training, causing the distinct modes to collapse into a single mode that deviates greatly from the original data distribution.
2. How does the proposed detection mechanism help mitigate the model collapse during recursive training? The authors show that using the variance-based metric to filter out hallucinated samples before retraining effectively mitigates model collapse during recursive training on 2D Gaussians, Simple Shapes, and MNIST; a procedural sketch of the filtered retraining loop follows.
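A procedural sketch of the filtered recursive-training loop described in this section. `train_diffusion`, `sample_with_x0_trajectory`, and `score_fn` are hypothetical callables the caller must supply (for example, wrappers around their own model, sampler, and the trajectory-variance score sketched above); only the loop structure mirrors the setup described here.

```python
import numpy as np

def recursive_training(
    real_data,
    train_diffusion,            # hypothetical: data -> trained model
    sample_with_x0_trajectory,  # hypothetical: (model, n) -> (samples, x0_preds)
    score_fn,                   # e.g. the trajectory-variance score above
    generations=5,
    n_samples=10_000,
    percentile=95.0,
    filter_hallucinations=True,
):
    """Train generation 0 on real data, then train each later generation on the
    previous generation's (optionally filtered) samples."""
    data = real_data
    model = None
    for _ in range(generations):
        model = train_diffusion(data)
        samples, x0_preds = sample_with_x0_trajectory(model, n_samples)
        if filter_hallucinations:
            scores = score_fn(x0_preds)
            keep = scores <= np.percentile(scores, percentile)
            samples = samples[keep]
        data = samples  # the next generation trains only on generated data
    return model
```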