"""
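Test script for the DDIM-GMM scheduler with Hugging Face Diffusers

This script demonstrates how to use the DDIM-GMM scheduler: it loads the
pretrained google/ddpm-cat-256 UNet, attaches GMM parameters to a
DDIMScheduler, generates one image and saves it as test_generation_output.png.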

Prerequisites:
1. Install diffusers: pip install diffusers torch pillow
2. Copy scheduling_ddim.py into your diffusers installation (this overwrites the
   scheduling_ddim.py that ships with diffusers, so back it up first):
   cp scheduling_ddim.py $(python -c "import diffusers; import os; print(os.path.dirname(diffusers.__file__))")/schedulers/

Usage:
    python test_generation.py
"""

import os
import sys

import torch
from PIL import Image

# Set the cache directory (optional). Use setdefault so that a user's own
# HF_HOME is respected rather than silently overwritten. Keep this above the
# diffusers import below: huggingface_hub reads HF_HOME when it is imported.
os.environ.setdefault("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

# Import from diffusers
# Note: scheduling_ddim.py must be in diffusers/schedulers/ directory
try:
    from diffusers import DDIMScheduler, UNet2DModel
    from diffusers.schedulers.scheduling_ddim import GMM
    print("✓ Successfully imported from installed diffusers")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("\nPlease ensure:")
    print("1. diffusers is installed: pip install diffusers")
    print("2. scheduling_ddim.py is copied to diffusers/schedulers/")
    print("   Run: cp scheduling_ddim.py $(python -c \"import diffusers; import os; print(os.path.dirname(diffusers.__file__))\")/schedulers/")
    exit(1)

def _spatial_size(sample_size):
    """Return ``(height, width)`` for a UNet ``sample_size`` config value.

    Accepts either a single int (square samples) or a 2-element sequence,
    both of which diffusers allows for ``UNet2DModel.config.sample_size``.
    """
    if isinstance(sample_size, int):
        return sample_size, sample_size
    height, width = sample_size
    return int(height), int(width)


def _load_pretrained(model_id, device):
    """Load the UNet (moved to *device*) and DDIM scheduler for *model_id*.

    Exits the process with status 1 if loading fails.
    """
    try:
        model = UNet2DModel.from_pretrained(model_id).to(device)
        scheduler = DDIMScheduler.from_pretrained(model_id)
    except Exception as e:  # script boundary: report any load failure and exit
        print(f"❌ Failed to load model: {e}", file=sys.stderr)
        print("Make sure you have internet connection for first-time download", file=sys.stderr)
        sys.exit(1)
    print("✓ Model loaded successfully")
    return model, scheduler


def _denoise(model, scheduler, sample):
    """Run the sampling loop over ``scheduler.timesteps`` starting from *sample*.

    Returns the final sample produced by the last ``scheduler.step`` call.
    """
    timesteps = scheduler.timesteps
    total = len(timesteps)
    # One no_grad context for the whole loop: sampling needs no autograd graph.
    with torch.no_grad():
        for step, t in enumerate(timesteps, start=1):
            if step % 10 == 0:
                print(f"  Step {step}/{total}")
            noise_pred = model(sample, t).sample
            sample = scheduler.step(noise_pred, t, sample).prev_sample
    return sample


def _to_pil(sample):
    """Convert the first image of an NCHW batch to a PIL image.

    Values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1], i.e. the
    sample is expected to lie in [-1, 1].
    """
    image = (sample / 2 + 0.5).clamp(0, 1)
    pixels = image[0].cpu().permute(1, 2, 0).numpy()  # CHW -> HWC
    return Image.fromarray((pixels * 255).round().astype("uint8"))


def main(*, seed=None):
    """Generate one image with the DDIM-GMM scheduler and save it to disk.

    Loads ``google/ddpm-cat-256``, initialises GMM parameters, hands them to
    the scheduler via ``set_gmm_params``, runs the sampling loop and writes
    ``test_generation_output.png`` to the current working directory.

    Args:
        seed: Optional integer seed for the initial noise, for reproducible
            runs. ``None`` (the default) draws unseeded noise as before.
    """
    print("=" * 60)
    print("DDIM-GMM Test Generation")
    print("=" * 60)

    # Configuration
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    n_steps = 50
    gmm_n_components = 16
    gmm_scale = 1.0
    gmm_uniform_priors = True
    gmm_orthonormal_offsets = True
    gmm_vub = True

    print("\nGMM Configuration:")
    print(f"  - Components: {gmm_n_components}")
    print(f"  - Steps: {n_steps}")
    print(f"  - Scale: {gmm_scale}")
    print(f"  - VUB: {gmm_vub}")

    # Load pretrained model and scheduler
    print("\nLoading model...")
    model, scheduler = _load_pretrained("google/ddpm-cat-256", device)

    # Initialize GMM parameters over the flattened sample dimension
    print("\nInitializing GMM parameters...")
    height, width = _spatial_size(model.config.sample_size)  # e.g. 256x256 for the cat model
    z_channels = model.config.in_channels

    gmm_params = GMM(device=device)
    gmm_params.initialize(
        dim=z_channels * height * width,
        n_components=gmm_n_components,
        n_steps=n_steps,
        scale=gmm_scale,
        uniform_priors=gmm_uniform_priors,
        orthonormal=gmm_orthonormal_offsets,
        upper_bound_vars=gmm_vub,
    )
    print("✓ GMM initialized")

    # IMPORTANT: the GMM parameters must be set on the scheduler before stepping.
    scheduler.set_gmm_params(gmm_params=gmm_params)
    print("✓ GMM parameters set in scheduler")

    scheduler.set_timesteps(n_steps)
    print(f"✓ Timesteps set to {n_steps}")

    # Generate image
    print("\nGenerating image...")
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)
    noise = torch.randn(
        (1, z_channels, height, width), generator=generator, device=device
    )
    sample = _denoise(model, scheduler, noise)
    print("✓ Generation complete")

    # Convert to image and save
    print("\nSaving image...")
    output_path = "test_generation_output.png"
    _to_pil(sample).save(output_path)
    print(f"✓ Image saved to {output_path}")

    print("\n" + "=" * 60)
    print("✅ Test completed successfully!")
    print("=" * 60)

if __name__ == "__main__":
    # Entry point: run the demo only when executed as a script, not on import.
    main()
