Keywords: Generative Conditional Diffusion Models, Deep Mutual Learning, Long-Tailed Learning, Learning in Games
TL;DR: We provide a mathematically rigorous framework for studying Deep Mutual Learning in the context of conditional diffusion models for long-tailed generation.
Abstract: Deep generative models, particularly diffusion models, have achieved remarkable success but face significant challenges when trained on real-world, long-tailed datasets, where a few "head" classes dominate and many "tail" classes are underrepresented. This paper develops a theoretical framework for long-tailed learning via diffusion models through the lens of deep mutual learning. We introduce a novel regularized training objective that combines the standard diffusion loss with a mutual learning term, enabling balanced performance across all class labels, including the underrepresented tails. We formulate learning under the proposed regularized objective as a multi-player game, with Nash equilibrium serving as the solution concept. We derive a non-asymptotic first-order convergence result for the individual gradient descent algorithm used to find the Nash equilibrium. We show that the Nash gap of the score network obtained from the algorithm is upper bounded by $\mathcal{O}(\frac{1}{\sqrt{T_{train}}}+\beta)$, where $\beta$ is the regularization parameter and $T_{train}$ is the number of iterations of the training algorithm. Furthermore, we theoretically establish hyper-parameters for the training and sampling algorithms that ensure we find conditional score networks (under our model) with a worst-case sampling error of $\mathcal{O}(\epsilon+1)$ for all $\epsilon>0$ across all class labels. Our results offer insights and guarantees for training diffusion models on imbalanced, long-tailed data, with implications for fairness, privacy, and generalization in real-world generative modeling scenarios.
Supplementary Material: zip
Primary Area: Theory (e.g., control theory, learning theory, algorithmic game theory)
Submission Number: 22545