import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` at the timesteps ``t``.

    A zero is prepended to ``beta`` so that index ``t + 1`` of the cumulative
    product corresponds to timestep ``t``; consequently ``t == -1`` selects
    the leading entry, which is exactly 1.

    Args:
        beta: 1-D tensor holding the beta schedule.
        t: 1-D integer tensor of timestep indices (``-1`` is allowed), on the
            same device as ``beta``.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` with the selected cumulative
        products.
    """
    leading_zero = torch.zeros(1).to(beta.device)
    padded_beta = torch.cat([leading_zero, beta], dim=0)
    cumulative = torch.cumprod(1 - padded_beta, dim=0)
    selected = torch.index_select(cumulative, 0, t + 1)
    return selected.view(-1, 1, 1, 1)


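# x: initial noise; seq: sequence of timestep indices; model: diffusion model; b: beta schedule; H_funcs: observation matrix; y_0: observation; sigma_0: standard deviation of the observation noise; cls_fn: classifier; classes: class labels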
def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC, cls_fn=None, classes=None):
    """Run the reverse sampling loop with a DDNM-style data-consistency step.

    Starting from freshly drawn noise scaled by ``sqrt(1 - alpha_T)``, each
    step predicts the noise with ``model``, forms the clean-image estimate
    ``x0_t``, corrects it towards the observation ``y_0`` using
    ``H_funcs.upsampling`` / ``H_funcs.downsampling``, and then takes a
    DDIM-style step to the next timestep.

    Args:
        x: Tensor whose batch size, shape and device define the samples. It is
            stored as the first element of the returned ``xs`` but is not used
            as the starting point of the trajectory.
        seq: Non-empty sequence of timestep indices (ascending); it is
            traversed in reverse.
        model: Noise predictor called as ``model(xt, t)`` or, when ``cls_fn``
            is given, ``model(xt, t, classes)``.
        b: 1-D beta schedule passed to ``compute_alpha``.
        H_funcs: Object providing ``upsampling`` and ``downsampling``.
        y_0: Observation fed to the data-consistency correction.
        sigma_0: Standard deviation of the observation noise.
        etaB, etaA, etaC: Accepted for interface compatibility; not used here
            (the noise mixing uses a fixed ``eta = 0.85``).
        cls_fn: Optional classifier-guidance function called as
            ``cls_fn(xt, t, classes)``.
        classes: Class labels forwarded to ``model`` and ``cls_fn``.

    Returns:
        ``(xs, x0_preds)``: ``xs`` is ``[x]`` followed by a CPU copy of every
        intermediate sample; ``x0_preds`` holds a CPU copy of every
        uncorrected clean-image estimate ``x0_t``.

    NOTE(review): autograd is not disabled inside this function (the original
    had a commented-out ``torch.no_grad()``). If gradients are not needed,
    callers may want to invoke it under ``torch.no_grad()`` to avoid keeping
    the graph alive across steps.
    """
    n = x.size(0)
    seq_next = [-1] + list(seq[:-1])
    x0_preds = []
    xs = [x]

    # Fixed stochasticity of the update: eta weights fresh noise, the
    # complementary factor weights the predicted noise.
    eta = 0.85
    deterministic_weight = (1 - eta ** 2) ** 0.5

    # Initialise x_T as pure noise scaled by sqrt(1 - alpha_T).
    t = (torch.ones(n) * seq[-1]).to(x.device)
    at = compute_alpha(b, t.long())
    xt = torch.randn_like(x) * (1 - at).sqrt()

    for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
        t = (torch.ones(n) * i).to(x.device)
        next_t = (torch.ones(n) * j).to(x.device)
        # alpha_t and alpha_{t-1}
        at = compute_alpha(b, t.long())
        at_next = compute_alpha(b, next_t.long())

        if cls_fn is None:
            et = model(xt, t)
        else:
            et = model(xt, t, classes)
            et = et[:, :3]
            # Classifier guidance must be evaluated at the current iterate xt;
            # the original passed the never-updated input x here.
            et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)

        # Models that emit 6 channels: keep only the first three.
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

        # DDNM correction weight: a full correction while the diffusion noise
        # level dominates the scaled observation noise, a damped one otherwise.
        # The DDNM noise scale gamma_t is fixed to 1 in this implementation,
        # so it is omitted from the update below.
        sigma_t = (1 - at_next ** 2).sqrt()
        if sigma_t[0, 0, 0, 0] >= at_next[0, 0, 0, 0] * sigma_0:
            lambda_t = 1.
        else:
            lambda_t = sigma_t / (at_next * sigma_0)

        residual = y_0 - H_funcs.downsampling(x0_t)
        x0_t_hat = x0_t + lambda_t * H_funcs.upsampling(residual)

        c1 = (1 - at_next).sqrt() * eta
        c2 = (1 - at_next).sqrt() * deterministic_weight

        # different from the paper, we use DDIM here instead of DDPM
        xt_next = at_next.sqrt() * x0_t_hat + (c1 * torch.randn_like(x0_t) + c2 * et)
        xt = xt_next

        x0_preds.append(x0_t.to('cpu'))
        xs.append(xt_next.to('cpu'))

    return xs, x0_preds