
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


def extract(v, t, x_shape):
    """Look up per-sample coefficients and shape them for broadcasting.

    Args:
        v: 1-D tensor of coefficients, indexed along dim 0.
        t: 1-D integer tensor of indices into ``v``, one per sample.
        x_shape: Shape of the tensor the result will be combined with; only
            its length (number of dimensions) is used.

    Returns:
        A float32 tensor on ``t``'s device with shape
        ``(t.shape[0], 1, ..., 1)`` and ``len(x_shape)`` dimensions in total.
    """
    picked = v.gather(0, t).float()
    picked = picked.to(t.device)
    # One leading batch axis followed by singleton axes so that the result
    # broadcasts against a tensor of rank len(x_shape).
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.view((t.shape[0],) + singleton_axes)


class GaussianDiffusionTrainer(nn.Module):
    """Compute the noise-prediction training loss for a diffusion model.

    A linear beta schedule of length ``T`` is built between ``beta_1`` and
    ``beta_T``; the square roots of the cumulative alpha product and of its
    complement are stored as buffers and used to noise the clean input.

    Args:
        model: Callable invoked as ``model(x_t, t, labels)``; its output is
            compared against the sampled noise.
        beta_1: First value of the linear beta schedule.
        beta_T: Last value of the linear beta schedule.
        T: Number of diffusion steps.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # The schedule is kept in float64; `extract` casts to float32 on use.
        schedule = torch.linspace(beta_1, beta_T, T).double()
        self.register_buffer('betas', schedule)

        cumulative_alpha = torch.cumprod(1. - self.betas, dim=0)

        # Coefficients of q(x_t | x_0): signal scale and noise scale.
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(cumulative_alpha))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - cumulative_alpha))

    def forward(self, x_0, labels):
        """Noise ``x_0`` at a random timestep and return the MSE loss.

        Args:
            x_0: Clean input batch; the first dimension is the batch size.
            labels: Conditioning passed through to the model unchanged.

        Returns:
            Scalar tensor: mean squared error between the model output and
            the Gaussian noise that was mixed into ``x_0``.
        """
        batch_size = x_0.shape[0]
        # One timestep per sample, drawn uniformly from [0, T).
        t = torch.randint(self.T, size=(batch_size, ), device=x_0.device)
        noise = torch.randn_like(x_0)

        signal_scale = extract(self.sqrt_alphas_bar, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape)
        x_t = signal_scale * x_0 + noise_scale * noise

        predicted_noise = self.model(x_t, t, labels)
        return F.mse_loss(predicted_noise, noise, reduction='mean')


class GaussianDiffusionSampler(nn.Module):
    """Reverse-diffusion sampler with classifier-free guidance.

    Coefficients are precomputed from a linear beta schedule and stored as
    float64 buffers. Two sampling routines are provided: ``forward`` walks
    every timestep and injects fresh noise at each step, ``ddim_sample``
    visits only ``num_steps`` evenly spaced timesteps and adds no noise.

    Args:
        model: Callable invoked as ``model(x_t, t, labels)``; it must return
            a noise prediction with the same shape as ``x_t``.
        beta_1: First value of the linear beta schedule.
        beta_T: Last value of the linear beta schedule.
        T: Number of diffusion steps.
        w: Guidance weight. The noise estimate used for sampling is
            ``(1 + w) * eps_cond - w * eps_uncond``, where ``eps_uncond`` is
            the model output for all-zero labels.
            NOTE(review): this presumes all-zero labels are the "no
            condition" input the model was trained with -- confirm against
            the training code.
    """

    def __init__(self, model, beta_1, beta_T, T, w=0.):
        super().__init__()

        self.model = model
        self.T = T
        self.w = w

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alphas_bar shifted right by one step, with 1 (no noise) in front.
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        # x_{t-1} mean = coeff1 * x_t - coeff2 * eps
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer(
            'coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
        self.register_buffer('alphas_bar', alphas_bar)
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def _guided_eps(self, x_t, t, labels):
        """Return the classifier-free-guided noise estimate for ``x_t``."""
        eps_cond = self.model(x_t, t, labels)
        if self.w == 0:
            # The unconditional term is multiplied by zero; skip the second
            # forward pass instead of computing and discarding it.
            return eps_cond
        eps_uncond = self.model(x_t, t, torch.zeros_like(labels))
        return (1. + self.w) * eps_cond - self.w * eps_uncond

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return the mean of the previous-step sample given ``x_t`` and ``eps``.

        Raises:
            AssertionError: If ``x_t`` and ``eps`` differ in shape.
        """
        assert x_t.shape == eps.shape
        return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps

    def p_mean_variance(self, x_t, t, labels):
        """Return ``(mean, variance)`` of the reverse step at timestep ``t``.

        The mean is derived from the guided noise estimate; the variance is
        read from the precomputed ``posterior_var`` buffer.
        """
        var = extract(self.posterior_var, t, x_t.shape)
        eps = self._guided_eps(x_t, t, labels)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T, labels, epoch):
        """Sample by stepping through every timestep from ``T - 1`` down to 0.

        Args:
            x_T: Starting tensor; the first dimension is the batch size.
            labels: Conditioning passed to the model.
            epoch: Only shown in the progress bar.

        Returns:
            The final sample, clipped to ``[-1, 1]``.

        Raises:
            AssertionError: If a NaN appears in the running sample.
        """
        x_t = x_T
        batch_size = x_T.shape[0]
        # `reversed(range(...))` has no length, so pass the total explicitly
        # to give the progress bar a known end.
        with tqdm(reversed(range(self.T)), total=self.T, dynamic_ncols=True) as tqdmDataLoader:
            for time_step in tqdmDataLoader:
                t = x_t.new_full((batch_size, ), time_step, dtype=torch.long)
                mean, var = self.p_mean_variance(x_t=x_t, t=t, labels=labels)
                if time_step > 0:
                    x_t = mean + torch.sqrt(var) * torch.randn_like(x_t)
                else:
                    # No noise is added on the last step.
                    x_t = mean
                assert not torch.isnan(x_t).any(), "nan in tensor."
                tqdmDataLoader.set_postfix(ordered_dict={"epoch": epoch})
        return torch.clip(x_t, -1, 1)

    def ddim_sample(self, x_T, labels, epoch, num_steps=10):
        """Sample without added noise over ``num_steps`` spaced timesteps.

        Timesteps are spaced linearly from ``T - 1`` down to 0. Each step
        estimates the clean sample ``x0_hat`` from the guided noise
        prediction and re-noises it to the level of the next timestep. The
        last step returns ``x0_hat`` itself, matching the leading 1 of
        ``alphas_bar_prev`` in ``__init__`` and the noise-free last step of
        ``forward``.

        Args:
            x_T: Starting tensor; the first dimension is the batch size.
            labels: Conditioning passed to the model.
            epoch: Only shown in the progress bar.
            num_steps: Number of model-guided update steps.

        Returns:
            The final sample, clipped to ``[-1, 1]``.
        """
        x_t = x_T
        batch_size = x_t.size(0)
        timesteps = torch.linspace(
            self.T - 1, 0, steps=num_steps, dtype=torch.long, device=x_t.device)
        # Python copies for the progress bar: avoids a device sync per step.
        step_ids = timesteps.tolist()
        last = num_steps - 1
        with tqdm(range(num_steps), dynamic_ncols=True) as pbar:
            for i in pbar:
                t = timesteps[i].repeat(batch_size)
                eps = self._guided_eps(x_t, t, labels)

                alpha_bar_t = extract(self.alphas_bar, t, x_t.shape)
                sqrt_1m_ab_t = extract(self.sqrt_one_minus_alphas_bar, t, x_t.shape)
                x0_hat = (x_t - sqrt_1m_ab_t * eps) / alpha_bar_t.sqrt()

                if i < last:
                    t_prev = timesteps[i + 1].repeat(batch_size)
                    alpha_bar_t_prev = extract(self.alphas_bar, t_prev, x_t.shape)
                    sqrt_1m_ab_t_prev = extract(
                        self.sqrt_one_minus_alphas_bar, t_prev, x_t.shape)
                    x_t = alpha_bar_t_prev.sqrt() * x0_hat + sqrt_1m_ab_t_prev * eps
                else:
                    # Target noise level after the last timestep is zero
                    # (alpha_bar = 1). Re-noising to the same timestep, as
                    # was done before, left x_t unchanged and wasted the
                    # model evaluations of this iteration.
                    x_t = x0_hat
                pbar.set_postfix(ordered_dict={"epoch": epoch, "t": step_ids[i]})
        return torch.clip(x_t, -1, 1)




