import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` evaluated at timesteps ``t``.

    A zero is prepended to ``beta`` so that position ``t + 1`` of the cumulative
    product equals ``prod_{s <= t} (1 - beta[s])``; consequently ``t = -1``
    selects the leading entry, which is exactly 1.

    Args:
        beta: 1-D tensor of per-step betas (must be 1-D to concatenate with
            the 1-element zero tensor).
        t: Long tensor of timesteps used as an index; valid values lie in
            ``[-1, len(beta) - 1]``.

    Returns:
        Tensor of shape ``(t.numel(), 1, 1, 1)``, ready to broadcast against
        image batches.
    """
    padded = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded, dim=0)
    picked = alpha_bar.index_select(0, t + 1)
    return picked.view(-1, 1, 1, 1)


def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC,
                                cls_fn=None, classes=None, *, lr=1.5, momentum=0.0, inner_steps=1):
    """Run diffusion sampling interleaved with an observation-consistency step.

    A noisy image at the last timestep of ``seq`` is drawn and denoised once to
    get an initial clean estimate ``x0_t``. Then, for each timestep of ``seq``
    visited from last to first (``inner_steps`` times per timestep):

    1. ``x0_t`` is re-noised to the observation noise level and passed through
       ``H_funcs.prox_by_error_bp`` together with the scaled observation;
    2. a sample ``xt`` at diffusion level ``alpha_bar_t`` is built from it;
    3. ``model`` predicts the noise in ``xt`` (optionally shifted by ``cls_fn``),
       giving a new clean estimate ``x0_t_new``;
    4. ``x0_t`` moves toward ``x0_t_new`` with step ``lr`` and momentum ``momentum``.

    Args:
        x: Initial noise tensor. Only its batch size, its device, and its role as
            ``xs[0]`` are used; the starting sample is re-drawn with
            ``torch.randn_like(x)``.
        seq: Sequence of integer timesteps, visited in reverse order.
        model: Noise (epsilon) predictor, called as ``model(xt, t)`` or
            ``model(xt, t, classes)``. If it returns 6 channels, only the first
            3 are used.
        b: 1-D tensor of betas, passed to ``compute_alpha``.
        H_funcs: Degradation operator object; must provide
            ``prox_by_error_bp(x, y)``.
        y_0: Observation.
        sigma_0: Standard deviation of the observation noise.
        etaB, etaA, etaC: Not used by this sampler (presumably accepted so the
            signature matches other samplers — confirm with callers).
        cls_fn: Optional classifier-guidance function, called as
            ``cls_fn(xt, t, classes)``.
        classes: Class labels forwarded to ``model`` and ``cls_fn``.
        lr: Step size of the ``x0_t`` update. Earlier tuning notes: about 0.1
            for 1000 steps, 1.8 for 100 steps, 1.7 for 20 steps.
        momentum: Heavy-ball coefficient on the update direction; 0 disables it.
        inner_steps: Number of updates performed per timestep.

    Returns:
        ``(xs, x0_preds)``. ``xs[0]`` is the input ``x``. After every update, a
        CPU copy of the current ``x0_t`` is appended to both lists.
    """
    with torch.no_grad():
        n = x.size(0)
        xs = [x]
        x0_preds = []

        # Observation noise variance, scaled by (256 + 128)^2 / 256^2 (the
        # original note attributes this factor to padding).
        # NOTE(review): that note also mentioned an extra factor of 2 for the
        # imaginary part, which this formula does not include — confirm.
        var_obs = sigma_0 ** 2 * (256 + 128) ** 2 / 256 ** 2
        # Express the observation noise as a signal level alpha_obs, like a
        # diffusion timestep with alpha_bar = 1 / (1 + var_obs).
        alpha_obs = torch.tensor(1 / (1 + var_obs))
        y_standard = y_0 * alpha_obs.sqrt()

        # Initial clean estimate: one denoising step from pure noise at seq[-1].
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()

        # This draw is unused, but dropping it would shift the RNG stream and
        # change the results of seeded runs.
        torch.randn_like(x0_t)

        v = None
        for i in tqdm(reversed(seq), total=len(seq)):
            for _ in range(inner_steps):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())

                # Stochastic encoding: re-noise x0_t to the observation noise
                # level, then apply the operator's error back-projection step.
                x_obs_t = (alpha_obs.sqrt() * x0_t
                           + (1 - alpha_obs).sqrt() * torch.randn_like(x0_t))
                x_obs_t = H_funcs.prox_by_error_bp(x_obs_t, y_standard)

                # Build xt at level alpha_bar_t. Only the first batch element is
                # checked because every element shares the same t.
                if at[0, 0, 0, 0] <= alpha_obs:
                    # xt is noisier than the observation: top up the noise.
                    xt = ((at / alpha_obs).sqrt() * x_obs_t
                          + (1 - at / alpha_obs).sqrt() * torch.randn_like(x0_t))
                else:
                    # xt is cleaner than the observation: reuse the noise of
                    # x_obs_t relative to x0_t, rescaled to unit variance. Here
                    # alpha_obs < at, so 1 - alpha_obs > 0 when betas are in [0, 1).
                    implied_noise = (x_obs_t - alpha_obs.sqrt() * x0_t) / (1 - alpha_obs).sqrt()
                    xt = at.sqrt() * x0_t + (1 - at).sqrt() * implied_noise

                # Epsilon prediction, with optional classifier guidance evaluated
                # at the current sample xt (not the initial noise x).
                if cls_fn is None:
                    et = model(xt, t)
                else:
                    et = model(xt, t, classes)
                    et = et[:, :3]
                    et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)
                if et.size(1) == 6:
                    et = et[:, :3]

                x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()

                # Move x0_t toward the model's estimate, optionally with momentum.
                d = x0_t_new - x0_t
                v = d if v is None else momentum * v + (1 - momentum) * d
                # Out-of-place on purpose: on CPU, x0_t.to('cpu') returns x0_t
                # itself, so an in-place `+=` would overwrite earlier snapshots.
                x0_t = x0_t + lr * v

                x0_preds.append(x0_t.to('cpu'))
                xs.append(x0_t.to('cpu'))

    return xs, x0_preds