import torch
import numpy as np
from inspect import isfunction
import math

def exists(x):
    """Return ``True`` when ``x`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or an empty list still count as
    existing; only the ``None`` singleton does not.
    """
    if x is None:
        return False
    return True

def cycle(dl):
    """Endlessly iterate over ``dl``, yielding 3-tuples.

    Each item produced by ``dl`` is indexed at positions 0, 1 and 2 and
    those three values are yielded as a tuple; any further entries in the
    item are ignored. Once ``dl`` is exhausted it is iterated again from
    the start, so this generator never terminates on its own.
    """
    while True:
        # One full pass over the iterable per outer-loop turn.
        yield from ((batch[0], batch[1], batch[2]) for batch in dl)

def _sqrt(x):
    """Element-wise square root that treats negative entries as zero.

    Every element of ``x`` is replaced by ``max(element, 0)`` before the
    root is taken, so negative inputs produce ``0`` instead of ``NaN``.
    """
    floor = torch.zeros_like(x)
    non_negative = torch.max(x, floor)
    return non_negative.sqrt()

def normal_log_density(x, mean, log_std, std):
    """Log-density of ``x`` under an element-wise normal distribution.

    For every element the term
    ``-(x - mean)^2 / (2 * std^2) - 0.5 * log(2 * pi) - log_std``
    is evaluated, and the results are summed along dimension 1.

    Args:
        x: Points at which the density is evaluated.
        mean: Mean of the distribution.
        log_std: Subtracted as the normalisation term; it is used as given
            and is not derived from ``std`` here (presumably the caller
            passes ``log(std)``).
        std: Standard deviation, squared to obtain the variance.

    Returns:
        Tensor with dimension 1 reduced to size 1 (``keepdim=True``).
    """
    variance = std.pow(2)
    squared_error = (x - mean).pow(2)
    half_log_two_pi = 0.5 * math.log(2 * math.pi)
    per_element = -squared_error / (2 * variance) - half_log_two_pi - log_std
    return per_element.sum(1, keepdim=True)

def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback ``d`` is a plain Python function (as reported by
    ``inspect.isfunction``) it is called with no arguments and its result
    is returned; any other object, including classes and builtins, is
    returned unchanged.
    """
    if not exists(val):
        if isfunction(d):
            # Lazy fallback: only evaluated when actually needed.
            return d()
        return d
    return val

def extract(a, t, x_shape):
    """Pick entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Values are gathered from the last dimension of ``a`` using ``t`` as the
    index tensor. The result is reshaped to ``(b, 1, ..., 1)`` where ``b``
    is the first dimension of ``t`` and the total number of dimensions
    equals ``len(x_shape)``.
    """
    batch, *_ = t.shape
    singleton_dims = (1,) * (len(x_shape) - 1)
    return a.gather(-1, t).reshape(batch, *singleton_dims)

def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given ``shape`` on ``device``.

    With ``repeat=False`` every element is sampled independently. With
    ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
    and tiled ``shape[0]`` times along the first dimension, so every entry
    along that dimension carries identical noise.
    """
    if not repeat:
        return torch.randn(shape, device=device)

    single = torch.randn((1, *shape[1:]), device=device)
    # Tile only along the first dimension; all others keep their size.
    tiling = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tiling)

class EMA:
    """Exponential moving average of model parameters.

    ``beta`` is the weight given to the existing average; ``1 - beta`` is
    the weight given to the incoming value.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place.

        Parameters of the two models are paired in iteration order, and
        each averaged parameter's ``.data`` is replaced by the blended
        tensor. ``current_model`` is left untouched.
        """
        paired = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in paired:
            averaged_param.data = self.update_average(
                averaged_param.data, source_param.data
            )

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if ``old`` is ``None``."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

# Added to the norm product so the cosine similarity never divides by zero.
EPSILON = 1e-9

def cosine_distance(x, y):
    """Angle in radians between ``x`` and ``y``, reduced over dimension 1.

    The cosine similarity is the dot product along dimension 1 divided by
    the product of the two Euclidean norms (plus ``EPSILON``). The angle is
    then recovered as ``atan2(sqrt(1 - cos^2), cos)``, which lies in
    ``[0, pi]``. ``_sqrt`` clamps its argument at zero, so rounding that
    pushes ``cos^2`` slightly above 1 does not produce ``NaN``.
    """
    dot = torch.sum(x * y, dim=1)
    x_norm = torch.sqrt(torch.sum(x**2, dim=1))
    y_norm = torch.sqrt(torch.sum(y**2, dim=1))
    cos_sim = dot / (x_norm * y_norm + EPSILON)
    sin_sim = _sqrt(1. - cos_sim**2)
    return torch.atan2(sin_sim, cos_sim)

def cosine_beta_schedule(timesteps, s=0.008):
    """Build a beta schedule from a squared-cosine cumulative-alpha curve.

    ``timesteps + 1`` evenly spaced points on ``[0, 1]`` are mapped through
    ``cos(((u + s) / (1 + s)) * pi / 2) ** 2`` and normalised by the first
    value. Each beta is one minus the ratio of consecutive curve values,
    clipped to ``[0, 0.999]``.

    Args:
        timesteps: Number of betas to produce.
        s: Offset added to the normalised position before the cosine.

    Returns:
        NumPy array of length ``timesteps``.
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    phase = ((grid / n_points) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    consecutive_ratio = cumulative[1:] / cumulative[:-1]
    return np.clip(1 - consecutive_ratio, a_min=0, a_max=0.999)
