import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Look up the cumulative product of ``1 - beta`` at the timesteps ``t``.

    A single zero is prepended to ``beta`` before taking the product, so
    index ``t + 1`` selects ``prod(1 - beta[0..t])`` and ``t == -1`` selects
    the empty product 1.0. The gathered values are reshaped with
    ``view(-1, 1, 1, 1)`` so they broadcast against 4-D batches.
    """
    padded_beta = torch.cat([torch.zeros(1, device=beta.device), beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded_beta, dim=0)
    picked = torch.index_select(alpha_bar, 0, t + 1)
    return picked.view(-1, 1, 1, 1)


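# x: initial noise; seq: sequence of timestep indices; model: diffusion model; b: beta schedule; H_funcs: observation operator; y_0: observation; sigma_0: observation-noise standard deviation; cls_fn: classifier; classes: class labels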
def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC,
                                cls_fn=None, classes=None, step_size=1.0):
    """Run a measurement-guided (DPS-style) reverse diffusion sampler.

    At every step the model's noise prediction ``et`` is turned into an
    estimate ``x0_t`` of the clean sample. The squared measurement residual
    ``sum((H_funcs.obs(x0_t[k]) - y_0) ** 2)`` is then differentiated with
    respect to the current iterate ``xt``. Finally a stochastic DDIM step
    (eta = 1, i.e. the DDPM posterior variance) is taken and corrected by
    that gradient, divided by each sample's residual L2 norm.

    Args:
        x: Batch tensor whose batch size (dim 0), device and shape define the
            run; it is stored as ``xs[0]``. NOTE(review): the actual starting
            iterate is fresh Gaussian noise drawn here, not ``x`` itself —
            confirm whether ``x`` was meant to be used.
        seq: Sequence of integer timesteps, traversed from ``seq[-1]`` down
            to ``seq[0]`` (presumably sorted ascending — verify at caller).
        model: Noise-prediction network called as ``model(xt, t)``, or as
            ``model(xt, t, classes)`` when ``cls_fn`` is given.
        b: 1-D beta schedule tensor, indexed by timestep via ``compute_alpha``.
        H_funcs: Measurement operator exposing ``obs(sample)``; it is called
            on one unbatched sample ``x0_t[k]`` at a time.
        y_0: Measurement compared with ``H_funcs.obs(...)`` for every sample.
        sigma_0, etaB, etaA, etaC: Accepted for signature compatibility;
            not used by this sampler.
        cls_fn: Optional classifier-guidance term ``cls_fn(xt, t, classes)``.
        classes: Class labels passed to ``model`` and ``cls_fn``.
        step_size: Scale of the measurement-guidance correction. The default
            1.0 reproduces the original update.

    Returns:
        ``(xs, x0_preds)``: lists of CPU tensors holding the iterates
        (``xs[0]`` is ``x``) and the per-step ``x0`` estimates.
    """
    n = x.size(0)
    device = x.device
    seq_next = [-1] + list(seq[:-1])
    x0_preds = []
    xs = [x]

    # Starting iterate: Gaussian noise scaled by sqrt(1 - alpha_bar) at seq[-1].
    t = (torch.ones(n) * seq[-1]).to(device)
    at = compute_alpha(b, t.long())
    xt = torch.randn_like(x) * (1 - at).sqrt()

    def cal_grad(xt, t):
        """Return ``(x0_t, grad, norm, et)`` for iterate ``xt`` at timestep ``t``.

        ``grad`` is the gradient of the summed squared measurement residual
        with respect to ``xt``. ``norm`` has shape ``(n,)`` and holds each
        sample's residual L2 norm.
        """
        # Work on a fresh leaf so the tensor already stored in ``xs`` is not
        # mutated into a requires-grad tensor.
        xt = xt.detach().requires_grad_(True)
        at = compute_alpha(b, t.long())
        if cls_fn is None:
            et = model(xt, t)
        else:
            et = model(xt, t, classes)
            et = et[:, :3]
            # Classifier guidance is evaluated on the current iterate.
            et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)
        if et.size(1) == 6:
            # Keep the first 3 channels (presumably the noise prediction when
            # the model also outputs variance channels — confirm).
            et = et[:, :3]
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

        losses = torch.stack([
            torch.sum((H_funcs.obs(x0_t[k]) - y_0) ** 2) for k in range(n)
        ])
        # A single backward pass on the summed loss. By linearity this equals
        # accumulating every sample's gradient, and it avoids re-traversing a
        # graph that autograd.grad frees after its first call.
        grad = torch.autograd.grad(outputs=losses.sum(), inputs=xt)[0]
        norm = losses.detach().sqrt()
        return x0_t, grad, norm, et

    for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
        t = (torch.ones(n) * i).to(device)
        next_t = (torch.ones(n) * j).to(device)
        # alpha_bar at the current step t and at the next (earlier) step.
        at = compute_alpha(b, t.long())
        at_next = compute_alpha(b, next_t.long())
        x0_t, grad, norm, et = cal_grad(xt, t)
        with torch.no_grad():
            # Stochastic DDIM step with eta = 1 (DDPM posterior variance).
            sigma = ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            # Clamp guards against tiny negative values from float rounding.
            et_coeff = (1 - at_next - sigma ** 2).clamp_min(0).sqrt()
            # grad/norm is the DPS correction. norm is reshaped to (n,1,1,1) so
            # it divides each sample rather than broadcasting over the last
            # axis. The clamp avoids 0/0 when a residual is exactly zero.
            guidance = grad / norm.clamp_min(1e-12).view(-1, 1, 1, 1)
            xt = (at_next.sqrt() * x0_t
                  + et_coeff * et
                  + sigma * torch.randn_like(x0_t)
                  - step_size * guidance)
            x0_preds.append(x0_t.detach().to('cpu'))
            xs.append(xt.to('cpu'))

    return xs, x0_preds