import torch

from helpers import (cosine_beta_schedule,
                     linear_beta_schedule,
                     vp_beta_schedule,
                     extract,
                     Losses,
                     Progress,
                     Silent)


def _select_betas(n_timesteps, beta_schedule):
    """Return the beta schedule named by *beta_schedule* for *n_timesteps* steps.

    Raises:
        ValueError: if *beta_schedule* is not 'linear', 'cosine' or 'vp'.
            (Previously an unknown name left ``betas`` unbound and crashed later.)
    """
    if beta_schedule == 'linear':
        return linear_beta_schedule(n_timesteps)
    if beta_schedule == 'cosine':
        return cosine_beta_schedule(n_timesteps)
    if beta_schedule == 'vp':
        return vp_beta_schedule(n_timesteps)
    raise ValueError(
        f"unknown beta_schedule {beta_schedule!r}; "
        "expected 'linear', 'cosine' or 'vp'"
    )


def _diffusion_coefficients(betas):
    """Compute forward-process and posterior coefficients from a 1-D *betas*.

    *betas* must be 1-D: it is concatenated with ``torch.ones(1)`` below.
    Every returned tensor has the same length as *betas*.

    Returns:
        dict mapping coefficient name to tensor.
    """
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift right by one step; the value before the first step is defined as 1.
    alphas_cumprod_prev = torch.cat([torch.ones(1).to(betas), alphas_cumprod[:-1]])

    # Posterior q(x_{t-1} | x_t, x_0). Its variance is exactly 0 at the first
    # step (alphas_cumprod_prev[0] == 1), so the log is taken of a clamped value.
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'alphas_cumprod_prev': alphas_cumprod_prev,
        # Forward diffusion q(x_t | x_{t-1}) and related quantities.
        'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
        'sqrt_one_minus_alphas_cumprod': torch.sqrt(1. - alphas_cumprod),
        'log_one_minus_alphas_cumprod': torch.log(1. - alphas_cumprod),
        'sqrt_recip_alphas_cumprod': torch.sqrt(1. / alphas_cumprod),
        'sqrt_recipm1_alphas_cumprod': torch.sqrt(1. / alphas_cumprod - 1),
        'posterior_variance': posterior_variance,
        'posterior_log_variance_clipped': torch.log(torch.clamp(posterior_variance, min=1e-20)),
        'posterior_mean_coef1': betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod),
        'posterior_mean_coef2': (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod),
    }


def main(n_timesteps=5, beta_schedule='vp', base_beta=0.6, eta=0.0):
    """Print diffusion coefficients for a blended beta schedule.

    The schedule named by *beta_schedule* (the "target") is blended with a
    constant schedule of value *base_beta*::

        betas = base_beta * (target / base_beta) ** eta

    This is a geometric interpolation: ``eta=0`` yields the constant schedule
    and ``eta=1`` yields the target schedule. With the default ``eta=0.0``,
    every beta therefore equals *base_beta*.

    Args:
        n_timesteps: number of diffusion steps requested from the schedule helper.
        beta_schedule: one of 'linear', 'cosine', 'vp'.
        base_beta: value of the constant schedule. It must be positive because
            the blend divides by it.
        eta: interpolation exponent between the constant and target schedules.

    Returns:
        dict mapping coefficient name to tensor, covering every computed quantity.

    Raises:
        ValueError: for an unknown *beta_schedule* or a non-positive *base_beta*.
    """
    # `not (x > 0)` also rejects NaN, which would otherwise poison every output.
    if not (base_beta > 0):
        raise ValueError(f"base_beta must be positive, got {base_beta!r}")

    target_betas = _select_betas(n_timesteps, beta_schedule)
    base_betas = torch.ones_like(target_betas) * base_beta
    betas = base_betas * (target_betas / base_betas).pow(eta)

    coeffs = _diffusion_coefficients(betas)

    print("betas:", coeffs['betas'])
    print("sqrt_alphas_cumprod", coeffs['sqrt_alphas_cumprod'])
    print("sqrt_one_minus_alphas_cumprod", coeffs['sqrt_one_minus_alphas_cumprod'])
    print("posterior_variance", coeffs['posterior_variance'])
    return coeffs


if __name__ == "__main__":
    main()