####### This code is built on I2SB (https://github.com/NVlabs/I2SB/tree/master)

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from functools import partial
import torch
import torch.nn as nn
from .util import unsqueeze_xdim
import pygrappa.grappa as grappa

from ipdb import set_trace as debug
import randoms
def compute_gaussian_product_coef(sigma1, sigma2):
    """ Coefficients of the product of two Gaussians over x_t.

    With p1 = N(x_t | x_0, sigma1**2) and p2 = N(x_t | x_1, sigma2**2), the
    product is p1 * p2 = N(x_t | coef1 * x_0 + coef2 * x_1, var).

    Args:
        sigma1: standard deviation of p1 (any type supporting ``**``, ``+``,
            ``*`` and ``/``, e.g. float, numpy array or torch tensor).
        sigma2: standard deviation of p2 (same kind of value as ``sigma1``).

    Returns:
        Tuple ``(coef1, coef2, var)`` where ``coef1`` weights x_0, ``coef2``
        weights x_1 and ``var`` is the variance of the product.
    """
    var1 = sigma1**2
    var2 = sigma2**2
    total = var1 + var2

    # Each mean is weighted by the *other* Gaussian's variance.
    weight_x0 = var2 / total
    weight_x1 = var1 / total
    product_var = (var1 * var2) / total
    return weight_x0, weight_x1, product_var



class Adative_Network():
    """ Couples a network with noise-schedule statistics derived from `betas`.

    The constructor precomputes, per step, the forward/backward standard
    deviations and the Gaussian-product coefficients (eq 11), and stores them
    as float32 tensors on `device`. The remaining methods use the forward
    standard deviation to turn network outputs into x0 predictions and to
    move between steps.
    """

    def __init__(self, betas, device, net):
        """ Store the network and tensorize the schedule statistics.

        Args:
            betas: per-step noise increments; consumed by numpy (`np.cumsum`,
                `np.flip`) before being converted to tensors.
            device: torch device on which the schedule tensors are placed.
            net: callable invoked as ``net(xt, step)``.
        """
        self.device = device
        self.net = net

        # compute analytic std: eq 11
        # Forward std accumulates betas from the first step; backward std
        # accumulates them from the last step.
        cum_from_start = np.cumsum(betas)
        cum_from_end = np.flip(np.cumsum(np.flip(betas)))
        std_fwd = np.sqrt(cum_from_start)
        std_bwd = np.sqrt(cum_from_end)
        mu_x0, mu_x1, var = compute_gaussian_product_coef(std_fwd, std_bwd)
        std_sb = np.sqrt(var)

        def on_device(values):
            # float32 copy of `values`, moved to the target device
            return torch.tensor(values, dtype=torch.float32).to(device)

        self.betas = on_device(betas)
        self.std_fwd = on_device(std_fwd)
        self.std_bwd = on_device(std_bwd)
        self.std_sb = on_device(std_sb)
        self.mu_x0 = on_device(mu_x0)
        self.mu_x1 = on_device(mu_x1)

    def get_std_fwd(self, step, xdim=None):
        """ Look up the forward std at `step`.

        When `xdim` is given, the value is passed through `unsqueeze_xdim`
        (presumably to add trailing singleton dims for broadcasting -- confirm
        against `.util`); otherwise the indexed tensor is returned as is.
        """
        sigma = self.std_fwd[step]
        if xdim is None:
            return sigma
        return unsqueeze_xdim(sigma, xdim)

    def compute_pred_x0(self, step, xt, net_out, clip_denoise=False):
        """ Given network output, recover x0. This should be the inverse of Eq 12.

        Computes ``xt - std_fwd[step] * net_out``.

        NOTE: `clip_denoise` is accepted for interface compatibility but is
        not used here; the prediction is returned unclipped.
        """
        sigma = self.get_std_fwd(step, xdim=xt.shape[1:])
        return xt - sigma * net_out

    def p_posterior(self, nprev, n, x_n, x0):
        """ Mean of p(x_{nprev} | x_n, x_0), i.e. eq 4.

        Only the mean is formed: the variance returned by
        `compute_gaussian_product_coef` is discarded and no noise is added.

        Args:
            nprev: target step index; must be strictly smaller than `n`.
            n: current step index.
            x_n: state at step `n`.
            x0: estimate of the clean signal.

        Returns:
            Tuple ``(xt_prev, mu_xn, mu_x0)``: the weighted combination
            ``mu_x0 * x0 + mu_xn * x_n`` followed by the two weights
            (note the order: weight of x_n first, then weight of x0).
        """
        assert nprev < n

        sigma_n = self.std_fwd[n]
        sigma_prev = self.std_fwd[nprev]
        # Std of the noise accumulated between step nprev and step n.
        sigma_gap = (sigma_n**2 - sigma_prev**2).sqrt()

        weight_x0, weight_xn, _ = compute_gaussian_product_coef(sigma_prev, sigma_gap)
        posterior_mean = weight_x0 * x0 + weight_xn * x_n

        return posterior_mean, weight_xn, weight_x0

    def pred_x0_fn(self, xt, step):
        """ Run the network on `xt` at `step` and convert its output to x0.

        The scalar `step` is expanded to a long tensor with one entry per
        element along the first dimension of `xt`.
        """
        batch_steps = torch.full(
            (xt.shape[0],), step, device=self.device, dtype=torch.long
        )
        net_out = self.net(xt, batch_steps)
        return self.compute_pred_x0(batch_steps, xt, net_out)

    def pred(self, x_input, pred_x0, t=150, total_nfe=999):
        """ Move `x_input` from step `total_nfe` to `total_nfe - t`, then re-predict x0.

        Args:
            x_input: state treated as being at step `total_nfe`.
            pred_x0: current x0 estimate used to form the posterior mean.
            t: number of steps to move back from `total_nfe`.
            total_nfe: index of the starting step.

        Returns:
            Tuple ``(pred_x0, mu_xn, mu_x0)``: the new x0 prediction at the
            intermediate state and the two weights from `p_posterior`.
        """
        target_step = total_nfe - t
        x_inter, mu_xn, mu_x0 = self.p_posterior(
            target_step, total_nfe, x_input, pred_x0
        )
        refined_x0 = self.pred_x0_fn(x_inter, target_step)

        return refined_x0, mu_xn, mu_x0
