import math

import torch


def extract(tensor, t, x):
    """Look up per-sample entries of ``tensor`` and shape them to broadcast with ``x``.

    Args:
        tensor: Lookup table indexed along dim 0 (expected to be 1-D, since it
            is gathered with the index tensor ``t``).
        t: Integer index tensor; it is moved to ``tensor``'s device before the
            gather. Its first dimension sets the leading size of the result.
        x: Reference object; only ``x.shape`` (its number of dimensions) is used.

    Returns:
        The gathered values reshaped to ``(t.shape[0], 1, ..., 1)`` with as
        many dimensions as ``x``.
    """
    gathered = torch.gather(tensor, 0, t.to(tensor.device))
    # One singleton axis for every non-leading dimension of x.
    singleton_axes = (1,) * (len(x.shape) - 1)
    return gathered.reshape(t.shape[0], *singleton_axes)


def get_gammas(alphas, one_minus_alphas_bar_sqrt, t, y_t, squeeze=False):
    """Derive the per-sample gamma coefficients for step ``t`` from the schedules.

    The schedule tables are read at ``t`` and ``t - 1`` through ``extract`` and
    combined as follows, writing ``s_t`` for the looked-up
    ``one_minus_alphas_bar_sqrt`` entry at ``t`` and ``s_p`` for the one at
    ``t - 1``::

        sqrt_alpha_bar_t = sqrt(1 - s_t^2)
        sqrt_alpha_bar_p = sqrt(1 - s_p^2)
        gamma_0    = (1 - alpha_t) * sqrt_alpha_bar_p / s_t^2
        gamma_1    = s_p^2 * sqrt(alpha_t) / s_t^2
        gamma_2    = 1 + (sqrt_alpha_bar_t - 1)
                         * (sqrt(alpha_t) + sqrt_alpha_bar_p) / s_t^2
        beta_t_hat = s_p^2 / s_t^2 * (1 - alpha_t)

    Args:
        alphas: Table of per-step alpha values.
        one_minus_alphas_bar_sqrt: Table of per-step sqrt(1 - alpha_bar) values.
        t: Integer step indices, one per sample.
        y_t: Reference tensor; only its number of dimensions is used, to shape
            the looked-up coefficients for broadcasting.
        squeeze: When true, ``squeeze(1)`` is applied twice to each looked-up
            coefficient before the arithmetic.

    Returns:
        Tuple ``(sqrt_alpha_bar_t, gamma_0, gamma_1, gamma_2, beta_t_hat)``.

    NOTE(review): ``t - 1`` is passed straight through as a gather index, so
    entries of ``t`` equal to 0 look unsupported here — confirm with callers.
    """
    alpha_t = extract(alphas, t, y_t)
    sigma_t = extract(one_minus_alphas_bar_sqrt, t, y_t)
    sigma_prev = extract(one_minus_alphas_bar_sqrt, t - 1, y_t)

    if squeeze:
        alpha_t, sigma_t, sigma_prev = (
            coeff.squeeze(1).squeeze(1) for coeff in (alpha_t, sigma_t, sigma_prev)
        )

    # Shared sub-expressions, each evaluated once.
    var_t = sigma_t.square()
    var_prev = sigma_prev.square()
    beta_t = 1 - alpha_t
    sqrt_alpha_t = alpha_t.sqrt()

    sqrt_alpha_bar_t = (1 - var_t).sqrt()
    sqrt_alpha_bar_prev = (1 - var_prev).sqrt()

    gamma_0 = beta_t * sqrt_alpha_bar_prev / var_t
    gamma_1 = var_prev * sqrt_alpha_t / var_t
    gamma_2 = 1 + (sqrt_alpha_bar_t - 1) * (sqrt_alpha_t + sqrt_alpha_bar_prev) / var_t
    beta_t_hat = var_prev / var_t * beta_t

    return sqrt_alpha_bar_t, gamma_0, gamma_1, gamma_2, beta_t_hat


def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1 so that
    they broadcast across dim 1 of ``x``.

    Args:
        x: Tensor to modulate.
        shift: Additive term, lacking the axis at position 1.
        scale: Multiplicative term (applied as ``1 + scale``), lacking the
            axis at position 1.

    Returns:
        The modulated tensor.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset


def modify_gammas(sqrt_alpha_bar_t, gamma_0, gamma_1, gamma_2, beta_t_hat):
    """Insert singleton axes at positions 1 and 2 of every coefficient tensor.

    Args:
        sqrt_alpha_bar_t, gamma_0, gamma_1, gamma_2, beta_t_hat: Coefficient
            tensors, e.g. as returned by ``get_gammas``.

    Returns:
        A 5-tuple of the expanded tensors, in the same order as the arguments.
    """

    def _expand(coeff):
        # (B, ...) -> (B, 1, 1, ...)
        return coeff.unsqueeze(1).unsqueeze(2)

    coefficients = (sqrt_alpha_bar_t, gamma_0, gamma_1, gamma_2, beta_t_hat)
    return tuple(_expand(coeff) for coeff in coefficients)




def NST_normalize(device, inp, mask=None):
    """Normalize ``inp`` along dim 1 using statistics of the observed entries.

    Args:
        device: Device that ``inp`` (and ``mask``) are moved to.
        inp: Tensor to normalize; statistics are reduced over dim 1.
            NOTE(review): presumably (batch, time, features) — confirm.
        mask: Optional tensor where 1 marks an observed entry and 0 a missing
            one. When omitted, every entry is treated as observed.

    Returns:
        Tuple ``(normalized, means, stdev)``. ``means`` and ``stdev`` keep a
        singleton dim 1 and are detached from the autograd graph.

    Notes:
        The mean divides the plain sum of ``inp`` over dim 1 by the number of
        observed entries, so masked positions of ``inp`` still contribute to
        the sum; this assumes the caller has already zeroed them — TODO confirm.
        A slice with no observed entries divides by zero.
    """
    inp = inp.to(device)
    if mask is None:
        # The declared default used to crash on ``mask == 1``; an absent mask
        # now means "everything observed".
        mask = torch.ones_like(inp)
    else:
        # Keep the mask on the same device as the data it is combined with.
        mask = mask.to(device)

    observed_count = torch.sum(mask == 1, dim=1)

    means = torch.sum(inp, dim=1) / observed_count
    means = means.unsqueeze(1).detach()

    x_enc = inp.sub(means)
    # Missing positions must not contribute to the variance estimate.
    x_enc = x_enc.masked_fill(mask == 0, 0)

    # 1e-5 keeps the square root (and the later division) away from zero.
    stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed_count + 1e-5)
    stdev = stdev.unsqueeze(1).detach()

    inp = x_enc.div(stdev)
    return inp, means, stdev


def NST_denormalize(outputs, means, stdev, pred_len):
    """Undo a normalization by rescaling with ``stdev`` and re-adding ``means``.

    Args:
        outputs: Normalized tensor to map back to the original scale.
        means: Per-series means; only index 0 along dim 1 is used.
        stdev: Per-series standard deviations; only index 0 along dim 1 is used.
        pred_len: Number of times the statistics are repeated along dim 1.

    Returns:
        ``outputs * stdev + means`` with the statistics tiled to ``pred_len``
        steps along dim 1.
    """

    def _tile(stat):
        # Take the first step's statistic and repeat it pred_len times on dim 1.
        return stat[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)

    rescaled = outputs.mul(_tile(stdev))
    return rescaled.add(_tile(means))


def invalid(name, tensor):
    """Report whether ``tensor`` contains any NaN or infinite value.

    When a problem is found, a one-line message built from ``name`` and the
    tensor itself are printed. NaN is checked first; the Inf check only runs
    when no NaN is present.

    Args:
        name: Label used in the printed message.
        tensor: Tensor to inspect.

    Returns:
        ``True`` if the tensor holds a NaN or Inf, otherwise ``False``.
    """
    if torch.isnan(tensor).any():
        problem = f"{name} is NaN"
    elif torch.isinf(tensor).any():
        problem = f"{name} is Inf"
    else:
        return False

    print(problem)
    print(tensor)
    return True


def make_beta_schedule(schedule="linear", num_timesteps=1000, start=1e-5, end=1e-2):
    """Build a 1-D tensor of ``num_timesteps`` beta values for the named schedule.

    Supported schedules:
        * ``"linear"``: evenly spaced from ``start`` to ``end``.
        * ``"const"``: every step equals ``end`` (``start`` is ignored).
        * ``"quad"``: evenly spaced in sqrt-space, then squared.
        * ``"jsd"``: ``1/T, 1/(T-1), ..., 1`` with ``T = num_timesteps``
          (``start`` and ``end`` are ignored).
        * ``"sigmoid"``: sigmoid of evenly spaced points on [-6, 6], rescaled
          to lie between ``start`` and ``end``.
        * ``"cosine"``: ``1 - f(i + 1) / f(i)`` with
          ``f(i) = cos((i / T + 0.008) / 1.008 * pi / 2) ** 2``, capped at
          0.999 (``start`` and ``end`` are ignored).
        * ``"cosine_reverse"``: the ``"cosine"`` schedule in reversed order.
        * ``"cosine_anneal"``: half-cosine ramp from ``start`` to ``end``.

    Args:
        schedule: Name of the schedule (see above).
        num_timesteps: Number of steps, i.e. the length of the result.
        start: First / lower beta value for the schedules that use it.
        end: Last / upper beta value for the schedules that use it.

    Returns:
        A 1-D tensor of length ``num_timesteps``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == "linear":
        betas = torch.linspace(start, end, num_timesteps)
    elif schedule == "const":
        betas = end * torch.ones(num_timesteps)
    elif schedule == "quad":
        betas = torch.linspace(start**0.5, end**0.5, num_timesteps) ** 2
    elif schedule == "jsd":
        betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
    elif schedule == "sigmoid":
        betas = torch.linspace(-6, 6, num_timesteps)
        betas = torch.sigmoid(betas) * (end - start) + start
    elif schedule in {"cosine", "cosine_reverse"}:
        max_beta = 0.999
        cosine_s = 0.008

        def _alpha_bar(step):
            # Squared-cosine value at a (possibly fractional) position in [0, 1].
            frac = step / num_timesteps
            return math.cos((frac + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2

        # The cap keeps the final betas (where the denominator approaches 0)
        # strictly below 1.
        betas = torch.tensor([
            min(1 - _alpha_bar(i + 1) / _alpha_bar(i), max_beta)
            for i in range(num_timesteps)
        ])
        if schedule == "cosine_reverse":
            betas = betas.flip(0)
    elif schedule == "cosine_anneal":
        # A one-step schedule has no span to interpolate over; guarding the
        # denominator avoids 0/0 and yields [start], as torch.linspace would.
        last_step = max(num_timesteps - 1, 1)
        betas = torch.tensor([
            start + 0.5 * (end - start) * (1 - math.cos(t / last_step * math.pi))
            for t in range(num_timesteps)
        ])
    else:
        # Previously fell through to ``return betas`` and raised an opaque
        # UnboundLocalError.
        raise ValueError(f"Unknown beta schedule: {schedule!r}")
    return betas
