"""
This module provides the core components and utility functions for a diffusion model,
based on Denoising Diffusion Probabilistic Models (DDPM).

It includes functions for:
- Creating noise schedules (beta schedules).
- Pre-calculating and registering all necessary diffusion variables (alphas, betas, etc.).
- The forward process (q_sample) to add noise to data.
- The reverse process (p_sample, ddim_sample) to generate data from noise.
- Calculating the diffusion loss (`p_losses`), including support for ELBO and p2-weighting.
"""

import torch
import numpy as np
from tqdm import tqdm


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """
    Builds the per-timestep beta (noise variance) schedule.

    Supported schedules (all evaluated in float64, then cast to float32):
        - "linear": values spaced linearly between sqrt(linear_start) and
          sqrt(linear_end), then squared.
        - "cosine": derived from a squared-cosine alpha_bar curve with offset
          `cosine_s`, as betas = 1 - alpha_bar[t+1] / alpha_bar[t], clipped to [0, 0.999].
        - "sqrt_linear": values spaced linearly between linear_start and linear_end.
        - "sqrt": square root of values spaced linearly between linear_start and linear_end.

    Args:
        schedule (str): Name of the schedule (see above).
        n_timestep (int): Number of diffusion timesteps (length of the result).
        linear_start (float, optional): Start value for the linear-type schedules. Defaults to 1e-4.
        linear_end (float, optional): End value for the linear-type schedules. Defaults to 2e-2.
        cosine_s (float, optional): Offset `s` of the cosine schedule. Defaults to 8e-3.

    Raises:
        ValueError: If `schedule` is not one of the supported names.

    Returns:
        torch.Tensor: float32 tensor of shape (n_timestep,).
    """
    if schedule == "cosine":
        steps = torch.arange(n_timestep + 1, dtype=torch.float64)
        angle = ((steps / n_timestep) + cosine_s) / (1 + cosine_s) * torch.pi * 0.5
        alpha_bar = torch.cos(angle) ** 2
        alpha_bar = alpha_bar / alpha_bar[0]
        # Ratio of consecutive alpha_bar values gives the per-step retention 1 - beta.
        cosine_betas = torch.clip(1 - (alpha_bar[1:] / alpha_bar[:-1]), 0, 0.999)
        return cosine_betas.float()

    if schedule == "linear":
        lo, hi = linear_start**0.5, linear_end**0.5
        return (torch.linspace(lo, hi, n_timestep, dtype=torch.float64) ** 2).float()

    if schedule == "sqrt_linear":
        return torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        ).float()

    if schedule == "sqrt":
        ramp = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
        return (ramp**0.5).float()

    raise ValueError(f"schedule '{schedule}' unknown.")


def register_diffusion_buffers(
    beta_schedule="linear",
    timesteps=1000,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
    given_betas=None,
    device="cpu",
):
    """
    Pre-computes all variables needed by the diffusion process.

    Starting from a beta schedule, this computes alphas, cumulative alpha products and
    the coefficients used by the forward (q) and reverse (p) processes, and returns them
    in a dictionary keyed by name.

    All derived quantities are computed in float64 on the CPU and only then cast to the
    output dtype and moved to `device`. Several of them (e.g. ``1 - alphas_cumprod`` and
    ``1 / alphas_cumprod - 1`` at small t) suffer from catastrophic cancellation when
    evaluated in float32, which is the dtype `make_beta_schedule` returns.

    Args:
        beta_schedule (str, optional): Name of the beta schedule to generate. Defaults to "linear".
        timesteps (int, optional): Number of diffusion timesteps. Defaults to 1000.
        linear_start (float, optional): Start value for linear-type schedules. Defaults to 1e-4.
        linear_end (float, optional): End value for linear-type schedules. Defaults to 2e-2.
        cosine_s (float, optional): Offset `s` of the cosine schedule. Defaults to 8e-3.
        given_betas (torch.Tensor or array-like, optional): External betas used instead of
            generating a schedule (the schedule arguments are then ignored). A floating-point
            tensor keeps its dtype for all returned buffers; any other input (numpy array,
            list, integer tensor) produces float32 buffers. Defaults to None.
        device (str or torch.device, optional): Device for the returned tensors. Defaults to "cpu".

    Returns:
        dict[str, torch.Tensor]: Tensors shaped like the betas (expected 1D, one entry per
            timestep), under the keys "betas", "alphas_cumprod", "alphas_cumprod_prev",
            "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod",
            "log_one_minus_alphas_cumprod", "sqrt_recip_alphas_cumprod",
            "sqrt_recipm1_alphas_cumprod", "posterior_variance",
            "posterior_log_variance_clipped", "posterior_mean_coef1", "posterior_mean_coef2".
    """
    out_dtype = torch.float32
    if given_betas is None:
        betas = make_beta_schedule(
            schedule=beta_schedule,
            n_timestep=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
    elif isinstance(given_betas, torch.Tensor):
        betas = given_betas
        if betas.is_floating_point():
            out_dtype = betas.dtype
    else:
        # Array-likes (numpy arrays, lists) are accepted; keep full precision for the
        # float64 arithmetic below.
        betas = torch.as_tensor(np.asarray(given_betas, dtype=np.float64))

    # Do the arithmetic in float64 on the CPU (some accelerator backends, e.g. Apple MPS,
    # do not support float64) and cast/move once at the end.
    betas = betas.to(device="cpu", dtype=torch.float64)

    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
    alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
    one_minus_alphas_cumprod = 1.0 - alphas_cumprod

    # Variance of the posterior q(x_{t-1} | x_t, x_0); exactly 0 at t = 0 because
    # alphas_cumprod_prev[0] == 1, hence the clamp before taking the log.
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / one_minus_alphas_cumprod

    buffers = {
        "betas": betas,
        "alphas_cumprod": alphas_cumprod,
        "alphas_cumprod_prev": alphas_cumprod_prev,
        # calculations for diffusion q(x_t | x_{t-1}) and others
        "sqrt_alphas_cumprod": torch.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": torch.sqrt(one_minus_alphas_cumprod),
        "log_one_minus_alphas_cumprod": torch.log(one_minus_alphas_cumprod),
        "sqrt_recip_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod),
        "sqrt_recipm1_alphas_cumprod": torch.sqrt(1.0 / alphas_cumprod - 1),
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        "posterior_variance": posterior_variance,
        "posterior_log_variance_clipped": torch.log(
            torch.clamp(posterior_variance, min=1e-20)
        ),
        "posterior_mean_coef1": betas
        * torch.sqrt(alphas_cumprod_prev)
        / one_minus_alphas_cumprod,
        "posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
        * torch.sqrt(alphas)
        / one_minus_alphas_cumprod,
    }

    return {
        name: value.to(device=device, dtype=out_dtype)
        for name, value in buffers.items()
    }


def extract_into_tensor(a, t, x_shape):
    """
    Looks up per-sample values of a 1D schedule tensor and shapes them for broadcasting.

    Values are gathered from the last dimension of `a` at the indices in `t`. The result
    has shape (len(t), 1, ..., 1), with as many trailing singleton dimensions as needed
    to broadcast against a tensor of shape `x_shape`.

    Args:
        a (torch.Tensor): 1D tensor of per-timestep values.
        t (torch.Tensor): 1D integer tensor of indices (one per batch element), on the
            same device as `a`.
        x_shape (tuple): Shape of the tensor the result will be broadcast against.

    Returns:
        torch.Tensor: Gathered values of shape (batch, 1, 1, ...).
    """
    batch, *_ = t.shape
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return torch.gather(a, -1, t).reshape(broadcast_shape)


def q_sample(x_start, t, diffusion_buffers, noise=None):
    """
    Samples from the forward process q(x_t | x_0).

    Computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, using the
    "sqrt_alphas_cumprod" and "sqrt_one_minus_alphas_cumprod" buffers.

    Args:
        x_start (torch.Tensor): Clean data x_0 with a leading batch dimension.
        t (torch.Tensor): Timestep index per batch element.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        noise (torch.Tensor, optional): Noise to mix in; standard Gaussian noise shaped
            like `x_start` is drawn when None. Defaults to None.

    Returns:
        torch.Tensor: The noised sample x_t.
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    data_shape = x_start.shape
    signal_scale = extract_into_tensor(
        diffusion_buffers["sqrt_alphas_cumprod"], t, data_shape
    )
    noise_scale = extract_into_tensor(
        diffusion_buffers["sqrt_one_minus_alphas_cumprod"], t, data_shape
    )
    return signal_scale * x_start + noise_scale * eps


def predict_start_from_noise(x_t, t, noise, diffusion_buffers):
    """
    Recovers an estimate of x_0 from a noisy sample and a noise estimate.

    Inverts x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, i.e.
    x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise.

    Args:
        x_t (torch.Tensor): Noisy sample at timestep t.
        t (torch.Tensor): Timestep index per batch element.
        noise (torch.Tensor): Noise estimate (e.g. the model's epsilon prediction).
        diffusion_buffers (dict): Pre-computed diffusion variables.

    Returns:
        torch.Tensor: The estimate of x_0.
    """
    shape = x_t.shape
    coef_x = extract_into_tensor(
        diffusion_buffers["sqrt_recip_alphas_cumprod"], t, shape
    )
    coef_eps = extract_into_tensor(
        diffusion_buffers["sqrt_recipm1_alphas_cumprod"], t, shape
    )
    return coef_x * x_t - coef_eps * noise


def predict_start_from_v(x_t, t, v, diffusion_buffers):
    """
    Recovers an estimate of x_0 from a noisy sample and a velocity (v) prediction.

    With v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0 (the
    v-parameterization), x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v.

    Args:
        x_t (torch.Tensor): Noisy sample at timestep t.
        t (torch.Tensor): Timestep index per batch element.
        v (torch.Tensor): Velocity prediction from the model.
        diffusion_buffers (dict): Pre-computed diffusion variables.

    Returns:
        torch.Tensor: The estimate of x_0.
    """
    shape = x_t.shape
    root_alpha_bar = extract_into_tensor(
        diffusion_buffers["sqrt_alphas_cumprod"], t, shape
    )
    root_one_minus_alpha_bar = extract_into_tensor(
        diffusion_buffers["sqrt_one_minus_alphas_cumprod"], t, shape
    )
    return root_alpha_bar * x_t - root_one_minus_alpha_bar * v


def q_posterior_mean_variance(x_start, x_t, t, diffusion_buffers):
    """
    Returns the mean and variance of the posterior q(x_{t-1} | x_t, x_0).

    The mean is posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t. The
    variance terms are read directly from the pre-computed buffers.

    Args:
        x_start (torch.Tensor): x_0 (the true or the predicted clean sample).
        x_t (torch.Tensor): Noisy sample at timestep t.
        t (torch.Tensor): Timestep index per batch element.
        diffusion_buffers (dict): Pre-computed diffusion variables.

    Returns:
        tuple: (posterior_mean, posterior_variance, posterior_log_variance_clipped).
            The two variance terms have shape (batch, 1, ...) for broadcasting.
    """

    def lookup(name):
        # Per-sample value of a schedule buffer, broadcastable against x_start.
        return extract_into_tensor(diffusion_buffers[name], t, x_start.shape)

    mean = lookup("posterior_mean_coef1") * x_start + lookup("posterior_mean_coef2") * x_t
    variance = lookup("posterior_variance")
    log_variance = lookup("posterior_log_variance_clipped")
    return mean, variance, log_variance


def p_losses(
    model,
    x_start,
    t,
    diffusion_buffers,
    cond,
    noise=None,
    clip_denoised=True,
    original_elbo_weight=0.0,
    l_simple_weight=1.0,
    p2_gamma=0.5,
    p2_k=1.0,
):
    """
    Calculates the diffusion training loss for an epsilon-predicting model.

    The loss combines:
      - L_simple: the per-sample MSE between the true and predicted noise. In the total
        loss it is p2-weighted by w_t = (SNR_t + p2_k) ** (-p2_gamma), where
        SNR_t = alpha_bar_t / (1 - alpha_bar_t). The weighting is always applied; set
        p2_gamma=0 to make every weight 1.
      - Optionally, a VLB term (when original_elbo_weight > 0): for t > 0, the squared
        distance between the true posterior mean and the model's posterior mean, divided
        by twice the posterior variance and averaged over non-batch dimensions. This is
        the KL between two Gaussians that share the posterior variance. It is zero for
        t == 0.

    The model is called exactly once, as model(t, x_t, cond). The VLB term reuses that
    output rather than running a second forward pass.

    Args:
        model (Callable): The diffusion model, called as model(t, x_t, cond) to predict noise.
        x_start (torch.Tensor): Clean data x_0 with a leading batch dimension.
        t (torch.Tensor): Timestep index per batch element, shape [B].
        diffusion_buffers (dict): Pre-computed diffusion variables.
        cond: Conditioning information forwarded to the model.
        noise (torch.Tensor, optional): Noise used to build x_t; drawn when None. Defaults to None.
        clip_denoised (bool, optional): Clamp the predicted x_0 to [-1, 1] before computing
            the model posterior (VLB term only). Defaults to True.
        original_elbo_weight (float, optional): Weight of the VLB term; the term is skipped
            when this is 0. Defaults to 0.0.
        l_simple_weight (float, optional): Weight of the p2-weighted simple loss. Defaults to 1.0.
        p2_gamma (float, optional): Exponent gamma of the p2 weighting. Defaults to 0.5.
        p2_k (float, optional): Offset k of the p2 weighting. Defaults to 1.0.

    Returns:
        tuple: (loss, loss_simple) where `loss` is the scalar combined loss and
            `loss_simple` is the unweighted per-sample simple loss of shape [B].
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    x_t = q_sample(x_start, t, diffusion_buffers, noise)
    model_output = model(t, x_t, cond)

    # Epsilon parameterisation: the regression target is the injected noise.
    target = noise
    reduce_dims = list(range(1, x_start.dim()))

    loss_simple = torch.nn.functional.mse_loss(target, model_output, reduction="none")
    loss_simple = loss_simple.mean(dim=reduce_dims)  # [B]

    # SNR (p2) weighting. reshape(-1) keeps a [B] shape even when B == 1
    # (squeeze() would collapse it to a 0-d tensor).
    alpha_bar_t = extract_into_tensor(
        diffusion_buffers["alphas_cumprod"], t, x_start.shape
    ).reshape(-1)
    snr_t = alpha_bar_t / torch.clamp(1.0 - alpha_bar_t, min=1e-8)
    weights = (1.0 / (snr_t + p2_k)).pow(p2_gamma)  # [B]
    loss_simple_weighted = (weights * loss_simple).mean()

    # Per-sample zeros so the combination below works whether or not VLB is computed.
    loss_vlb = torch.zeros_like(loss_simple)
    if original_elbo_weight > 0:
        # True posterior q(x_{t-1} | x_t, x_0).
        q_mean, _, q_log_variance = q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t, diffusion_buffers=diffusion_buffers
        )

        # Model posterior p(x_{t-1} | x_t), built from the same model_output used by
        # L_simple instead of a second (costly, possibly stochastic) forward pass.
        x_recon = predict_start_from_noise(x_t, t, model_output, diffusion_buffers)
        if clip_denoised:
            x_recon = x_recon.clamp(-1.0, 1.0)
        p_mean, _, _ = q_posterior_mean_variance(
            x_start=x_recon, x_t=x_t, t=t, diffusion_buffers=diffusion_buffers
        )

        # KL between Gaussians with the same (posterior) variance.
        kl_term = torch.nn.functional.mse_loss(q_mean, p_mean, reduction="none")
        kl_term = kl_term / (2 * torch.exp(q_log_variance))
        kl_term = kl_term.mean(dim=reduce_dims)

        # At t == 0 the posterior variance is clamped to 1e-20, so the term is huge and
        # meaningless. Select it away with torch.where: multiplying by a 0 mask would
        # turn an inf into NaN.
        loss_vlb = torch.where(t > 0, kl_term, torch.zeros_like(kl_term))

    loss = (
        l_simple_weight * loss_simple_weighted + original_elbo_weight * loss_vlb.mean()
    )
    return loss, loss_simple


def p_mean_variance(model, x, t, diffusion_buffers, clip_denoised, cond=None):
    """
    Computes the mean and variance of the reverse step p(x_{t-1} | x_t).

    The model output is treated as a noise prediction. It is converted to an x_0
    estimate, optionally clamped to [-1, 1], and plugged into the true posterior
    q(x_{t-1} | x_t, x_0), whose variance is used as the model variance.

    Args:
        model (Callable): The diffusion model, called as model(t, x, cond).
        x (torch.Tensor): Noisy sample x_t.
        t (torch.Tensor): Timestep index per batch element.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        clip_denoised (bool): Whether to clamp the x_0 estimate to [-1, 1].
        cond (optional): Conditioning forwarded to the model. Defaults to None.

    Returns:
        tuple: (model_mean, posterior_variance, posterior_log_variance_clipped).
    """
    eps_hat = model(t, x, cond)
    x0_hat = predict_start_from_noise(x, t, eps_hat, diffusion_buffers)
    if clip_denoised:
        x0_hat = x0_hat.clamp(-1.0, 1.0)

    mean, variance, log_variance = q_posterior_mean_variance(
        x0_hat, x, t, diffusion_buffers
    )
    return mean, variance, log_variance


def p_sample(
    model,
    x,
    t,
    diffusion_buffers,
    clip_denoised=False,
    cond=None,
    temperature=1.0,
):
    """
    Takes one ancestral (DDPM) reverse step from x_t to x_{t-1}.

    The new sample is mean + sqrt(variance) * temperature * z with z ~ N(0, I). The noise
    term is masked out when t == 0, although z is still drawn.

    Args:
        model (Callable): The diffusion model, called as model(t, x, cond).
        x (torch.Tensor): Current sample x_t with a leading batch dimension.
        t (int): Current timestep, shared by the whole batch.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        clip_denoised (bool, optional): Clamp the x_0 estimate to [-1, 1]. Defaults to False.
        cond (optional): Conditioning forwarded to the model. Defaults to None.
        temperature (float, optional): Scale applied to the sampling noise. Defaults to 1.0.

    Returns:
        torch.Tensor: The sample x_{t-1}.
    """
    batch_size = x.shape[0]
    timestep = torch.full((batch_size,), t, device=x.device, dtype=torch.long)

    mean, variance, _ = p_mean_variance(
        model, x, timestep, diffusion_buffers, clip_denoised, cond
    )

    scaled_noise = torch.randn_like(x) * temperature
    # 1.0 for samples that still get noise, 0.0 on the final step (t == 0).
    keep_noise = (timestep != 0).float().view(batch_size, *([1] * (x.dim() - 1)))

    return mean + keep_noise * torch.sqrt(variance) * scaled_noise


@torch.no_grad()
def p_sample_loop(
    model,
    shape,
    cond,
    diffusion_buffers,
    device,
    timesteps,
    clip_denoised=False,
    temperature=1.0,
    verbose=False,
):
    """
    Generates samples with the full ancestral (DDPM) reverse process.

    Starts from standard Gaussian noise of the requested shape and applies `p_sample`
    for t = timesteps - 1, ..., 0. Runs without gradient tracking.

    Args:
        model (Callable): The diffusion model, called as model(t, x, cond).
        shape (tuple): Shape of the samples to generate (batch first).
        cond: Conditioning forwarded to the model.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        device (torch.device or str): Device on which the samples are created.
        timesteps (int): Number of reverse steps.
        clip_denoised (bool, optional): Clamp x_0 estimates to [-1, 1]. Defaults to False.
        temperature (float, optional): Scale applied to the sampling noise. Defaults to 1.0.
        verbose (bool, optional): Show a tqdm progress bar. Defaults to False.

    Returns:
        torch.Tensor: The generated samples.
    """
    sample = torch.randn(shape, device=device)

    # Descending timesteps T-1, ..., 0; slicing a range keeps a len() for tqdm.
    schedule = range(timesteps)[::-1]
    if verbose:
        schedule = tqdm(schedule, desc="Sampling t")

    for step in schedule:
        sample = p_sample(
            model,
            sample,
            step,
            diffusion_buffers,
            clip_denoised=clip_denoised,
            cond=cond,
            temperature=temperature,
        )
    return sample


@torch.no_grad()
def ddim_sample(
    model,
    x,
    t,
    t_prev,
    diffusion_buffers,
    clip_denoised=False,
    cond=None,
    eta=0.0,
):
    """
    Performs one DDIM (Denoising Diffusion Implicit Models) update from `t` to `t_prev`.

    DDIM can skip timesteps, so `t_prev` need not equal t - 1. With eta=0 the update is
    deterministic. With eta > 0, Gaussian noise with standard deviation

        sigma = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

    is added, where a_* denote alpha_bar values. A negative `t_prev` marks the final step
    to x_0; there alpha_bar_prev is taken as 1.

    Args:
        model (Callable): The diffusion model, called as model(t, x, cond) to predict noise.
        x (torch.Tensor): Current sample x_t with a leading batch dimension.
        t (torch.Tensor): Current timestep index per batch element (long).
        t_prev (torch.Tensor): Target timestep index per batch element (long); -1 for the
            final step.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        clip_denoised (bool, optional): Clamp the x_0 estimate to [-1, 1]. Defaults to False.
        cond (optional): Conditioning forwarded to the model. Defaults to None.
        eta (float, optional): Stochasticity; 0 gives deterministic DDIM, and 1 uses the
            DDPM posterior variance. Defaults to 0.0.

    Returns:
        torch.Tensor: The sample at timestep `t_prev` (x_0 when `t_prev` is negative).
    """
    # Predict the noise, then the clean sample x_0 it implies.
    pred_noise = model(t, x, cond)
    x_start = predict_start_from_noise(x, t, pred_noise, diffusion_buffers)
    if clip_denoised:
        x_start.clamp_(-1.0, 1.0)

    alphas_cumprod = diffusion_buffers["alphas_cumprod"]
    alpha_prod_t = extract_into_tensor(alphas_cumprod, t, x.shape)

    # Final step (t_prev < 0): alpha_bar_prev is 1. Negative indices are clamped to 0
    # only so that the gather stays in bounds; torch.where discards those values.
    is_final_step = (t_prev < 0).view(-1, *([1] * (x.dim() - 1)))
    alpha_prod_t_prev = torch.where(
        is_final_step,
        torch.ones_like(alpha_prod_t),
        extract_into_tensor(alphas_cumprod, t_prev.clamp(min=0), x.shape),
    )

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # Variance of the posterior q(x_{t_prev} | x_t, x_0), scaled by eta to get sigma.
    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    sigma = eta * torch.sqrt(variance)

    # Direction pointing to x_t. The clamp guards against tiny negative values from
    # floating-point error, which would make sqrt return NaN.
    pred_dir_xt = (
        torch.sqrt((1 - alpha_prod_t_prev - sigma**2).clamp(min=0.0)) * pred_noise
    )

    x_prev = torch.sqrt(alpha_prod_t_prev) * x_start + pred_dir_xt

    if eta > 0:
        x_prev = x_prev + sigma * torch.randn_like(x)

    return x_prev


@torch.no_grad()
def ddim_sample_loop(
    model,
    shape,
    cond,
    diffusion_buffers,
    device,
    total_timesteps,
    ddim_num_steps,
    eta=0.0,
    clip_denoised=False,
    verbose=False,
):
    """
    Generates samples with DDIM over an evenly spaced subset of timesteps.

    Timesteps are np.linspace(-1, total_timesteps - 1, ddim_num_steps + 1), truncated to
    integers and visited in descending order. The trailing -1 makes the last
    `ddim_sample` step go to x_0. When ddim_num_steps > total_timesteps, truncation yields
    repeated timesteps; consecutive duplicates are dropped, so fewer than ddim_num_steps
    model calls may be made. Runs without gradient tracking.

    Args:
        model (Callable): The diffusion model, called as model(t, x, cond).
        shape (tuple): Shape of the samples to generate (batch first).
        cond: Conditioning forwarded to the model.
        diffusion_buffers (dict): Pre-computed diffusion variables.
        device (torch.device or str): Device on which the samples are created.
        total_timesteps (int): Number of timesteps the model was trained with.
        ddim_num_steps (int): Requested number of DDIM steps (must be >= 1).
        eta (float, optional): DDIM stochasticity parameter. Defaults to 0.0.
        clip_denoised (bool, optional): Clamp x_0 estimates to [-1, 1]. Defaults to False.
        verbose (bool, optional): Show a tqdm progress bar. Defaults to False.

    Raises:
        ValueError: If ddim_num_steps < 1. Previously this silently returned pure noise.

    Returns:
        torch.Tensor: The generated samples.
    """
    if ddim_num_steps < 1:
        raise ValueError(f"ddim_num_steps must be >= 1, got {ddim_num_steps}.")

    raw_times = np.linspace(-1, total_timesteps - 1, ddim_num_steps + 1).astype(int)
    times = []
    for ti in reversed(raw_times.tolist()):
        # Clamp to the valid range after integer casting.
        ti = min(max(ti, -1), total_timesteps - 1)
        # A (t, t) step has zero DDIM variance and leaves the noise level unchanged,
        # so it would only cost an extra model call.
        if not times or ti != times[-1]:
            times.append(ti)
    time_pairs = list(zip(times[:-1], times[1:]))

    img = torch.randn(shape, device=device)

    iterator = (
        tqdm(time_pairs, desc="DDIM Sampling", total=len(time_pairs))
        if verbose
        else time_pairs
    )

    batch_size = shape[0]
    for t_cur, t_next in iterator:
        t = torch.full((batch_size,), t_cur, device=device, dtype=torch.long)
        prev_t = torch.full((batch_size,), t_next, device=device, dtype=torch.long)

        img = ddim_sample(
            model,
            img,
            t,
            prev_t,
            diffusion_buffers,
            clip_denoised=clip_denoised,
            cond=cond,
            eta=eta,
        )

    return img
