"""
DDIM-GMM Sampler

This is a self-contained implementation of DDIM-GMM sampling for Stable Diffusion and compatible diffusion models.

================================================================================
LICENSE AND ATTRIBUTION
================================================================================

This file incorporates code adapted from:
Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors

Licensed under the CreativeML Open RAIL-M License.
See https://github.com/CompVis/stable-diffusion/blob/main/LICENSE for full text.

This file incorporates code adapted from:
- https://github.com/openai/improved-diffusion (Gaussian diffusion utilities)
- https://github.com/openai/guided-diffusion (Neural network utilities)
- https://github.com/lucidrains/denoising-diffusion-pytorch (Diffusion implementation)

================================================================================
DEPENDENCIES
================================================================================

Required:
- torch (PyTorch)
- numpy
- tqdm (progress bars)

Optional:
- PIL (Python Imaging Library) - image saving; if missing, a warning is
  printed at import time and image saving is disabled
- torch_fidelity (FID calculation, evaluation only)
- ipr (Improved Precision-Recall metrics, evaluation only)

Model Interface Requirements:
The diffusion model object must provide:
- model.num_timesteps: int - Number of diffusion timesteps
- model.alphas_cumprod: Tensor - Cumulative product of (1 - beta_t)
- model.betas: Tensor - Beta schedule
- model.alphas_cumprod_prev: Tensor - Shifted alphas_cumprod
- model.cond_stage_key: str or None - Key for conditioning (e.g., "class_label")
- model.cond_stage_model: nn.Module or None - Conditioning model
- model.first_stage_key: str - Key for input data (e.g., "image")
- model.channels: int - Number of latent channels
- model.image_size: int - Latent image size
- model.device: torch.device - Device for computation
- model.get_input(batch, key, **kwargs): Get and encode input data
- model.apply_model(x, t, c): Apply denoising model
- model.decode_first_stage(z): Decode latent to image space
- model.get_learned_conditioning(batch): Get conditioning embeddings
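- model.train() / model.eval(): Set train/eval mode

Used only by specific code paths of the sampler:
- model.q_sample(x0, t): Forward-noise x0 (masked sampling via mask/x0)
- model.parameterization: Must be "eps" when a score_corrector is supplied
- model.first_stage_model.quantize(z): Quantize pred_x0 (quantize_x0=True)
- model.model.diffusion_model.forward_with_cfg, model.model.conditioning_key:
  Checked to enable a DiT model's native classifier-free guidance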

================================================================================
"""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from itertools import cycle
import os

# PyTorch submodule imports
from torch.distributions.categorical import Categorical
from torch.nn.functional import relu, softmax, normalize
from torch.linalg import norm

# Optional imports for evaluation
try:
    from PIL import Image
    HAS_PIL = True
except ImportError:
    HAS_PIL = False
    print("Warning: PIL not available. Image saving will be disabled.")

try:
    import torch_fidelity
    HAS_TORCH_FIDELITY = True
except ImportError:
    HAS_TORCH_FIDELITY = False

try:
    import ipr.improved_precision_recall as ipr
    HAS_IPR = True
except ImportError:
    HAS_IPR = False


# ================================================================================
# UTILITY FUNCTIONS (Inlined from ldm.modules.diffusionmodules.util)
# ================================================================================
# The following functions are adapted from:
# - https://github.com/openai/improved-diffusion
# - https://github.com/openai/guided-diffusion
# ================================================================================

def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Create timestep schedule for DDIM sampling.

    Args:
        ddim_discr_method: Discretization method ('uniform' or 'quad')
        num_ddim_timesteps: Number of DDIM sampling steps (must be >= 1)
        num_ddpm_timesteps: Total number of diffusion timesteps
        verbose: Whether to print selected timesteps

    Returns:
        Integer array of selected timesteps, each shifted by +1. Every entry is
        strictly less than num_ddpm_timesteps, so the result can safely index
        arrays of length num_ddpm_timesteps (e.g. alphas_cumprod).

    Raises:
        ValueError: If num_ddim_timesteps < 1, or (for 'uniform') if
            num_ddim_timesteps exceeds num_ddpm_timesteps.
        NotImplementedError: For an unknown discretization method.
    """
    if num_ddim_timesteps < 1:
        raise ValueError(f'num_ddim_timesteps must be >= 1, got {num_ddim_timesteps}')

    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        if c < 1:
            raise ValueError(f'Cannot select {num_ddim_timesteps} uniform DDIM steps out of '
                             f'{num_ddpm_timesteps} diffusion timesteps')
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, c)
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # Add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    # When the uniform stride does not divide num_ddpm_timesteps evenly, the last
    # grid point can be num_ddpm_timesteps - 1; after the +1 shift it would index
    # one past the end of the alpha schedule. Drop any such out-of-range entries.
    steps_out = steps_out[steps_out < num_ddpm_timesteps]
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the per-step DDIM quantities alpha_t, alpha_{t-1} and sigma_t.

    Args:
        alphacums: Cumulative alphas from diffusion model
        ddim_timesteps: Selected timestep indices
        eta: Stochasticity parameter (0 = deterministic DDIM, 1 = DDPM)
        verbose: Whether to print parameters

    Returns:
        Tuple of (sigmas, alphas, alphas_prev), one entry per selected step.
    """
    # alpha_t at each selected timestep.
    alphas = alphacums[ddim_timesteps]

    # alpha_{t-1}: the first step falls back to alphacums[0]; each later step
    # uses the alpha of the previously selected timestep.
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + earlier)

    # sigma_t(eta) as defined in https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas)
    step_factor = 1 - alphas / alphas_prev
    sigmas = eta * np.sqrt(variance_ratio * step_factor)

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def extract_into_tensor(a, t, x_shape):
    """
    Gather per-sample values from ``a`` and shape them for broadcasting.

    Args:
        a: Source tensor; values are gathered along its last dimension
        t: Index tensor whose first dimension is the batch size
        x_shape: Shape of the tensor the result will be broadcast against

    Returns:
        Tensor of shape (batch, 1, ..., 1) with len(x_shape) dimensions.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)


def noise_like(shape, device, repeat=False):
    """
    Draw standard-normal noise of the given shape.

    Args:
        shape: Shape of noise tensor; the first entry is the batch size
        device: Device to create tensor on
        repeat: If True, draw one sample and tile it across the batch

    Returns:
        Random noise tensor of ``shape``.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    # One draw, copied to every batch element.
    single = torch.randn((1, *shape[1:]), device=device)
    tiling = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tiling)


# ================================================================================
# DDIM-GMM SAMPLER
# ================================================================================

class DDIMSampler(object):
    """
    DDIM-GMM sampler.

    This implementation extends DDIM by replacing the unimodal Gaussian kernel
    with a multimodal Gaussian mixture kernel. The mixture parameters are
    constrained so that the DDIM-GMM forward marginals have the same first
    and second order moments as the DDPM forward marginals.

    Supports:
    - Standard DDIM sampling (deterministic when eta=0)
    - DDIM-GMM sampling with mixture kernels
    - Unconditional and conditional generation
    - Classifier-free guidance

    Reference: "Denoising Diffusion Implicit Models" (Song et al., 2020)
               https://arxiv.org/abs/2010.02502
    """

    def __init__(self, model, schedule="linear", **kwargs):
        """
        Initialize DDIM sampler.

        Args:
            model: Diffusion model (see model interface requirements in header)
            schedule: Beta schedule type; stored on the sampler but not read
                by any method of this class
            **kwargs:
                gpu: Device specification (False, "mps", or int GPU index);
                    default False
                gmm: Whether to use DDIM-GMM sampling; default False
                gmm_params: GMM parameters object; default None
                guidance_fn: Optional callable ``guidance_fn(x, t)``; its output,
                    scaled by sqrt(1 - alpha_t), is subtracted from the
                    predicted noise; default None
                inflate_latent_norms: Stored flag, not read by any method of
                    this class; default False
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.gpu = kwargs.get('gpu', False)
        self.gmm = kwargs.get('gmm', False)
        self.gmm_params = kwargs.get('gmm_params', None)
        self.guidance_fn = kwargs.get('guidance_fn', None)
        self.inflate_latent_norms = kwargs.get('inflate_latent_norms', False)

    def register_buffer(self, name, attr):
        """
        Store ``attr`` as attribute ``name``, moving tensors to ``self.gpu``.

        Tensors are moved to MPS (with float64 downcast to float32, which MPS
        does not support) when ``self.gpu == "mps"`` and MPS is available, or
        to ``cuda:<self.gpu>`` when ``self.gpu`` is an int. Non-tensors (e.g.
        numpy arrays) and tensors with ``self.gpu is False`` are stored as-is.
        """
        if isinstance(attr, torch.Tensor) and self.gpu is not False:
            if self.gpu == "mps" and torch.backends.mps.is_available():
                if attr.dtype == torch.float64:
                    attr = attr.float()
                attr = attr.to(torch.device("mps"))
            elif isinstance(self.gpu, int):
                attr = attr.to(torch.device("cuda:{}".format(str(self.gpu))))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """
        Create DDIM sampling schedule and register the derived buffers.

        Args:
            ddim_num_steps: Number of DDIM sampling steps
            ddim_discretize: Discretization method ('uniform' or 'quad')
            ddim_eta: Stochasticity parameter (0 = deterministic)
            verbose: Whether to print schedule info
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # sigma_t for every original DDPM step (used when use_original_steps=True)
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """
        DDIM-GMM sampling.

        Args:
            S: Number of sampling steps
            batch_size: Batch size
            shape: Shape of latent (C, H, W)
            conditioning: Conditional input (e.g., class labels)
            eta: Stochasticity (0 = deterministic)
            unconditional_guidance_scale: Classifier-free guidance scale
            unconditional_conditioning: Unconditional embeddings for guidance
            x_T: Optional starting noise (if None, sample from N(0,I))
            ... (see code for other optional parameters; ``normals_sequence``
            and ``**kwargs`` are accepted but ignored)

        Returns:
            Tuple of (samples, intermediates_dict)
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """
        Main DDIM sampling loop.

        With ``ddim_use_original_steps=True`` every DDPM timestep is visited and
        ``timesteps`` is a step count (int); otherwise ``timesteps`` is an array
        of selected DDIM timesteps.

        Returns:
            Tuple (img, intermediates) where intermediates holds lists under
            'x_inter' and 'pred_x0'.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        if ddim_use_original_steps:
            # Here ``timesteps`` is a plain step count, not an array.
            total_steps = int(timesteps)
            time_range = reversed(range(0, total_steps))
        else:
            total_steps = timesteps.shape[0]
            time_range = np.flip(timesteps)
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @staticmethod
    def _eps_channels(model_output, x):
        """Return the noise prediction, dropping learned-variance channels if present."""
        if model_output.shape[1] == x.shape[1] * 2:
            # Learned-variance models stack [eps, var] along channels; the variance
            # half is discarded because this sampler uses the fixed schedule.
            return model_output.chunk(2, dim=1)[0]
        return model_output

    def _predict_eps(self, x, t, c, unconditional_guidance_scale, unconditional_conditioning):
        """Run the model (with classifier-free guidance when requested) and return eps."""
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            return self._eps_channels(self.model.apply_model(x, t, c), x)

        # DiT models expose forward_with_cfg and use 'adm' conditioning; they apply
        # CFG internally, so only the conditional batch is passed.
        inner = getattr(self.model, 'model', None)
        use_dit_native_cfg = (hasattr(getattr(inner, 'diffusion_model', None), 'forward_with_cfg') and
                              getattr(inner, 'conditioning_key', None) == 'adm')
        if use_dit_native_cfg:
            model_output = self.model.apply_model(x, t, c, cfg_scale=unconditional_guidance_scale)
            return self._eps_channels(model_output, x)

        # Standard external CFG: one doubled batch, unconditional half first.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([unconditional_conditioning, c])
        out_uncond, out_cond = self.model.apply_model(x_in, t_in, c_in).chunk(2)
        e_t_uncond = self._eps_channels(out_uncond, x)
        e_t_cond = self._eps_channels(out_cond, x)
        return e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond)

    def _sample_gmm_offset(self, x, index):
        """
        Sample one mixture component per batch element for step ``index``.

        Returns:
            (mixture_weights, centered_offsets, sampled_offset):
            mixture_weights is (b, K); centered_offsets is (b, D, K) where D must
            equal prod(x.shape[1:]); sampled_offset has the shape of x.
        """
        b = x.shape[0]
        device = x.device

        # For sampling at a single timestep, use per-step weights directly
        # (NOT marginal weights - those are only for multi-step trajectories).
        weights = self.gmm_params.mixture_weights[index].to(device)   # (K,)
        offsets = self.gmm_params.means[index].to(device)              # (D, K)

        # Center the offsets so the weighted mixture mean is zero.
        mixture_mean = torch.sum(weights.unsqueeze(0) * offsets, dim=-1, keepdim=True)
        centered = (offsets - mixture_mean).unsqueeze(0).repeat(b, 1, 1)

        batch_weights = weights.repeat(b, 1)
        comp_idx = Categorical(probs=batch_weights).sample().to(device=device)
        gather_idx = comp_idx.type(torch.int64).view(b, 1, 1).expand(-1, centered.shape[1], -1)
        sampled = torch.gather(centered, -1, gather_idx).view(b, *x.shape[1:])
        return batch_weights, centered, sampled

    def _gmm_noise_std(self, sigma_t, x, mixture_weights, centered_offsets):
        """
        Reduce the per-coordinate noise variance by the mixture-offset variance.

        Returns the new per-element standard deviation (shape of x), clamped at 0.
        """
        b = x.shape[0]
        n_components = self.gmm_params.n_components
        sigma_t_sq = (sigma_t ** 2).repeat(1, *x.shape[1:]).view(b, -1)
        if self.gmm_params.upper_bound_vars:
            # Variance upper bound: scale^2 / K.
            # NOTE(review): only the first n_components-1 flattened coordinates are
            # reduced; confirm this matches the intended subspace of the offsets.
            vars_upper_bound = (self.gmm_params.scale ** 2) * (1.0 / n_components)
            sigma_t_sq[:, :n_components - 1] -= vars_upper_bound
        else:
            # True per-coordinate variances: diag(sum_k w_k c_k c_k^T), computed
            # directly instead of materializing the (b, D, D) covariance.
            weighted = mixture_weights.unsqueeze(-2) * centered_offsets
            sigma_t_sq -= torch.sum(weighted * centered_offsets, dim=-1)
        return relu(sigma_t_sq.view(*x.shape)).sqrt()

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """
        Single DDIM sampling step.

        Implements Equation 12 from DDIM paper:
        x_{t-1} = sqrt(α_{t-1}) * pred_x0 + sqrt(1 - α_{t-1} - σ_t^2) * ε_θ(x_t) + σ_t * ε

        With DDIM-GMM (self.gmm and self.gmm_params set) a sampled mixture mean
        offset is added to the direction term and σ_t is reduced per coordinate
        by the mixture variance.

        Returns:
            Tuple (x_prev, pred_x0).
        """
        b, *_, device = *x.shape, x.device

        if use_original_steps:
            # Per-DDPM-step schedules registered on the sampler by make_schedule.
            alphas = self.alphas_cumprod
            alphas_prev = self.alphas_cumprod_prev
            sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod
            sigmas = self.ddim_sigmas_for_original_num_steps
        else:
            alphas = self.ddim_alphas
            alphas_prev = self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
            sigmas = self.ddim_sigmas

        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # Sampling from a GMM MUST happen BEFORE the model call (original
        # ordering; it fixes which random numbers the categorical draw consumes).
        use_gmm = self.gmm and self.gmm_params is not None
        if use_gmm:
            mixture_weights, centered_offsets, sampled_mean_offset = self._sample_gmm_offset(x, index)

        e_t = self._predict_eps(x, t, c, unconditional_guidance_scale, unconditional_conditioning)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        if self.guidance_fn is not None:
            e_t -= sqrt_one_minus_at * self.guidance_fn(x, t)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t (uses the unadjusted sigma_t)
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t

        if use_gmm:
            dir_xt += sampled_mean_offset
            sigma_t = self._gmm_noise_std(sigma_t, x, mixture_weights, centered_offsets)

        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)

        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x_prev, pred_x0

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """
        Encode image to noisy latent at timestep t.

        Args:
            x0: Clean latent
            t: Timestep index tensor (one entry per batch element)
            use_original_steps: Whether to use original DDPM steps
            noise: Optional noise (if None, sample new noise)

        Returns:
            Noisy latent x_t
        """
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False):
        """
        Decode from noisy latent to clean image.

        Args:
            x_latent: Noisy latent at timestep t_start
            cond: Conditioning
            t_start: Number of schedule steps to run (the first t_start entries)
            unconditional_guidance_scale: CFG scale
            unconditional_conditioning: Unconditional embeddings
            use_original_steps: Whether to use original DDPM steps

        Returns:
            Decoded samples
        """
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
        return x_dec


# ================================================================================
# GAUSSIAN MIXTURE MODEL (GMM) FOR DDIM-GMM
# ================================================================================

class GMM(object):
    """
    Gaussian Mixture Model parameters for DDIM-GMM.

    This class stores mixture parameters for DDIM-GMM, which replaces DDIM's
    unimodal Gaussian kernel with a multimodal Gaussian mixture kernel. The
    mixture components have component-specific mean and covariance offsets.

    At each timestep, a mixture component k_t is sampled from a categorical
    distribution with probabilities π_t. The parameters are constrained so that
    the DDIM-GMM forward marginals have the same first and second order moments
    as the DDPM forward marginals.

    Key features:
    - Moment constraints: forward marginals match DDPM forward marginals
    - Upper-bound variance (VUB) or full covariance computation

    Attributes (populated by ``initialize``):
        mixture_weights: (n_steps, n_components) tensor; each row sums to 1
        means: (n_steps, dim, n_components) tensor of component mean offsets
        cov: (n_steps, dim, dim, n_components) tensor, or None
    """

    def __init__(self, **kwargs):
        """
        Create an empty parameter container; call ``initialize`` before use.

        Args:
            **kwargs:
                gpu: Device specification (False, "mps", or int GPU index)
        """
        super().__init__()
        self.dim = 0
        self.n_components = 0
        self.n_steps = 0
        self.means = None
        self.cov = None
        self.mixture_weights = None
        self.scale = 1.0
        self.dynamic_scale = False
        # Read by DDIMSampler.p_sample_ddim; defaulted here so the attribute
        # exists even before initialize() is called.
        self.upper_bound_vars = False
        self.gpu = kwargs.get('gpu', False)

    def _device(self):
        """Resolve ``self.gpu`` to a torch.device, falling back to CPU."""
        if self.gpu is not False:
            if self.gpu == "mps" and torch.backends.mps.is_available():
                return torch.device("mps")
            if isinstance(self.gpu, int):
                return torch.device("cuda:{}".format(str(self.gpu)))
        return torch.device("cpu")

    def initialize(self, dim, n_components, n_steps, scale=1.0, uniform_priors=False, orthonormal=False, init_cov=False,
                   upper_bound_vars=False, dynamic_scale=False):
        """
        Initialize GMM parameters with specified configuration.

        Args:
            dim: Latent dimensionality
            n_components: Number of mixture components (>= 1)
            n_steps: Number of diffusion timesteps
            scale: Scale for mean offsets
            uniform_priors: If True, use uniform mixture weights; else random
                softmax-normalized weights
            orthonormal: If True, use orthonormal initialization for means
                (requires n_components <= dim)
            init_cov: If True, initialize covariance matrices
            upper_bound_vars: If True, use VUB approximation
            dynamic_scale: If True, means are left unscaled (unit norm)

        Raises:
            ValueError: If n_components < 1, or if orthonormal is requested with
                more components than dimensions.
        """
        if n_components < 1:
            raise ValueError(f'n_components must be >= 1, got {n_components}')
        if orthonormal and n_components > dim:
            # A reduced SVD would return only `dim` columns, silently producing
            # means of the wrong shape.
            raise ValueError(f'orthonormal initialization needs n_components <= dim '
                             f'(got n_components={n_components}, dim={dim})')

        device = self._device()

        self.dim = dim
        self.n_components = n_components
        self.n_steps = n_steps
        self.scale = scale
        self.dynamic_scale = dynamic_scale
        self.upper_bound_vars = upper_bound_vars

        if uniform_priors:
            self.mixture_weights = (1. / self.n_components) * torch.ones((self.n_steps, self.n_components), device=device)
        else:
            # Random priors, normalized over the component axis.
            self.mixture_weights = softmax(torch.randn((self.n_steps, self.n_components), device=device), dim=-1)

        means = torch.randn(self.n_steps, self.dim, self.n_components, device=device)
        if orthonormal:
            # Left singular vectors: orthonormal columns of shape (dim, n_components).
            means = torch.linalg.svd(means, full_matrices=False)[0]
        else:
            # Unit-norm columns (normalized over the latent dimension).
            means = normalize(means, p=2.0, dim=1)
        self.means = means if self.dynamic_scale else self.scale * means

        if init_cov:
            self.cov = torch.randn(self.n_steps, dim, dim, n_components, device=device)


# ================================================================================
# END OF FILE
# ================================================================================
