import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np

def compute_alpha(beta, t):
    """Return the cumulative product of ``1 - beta`` selected at indices ``t``.

    A single zero is prepended to ``beta`` before the cumulative product is
    taken, so index ``t = -1`` selects the value 1 and ``t = k`` selects
    ``prod(1 - beta[: k + 1])``.

    Args:
        beta: Tensor of beta values; the product runs along dimension 0.
        t: Integer index tensor; each entry is shifted by one before lookup.

    Returns:
        The selected products reshaped to ``(-1, 1, 1, 1)``.
    """
    leading_zero = torch.zeros(1, device=beta.device)
    padded = torch.cat((leading_zero, beta), dim=0)
    running_product = torch.cumprod(1 - padded, dim=0)
    selected = torch.index_select(running_product, 0, t + 1)
    return selected.view(-1, 1, 1, 1)


def _predict_noise(model, xt, t):
    """Evaluate ``model(xt, t)``, keeping only the first 3 channels of a 6-channel output."""
    et = model(xt, t)
    # NOTE(review): the extra three channels are presumably a learned
    # variance head; that is not visible here, they are simply discarded.
    if et.size(1) == 6:
        et = et[:, :3]
    return et


def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC, cls_fn=None, classes=None):
    """Run the observation-guided sampling loop and return its trajectory.

    Two running tensors are maintained: ``x0_t``, the current clean-sample
    estimate, and ``x_obs_t``, a version of it noised to the observation
    level ``alpha_obs = 1 / (1 + sigma_0 ** 2)``. For every timestep in
    ``seq`` (visited last to first) a noisy sample ``xt`` is built from them,
    the model's noise prediction yields a new ``x0_t``, and ``x_obs_t`` is
    passed through ``H_funcs.prox`` together with the rescaled observation.

    Args:
        x: Initial-noise tensor. Only its batch size, shape, dtype and device
            are used (fresh Gaussian noise is drawn internally); the tensor
            itself is stored as the first element of the returned ``xs``.
        seq: Indexable sequence of timestep indices.
        model: Diffusion model, called as ``model(x_t, t)``.
        b: Beta schedule tensor, forwarded to ``compute_alpha``.
        H_funcs: Observation operator; only ``H_funcs.prox(x_obs, y)`` is
            called here.
        y_0: Observation. It is divided by ``sqrt(1 + sigma_0 ** 2)`` before
            being handed to ``H_funcs.prox``.
        sigma_0: Standard deviation of the observation noise.
        etaB, etaA, etaC: Accepted for interface compatibility; not used.
        cls_fn: Classifier; accepted for interface compatibility, not used.
        classes: Class labels; accepted for interface compatibility, not used.

    Returns:
        ``(xs, x0_preds)``. ``xs`` is ``[x]`` followed by one CPU copy of
        ``x0_t`` per step; ``x0_preds`` holds an independent CPU copy of
        ``x0_t`` per step.
    """
    with torch.no_grad():
        var_obs = sigma_0 ** 2
        y_standard = y_0 / torch.sqrt(torch.tensor(1 + var_obs))
        alpha_obs = 1 / torch.tensor(1 + var_obs)

        n = x.size(0)
        x0_preds = []
        xs = [x]

        # Initial estimate: denoise pure scaled noise at the largest timestep.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        noise = torch.randn_like(x)
        x_T = noise * (1 - at).sqrt()
        et = _predict_noise(model, x_T, t)
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()
        x_obs_t = alpha_obs.sqrt() * x0_t + (1 - alpha_obs).sqrt() * et

        # The original code drew an (unused) noise tensor at this point. The
        # draw is kept so that seeded runs consume the RNG stream exactly as
        # before and reproduce the same samples.
        torch.randn_like(x0_t)

        # Original tuning note, kept verbatim in meaning: "4.0 for 20 steps,
        # 1.0 for 100 steps". NOTE(review): the note does not say which of the
        # two knobs below it refers to - confirm before changing either.
        inner_steps = 1
        lr = 1.0

        for i in tqdm(reversed(seq), total=len(seq)):
            for _ in range(inner_steps):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())

                # Every batch entry shares the same timestep, so a single
                # element of ``at`` decides the branch for the whole batch.
                if at[0, 0, 0, 0] <= alpha_obs:
                    # xt is noisier than the observation level: add noise on
                    # top of x_obs_t to reach level ``at``.
                    noise = torch.randn_like(x0_t)
                    xt = (at / alpha_obs).sqrt() * x_obs_t + (1 - at / alpha_obs).sqrt() * noise
                    et = _predict_noise(model, xt, t)
                    x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()
                    # The estimate is replaced outright in this branch, so the
                    # relaxation step further down adds zero here.
                    x0_t = x0_t_new
                    x_obs_t_new = alpha_obs.sqrt() * x0_t_new + (1 - alpha_obs).sqrt() * torch.randn_like(x0_t_new)
                else:
                    # xt is cleaner than the observation level: reuse the
                    # noise direction implied by (x_obs_t, x0_t), no new noise.
                    xt = at.sqrt() * x0_t + (1 - at).sqrt() * (x_obs_t - alpha_obs.sqrt() * x0_t) / (1 - alpha_obs).sqrt()
                    et = _predict_noise(model, xt, t)
                    x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()
                    x_obs_t_new = x_obs_t

                x0_t += lr * (x0_t_new - x0_t)
                x_obs_t += 1.0 * (x_obs_t_new - x_obs_t)
                x_obs_t = H_funcs.prox(x_obs_t, y_standard)

                # ``Tensor.to('cpu')`` returns the very same tensor when it is
                # already on the CPU, and ``x0_t`` is updated in place on
                # later steps; store explicit copies so earlier history
                # entries are never overwritten.
                x0_preds.append(x0_t.to('cpu', copy=True))
                xs.append(x0_t.to('cpu', copy=True))

    return xs, x0_preds