# DDIM-GMM for Stable Diffusion

This is a self-contained implementation of DDIM-GMM sampling for Stable Diffusion and compatible diffusion models.

## Overview

DDIM-GMM extends DDIM by replacing the unimodal Gaussian kernel with a multimodal Gaussian mixture kernel. The mixture parameters are constrained so that the DDIM-GMM forward marginals have the same first- and second-order moments as the DDPM forward marginals. This provides:
- Additional sampling flexibility through multiple mixture components
- Mixture parameters that stay consistent with the DDPM forward marginals
- A choice between a variance upper bound (VUB) and full covariance computation
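
The moment constraint can be illustrated in one dimension. The sketch below is not taken from this package: it assumes a mixture whose components share a single variance and whose means are shifted from a common centre by `offsets` (the function and variable names are hypothetical). If the weighted offsets sum to zero and the shared variance is reduced by the weighted spread of the offsets, the mixture keeps the mean and variance of the single Gaussian it replaces.

```python
def matched_component_variance(weights, offsets, target_var):
    """Shared component variance that makes a 1-D mixture match target_var.

    Illustrative only: assumes components N(mean + offsets[k], s2) with
    mixing weights weights[k]; not the package's actual parameterisation.
    """
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    drift = sum(w * d for w, d in zip(weights, offsets))
    if abs(drift) > 1e-9:
        raise ValueError("weighted offsets must sum to 0 to keep the mean")
    spread = sum(w * d * d for w, d in zip(weights, offsets))
    s2 = target_var - spread
    if s2 < 0:
        raise ValueError("offsets are too large for the target variance")
    return s2


def mixture_moments(weights, means, component_var):
    """Mean and variance of a 1-D Gaussian mixture with a shared variance."""
    mean = sum(w * m for w, m in zip(weights, means))
    second = sum(w * (component_var + m * m) for w, m in zip(weights, means))
    return mean, second - mean * mean


# Target: a single Gaussian with mean 2.0 and variance 1.0.
weights = [0.5, 0.5]
offsets = [-0.3, 0.3]
s2 = matched_component_variance(weights, offsets, target_var=1.0)
mean, var = mixture_moments(weights, [2.0 + d for d in offsets], s2)
print(mean, var)  # approximately 2.0 and 1.0
```

Because the shared variance cannot be negative, the target variance bounds how large the offsets can be.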

## Features

- ✅ Self-contained implementation with no dependency on the original `ldm` codebase
- ✅ Compatible with Stable Diffusion models
- ✅ Clean API with minimal interface requirements
- ✅ Optional evaluation metrics (FID, IS, Precision-Recall)

## Installation

### Requirements

```bash
pip install torch numpy tqdm pillow
```

### Optional (for evaluation metrics)

```bash
pip install torch-fidelity
```

## Usage

### Basic Sampling

```python
from ddim_gmm import DDIMSampler, GMM
import torch

# Initialize GMM parameters
gmm_params = GMM(gpu=0)  # or gpu=False for CPU
gmm_params.initialize(
    dim=3 * 64 * 64,        # latent dimension
    n_components=8,          # number of mixture components
    n_steps=50,              # number of DDIM steps
    scale=0.1,               # scale factor for mean offsets
    uniform_priors=True,     # uniform mixture weights
    orthonormal=True,        # orthonormal mean offsets
    upper_bound_vars=True    # use variance upper bound (VUB)
)

# Create DDIM sampler with GMM
# (diffusion_model is your loaded model; see "Model Interface Requirements")
sampler = DDIMSampler(model=diffusion_model, gpu=0, gmm=True, gmm_params=gmm_params)
sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0)

# Sample
samples, intermediates = sampler.sample(
    steps=50,
    batch_size=4,
    shape=(3, 64, 64),
    conditioning=conditioning_vector,  # optional
    eta=0.0
)
```
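
`make_schedule(ddim_num_steps=50, ...)` selects a subsequence of the model's training timesteps. As a rough sketch, assuming this sampler keeps the uniform spacing used in the CompVis Stable Diffusion code it is adapted from (an assumption, not confirmed by this README), the selection looks like the following. `uniform_ddim_timesteps` is a hypothetical helper name, not part of the package API.

```python
def uniform_ddim_timesteps(num_ddpm_timesteps, num_ddim_steps):
    """Evenly spaced subset of training timesteps (assumed CompVis-style)."""
    if not 0 < num_ddim_steps <= num_ddpm_timesteps:
        raise ValueError("num_ddim_steps must be in 1..num_ddpm_timesteps")
    stride = num_ddpm_timesteps // num_ddim_steps
    # The +1 shift mirrors the CompVis code, which offsets the indices so
    # that the final alpha values line up during sampling.
    return [t + 1 for t in range(0, num_ddpm_timesteps, stride)]


steps = uniform_ddim_timesteps(num_ddpm_timesteps=1000, num_ddim_steps=50)
print(len(steps), steps[:3], steps[-1])  # 50 [1, 21, 41] 981
```

When the number of training timesteps is not an exact multiple of the requested step count, this spacing can return more entries than requested.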

### Model Interface Requirements

Your diffusion model must provide:

```python
# Required attributes
model.num_timesteps: int
model.alphas_cumprod: Tensor
model.betas: Tensor
model.alphas_cumprod_prev: Tensor
model.cond_stage_key: Optional[str]
model.cond_stage_model: Optional[nn.Module]
model.first_stage_key: str
model.channels: int
model.image_size: int
model.device: torch.device

# Required methods
model.get_input(batch, key, **kwargs)
model.apply_model(x, t, c)
model.decode_first_stage(z)
model.get_learned_conditioning(batch)  # only needed for conditional models
model.train()
model.eval()
```

## Parameters

### GMM.initialize()

- **dim** (int): Dimension of latent space (channels × height × width)
- **n_components** (int): Number of GMM mixture components
- **n_steps** (int): Number of DDIM sampling steps (set to the same value as `ddim_num_steps` and `steps` in the Basic Sampling example)
- **scale** (float): Scale factor for mean offsets (default: 1.0)
- **uniform_priors** (bool): Use uniform mixture weights (default: True)
- **orthonormal** (bool): Orthonormalize mean offsets (default: True)
- **upper_bound_vars** (bool): Use VUB approximation (default: True)
  - True: Diagonal variance upper bound (faster)
  - False: Full covariance computation (more accurate)
- **dynamic_scale** (bool): Dynamically scale means (default: False)
- **init_cov** (bool): Initialize covariance matrices (default: False)
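
Since `dim` is the flattened size of the latent, it can be derived from the same `(C, H, W)` shape that is later passed to `sample()`, which keeps the two values consistent. The sketch below uses hypothetical helper names and also encodes one purely mathematical limit: at most `dim` vectors can be mutually orthonormal in a `dim`-dimensional space. Whether the package enforces this itself is not stated here.

```python
import math


def latent_dim(shape):
    """Flattened size of a (C, H, W) latent, as expected by `dim`."""
    return math.prod(shape)


def check_gmm_sizes(shape, n_components, orthonormal=True):
    """Hypothetical sanity check; returns dim if the sizes are compatible."""
    dim = latent_dim(shape)
    # At most `dim` vectors can be mutually orthonormal in `dim` dimensions.
    if orthonormal and n_components > dim:
        raise ValueError(
            f"{n_components} orthonormal offsets do not fit in dim={dim}"
        )
    return dim


print(check_gmm_sizes((3, 64, 64), n_components=8))  # 12288
```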

### DDIMSampler.sample()

- **steps** (int): Number of sampling steps
- **batch_size** (int): Batch size
- **shape** (tuple): Shape of latent space (C, H, W)
- **eta** (float): DDIM eta parameter (0.0 = deterministic)
- **conditioning** (Tensor, optional): Conditioning vector for conditional generation
- **unconditional_conditioning** (Tensor, optional): For classifier-free guidance
- **unconditional_guidance_scale** (float): Guidance scale (default: 1.0)
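
The effect of `eta` can be seen in the standard DDIM noise scale, which multiplies a schedule-dependent term by `eta`. The sketch below shows only that standard formula. It assumes this sampler keeps the usual DDIM definition of `eta`; how the GMM kernel changes the per-step variance is not covered. `ddim_sigma` is a hypothetical helper name.

```python
import math


def ddim_sigma(alpha_cumprod_t, alpha_cumprod_prev, eta):
    """Standard DDIM noise scale for one step (0.0 when eta == 0)."""
    ratio = (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod_t)
    return eta * math.sqrt(ratio * (1.0 - alpha_cumprod_t / alpha_cumprod_prev))


# Hypothetical cumulative alphas for two neighbouring sampling steps.
a_t, a_prev = 0.5, 0.6
print(ddim_sigma(a_t, a_prev, eta=0.0))  # 0.0
print(ddim_sigma(a_t, a_prev, eta=1.0))  # largest noise scale for eta in [0, 1]
```

With `eta=0.0` no fresh noise is injected at a step, and the noise scale grows linearly as `eta` increases.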

## Examples

### Unconditional Generation

```python
# No conditioning required
samples, _ = sampler.sample(
    steps=50,
    batch_size=4,
    shape=(3, 64, 64),
    eta=0.0
)
```

### Conditional Generation with Classifier-Free Guidance

```python
# Get conditioning embeddings
conditioning = model.get_learned_conditioning(class_labels)
null_conditioning = model.get_learned_conditioning(null_labels)

samples, _ = sampler.sample(
    steps=50,
    batch_size=4,
    shape=(3, 64, 64),
    conditioning=conditioning,
    unconditional_conditioning=null_conditioning,
    unconditional_guidance_scale=5.0,
    eta=0.0
)
```
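
For reference, classifier-free guidance usually combines the two noise predictions by extrapolating from the unconditional one toward the conditional one. The sketch below assumes this sampler uses that standard combination, as the CompVis code does; plain Python lists stand in for tensors, and `guided_prediction` is a hypothetical name.

```python
def guided_prediction(eps_uncond, eps_cond, guidance_scale):
    """Move from the unconditional prediction toward the conditional one."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


eps_uncond = [0.0, 1.0]
eps_cond = [1.0, 3.0]
print(guided_prediction(eps_uncond, eps_cond, 1.0))  # [1.0, 3.0]
print(guided_prediction(eps_uncond, eps_cond, 5.0))  # [5.0, 11.0]
```

Under this formula a scale of 1.0 returns the conditional prediction unchanged, and larger values push further in the conditional direction.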

## License

Licensed under the CreativeML Open RAIL-M License. See the LICENSE file for details.

This implementation incorporates code adapted from:
- https://github.com/CompVis/stable-diffusion
- https://github.com/openai/improved-diffusion
- https://github.com/openai/guided-diffusion

## Notes

- For best results with conditional models, use classifier-free guidance
- VUB (`upper_bound_vars=True`) is recommended for faster sampling with minimal quality loss
- For high-quality generation, consider using more mixture components (16-32) and steps (100-250)
