import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` at the given timesteps.

    A zero is prepended to ``beta`` so that index ``t = -1`` maps to a
    cumulative product of exactly 1.

    Args:
        beta: 1-D tensor holding the beta schedule.
        t: 1-D integer tensor of timestep indices (``-1`` is allowed).

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` with the product of
        ``1 - beta[k]`` for ``k <= t``, one entry per element of ``t``.
    """
    zero_pad = torch.zeros(1).to(beta.device)
    padded_beta = torch.cat([zero_pad, beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded_beta, dim=0)
    # Shift by one because of the prepended zero.
    selected = torch.index_select(alpha_bar, 0, t + 1)
    return selected.view(-1, 1, 1, 1)


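# x: initial noise; seq: sequence of timestep indices; model: diffusion model; b: beta schedule; H_funcs: observation matrix; y_0: observation; sigma_0: standard deviation of the observation noise; cls_fn: classifier; classes: class labels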
def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC, cls_fn=None, classes=None, lr=0.1, n_inner=1):
    """Iteratively refine an x_0 estimate that is consistent with observation ``y_0``.

    The routine works in the spectral space given by ``H_funcs.Vt`` / ``H_funcs.V``
    (presumably the SVD of the degradation operator, as in DDRM -- confirm
    against the ``H_funcs`` implementation). Two images are tracked:

    * ``x0_t``: the running clean-image estimate, and
    * ``x_obs_t``: a partially noised image whose observed spectral
      components are pinned to the (rescaled) observation after every step.

    At each timestep the two are combined into a noisy ``xt``, the model
    predicts the noise, and both images are updated from the resulting
    x_0 prediction.

    Args:
        x: 4-D tensor; supplies batch size, shape, device and the first entry
            of the returned ``xs`` list. It is also what ``cls_fn`` receives.
        seq: Sequence of timestep indices (must support ``len``, indexing and
            ``reversed``); it is traversed from last to first.
        model: Callable ``model(xt, t)`` (or ``model(xt, t, classes)`` when
            ``cls_fn`` is given) returning the noise prediction. When the
            output has 6 channels only the first 3 are used.
        b: 1-D beta schedule passed to ``compute_alpha``.
        H_funcs: Object providing ``singulars()``, ``Ut``, ``Vt`` and ``V``.
        y_0: Observation, passed to ``H_funcs.Ut``.
        sigma_0: Standard deviation of the observation noise.
        etaB, etaA, etaC: Accepted for interface compatibility; not used here.
        cls_fn: Optional classifier-guidance callable
            ``cls_fn(x, t, classes)``.
        classes: Class labels forwarded to ``model`` and ``cls_fn``.
        lr: Step size for the x_0 update. The original source notes 0.5 for
            20 steps and 0.1 for 100 steps.
        n_inner: Number of refinement iterations per timestep.

    Returns:
        ``(xs, x0_preds)``. ``x0_preds`` holds the x_0 estimate after every
        iteration, moved to CPU. ``xs`` holds the input ``x`` followed by the
        same estimates.
    """
    with torch.no_grad():
        n = x.size(0)
        # NOTE(review): x is assumed to be 4-D (batch, channels, height, width).
        shape = [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]
        single_shape = [1, x.shape[1], x.shape[2], x.shape[3]]

        # Singular values, zero-padded to one entry per image element.
        singulars = H_funcs.singulars()
        Sigma = torch.zeros(x.shape[1] * x.shape[2] * x.shape[3], device=x.device)
        Sigma[:singulars.shape[0]] = singulars

        # Per-component signal level of the observation: 1 where nothing is
        # observed, 1 / (1 + (sigma_0 / s)^2) for a positive singular value s.
        positive = Sigma > 0
        alpha_obs = torch.ones_like(Sigma)
        alpha_obs[positive] = 1 / (1 + (sigma_0 / Sigma[positive]) ** 2)

        # Observation in spectral space, rescaled to the alpha_obs signal level.
        U_t_y = H_funcs.Ut(y_0)
        Sig_inv_U_t_y = U_t_y / singulars[:U_t_y.shape[-1]] * alpha_obs.sqrt()

        alpha_obs = alpha_obs.view(single_shape).repeat(n, 1, 1, 1)
        Sig_inv_U_t_y = Sig_inv_U_t_y.view(shape)
        Sigma = Sigma.view(single_shape).repeat(n, 1, 1, 1)

        # Loop-invariant quantities.
        observed = Sigma > 0
        sqrt_alpha_obs = alpha_obs.sqrt()
        sqrt_one_minus_alpha_obs = (1 - alpha_obs).sqrt()

        x0_preds = []
        xs = [x]

        # Initial x_0 estimate from a zero-mean noise sample at the last timestep.
        # NOTE(review): this call never passes `classes` or applies `cls_fn`,
        # unlike the loop below -- confirm this is intended for
        # class-conditional models.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()
        V_t_x0 = H_funcs.Vt(x0_t).view(shape)
        V_t_x_obs = sqrt_alpha_obs * V_t_x0 + sqrt_one_minus_alpha_obs * torch.randn_like(V_t_x0)
        x_obs_t = H_funcs.V(V_t_x_obs.view([n, -1])).view(x.shape)

        # Unused draw kept on purpose: dropping it would shift the RNG stream
        # and change the results of seeded runs.
        torch.randn_like(x0_t)

        # Iterate over the timesteps, from the last entry of seq to the first.
        for i in tqdm(reversed(seq), total=len(seq)):
            for _ in range(n_inner):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())
                at0 = at[0, 0, 0, 0]

                # Re-noise (stochastic encoding). The noise level has to be
                # chosen per singular value.
                V_t_x0 = H_funcs.Vt(x0_t).view(shape)
                V_t_x_obs = H_funcs.Vt(x_obs_t).view(shape)

                smaller_idx = alpha_obs < at0
                larger_idx = alpha_obs >= at0

                # Bring every component to signal level `at`.
                V_t_x_t = torch.zeros_like(V_t_x_obs)
                # Components at or above `at`: add fresh noise to x_obs.
                ratio = at0 / alpha_obs[larger_idx]
                obs_larger = V_t_x_obs[larger_idx]
                V_t_x_t[larger_idx] = ratio.sqrt() * obs_larger + (1 - ratio).sqrt() * torch.randn_like(obs_larger)
                # Components below `at`: reuse the noise already present in
                # x_obs relative to the current x_0 estimate.
                x0_smaller = V_t_x0[smaller_idx]
                alpha_smaller = alpha_obs[smaller_idx]
                residual = (V_t_x_obs[smaller_idx] - x0_smaller * alpha_smaller.sqrt()) / (1 - alpha_smaller).sqrt()
                V_t_x_t[smaller_idx] = at0.sqrt() * x0_smaller + (1 - at0).sqrt() * residual

                # Map back to image space.
                xt = H_funcs.V(V_t_x_t.view([n, -1])).view(x.shape)

                # Noise (score) prediction.
                if cls_fn is None:
                    et = model(xt, t)
                else:
                    et = model(xt, t, classes)
                    et = et[:, :3]
                    # NOTE(review): the classifier is evaluated on the input
                    # `x`, not on the current `xt` -- confirm this is intended.
                    et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes)

                if et.size(1) == 6:
                    et = et[:, :3]

                x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()

                V_t_x0_new = H_funcs.Vt(x0_t_new).view(shape)
                V_t_x_obs_new = sqrt_alpha_obs * V_t_x0_new + sqrt_one_minus_alpha_obs * torch.randn_like(V_t_x0_new)

                # Keep components that must not be updated at this step.
                V_t_x0[larger_idx] = V_t_x0_new[larger_idx]  # initializes x_0 for these components
                V_t_x_obs_new[smaller_idx] = V_t_x_obs[smaller_idx]

                # Map back to image space.
                x0_t = H_funcs.V(V_t_x0.view([n, -1])).view(x.shape)
                x_obs_t_new = H_funcs.V(V_t_x_obs_new.view([n, -1])).view(x.shape)

                # Update; the step size for x_obs is fixed at 1.
                x0_t = x0_t + lr * (x0_t_new - x0_t)
                x_obs_t = x_obs_t + 1.0 * (x_obs_t_new - x_obs_t)

                # Projection: pin the observed components to the observation.
                V_t_x_obs = H_funcs.Vt(x_obs_t).view(shape)
                V_t_x_obs[observed] = Sig_inv_U_t_y[observed]
                x_obs_t = H_funcs.V(V_t_x_obs.view([n, -1])).view(x.shape)

                xt_next = x0_t
                x0_preds.append(x0_t.to('cpu'))
                xs.append(xt_next.to('cpu'))

    return xs, x0_preds