import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` selected at steps ``t``.

    A single zero is prepended to ``beta`` before the cumulative product is
    taken, so index ``t + 1`` picks the product over ``beta[0..t]`` and
    ``t == -1`` picks the leading ``1``.

    Args:
        beta: 1-D tensor of beta values.
        t: 1-D integer tensor of step indices (``-1`` is allowed).

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` holding the selected products.
    """
    zero_pad = torch.zeros(1).to(beta.device)
    padded_beta = torch.cat([zero_pad, beta], dim=0)
    alpha_cumprod = torch.cumprod(1 - padded_beta, dim=0)
    picked = torch.index_select(alpha_cumprod, 0, t + 1)
    return picked.view(-1, 1, 1, 1)


def _predict_noise(model, xt, t, at, x, cls_fn, classes):
    """Return the model's noise (epsilon) prediction for ``xt`` at step ``t``.

    Without ``cls_fn`` the model is called as ``model(xt, t)``. With
    ``cls_fn`` it is called as ``model(xt, t, classes)``, the first three
    channels are kept and ``sqrt(1 - at) * cls_fn(x, t, classes)`` is
    subtracted. In both cases a 6-channel output is reduced to its first
    three channels.
    """
    if cls_fn is None:
        et = model(xt, t)
    else:
        et = model(xt, t, classes)
        et = et[:, :3]
        # NOTE(review): the classifier term is evaluated at the initial ``x``,
        # not at the current ``xt``. Kept exactly as in the original code;
        # confirm that this is intended.
        et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)

    if et.size(1) == 6:
        et = et[:, :3]
    return et


def _pinv_guided_step(A_funcs, xt, et, at, at_next, y, sigma_y, eta):
    """Run one pseudo-inverse guided update from level ``at`` to ``at_next``.

    Computes the clean-image estimate ``x0_t`` from ``xt`` and ``et``,
    corrects it with ``A_funcs.Lambda(A_funcs.A_pinv(A(x0_t) - y), ...)``
    (labelled "Eq. 17" in the original comments) and re-noises the corrected
    estimate with ``A_funcs.Lambda_noise``.

    Returns:
        ``(x0_t, xt_next)``: the uncorrected estimate and the next sample.
    """
    x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
    sigma_t = (1 - at_next).sqrt()[0, 0, 0, 0]
    sqrt_at_next = at_next.sqrt()[0, 0, 0, 0]
    batch = x0_t.size(0)

    # Eq. 17: pull x0_t towards consistency with the observation y.
    residual = A_funcs.A(x0_t.reshape(batch, -1)) - y.reshape(y.size(0), -1)
    correction = A_funcs.Lambda(
        A_funcs.A_pinv(residual).reshape(batch, -1),
        sqrt_at_next, sigma_y, sigma_t, eta,
    ).reshape(*x0_t.size())
    x0_t_hat = x0_t - correction

    # The fresh Gaussian draw happens after the Lambda call, as in the
    # original, so the random stream is unchanged.
    noise_term = A_funcs.Lambda_noise(
        torch.randn_like(x0_t).reshape(batch, -1),
        sqrt_at_next, sigma_y, sigma_t, eta, et.reshape(et.size(0), -1),
    ).reshape(*x0_t.size())
    xt_next = at_next.sqrt() * x0_t_hat + noise_term
    return x0_t, xt_next


def efficient_generalized_steps(x, seq, model, b, H_funcs, A_funcs, y_0, sigma_0,
                                etaB, etaA, etaC, cls_fn=None, classes=None):
    """Run the observation-guided reverse diffusion loop.

    The loop walks ``seq`` from its last entry to its first. For each step it
    compares the cumulative alpha ``at`` with ``alpha_obs = 1 / (1 + var_obs)``
    where ``var_obs = H_funcs.ratio**2 * sigma_0**2``:

    * ``at <= alpha_obs``: an estimate ``x_obs_t`` at the observation noise
      level is re-noised to level ``at``, refined with the model's noise
      prediction (with optional momentum) and made consistent with the
      up-sampled observation through ``H_funcs.upsampling`` /
      ``H_funcs.downsampling``. If the *next* level already exceeds
      ``alpha_obs`` the step finishes with a pseudo-inverse guided update
      (called "DDNM" in the original comments).
    * ``at > alpha_obs``: only the pseudo-inverse guided update is applied.

    Args:
        x: Tensor giving batch size, shape and device. It is stored as the
            first entry of the returned ``xs`` and passed to ``cls_fn``; the
            sampling itself starts from freshly drawn Gaussian noise.
        seq: Sequence of timestep indices; it is indexed, sliced and reversed.
        model: Diffusion model (epsilon predictor), called as
            ``model(xt, t)`` or ``model(xt, t, classes)``.
        b: Beta sequence forwarded to ``compute_alpha``.
        H_funcs: Object providing ``ratio``, ``upsampling`` and
            ``downsampling``.
        A_funcs: Object providing ``A``, ``A_pinv``, ``Lambda`` and
            ``Lambda_noise``.
        y_0: Observation.
        sigma_0: Standard deviation of the observation noise.
        etaB, etaA, etaC: Accepted for interface compatibility; not used by
            this function (a fixed ``eta = 0.85`` is used instead).
        cls_fn: Optional classifier function for guidance.
        classes: Class labels forwarded to ``model`` and ``cls_fn`` when
            ``cls_fn`` is given.

    Returns:
        ``(xs, x0_preds)``: ``xs`` is ``[x]`` followed by each step's next
        sample moved to CPU; ``x0_preds`` holds each step's clean-image
        estimate moved to CPU.
    """
    with torch.no_grad():
        # Noise level that corresponds to the observation.
        var_obs = H_funcs.ratio ** 2 * sigma_0 ** 2
        one_plus_var = torch.tensor(1 + var_obs)
        y_upsampling = H_funcs.upsampling(y_0) / torch.sqrt(one_plus_var)
        alpha_obs = 1 / one_plus_var
        sigma_y = sigma_0
        y = y_0

        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        # Initialise x_T from fresh noise at the largest timestep and derive
        # the first estimate at the observation noise level.
        # NOTE(review): this first model call ignores cls_fn/classes, unlike
        # every call inside the loop — confirm that this is intended.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        noise = torch.randn_like(x)
        x_T = noise * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x_obs_t = (x_T - et * (1 - at / alpha_obs).sqrt()) / (at / alpha_obs).sqrt()

        xt = x_T
        v = None      # momentum buffer, created on first use
        beta = 0.0    # momentum coefficient (0.0 disables momentum)
        # Step size. Values noted by the original author for other step
        # counts (presumably for this parameter — confirm):
        # 1000 steps: 0.1, 100 steps: 1.8, 20 steps: 1.7.
        lr = 1.0
        N = 1         # inner repetitions per timestep
        eta = 0.85

        # This draw was only needed by a disabled "restricted resampling"
        # variant. It is kept so the random stream, and therefore seeded
        # results, stay identical to earlier runs.
        torch.randn_like(x_obs_t)

        # Iterate over the timesteps, from the noisiest to the cleanest.
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            for _ in range(N):
                t = (torch.ones(n) * i).to(x.device)
                next_t = (torch.ones(n) * j).to(x.device)
                at = compute_alpha(b, t.long())
                at_next = compute_alpha(b, next_t.long())

                if at[0, 0, 0, 0] <= alpha_obs:
                    # The current noise level is above the observation noise:
                    # refine x_obs_t with the optimisation-style update.
                    ratio = at / alpha_obs

                    # Re-noise x_obs_t to level at (stochastic encoding).
                    xt = ratio.sqrt() * x_obs_t + torch.randn_like(x_obs_t) * (1 - ratio).sqrt()
                    et = _predict_noise(model, xt, t, at, x, cls_fn, classes)

                    # New estimate at the observation noise level, plus the
                    # usual clean-image estimate.
                    x_obs_t_new = (xt - et * (1 - ratio).sqrt()) / ratio.sqrt()
                    x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

                    # Update direction proposed by the diffusion model,
                    # smoothed with momentum.
                    d = x_obs_t_new - x_obs_t
                    v = d if v is None else beta * v + (1 - beta) * d

                    # Out-of-place on purpose: the previous x_obs_t may be the
                    # very tensor stored in xs (``.to('cpu')`` returns the same
                    # object for CPU tensors), so ``+=`` would silently rewrite
                    # the returned trajectory.
                    x_obs_t = x_obs_t + lr * v
                    # Enforce consistency with the up-sampled observation.
                    x_obs_t = x_obs_t + y_upsampling - H_funcs.upsampling(H_funcs.downsampling(x_obs_t))
                    xt_next = x_obs_t

                    if at_next[0, 0, 0, 0] > alpha_obs:
                        # The next level is below the observation noise:
                        # re-noise back to at and hand over to the
                        # pseudo-inverse guided update.
                        xt = ratio.sqrt() * x_obs_t + (1 - ratio).sqrt() * torch.randn_like(x_obs_t)
                        et = _predict_noise(model, xt, t, at, x, cls_fn, classes)
                        x0_t, xt_next = _pinv_guided_step(
                            A_funcs, xt, et, at, at_next, y, sigma_y, eta)
                        xt = xt_next
                else:
                    et = _predict_noise(model, xt, t, at, x, cls_fn, classes)
                    x0_t, xt_next = _pinv_guided_step(
                        A_funcs, xt, et, at, at_next, y, sigma_y, eta)
                    xt = xt_next

                x0_preds.append(x0_t.to('cpu'))
                xs.append(xt_next.to('cpu'))

    return xs, x0_preds