import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math



    


def beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-step beta values for a named noise schedule.

    :param beta_schedule: schedule name; one of "quad", "linear", "const",
                          "jsd" or "sigmoid". Any other value prints a
                          message and falls back to the linear schedule.
    :param beta_start: first beta (ignored by "const" and "jsd").
    :param beta_end: last beta (the constant value for "const"; ignored
                     by "jsd").
    :param num_diffusion_timesteps: number of betas to produce.
    :return: 1-D ``torch.float`` tensor of length ``num_diffusion_timesteps``.
    """
    if beta_schedule == "quad":
        # Linear in sqrt(beta), then squared.
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        # The logistic function is evaluated inline with NumPy: no `sigmoid`
        # helper is imported at the top of this module, so calling a bare
        # `sigmoid(...)` here would raise NameError.
        grid = np.linspace(-6, 6, num_diffusion_timesteps)
        logistic = 1.0 / (1.0 + np.exp(-grid))
        betas = logistic * (beta_end - beta_start) + beta_start
    else:
        # Best-effort fallback: report the problem and use the linear schedule.
        print('beta schedule is not defined properly')
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    betas = torch.tensor(betas).type(torch.float)
    return betas

def cosine_schedule(timesteps):
    """Return ``timesteps`` betas derived from a squared-cosine alpha-bar curve.

    The curve ``cos((t + 0.008) / 1.008 * pi / 2) ** 2`` is handed to
    ``betas_for_alpha_bar`` for discretization, and the resulting array is
    converted to a ``torch.float`` tensor.
    """

    def _squared_cosine(t):
        # Cumulative signal level at normalized time t.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    discretized = betas_for_alpha_bar(timesteps, _squared_cosine)
    return torch.tensor(discretized).type(torch.float)



def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve into a beta schedule.

    Step i covers the interval [i/T, (i+1)/T] with T = num_diffusion_timesteps;
    its beta is 1 - alpha_bar(end) / alpha_bar(start), capped at max_beta.
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1-beta) up to that point of
                      the diffusion process.
    :param max_beta: upper bound applied to every beta; values below 1
                     prevent singularities.
    :return: a NumPy array holding one beta per step.
    """
    total = num_diffusion_timesteps

    def _step_beta(step):
        start = step / total
        end = (step + 1) / total
        return min(1 - alpha_bar(end) / alpha_bar(start), max_beta)

    return np.array([_step_beta(step) for step in range(total)])


def get_index_from_list(vals, t, x_shape, config):
    """
    Look up vals[t] for each entry of the batch index tensor t.

    The result is reshaped to (batch, 1, ..., 1) with as many dimensions as
    x_shape, then moved to t's device. The index is moved to CPU before the
    gather (presumably vals lives on the CPU; confirm against callers).
    config is accepted but not used here.
    """
    n_batch = t.shape[0]
    # One singleton axis for every non-batch dimension of x_shape.
    singleton_axes = (1,) * (len(x_shape) - 1)
    picked = vals.gather(-1, t.cpu())
    return picked.reshape(n_batch, *singleton_axes).to(t.device)

def compute_alpha(beta, t, config):
    """Return the cumulative product of (1 - beta) at the steps in ``t``.

    A zero is prepended to ``beta`` so that index ``t + 1`` selects the
    product over the first ``t + 1`` betas (and ``t = -1`` selects 1).

    :param beta: 1-D tensor of betas.
    :param t: 1-D tensor of step indices.
    :param config: object exposing ``config.model.device``.
    :return: tensor of shape (len(t), 1, 1, 1) on ``config.model.device``.
    """
    device = config.model.device
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    beta = beta.to(device)
    # index_select needs an integer index on the same device as its input;
    # this mirrors what compute_alpha2 does.
    t = t.to(device).long()
    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
    return a

def compute_alpha2(beta, t, config):
    """Return the cumulative product of (1 - beta) at the steps in ``t``.

    ``beta`` gets a leading zero so that position ``t + 1`` holds the product
    over the first ``t + 1`` betas. Both the padded betas and the index are
    placed on ``config.model.device``; the index is cast to long.
    The result has shape (len(t), 1, 1, 1).
    """
    target = config.model.device

    leading_zero = torch.zeros(1).to(beta.device)
    padded = torch.cat([leading_zero, beta], dim=0).to(target)

    # index_select needs an integer index on the same device as its input.
    step_idx = t.to(target).long() + 1

    running_product = torch.cumprod(1 - padded, dim=0)
    return running_product.index_select(0, step_idx).view(-1, 1, 1, 1)