import math
import torch
import numpy as np
from tqdm import tqdm

def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2):
    """
    Build a beta (per-step noise variance) schedule for a diffusion process.

    Parameters:
    - schedule: One of 'linear', 'const', 'quad', 'jsd', 'sigmoid', 'cosine',
      'cosine_reverse', 'cosine_anneal'. 'cosine_reverse' produces exactly the same
      betas as 'cosine' here. 'const' uses only ``end``; 'jsd', 'cosine' and
      'cosine_reverse' ignore both ``start`` and ``end``.
    - num_timesteps: The number of timesteps for the schedule.
    - start: The starting value of beta (where the schedule uses it).
    - end: The ending value of beta (where the schedule uses it).

    Returns:
    - betas: 1-D tensor with one beta value per timestep.

    Raises:
    - NotImplementedError: if ``schedule`` is not one of the names listed above.
    """
    if schedule == "linear":
        betas = torch.linspace(start, end, num_timesteps)
    elif schedule == "const":
        betas = end * torch.ones(num_timesteps)
    elif schedule == "quad":
        betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
    elif schedule == "jsd":
        betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
    elif schedule == "sigmoid":
        betas = torch.linspace(-6, 6, num_timesteps)
        betas = torch.sigmoid(betas) * (end - start) + start
    elif schedule in ("cosine", "cosine_reverse"):
        max_beta = 0.999
        cosine_s = 0.008

        def alpha_bar(step):
            # Squared-cosine cumulative alpha at integer step ``step``.
            return math.cos((step / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(i + 1) / alpha_bar(i), clipped at max_beta.
        betas = torch.tensor(
            [min(1 - alpha_bar(i + 1) / alpha_bar(i), max_beta) for i in range(num_timesteps)])
    elif schedule == "cosine_anneal":
        # A single-step schedule would otherwise divide by zero; it yields [start].
        denom = max(num_timesteps - 1, 1)
        betas = torch.tensor(
            [start + 0.5 * (end - start) * (1 - math.cos(t / denom * math.pi)) for t in
             range(num_timesteps)])
    else:
        raise NotImplementedError(f'There is no beta schedule called "{schedule}"')
    return betas


def extract(input, t, x):
    """
    Pick schedule values at the given indices and shape them for broadcasting.

    Parameters:
    - input: Tensor of schedule values, gathered along dim 0 (``t`` must have the
      same number of dims as ``input``, so in practice both are 1-D).
    - t: Index tensor; it is moved to ``input``'s device before gathering.
    - x: Reference tensor; only its number of dimensions is used.

    Returns:
    - Tensor of shape (len(t), 1, ..., 1) with as many dims as ``x``.
    """
    picked = torch.gather(input, 0, t.to(input.device))
    # One leading entry per index, singleton dims for everything else in x.
    target_shape = (t.shape[0],) + (1,) * (x.dim() - 1)
    return picked.reshape(target_shape)

def e2y(y_t_batch, epsilon_pred, t, alphas_bar_sqrt, one_minus_alphas_bar_sqrt):
    """
    Convert a noise prediction into a clean-sample (y_0) estimate.

    Solves y_t = sqrt(alpha_bar_t) * y_0 + sqrt(1 - alpha_bar_t) * eps for y_0.

    Parameters:
    - y_t_batch: Noisy samples at timestep ``t``.
    - epsilon_pred: Predicted noise for ``y_t_batch``.
    - t: Timestep indices used to look up the schedule values.
    - alphas_bar_sqrt: Precomputed sqrt(alpha_bar) values.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values.

    Returns:
    - The y_0 estimate (broadcast shape of the inputs).
    """
    signal_scale = extract(alphas_bar_sqrt, t, y_t_batch)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, y_t_batch)
    return 1 / signal_scale * (y_t_batch - noise_scale * epsilon_pred)


def y2e(y_t_batch, y_pred, t, alphas_bar_sqrt, one_minus_alphas_bar_sqrt):
    """
    Convert a clean-sample (y_0) prediction into the implied noise.

    Solves y_t = sqrt(alpha_bar_t) * y_0 + sqrt(1 - alpha_bar_t) * eps for eps.

    Parameters:
    - y_t_batch: Noisy samples at timestep ``t``.
    - y_pred: Predicted y_0 for ``y_t_batch``.
    - t: Timestep indices used to look up the schedule values.
    - alphas_bar_sqrt: Precomputed sqrt(alpha_bar) values.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values.

    Returns:
    - The implied noise (broadcast shape of the inputs).
    """
    signal_scale = extract(alphas_bar_sqrt, t, y_t_batch)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, y_t_batch)
    return 1 / noise_scale * (y_t_batch - signal_scale * y_pred)

def q_sample(y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise=None, fq_x=None):
    """
    Forward diffusion: draw y_t from q(y_t | y_0[, x]).

    Without ``fq_x``: y_t = sqrt(alpha_bar_t) * y + sqrt(1 - alpha_bar_t) * noise.
    With ``fq_x``, the mean becomes sqrt(alpha_bar_t) * y + (1 - sqrt(alpha_bar_t)) * fq_x,
    i.e. it interpolates between ``y`` and ``fq_x`` with weight sqrt(alpha_bar_t).

    Parameters:
    - y: The clean input tensor.
    - alphas_bar_sqrt: Precomputed sqrt(alpha_bar) values.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values.
    - t: Timestep indices.
    - noise: Optional noise tensor; standard normal noise like ``y`` is drawn when omitted.
    - fq_x: Optional conditional mean tensor.

    Returns:
    - y_t: The noisy sample at timestep t.
    """
    eps = noise if noise is not None else torch.randn_like(y)
    signal_scale = extract(alphas_bar_sqrt, t, y)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, y)

    mean = signal_scale * y
    if fq_x is not None:
        mean = mean + (1 - signal_scale) * fq_x
    return mean + noise_scale * eps


def p_sample(model, x, y_t, fp_x, t, alphas, one_minus_alphas_bar_sqrt, stochastic=False, fq_x=None):
    """
    Reverse diffusion process sampling -- one time step (y_t -> y_{t-1}).

    The model output is treated as a noise prediction and converted into a y_0
    estimate. That estimate is combined with y_t (and with fq_x when given) into
    the posterior mean. Posterior noise is added only when ``stochastic`` is truthy.

    Parameters:
    - model: The diffusion model, called as ``model(x, y_t, t, fp_x)``; its output is
      used as the noise (epsilon) prediction.
    - x: The input tensor, passed through to the model.
    - y_t: The noisy sample at timestep t.
    - fp_x: Embedding of fp encoder, passed through to the model.
    - t: The current timestep as a Python int (wrapped into a 1-element tensor below).
      Index ``t - 1`` is also read, so presumably t >= 1 -- p_sample_loop only calls it
      with t >= 1; confirm for other callers.
    - alphas: Precomputed per-step alpha values.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values; sqrt(alpha_bar)
      is derived from these.
    - stochastic: If falsy, the noise term is zero and the posterior mean is returned.
    - fq_x: Optional conditional mean tensor; when given it enters both the y_0
      reparameterization and the posterior mean.

    Returns:
    - y_t_m_1: The sample at timestep t-1.
    """
    device = next(model.parameters()).device
    # bool * tensor: all-zero noise when stochastic is False (RNG is still consumed).
    z = stochastic * torch.randn_like(y_t)
    # 1-element index; extract() reshapes it so it broadcasts over the whole batch.
    t = torch.tensor([t]).to(device)
    alpha_t = extract(alphas, t, y_t)
    sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y_t)
    sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y_t)
    # Recover sqrt(alpha_bar) from sqrt(1 - alpha_bar).
    sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()
    sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()
    # y_t_m_1 posterior mean component coefficients (weights on y_0_reparam and y_t)
    gamma_0 = (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())
    gamma_1 = (sqrt_one_minus_alpha_bar_t_m_1.square()) * (alpha_t.sqrt()) / (sqrt_one_minus_alpha_bar_t.square())

    # Detached: no gradients flow back through the model during sampling.
    eps_theta = model(x, y_t, t, fp_x).to(device).detach()
    # y_0 reparameterization: invert the forward process using the predicted noise
    if fq_x is None:
        y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - eps_theta * sqrt_one_minus_alpha_bar_t).to(device)
    else:
        y_0_reparam = 1 / sqrt_alpha_bar_t * (y_t - (1 - sqrt_alpha_bar_t) * fq_x - eps_theta * sqrt_one_minus_alpha_bar_t).to(device)

    # posterior mean
    if fq_x is None:
        y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y_t
    else:
        # Extra coefficient on fq_x when the forward process is shifted toward fq_x.
        gamma_2 = 1 + (sqrt_alpha_bar_t - 1) * (alpha_t.sqrt() + sqrt_alpha_bar_t_m_1) / (
            sqrt_one_minus_alpha_bar_t.square())
        y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y_t + gamma_2 * fq_x

    # posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_t)
    beta_t_hat = (sqrt_one_minus_alpha_bar_t_m_1.square()) / (sqrt_one_minus_alpha_bar_t.square()) * (1 - alpha_t)
    y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device)
    return y_t_m_1


def p_sample_t_1to0(model, x, y_t, fp_x, one_minus_alphas_bar_sqrt, fq_x=None):
    """
    Final reverse step: turn the sample at the first diffusion step into a y_0 estimate.

    No noise is added. The model output is used as a noise prediction and the
    forward process is inverted at schedule index 0.

    Parameters:
    - model: The diffusion model, called as ``model(x, y_t, t, fp_x)``.
    - x: The input tensor, passed through to the model.
    - y_t: The noisy sample at the first diffusion step.
    - fp_x: Embedding of fp encoder, passed through to the model.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values.
    - fq_x: Optional conditional mean tensor.

    Returns:
    - The y_0 estimate.
    """
    device = next(model.parameters()).device
    # Schedule index 0, i.e. diffusion step t=1.
    first_step = torch.tensor([0]).to(device)
    noise_scale = extract(one_minus_alphas_bar_sqrt, first_step, y_t)
    signal_scale = (1 - noise_scale.square()).sqrt()
    eps_theta = model(x, y_t, first_step, fp_x).to(device).detach()

    if fq_x is not None:
        residual = y_t - (1 - signal_scale) * fq_x - eps_theta * noise_scale
    else:
        residual = y_t - eps_theta * noise_scale
    y_0 = 1 / signal_scale * residual.to(device)
    return y_0.to(device)


def p_sample_loop(model, x, fp_x, n_steps, alphas, one_minus_alphas_bar_sqrt,
                  only_last_sample=True, stochastic=True, fq_x=None):
    """
    Run the full reverse diffusion chain from y_T down to y_0.

    Parameters:
    - model: The diffusion model; must expose ``y_dim`` and have parameters.
    - x: The input tensor; its first dimension is the batch size.
    - fp_x: Embedding of fp encoder.
    - n_steps: The number of diffusion timesteps.
    - alphas: Precomputed alpha values.
    - one_minus_alphas_bar_sqrt: Precomputed sqrt(1 - alpha_bar) values.
    - only_last_sample: If truthy, return only y_0; otherwise return the whole chain.
    - stochastic: Scales both the starting noise and each step's posterior noise;
      if falsy, the chain starts at zeros (or at ``fq_x``) and is deterministic.
    - fq_x: Optional conditional mean tensor added to the starting sample.

    Returns:
    - y_0 when ``only_last_sample`` is truthy; otherwise a tensor of shape
      (batch, y_dim, n_steps + 1) whose last slot is the starting sample, slot k
      (1 <= k < n_steps) is the output of the reverse step at t=k, and slot 0 is y_0.
    """
    device = next(model.parameters()).device
    batch_size = x.shape[0]

    # Starting sample y_T (noise drawn on the CPU, then moved), optionally shifted by fq_x.
    y_t = stochastic * torch.randn_like(torch.zeros([batch_size, model.y_dim])).to(device)
    if fq_x is not None:
        y_t = y_t + fq_x

    history = None
    if not only_last_sample:
        history = torch.zeros([y_t.shape[0], y_t.shape[1], n_steps + 1]).to(device)
        history[:, :, n_steps] = y_t

    steps_taken = 1
    for step in range(n_steps - 1, 0, -1):
        # After this call y_t holds y_{step-1}.
        y_t = p_sample(model, x, y_t, fp_x, step, alphas, one_minus_alphas_bar_sqrt,
                       stochastic=stochastic, fq_x=fq_x)
        if history is None:
            steps_taken += 1
        else:
            history[:, :, step] = y_t

    if history is None:
        assert steps_taken == n_steps
        return p_sample_t_1to0(model, x, y_t, fp_x, one_minus_alphas_bar_sqrt, fq_x=fq_x)

    history[:, :, 0] = p_sample_t_1to0(model, x, history[:, :, 1], fp_x,
                                       one_minus_alphas_bar_sqrt, fq_x=fq_x)
    return history


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps):
    """
    Make timesteps for DDIM sampling.

    Parameters:
    - ddim_discr_method: 'uniform' (evenly spaced with stride
      num_ddpm_timesteps // num_ddim_timesteps) or 'quad' (quadratically spaced,
      denser near 0, covering up to 80% of the DDPM timesteps).
    - num_ddim_timesteps: The number of DDIM timesteps.
    - num_ddpm_timesteps: The number of timesteps of the underlying DDPM schedule.

    Returns:
    - steps_out: NumPy integer array of the selected timesteps, each shifted by +1.

    Raises:
    - ValueError: for 'uniform' when num_ddim_timesteps is not in [1, num_ddpm_timesteps]
      (the stride would be zero or the division undefined).
    - NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        if not 0 < num_ddim_timesteps <= num_ddpm_timesteps:
            raise ValueError(
                f'num_ddim_timesteps must be in [1, {num_ddpm_timesteps}] for uniform '
                f'discretization, got {num_ddim_timesteps}')
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, stride)
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    steps_out = ddim_timesteps + 1

    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
    """
    Make DDIM sampling parameters.

    Parameters:
    - alphacums: 1-D tensor of cumulative alpha (alpha_bar) values.
    - ddim_timesteps: Indices into ``alphacums`` selected for DDIM sampling.
    - eta: Noise level; 0 gives deterministic DDIM (all sigmas zero).

    Returns:
    - sigmas: Per-step sigma values,
      eta * sqrt((1 - a_prev) / (1 - a) * (1 - a / a_prev)).
    - alphas: alphacums at ``ddim_timesteps``.
    - alphas_prev: alphacums at the preceding DDIM step; the first step uses alphacums[0].
      Same dtype and device as ``alphacums``.
    """
    alphas = alphacums[ddim_timesteps]
    # Concatenate on-device instead of round-tripping through a Python list, which
    # would downcast float64 schedules to the default dtype and copy via the host.
    alphas_prev = torch.cat([alphacums[:1], alphacums[ddim_timesteps[:-1]]])

    sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev


def ddim_sample_loop(model, x_embed, fp_x, timesteps, y_dim, ddim_alphas, ddim_alphas_prev, ddim_sigmas, stochastic=True, pred_type='epsilon'):
    """
    Perform DDIM sampling over the given timesteps, from the last one to the first.

    Parameters:
    - model: The diffusion model.
    - x_embed: The embedded input tensor; its first dimension is the batch size.
    - fp_x: Embedding of fp encoder.
    - timesteps: Array of timesteps passed to the model; presumably increasing
      (e.g. from make_ddim_timesteps) -- confirm for other callers.
    - y_dim: The dimension of the output.
    - ddim_alphas: The alpha values for DDIM sampling, indexed like ``timesteps``.
    - ddim_alphas_prev: The previous alpha values for DDIM sampling, indexed like ``timesteps``.
    - ddim_sigmas: The sigma values for DDIM sampling; they scale each step's added noise.
    - stochastic: Scales only the starting noise; if falsy, sampling starts from zeros.
    - pred_type: 'epsilon' if the model predicts noise; any other value means it predicts y_0.

    Returns:
    - y_t: The final sampled tensor.
    - pred_y0: The y_0 estimate produced by the last step.

    Raises:
    - ValueError: if ``timesteps`` is empty (there would be no y_0 estimate to return).
    """
    device = next(model.parameters()).device
    batch_size = x_embed.shape[0]
    total_steps = timesteps.shape[0]
    if total_steps == 0:
        raise ValueError("ddim_sample_loop needs at least one timestep")

    y_t = stochastic * torch.randn_like(torch.zeros([batch_size, y_dim])).to(device)

    pred_y0 = None
    for i, step in enumerate(np.flip(timesteps)):
        # Position of ``step`` in the original (un-flipped) timesteps array.
        index = total_steps - i - 1
        t = torch.full((batch_size,), step, device=device, dtype=torch.long)

        y_t, pred_y0 = ddim_sample_step(model, x_embed, y_t, fp_x, t, index, ddim_alphas, ddim_alphas_prev, ddim_sigmas, pred_type)

    return y_t, pred_y0


def ddim_sample_step(model, x_embed, y_t, fp_x, t, index, ddim_alphas, ddim_alphas_prev, ddim_sigmas, pred_type):
    """
    Perform a single step of DDIM sampling.

    Parameters:
    - model: The diffusion model, called as ``model(x_embed, y_t, t, fp_x)``.
    - x_embed: The embedded input tensor; its first dimension is the batch size.
    - y_t: The noisy sample at timestep t.
    - fp_x: Embedding of fp encoder.
    - t: The current timestep tensor passed to the model.
    - index: Position of the current step in ``ddim_alphas`` / ``ddim_alphas_prev`` / ``ddim_sigmas``.
    - ddim_alphas: The alpha values for DDIM sampling.
    - ddim_alphas_prev: The previous alpha values for DDIM sampling.
    - ddim_sigmas: The sigma values for DDIM sampling.
    - pred_type: 'epsilon' if the model predicts noise; any other value means it
      predicts y_0 directly.

    Returns:
    - y_t_m_1: The sample at the previous DDIM step.
    - y_0_reparam: The y_0 estimate used for this step.
    """
    batch_size = x_embed.shape[0]
    device = next(model.parameters()).device

    # Schedule values for this step only, as (batch, 1) columns that broadcast over y_t.
    a_t = torch.full([batch_size, 1], ddim_alphas[index], device=device)
    a_t_m_1 = torch.full([batch_size, 1], ddim_alphas_prev[index], device=device)
    sigma_t = torch.full([batch_size, 1], ddim_sigmas[index], device=device)
    sqrt_one_minus_at = torch.full([batch_size, 1], torch.sqrt(1. - ddim_alphas[index]), device=device)

    model_out = model(x_embed, y_t, t, fp_x).to(device).detach()
    if pred_type == 'epsilon':
        e_t = model_out
        # reparameterize y_0 from the predicted noise
        y_0_reparam = (y_t - sqrt_one_minus_at * e_t) / a_t.sqrt()
    else:
        # The model predicts y_0; the implied noise follows from
        # y_t = sqrt(a_t) * y_0 + sqrt(1 - a_t) * eps (same relation as y2e).
        y_0_reparam = model_out
        e_t = (y_t - a_t.sqrt() * y_0_reparam) / sqrt_one_minus_at

    # direction pointing to y_t
    dir_y_t = (1. - a_t_m_1 - sigma_t ** 2).sqrt() * e_t
    noise = sigma_t * torch.randn_like(y_t).to(device)

    y_t_m_1 = a_t_m_1.sqrt() * y_0_reparam + dir_y_t + noise

    return y_t_m_1, y_0_reparam

def ddim_sample_loop_hard_steps(
    model, x_embed, fp_x,
    timesteps, y_dim,
    ddim_alphas, ddim_alphas_prev, ddim_sigmas,
    stochastic=True, pred_type='epsilon'
):
    """
    Run DDIM sampling and record the argmax label after every step.

    Parameters:
    - model: The diffusion model.
    - x_embed: The embedded input tensor; its first dimension is the batch size.
    - fp_x: Embedding of fp encoder.
    - timesteps: Array of timesteps passed to the model, walked from last to first.
    - y_dim: The dimension of the output.
    - ddim_alphas, ddim_alphas_prev, ddim_sigmas: Per-step DDIM values indexed like ``timesteps``.
    - stochastic: Scales the starting noise; if falsy, sampling starts from zeros.
    - pred_type: 'epsilon' if the model predicts noise; otherwise it predicts y_0.

    Returns:
    - y_t: The final sample.
    - step_hard_labels: List of CPU tensors of argmax indices over dim 1. The first
      entry comes from the starting sample, then one per step from that step's y_0 estimate.
    """
    device = next(model.parameters()).device
    batch_size = x_embed.shape[0]

    y_t = (stochastic * torch.randn(batch_size, y_dim, device=device)).float()
    hard_labels = [torch.argmax(y_t, dim=1).detach().cpu()]

    # Visit the timesteps from the last entry back to the first.
    for index in reversed(range(timesteps.shape[0])):
        t = torch.full((batch_size,), int(timesteps[index]), device=device, dtype=torch.long)
        y_t, pred_y0 = ddim_sample_step(
            model, x_embed, y_t, fp_x, t, index,
            ddim_alphas, ddim_alphas_prev, ddim_sigmas,
            pred_type,
        )
        hard_labels.append(torch.argmax(pred_y0, dim=1).detach().cpu())

    return y_t, hard_labels



def predict_sample(model, x_embed, fp_x, ddpm_num_timesteps, ddim_timesteps, y_dim, ddim_alphas, one_step, pred_type='y'):
    """
    Predict y_0 from pure noise, optionally refining by repeated re-noising.

    First the model is queried once on standard-normal noise at timestep
    ``ddpm_num_timesteps``. Unless ``one_step`` is set, each DDIM timestep (last to
    first) then re-noises the current y_0 estimate to that step's level and asks
    the model for a new estimate.

    Parameters:
    - model: The diffusion model, called as ``model(x_embed, y, t, fp_x)``.
    - x_embed: The embedded input tensor; its first dimension is the batch size.
    - fp_x: Embedding of fp encoder.
    - ddpm_num_timesteps: Timestep value passed to the model for the initial prediction.
    - ddim_timesteps: Array of DDIM timesteps.
    - y_dim: The dimension of the output.
    - ddim_alphas: Per-DDIM-step alpha values, indexed like ``ddim_timesteps``
      (presumably alpha_bar values -- confirm). Not read when ``one_step`` is truthy
      and ``pred_type`` is not 'epsilon'.
    - one_step: If truthy, return the initial prediction without refinement.
    - pred_type: 'epsilon' if the model predicts noise; any other value means it predicts y_0.

    Returns:
    - The final y_0 estimate.
    """
    device = next(model.parameters()).device
    batch_size = x_embed.shape[0]

    # Initial prediction from pure noise.
    y_T = torch.randn(batch_size, y_dim).to(device)
    T = torch.full((batch_size,), ddpm_num_timesteps, device=device, dtype=torch.long)
    pred_y_0_t = model(x_embed, y_T, T, fp_x).to(device).detach()
    if pred_type == 'epsilon':
        # Convert the noise prediction using the last DDIM step's schedule values.
        last = ddim_timesteps.shape[0] - 1
        sqrt_a_T = torch.full([batch_size, 1], ddim_alphas[last], device=device).sqrt()
        sqrt_one_minus_aT = torch.full([batch_size, 1], torch.sqrt(1. - ddim_alphas)[last], device=device)
        pred_y_0_t = 1 / sqrt_a_T * (y_T - sqrt_one_minus_aT * pred_y_0_t)

    if not one_step:
        # Loop-invariant; computed once rather than on every iteration.
        sqrt_one_minus_alphas = torch.sqrt(1. - ddim_alphas)
        total_steps = ddim_timesteps.shape[0]

        for i, step in enumerate(np.flip(ddim_timesteps)):
            index = total_steps - i - 1
            t = torch.full((batch_size,), step, device=device, dtype=torch.long)

            sqrt_a_t = torch.full([batch_size, 1], ddim_alphas[index], device=device).sqrt()
            sqrt_one_minus_at = torch.full([batch_size, 1], sqrt_one_minus_alphas[index], device=device)

            # Re-noise the current y_0 estimate to this step's noise level.
            noise = torch.randn_like(pred_y_0_t).to(device)
            y_t = sqrt_a_t * pred_y_0_t + sqrt_one_minus_at * noise

            model_out = model(x_embed, y_t, t, fp_x).to(device).detach()
            if pred_type == 'epsilon':
                # Convert the predicted noise to y_0 at this timestep.
                pred_y_0_t = 1 / sqrt_a_t * (y_t - sqrt_one_minus_at * model_out)
            else:
                pred_y_0_t = model_out

    return pred_y_0_t

def predict_sample_hard_steps(
    model, x_embed, fp_x,
    ddpm_num_timesteps, ddim_timesteps, y_dim,
    ddim_alphas, one_step, pred_type='y'
):
    """
    Predict y_0 by repeated re-noising, recording the argmax label after each prediction.

    First the model is queried once on standard-normal noise at timestep
    ``ddpm_num_timesteps``. Unless ``one_step`` is set, each DDIM timestep (last to
    first) then re-noises the current y_0 estimate to that step's level and asks
    the model for a new estimate.

    Parameters:
    - model: The diffusion model, called as ``model(x_embed, y, t, fp_x)``.
    - x_embed: The embedded input tensor; its first dimension is the batch size.
    - fp_x: Embedding of fp encoder.
    - ddpm_num_timesteps: Timestep value passed to the model for the initial prediction.
    - ddim_timesteps: Array of DDIM timesteps.
    - y_dim: The dimension of the output.
    - ddim_alphas: Per-DDIM-step alpha values, indexed like ``ddim_timesteps``
      (presumably alpha_bar values -- confirm). Not read when ``one_step`` is truthy
      and ``pred_type`` is not 'epsilon'.
    - one_step: If truthy, stop after the initial prediction.
    - pred_type: 'epsilon' if the model predicts noise; any other value means it predicts y_0.

    Returns:
    - pred_y_0_t: The final y_0 estimate.
    - step_hard_labels: List of CPU tensors of argmax indices over dim 1. It holds the
      starting noise's labels, the initial prediction's labels, then one entry per
      refinement step.
    """
    device = next(model.parameters()).device
    batch_size = x_embed.shape[0]

    y_T = torch.randn(batch_size, y_dim, device=device)
    step_hard_labels = [torch.argmax(y_T, dim=1).detach().cpu()]

    T = torch.full((batch_size,), ddpm_num_timesteps, device=device, dtype=torch.long)
    pred_y_0_t = model(x_embed, y_T, T, fp_x).to(device).detach()

    if pred_type == 'epsilon':
        # Convert the noise prediction using the last DDIM step's schedule values.
        last = ddim_timesteps.shape[0] - 1
        sqrt_a_T = torch.full([batch_size, 1], ddim_alphas[last], device=device).sqrt()
        sqrt_one_minus_aT = torch.full([batch_size, 1], torch.sqrt(1. - ddim_alphas)[last], device=device)
        pred_y_0_t = 1 / sqrt_a_T * (y_T - sqrt_one_minus_aT * pred_y_0_t)

    step_hard_labels.append(torch.argmax(pred_y_0_t, dim=1).detach().cpu())

    if not one_step:
        # Loop-invariant; computed once rather than on every iteration.
        sqrt_one_minus_alphas = torch.sqrt(1. - ddim_alphas)
        total_steps = ddim_timesteps.shape[0]

        for i, step in enumerate(np.flip(ddim_timesteps)):
            index = total_steps - i - 1
            t = torch.full((batch_size,), step, device=device, dtype=torch.long)

            sqrt_a_t = torch.full([batch_size, 1], ddim_alphas[index], device=device).sqrt()
            sqrt_one_minus_at = torch.full([batch_size, 1], sqrt_one_minus_alphas[index], device=device)

            # Re-noise the current y_0 estimate to this step's noise level.
            noise = torch.randn_like(pred_y_0_t, device=device)
            y_t = sqrt_a_t * pred_y_0_t + sqrt_one_minus_at * noise

            model_out = model(x_embed, y_t, t, fp_x).to(device).detach()
            if pred_type == 'epsilon':
                pred_y_0_t = 1 / sqrt_a_t * (y_t - sqrt_one_minus_at * model_out)
            else:
                pred_y_0_t = model_out

            step_hard_labels.append(torch.argmax(pred_y_0_t, dim=1).detach().cpu())

    return pred_y_0_t, step_hard_labels

