import torch
from torch import nn, einsum
import torch.nn.functional as F

def linear_beta_schedule(timesteps, beta_start=0.0015, beta_end=0.0195):
    """Return the "scaled linear" beta schedule for the diffusion process.

    The betas are spaced linearly in square-root space, between
    ``sqrt(beta_start)`` and ``sqrt(beta_end)``, and then squared. For
    ``timesteps >= 2`` the first and last entries are therefore
    ``beta_start`` and ``beta_end``.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Beta value at the first step.
        beta_end: Beta value at the last step.

    Returns:
        A 1-D ``float32`` tensor of shape ``(timesteps,)``.
    """
    return torch.linspace(
        beta_start ** 0.5, beta_end ** 0.5, timesteps, dtype=torch.float32
    ) ** 2

timesteps = 1000

# define beta schedule. Scaled linear beta schedule.
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
# Running product over the time axis: alpha_bar_t = prod_{s <= t} alpha_s.
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar shifted right by one step; the leading pad value 1.0 makes
# index 0 well defined (there is no step before t = 0).
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Coefficients of the closed-form forward process q(x_t | x_0), used by
# q_sample: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

def extract(a, t, x_shape):
    """Gather per-sample schedule values and reshape them for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values (e.g. ``sqrt_alphas_cumprod``).
        t: 1-D integer tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with;
            only its length is used.

    Returns:
        Tensor of shape ``(batch_size, 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # Put the indices on ``a``'s device (instead of forcing CPU) so this also
    # works when the schedule lives on an accelerator, and cast to int64
    # because ``gather`` rejects other index dtypes.
    out = a.gather(-1, t.to(device=a.device, dtype=torch.long))
    # Trailing singleton dims let ``out`` broadcast against a tensor of x_shape.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def q_sample(x_start, t, noise=None):
    """Sample x_t from the closed-form forward process q(x_t | x_0).

    Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``
    from the module-level schedule tensors. The per-sample coefficients are
    reshaped by ``extract`` so they broadcast over the trailing dimensions
    of ``x_start``.

    Args:
        x_start: Clean input batch; its leading dimension indexes samples.
        t: Integer timestep index per sample.
        noise: Optional noise tensor; standard normal noise shaped like
            ``x_start`` is drawn when omitted.

    Returns:
        The noised tensor (shaped like ``x_start``, broadcasting permitting).
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    shape = x_start.shape
    signal_coef = extract(sqrt_alphas_cumprod, t, shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_coef * x_start + noise_coef * eps

def get_noisy_image(x_start, t, noise=None):
    """Return ``x_start`` diffused to timestep(s) ``t`` via ``q_sample``.

    ``x_start`` should already be scaled to the range [-1, 1]; this function
    does not rescale or validate it.

    Args:
        x_start: Clean input batch, expected to lie in [-1, 1].
        t: Integer timestep index per sample.
        noise: Optional noise tensor forwarded to ``q_sample``. Pass a fixed
            tensor for reproducible output; random noise is drawn when omitted.

    Returns:
        The noised batch produced by ``q_sample``.
    """
    # add noise
    x_noisy = q_sample(x_start, t=t, noise=noise)
    return x_noisy

