import torch
from tqdm import tqdm
import torchvision.utils as tvu
import os
import numpy as np
import torch.optim as optim
import lpips


def lmap_rps(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, optimize_iters=200, vae_lr=0.5, w_prior=0.15, noise_t=50, renoise_t=100, lam=0.1, use_psld=True, cls_fn=None, classes=None):
    """Latent MAP estimate followed by re-noised, measurement-guided sampling.

    Stage 1 (MAP): encode ``H_funcs.H_pinv(y_0)`` and refine that latent ``z``
    for ``optimize_iters`` AdamW steps (learning rate ``vae_lr``, cosine-annealed
    to 1e-5) on

        sum((y_0 - H(decode(z)))**2) + w_prior * sum(eps * z),

    where ``eps`` is the model's noise prediction for ``z`` noised to timestep
    ``noise_t``. ``eps`` is treated as a constant, so the prior's gradient
    w.r.t. ``z`` is ``w_prior * eps``.

    Stage 2 (sampling): noise the MAP latent to ``alphas_cumprod[renoise_t - 1]``
    and walk the timesteps of ``seq`` below ``renoise_t`` in reverse. Each step
    is an eta = 1 DDIM update minus ``lr * alpha_bar_t`` times the gradient of
    the guidance loss. With ``use_psld`` that loss is the measurement residual
    plus ``lam`` times the latent re-encoding term; otherwise it is the
    residual on clamped pixels.

    Args:
        x: tensor giving the batch size, device and shape of some random draws.
        seq: timesteps in ascending order (iterated in reverse).
        model: latent diffusion model exposing ``encode_first_stage``,
            ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products, indexed by timestep.
        H_funcs: degradation operator exposing ``H_pinv``, ``forward`` and ``proj``.
        y_0: observed measurement.
        sigma_0: scale of the noise added to the residual in the PSLD loss.
        lr: step size of the guidance gradient in stage 2.
        N, cls_fn, classes: unused; kept for a uniform sampler signature.
        optimize_iters, vae_lr, w_prior, noise_t: stage-1 settings (see above).
        renoise_t: stage-2 starting level (see above).
        lam: weight of the PSLD re-encoding term.
        use_psld: select the PSLD loss (True) or the DPS-style loss (False).

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    n = x.size(0)
    # alpha_bar of the clean level, created on the schedule's device/dtype
    # instead of the default CUDA device.
    one = alphas_cumprod.new_ones(())
    x0_preds = []
    xs = []

    # ---- Stage 1: MAP refinement of the encoded back-projection ----
    # -1 keeps whatever batch size H_pinv produces instead of forcing 1.
    img = H_funcs.H_pinv(y_0).view([-1, 3, 256, 256])
    x0_init = model.encode_first_stage(img)

    i = int(noise_t)
    t_prior = (torch.ones(n) * i).to(x.device)
    at = alphas_cumprod[i]
    with torch.enable_grad():
        x0_t_with_grad = x0_init.clone().requires_grad_(True)
        optimizer = optim.AdamW([x0_t_with_grad], lr=vae_lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=optimize_iters, eta_min=1e-5)
        for _ in tqdm(range(optimize_iters)):
            pixels = model.decode_first_stage(x0_t_with_grad)
            loss_likelihood = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
            xt_noisy = at.sqrt() * x0_t_with_grad + (1 - at).sqrt() * torch.randn_like(x0_t_with_grad)
            # eps only enters the prior as a constant, so no autograd graph
            # needs to be recorded through the UNet.
            with torch.no_grad():
                et = model.apply_model(xt_noisy, t_prior, cond=None)
            loss_prior = w_prior * torch.sum(et * x0_t_with_grad)
            loss = loss_likelihood + loss_prior
            # Only the latent's gradient is needed; autograd.grad keeps model
            # parameters from receiving (and accumulating) .grad.
            x0_t_with_grad.grad = torch.autograd.grad(loss, x0_t_with_grad)[0]
            optimizer.step()
            scheduler.step()
    x0_t = x0_t_with_grad.detach()

    # ---- Stage 2: re-noise the MAP latent and run guided reverse steps ----
    with torch.no_grad():
        seq_next = [-1] + list(seq[:-1])
        # NOTE(review): the latent is noised to level renoise_t - 1, while the
        # first step taken below is the largest i in seq with i < renoise_t;
        # the two coincide only if renoise_t - 1 is in seq. Confirm.
        at_init = alphas_cumprod[int(renoise_t) - 1] if renoise_t > 0 else one
        xt = at_init.sqrt() * x0_t + (1 - at_init).sqrt() * torch.randn_like(x0_t)
        eta = 1.0
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            if i >= int(renoise_t):
                continue
            at = alphas_cumprod[i]
            at_next = one if j < 0 else alphas_cumprod[j]
            t = (torch.ones(n) * i).to(x.device)
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                et = model.apply_model(xt_with_grad, t, cond=None)
                x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                pixels = model.decode_first_stage(x0_t)
                if use_psld:
                    loss1 = torch.linalg.norm(y_0 - H_funcs.forward(pixels) - sigma_0 * torch.randn_like(y_0)) ** 2.0
                    pixels_recon = H_funcs.proj(pixels, y_0)
                    loss2 = torch.linalg.norm(x0_t - model.encode_first_stage(pixels_recon.detach())) ** 2.0
                    loss = loss1 + loss2 * lam
                else:
                    # DPS-style data term on clamped pixels.
                    loss = torch.linalg.norm(y_0 - H_funcs.forward(pixels.clamp(-1, 1))) ** 2.0
                grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0] * at
            # Draw whose value is unused; kept so seeded runs reproduce the
            # same random stream.
            torch.randn_like(x)
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            noise = torch.randn_like(x0_t)
            xt = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * noise - lr * grad
        img = model.decode_first_stage(xt)
    x0_preds.append(img.to('cpu'))
    xs.append(img.to('cpu'))
    return xs, x0_preds


def resample(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """ReSample-style latent sampler with periodic hard data consistency.

    Each reverse step of ``seq`` takes an eta = 1 DDIM step and then a DPS
    gradient step scaled by ``lr * alpha_bar_t``. For timesteps 0 < t < 667,
    at every 10th position of ``seq``, an x0 estimate is fitted to ``y_0``:
    in pixel space with AdamW when t > 333, in latent space otherwise. The
    next latent is then resampled from a Gaussian whose mean blends the
    re-noised fitted estimate with the DDIM/DPS proposal. A final
    latent-space fit to ``y_0`` precedes decoding.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps; must contain exactly 500 entries.
        model: latent diffusion model exposing ``encode_first_stage``,
            ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward``.
        y_0: observed measurement.
        lr: DPS step size.
        sigma_0, N, cls_fn, classes: unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    assert len(seq) == 500
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        gamma = 40
        eta = 1.0
        x0_preds = []
        xs = []
        steps = zip(reversed(seq), reversed(seq_next))
        for cnt, (i, j) in enumerate(tqdm(steps, total=len(seq))):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()

            # eta = 1 DDIM proposal.
            et = model.apply_model(xt, t, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            # NOTE(review): scales et by (1 - at), not sqrt(1 - at) as x0_t
            # does above. Confirm this alternative estimate is intended.
            x0_t_abaaba = (xt - et * (1 - at)) / at.sqrt()
            xt_next_prime = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * torch.randn_like(x0_t)

            # DPS correction.
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                et_g = model.apply_model(xt_with_grad, t, cond=None)
                x0_g = (xt_with_grad - et_g * (1 - at).sqrt()) / at.sqrt()
                pixels = model.decode_first_stage(x0_g)
                loss = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
                grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0]
            xt_next_prime = xt_next_prime - grad * lr * at
            xt_next = xt_next_prime

            index = len(seq) - cnt - 1  # position of i within seq
            if 0 < i < 667 and index % 10 == 0:
                if i > 333:
                    # Hard data consistency in pixel space, then re-encode.
                    pixels = model.decode_first_stage(x0_t_abaaba)
                    with torch.enable_grad():
                        pixels_with_grad = pixels.clone().requires_grad_(True)
                        optimizer = optim.AdamW([pixels_with_grad], lr=0.01)
                        for _ in range(2000):
                            loss = torch.mean((y_0 - H_funcs.forward(pixels_with_grad)) ** 2)
                            pixels_with_grad.grad = torch.autograd.grad(loss, pixels_with_grad)[0]
                            optimizer.step()
                            if loss.item() < 1e-4:
                                break  # early stop
                    x0_t_hat = model.encode_first_stage(pixels_with_grad.detach())
                else:
                    # Hard data consistency in latent space.
                    with torch.enable_grad():
                        x0_t_with_grad = x0_t_abaaba.clone().requires_grad_(True)
                        optimizer = optim.AdamW([x0_t_with_grad], lr=0.005)
                        for _ in range(500):
                            pixels = model.decode_first_stage(x0_t_with_grad)
                            loss = torch.mean((y_0 - H_funcs.forward(pixels)) ** 2)
                            # autograd.grad keeps model parameters from receiving .grad.
                            x0_t_with_grad.grad = torch.autograd.grad(loss, x0_t_with_grad)[0]
                            optimizer.step()
                            if loss.item() < 1e-4:
                                break
                    x0_t_hat = x0_t_with_grad.detach()

                # Stochastic resampling: applied after either branch, so the
                # fitted estimate x0_t_hat is actually used.
                sigma_t_square = gamma * (1 - at_next) / at * (1 - at / at_next)
                if sigma_t_square == 0:
                    mean = xt_next_prime
                    var = torch.zeros_like(sigma_t_square)
                else:
                    denom = sigma_t_square + 1 - at_next
                    var = sigma_t_square * (1 - at_next) / denom
                    mean = (sigma_t_square * at_next.sqrt() * x0_t_hat + (1 - at_next) * xt_next_prime) / denom
                xt_next = mean + var.sqrt() * torch.randn_like(x)
            xt = xt_next

        # Final latent-space fit to y_0 before decoding.
        with torch.enable_grad():
            x0_t_with_grad = xt.clone().requires_grad_(True)
            optimizer = optim.AdamW([x0_t_with_grad], lr=0.005)
            for _ in range(500):
                pixels = model.decode_first_stage(x0_t_with_grad)
                loss = torch.mean((y_0 - H_funcs.forward(pixels)) ** 2)
                x0_t_with_grad.grad = torch.autograd.grad(loss, x0_t_with_grad)[0]
                optimizer.step()
                if loss.item() < 1e-4:
                    break

        pixels = model.decode_first_stage(x0_t_with_grad.detach())
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def latent_dps(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Diffusion posterior sampling (DPS) carried out in latent space.

    Starting from Gaussian noise shaped like ``x``, each reverse step of ``seq``
    takes an eta = 1 DDIM step. It then subtracts ``lr * alpha_bar_t`` times
    the gradient, w.r.t. the current latent, of
    ``sum((y_0 - H(decode(x0_hat)))**2)``, where ``x0_hat`` is the model's
    clean-latent estimate.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps (iterated in reverse).
        model: latent diffusion model exposing ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward``.
        y_0: observed measurement.
        lr: guidance step size.
        sigma_0, N, cls_fn, classes: unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        eta = 1.0
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                et = model.apply_model(xt_with_grad, t, cond=None)
                x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                pixels = model.decode_first_stage(x0_t)
                loss = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
                grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0]
            # Draw whose value is unused; kept so seeded runs reproduce the
            # same random stream.
            torch.randn_like(x)
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            xt = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * torch.randn_like(x0_t) - grad * lr * at

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def ldir(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Latent sampler guided through the decoded one-step-ahead proposal.

    Each reverse step of ``seq`` forms the eta = 1 DDIM proposal for the next
    latent and decodes that proposal. It then takes the proposal minus
    ``lr * alpha_bar_t`` times the gradient, w.r.t. the current latent, of
    ``sum((y_0 - H(decode(proposal)))**2)``.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps (iterated in reverse).
        model: latent diffusion model exposing ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward``.
        y_0: observed measurement.
        lr: guidance step size.
        sigma_0, N, cls_fn, classes: unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        eta = 1.0
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                et = model.apply_model(xt_with_grad, t, cond=None)
                x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                # Draw whose value is unused; kept so seeded runs reproduce the
                # same random stream.
                torch.randn_like(x)
                sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
                xt_next_par = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * torch.randn_like(x0_t)
                # NOTE(review): decodes the noisy proposal rather than x0_t.
                # Confirm this is intended.
                pixels = model.decode_first_stage(xt_next_par)
                loss = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
                grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0]
            xt = xt_next_par.detach() - grad * lr * at

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def stsl(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, stepsize=0.02, cls_fn=None, classes=None, K=5, lam=1.0, num_probes=2):
    """Latent sampler with per-step Adam refinement of the noisy latent.

    First, ``encode(H_pinv(y_0))`` is noised to the level of ``seq[0]`` and
    deterministically inverted (eta = 0 DDIM) up to ``seq[-1]``. Then, for each
    reverse step of ``seq``, the latent is refined with ``K`` Adam steps
    (learning rate ``lr``) on

        lam * sum((y_0 - H(decode(x0_hat)))**2)
        + stepsize / d * mean_over_probes(sum(v * (s(x + v) - s(x)))),

    with ``v ~ N(0, I)``, ``s = eps / sqrt(1 - alpha_bar_t)`` and ``d`` the
    number of elements per latent. An eta = 1 DDIM step follows the refinement.
    The probe term appears to be a finite-difference trace estimate (confirm).

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps.
        model: latent diffusion model exposing ``encode_first_stage``,
            ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``H_pinv`` and ``forward``.
        y_0: observed measurement.
        lr: Adam learning rate of the inner refinement.
        N: ignored (it was always overridden); use ``num_probes``.
        stepsize: weight of the probe term.
        sigma_0, cls_fn, classes: unused; kept for a uniform sampler signature.
        K: inner Adam steps per timestep (default 5, the former hard-coded value).
        lam: weight of the data term (default 1.0, the former hard-coded value).
        num_probes: random probes per inner step (default 2, the former
            hard-coded value); must be >= 1.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.

    Raises:
        ValueError: if ``num_probes < 1``.
    """
    if num_probes < 1:
        raise ValueError("num_probes must be >= 1")
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        # Draw whose value is unused (xt is replaced below); kept so seeded
        # runs reproduce the same random stream.
        torch.randn_like(x)
        eta = 1.0
        d = x.shape[1] * x.shape[2] * x.shape[3]
        x0_preds = []
        xs = []

        # NOTE(review): H_pinv(y_0) is encoded without reshaping. Confirm it
        # already returns an image-shaped tensor.
        xt = model.encode_first_stage(H_funcs.H_pinv(y_0))
        # Forward pass: add noise at seq[0], then deterministic DDIM inversion.
        for i, j in tqdm(zip(seq_next, seq), total=len(seq)):
            at_next = alphas_cumprod[j]
            if i < 0:
                xt = at_next.sqrt() * xt + (1 - at_next).sqrt() * torch.randn_like(x)
            else:
                t = (torch.ones(n) * i).to(x.device)
                at = alphas_cumprod[i]
                et = model.apply_model(xt, t, cond=None)
                x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
                xt = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et

        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]

            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                optimizer = optim.Adam([xt_with_grad], lr=lr)
                for _ in range(K):
                    et = model.apply_model(xt_with_grad, t, cond=None)
                    s_t = et / (1 - at).sqrt()
                    x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                    pixels = model.decode_first_stage(x0_t)
                    loss1 = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
                    loss2 = 0
                    for _ in range(num_probes):
                        random_noise = torch.randn_like(x)
                        et_purt = model.apply_model(xt_with_grad + random_noise, t, cond=None)
                        s_t_purt = et_purt / (1 - at).sqrt()
                        loss2 = loss2 + torch.sum(random_noise * (s_t_purt - s_t))
                    loss = lam * loss1 + stepsize / d * loss2 / num_probes
                    # Only the latent's gradient is needed; autograd.grad keeps
                    # UNet parameters from accumulating .grad across steps.
                    xt_with_grad.grad = torch.autograd.grad(loss, xt_with_grad)[0]
                    optimizer.step()
            xt = xt_with_grad.detach()
            et = model.apply_model(xt, t, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            noise = torch.randn_like(x)
            # eta = 1 DDIM update
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            xt = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * noise

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def psld(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, lam, N, cls_fn=None, classes=None):
    """Latent sampler with a measurement term plus a re-encoding ("gluing") term.

    Starting from Gaussian noise shaped like ``x``, each reverse step of ``seq``
    takes an eta = 1 DDIM step. It then subtracts ``lr * alpha_bar_t`` times
    the gradient, w.r.t. the current latent, of

        ||y_0 - H(decode(x0_hat)) - sigma_0 * n||
        + lam * ||x0_hat - encode(proj(decode(x0_hat), y_0))||,

    with ``n`` fresh Gaussian noise and the projected image detached. Note
    that the norms are not squared.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps (iterated in reverse).
        model: latent diffusion model exposing ``encode_first_stage``,
            ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward`` and ``proj``.
        y_0: observed measurement.
        sigma_0: scale of the noise added to the residual.
        lr: guidance step size.
        lam: weight of the re-encoding term.
        N, cls_fn, classes: unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype); index -1 on the final step
        # therefore yields the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        eta = 1.0
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                et = model.apply_model(xt_with_grad, t, cond=None)
                x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                pixels = model.decode_first_stage(x0_t)
                loss1 = torch.linalg.norm(y_0 - H_funcs.forward(pixels) - sigma_0 * torch.randn_like(y_0))
                pixels_recon = H_funcs.proj(pixels, y_0)
                loss2 = torch.linalg.norm(x0_t - model.encode_first_stage(pixels_recon.detach()))
                loss = loss1 + loss2 * lam
                grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0]
            # Draw whose value is unused; kept so seeded runs reproduce the
            # same random stream.
            torch.randn_like(x)
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            xt = at_next.sqrt() * x0_t + (1 - at_next - sigma ** 2).sqrt() * et + sigma * torch.randn_like(x0_t) - grad * lr * at

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds

def daps_latent(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Latent sampler alternating ODE denoising, Langevin refinement and re-noising.

    For each reverse step of ``seq``:
    1. Map the current latent to an x0 estimate with a short deterministic
       DDIM (eta = 0) solve.
    2. Refine that estimate with 100 Langevin steps. The step size is
       ``lr * (0.01 + t / 1000 * 0.99)``.
    3. Re-noise the refined estimate to the next level.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending timesteps (iterated in reverse).
        model: latent diffusion model exposing ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward``.
        y_0: observed measurement.
        lr: base Langevin step size.
        sigma_0, N, cls_fn, classes: unused. The Langevin step uses its own
            hard-coded noise level and step count.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    def ode(xt, t, order=5):
        """Map ``xt`` at timestep ``t`` to an x0 estimate via up to ``order`` DDIM (eta = 0) steps."""
        n = xt.shape[0]
        skip = t // (order - 1)
        steps = list(range(0, t, skip)) if skip > 0 else [0]
        steps = steps[1:] + [t]
        steps_next = [-1] + steps[:-1]
        for i, j in zip(reversed(steps), reversed(steps_next)):
            t_vec = (torch.ones(n) * i).to(xt.device)
            # alphas_cumprod is read from the enclosing scope, i.e. the
            # extended schedule, so index -1 is alpha_bar = 1.
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            et = model.apply_model(xt, t_vec, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            # NOTE(review): clips the *latent* estimate to [-1, 1]. Confirm this
            # range is appropriate for the first-stage latent space.
            x0_t = x0_t.clip(-1, 1)
            xt = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        return xt

    def langevin(x0, y_0, eta, at, N=100, nonlinear=True):
        """Refine ``x0`` with ``N`` Langevin steps of size ``eta``.

        Energy: ||z - x0||^2 / (2 r_t^2) + ||H(decode(z)) - y_0||^2 / (2 * 0.01^2),
        where r_t = max(sqrt(1 - alpha_bar_t), 1e-4). ``nonlinear`` is accepted
        but has no effect.
        """
        with torch.enable_grad():
            rt = max((1 - at).sqrt(), 1e-4)
            # NOTE(review): hard-coded; shadows the sampler's sigma_0 argument.
            sigma_0 = 0.01
            x0_variable = x0.detach().clone().requires_grad_()
            for _ in range(N):
                loss = torch.sum((x0_variable - x0) ** 2) / (2 * rt ** 2) + torch.sum((H_funcs.forward(model.decode_first_stage(x0_variable)) - y_0) ** 2) / (2 * sigma_0 ** 2)
                grad = torch.autograd.grad(outputs=loss, inputs=x0_variable)[0]
                # Make each iterate a fresh leaf so autograd does not chain all
                # N updates together.
                x0_variable = (x0_variable - eta * grad + (2 * eta) ** 0.5 * torch.randn_like(x0)).detach().requires_grad_()
        return x0_variable.detach()

    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        x0_preds = []
        xs = []
        eta0 = lr
        delta = 1e-2
        order = 5
        T = 1000
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            x0_t = ode(xt, int(t[0]), order=order)
            # Langevin step size decays linearly with t down to delta * eta0.
            eta = eta0 * (delta + t[0] / T * (1 - delta))
            x0_t = langevin(x0_t, y_0, eta, at, nonlinear=True)
            xt = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * torch.randn_like(x0_t)

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def dcdp_latent(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, optimize_iters=100, cls_fn=None, classes=None):
    """Latent sampler alternating denoising, data fitting and re-noising.

    Starting from noise scaled by ``sqrt(1 - alpha_bar[seq[-1]])``, each
    reverse step of ``seq`` does the following:
    1. Predict the clean latent.
    2. Fit it to ``y_0`` with ``optimize_iters`` AdamW steps (learning rate
       ``lr``) on ``sum((y_0 - H(decode(z)))**2)``.
    3. Round-trip it through the decoder and encoder.
    4. Re-noise it to the next level.

    Args:
        x: tensor giving the batch size, device and latent shape.
        seq: ascending, non-empty timesteps (iterated in reverse).
        model: latent diffusion model exposing ``encode_first_stage``,
            ``decode_first_stage`` and ``apply_model``.
        alphas_cumprod: 1-D tensor of cumulative alpha products.
        H_funcs: degradation operator exposing ``forward``.
        y_0: observed measurement.
        lr: AdamW learning rate for the data fit.
        optimize_iters: AdamW steps per timestep.
        sigma_0, N, cls_fn, classes: unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the decoded
        result on CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 (same device/dtype) so index -1 is the clean level.
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x) * (1 - alphas_cumprod[seq[-1]]).sqrt()
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq)):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            et = model.apply_model(xt, t, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            with torch.enable_grad():
                x0_t_with_grad = x0_t.clone().requires_grad_(True)
                optimizer = optim.AdamW([x0_t_with_grad], lr=lr)
                for _ in range(optimize_iters):
                    pixels = model.decode_first_stage(x0_t_with_grad)
                    loss = torch.sum((y_0 - H_funcs.forward(pixels)) ** 2)
                    # autograd.grad keeps model parameters from receiving .grad.
                    x0_t_with_grad.grad = torch.autograd.grad(loss, x0_t_with_grad)[0]
                    optimizer.step()
            # Re-encode so the latent lies in the first-stage model's range.
            x0_t = model.encode_first_stage(model.decode_first_stage(x0_t_with_grad.detach()))
            xt = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * torch.randn_like(x0_t)

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def dmap_latent(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N=2, cls_fn=None, classes=None):
    """Latent DDPM (eta = 1) sampler with measurement guidance confined to a sphere.

    After each ancestral step from ``t`` to ``t_next`` (skipped on the final,
    noise-free step), ``N`` guidance iterations move ``xt_next`` against the
    gradient of ``||y_0 - H_funcs.forward(decode(x0(xt_next)))||``. After each
    iteration, every sample's offset from the DDPM mean ``mu_t`` is rescaled to
    norm ``sigma * sqrt(d)``, where ``d`` is the number of elements per sample.

    Args:
        x: 4-D tensor whose shape and device define the latent being sampled.
        seq: Timesteps; traversed in reverse order.
        model: Object providing ``apply_model`` and ``decode_first_stage``.
        alphas_cumprod: 1-D tensor of cumulative alpha products indexed by timestep.
        H_funcs: Degradation operator exposing ``forward``.
        y_0: Observed measurement.
        lr: Guidance step size (further scaled by ``alpha_bar_{t_next}``).
        N: Number of guidance iterations per timestep.
        sigma_0, cls_fn, classes: Unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the final decoded
        image moved to the CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        d = x.size(1) * x.size(2) * x.size(3)  # elements per sample
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 so index -1 (the step after the last) is noise-free.
        # Built on alphas_cumprod's own device/dtype rather than a hard-coded .cuda().
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x)
        eta = 1.0
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next))):
            t = (torch.ones(n) * i).to(x.device)
            t_next = (torch.ones(n) * j).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            # Score / clean estimate at the current step.
            et = model.apply_model(xt, t, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            # DDPM ancestral update: mean mu_t plus sigma-scaled Gaussian noise.
            sigma = eta * ((1 - at_next) / (1 - at)).sqrt() * (1 - at / at_next).sqrt()
            mu_t = at_next.sqrt() * x0_t + (1 - at_next - sigma**2).sqrt() * et
            xt_next = mu_t + sigma * torch.randn_like(x0_t)
            if j >= 0:
                for _ in range(N):
                    with torch.enable_grad():
                        xt_with_grad = xt_next.clone().requires_grad_(True)
                        et_next = model.apply_model(xt_with_grad, t_next, cond=None)
                        x0_next = (xt_with_grad - et_next * (1 - at_next).sqrt()) / at_next.sqrt()
                        pixels = model.decode_first_stage(x0_next)
                        loss = torch.norm(y_0 - H_funcs.forward(pixels))
                        grad = torch.autograd.grad(outputs=loss, inputs=xt_with_grad)[0]
                    xt_next = xt_next - lr * grad * at_next
                    # Per-sample norm: d counts elements of ONE sample, so the
                    # norm must not be pooled across the batch dimension.
                    offset = xt_next - mu_t
                    offset_norm = offset.flatten(1).norm(dim=1).view(-1, 1, 1, 1)
                    xt_next = mu_t + offset / offset_norm * sigma * (d ** 0.5)
            xt = xt_next

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds

def sitcom_latent(x, seq, model, alphas_cumprod, H_funcs, y_0, sigma_0, lr, N, optimize_iters=30, cls_fn=None, classes=None, lam=0.0):
    """Latent sampler that optimises the noisy latent itself for data consistency.

    At every timestep ``xt`` is refined with ``optimize_iters`` AdamW steps on
    ``sum((y_0 - H_funcs.forward(decode(x0(xt))))**2) + lam * ||xt' - xt||^2``.
    Here ``x0(.)`` is the model's clean estimate. The refined latent's clean
    estimate is then re-noised to the next timestep.

    Args:
        x: Tensor whose shape and device define the latent being sampled.
        seq: Timesteps; traversed in reverse order.
        model: Object providing ``apply_model`` and ``decode_first_stage``.
        alphas_cumprod: 1-D tensor of cumulative alpha products indexed by timestep.
        H_funcs: Degradation operator exposing ``forward``.
        y_0: Observed measurement.
        lr: AdamW learning rate.
        optimize_iters: Number of AdamW steps per timestep.
        lam: Weight of the proximal term keeping the optimised latent near the
            unoptimised one (0.0 disables it, matching the previous hard-coded value).
        sigma_0, N, cls_fn, classes: Unused; kept for a uniform sampler signature.

    Returns:
        ``(xs, x0_preds)``: two one-element lists, each holding the final decoded
        image moved to the CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        # Append alpha_bar = 1 so index -1 (the step after the last) is noise-free.
        # Built on alphas_cumprod's own device/dtype rather than a hard-coded .cuda().
        alphas_cumprod = torch.cat((alphas_cumprod, alphas_cumprod.new_ones(1)))
        xt = torch.randn_like(x) * (1 - alphas_cumprod[seq[-1]]).sqrt()
        x0_preds = []
        xs = []
        for i, j in tqdm(zip(reversed(seq), reversed(seq_next))):
            t = (torch.ones(n) * i).to(x.device)
            at = alphas_cumprod[i]
            at_next = alphas_cumprod[j]
            with torch.enable_grad():
                xt_with_grad = xt.clone().requires_grad_(True)
                optimizer = optim.AdamW([xt_with_grad], lr=lr)
                for _ in range(optimize_iters):
                    optimizer.zero_grad()
                    et = model.apply_model(xt_with_grad, t, cond=None)
                    x0_t = (xt_with_grad - et * (1 - at).sqrt()) / at.sqrt()
                    pixels = model.decode_first_stage(x0_t)
                    loss = torch.sum((y_0 - H_funcs.forward(pixels))**2) + lam * torch.sum((xt_with_grad - xt)**2)
                    loss.backward()
                    optimizer.step()
            xt = xt_with_grad.detach()
            et = model.apply_model(xt, t, cond=None)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            add_up = (1 - at_next).sqrt() * torch.randn_like(x0_t)
            xt = at_next.sqrt() * x0_t + add_up

        pixels = model.decode_first_stage(xt)
        x0_preds.append(pixels.to('cpu'))
        xs.append(pixels.to('cpu'))

    return xs, x0_preds


def compute_alpha(beta, t):
    """Look up cumulative products of ``(1 - beta)`` for the timesteps in ``t``.

    A zero is prepended to ``beta`` before the cumulative product, so ``t == -1``
    selects exactly 1 and ``t == k`` selects ``prod_{s<=k} (1 - beta[s])``.
    The result is reshaped to ``(len(t), 1, 1, 1)`` for broadcasting.
    """
    padded = torch.cat([torch.zeros(1, device=beta.device), beta], dim=0)
    alphas_bar = torch.cumprod(1 - padded, dim=0)
    selected = alphas_bar.index_select(0, t + 1)
    return selected.view(-1, 1, 1, 1)


def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Noise-free sampler alternating re-noising, denoising and projection.

    A first clean estimate is predicted from pure noise at ``seq[-1]``. Then, for
    each timestep (in reverse) and each of ``N`` inner iterations, ``x0_t`` is
    re-noised to level ``t``, a new clean prediction is obtained from ``model``,
    ``x0_t`` moves a fraction ``lr`` towards it, and ``H_funcs.proj(x0_t, y_0)``
    is applied.

    Args:
        x: Tensor defining batch size, shape and device; also stored as ``xs[0]``.
        seq: Timesteps; traversed in reverse order.
        model: Noise predictor called as ``model(xt, t)`` (or with ``classes``);
            only the first 3 channels are kept when it returns 6.
        b: Beta schedule passed to ``compute_alpha``.
        H_funcs: Operator providing ``proj(x, y)``.
        y_0: Observed measurement.
        sigma_0: Unused.
        lr: Step size towards each new clean prediction.
        N: Inner iterations per timestep.
        cls_fn, classes: Optional classifier guidance; ``cls_fn(xt, t, classes)``
            scaled by ``sqrt(1 - alpha_t)`` is subtracted from the noise prediction.

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x`` followed by a CPU copy of
        ``x0_t`` after every inner iteration; ``x0_preds`` holds the same copies.
    """
    with torch.no_grad():
        n = x.size(0)
        x0_preds = []
        xs = [x]

        # Initial clean estimate from pure noise at the largest timestep.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()

        v = None
        beta = 0.0  # momentum coefficient; 0.0 makes v the latest difference
        for i in tqdm(reversed(seq)):
            for _ in range(N):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())
                xt = at.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at).sqrt()
                if cls_fn is None:
                    et = model(xt, t)
                else:
                    et = model(xt, t, classes)
                    et = et[:, :3]
                    # Guidance gradient is taken at the current noisy sample xt.
                    et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)

                if et.size(1) == 6:
                    et = et[:, :3]
                x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()
                # "Stochastic gradient" step towards the new clean prediction.
                d = x0_t_new - x0_t
                v = d if v is None else beta * v + (1 - beta) * d
                # Out-of-place: on the CPU, .to('cpu') returns the same tensor, so an
                # in-place update would corrupt entries already stored in x0_preds/xs.
                x0_t = x0_t + lr * v
                x0_t = H_funcs.proj(x0_t, y_0)
                x0_preds.append(x0_t.to('cpu'))
                xs.append(x0_t.to('cpu'))

    return xs, x0_preds

# Interestingly, after the submission I found that this function was the one used for the noise-free phase retrieval task. Note that alpha_obs is always 1,
# so this function is equivalent to the one above except that it performs the projection operation twice per step, which may yield better accuracy.
# I have therefore decided to keep this function here. Switching to the function above gives slightly lower results (PSNR 30~31).
def efficient_generalized_steps_phase(x, seq, model, b, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Noise-free sampler variant used for phase retrieval (projects twice per step).

    Like ``efficient_generalized_steps``, but each inner iteration first builds an
    observation-side estimate ``x_obs_t`` from ``x0_t`` and projects it with
    ``H_funcs.proj(x_obs_t, y_0, alpha_obs)``. It then re-noises that projected
    estimate to level ``t``, and projects ``x0_t`` again after the update.
    ``alpha_obs`` is fixed to 1, so the noise term on ``x_obs_t`` vanishes.

    Args:
        x: Tensor defining batch size, shape and device; also stored as ``xs[0]``.
        seq: Timesteps; traversed in reverse order.
        model: Noise predictor called as ``model(xt, t)`` (or with ``classes``);
            only the first 3 channels are kept when it returns 6.
        b: Beta schedule passed to ``compute_alpha``.
        H_funcs: Operator providing ``proj(x, y, alpha)``.
        y_0: Observed measurement.
        sigma_0: Unused.
        lr: Step size towards each new clean prediction.
        N: Inner iterations per timestep.
        cls_fn, classes: Optional classifier guidance evaluated at ``xt``.

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x`` followed by a CPU copy of
        ``x0_t`` after every inner iteration; ``x0_preds`` holds the same copies.
    """
    with torch.no_grad():
        n = x.size(0)
        x0_preds = []
        xs = [x]

        # Initial clean estimate from pure noise at the largest timestep.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()
        v = None
        beta = 0.0  # momentum coefficient; 0.0 makes v the latest difference
        # The sample is unused, but the draw is kept so the RNG stream (and therefore
        # seeded results) stays identical.
        torch.randn_like(x0_t)
        # alpha_obs is always 1
        alpha_obs = torch.tensor(1)
        for i in tqdm(reversed(seq)):
            for _ in range(N):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())
                # Observation-side estimate (noise term is zero since alpha_obs == 1).
                x_obs_t = alpha_obs.sqrt() * x0_t + (1 - alpha_obs).sqrt() * torch.randn_like(x0_t)
                x_obs_t = H_funcs.proj(x_obs_t, y_0, alpha_obs)
                if at[0, 0, 0, 0] <= alpha_obs:
                    xt = (at / alpha_obs).sqrt() * x_obs_t + (1 - at / alpha_obs).sqrt() * torch.randn_like(x0_t)
                else:
                    # Only reachable if alpha_t > 1, i.e. presumably never for betas in [0, 1].
                    xt = at.sqrt() * x0_t + (1 - at).sqrt() * (x_obs_t - alpha_obs.sqrt() * x0_t) / (1 - alpha_obs).sqrt()
                if cls_fn is None:
                    et = model(xt, t)
                else:
                    et = model(xt, t, classes)
                    et = et[:, :3]
                    # Guidance gradient is taken at the current noisy sample xt.
                    et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)

                if et.size(1) == 6:
                    et = et[:, :3]

                x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()
                d = x0_t_new - x0_t
                v = d if v is None else beta * v + (1 - beta) * d
                # Out-of-place so tensors already stored in x0_preds/xs are never mutated.
                x0_t = x0_t + lr * v
                x0_t = H_funcs.proj(x0_t, y_0, alpha_obs)

                x0_preds.append(x0_t.to('cpu'))
                xs.append(x0_t.to('cpu'))
    return xs, x0_preds


def efficient_generalized_steps_noisy(x, seq, model, b, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Noisy-measurement sampler that treats the observation as a diffusion state.

    The measurement noise is converted to an equivalent level
    ``alpha_obs = 1 / (1 + H_funcs.eq_var(sigma_0 ** 2))``. Two estimates are
    tracked: the clean estimate ``x0_t`` and an observation-level estimate
    ``x_obs_t``.

    - While ``alpha_t <= alpha_obs``, ``x_t`` is diffused from ``x_obs_t``, and
      ``x_obs_t`` is re-projected with ``H_funcs.proj(x_obs_t, y_0, alpha_obs)``.
    - Otherwise ``x_t`` is assembled from ``x0_t`` and the residual of ``x_obs_t``,
      and ``x0_t`` moves a fraction ``lr`` towards the new clean prediction.

    Args:
        x: Tensor defining batch size, shape and device; also stored as ``xs[0]``.
        seq: Timesteps; traversed in reverse order.
        model: Noise predictor called as ``model(xt, t)`` (or with ``classes``);
            only the first 3 channels are kept when it returns 6.
        b: Beta schedule passed to ``compute_alpha``.
        H_funcs: Operator providing ``eq_var`` and ``proj(x, y, alpha)``.
        y_0: Observed measurement.
        sigma_0: Measurement noise standard deviation.
        lr: Step size towards each new clean prediction.
        N: Inner iterations per timestep.
        cls_fn, classes: Optional classifier guidance evaluated at ``xt``
            (previously accepted but ignored).

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x`` followed by a CPU copy of
        ``x0_t`` after every inner iteration; ``x0_preds`` holds the same copies.
    """

    def predict_eps(xt, t, at):
        """Noise prediction at (xt, t), with optional classifier guidance."""
        if cls_fn is None:
            et = model(xt, t)
        else:
            et = model(xt, t, classes)
            et = et[:, :3]
            et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)
        if et.size(1) == 6:
            et = et[:, :3]
        return et

    with torch.no_grad():
        var_obs = H_funcs.eq_var(sigma_0 ** 2)
        # as_tensor avoids the copy-construct warning when eq_var returns a tensor.
        alpha_obs = 1 / torch.as_tensor(1 + var_obs)
        n = x.size(0)
        x0_preds = []
        xs = [x]

        # Initial clean estimate from pure noise at the largest timestep.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()
        x_obs_t = alpha_obs.sqrt() * x0_t + (1 - alpha_obs).sqrt() * torch.randn_like(x0_t)
        lr_obs = 1.0
        for i in tqdm(reversed(seq)):
            for _ in range(N):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())
                below_obs = at[0, 0, 0, 0] <= alpha_obs
                if below_obs:
                    # x_t is noisier than the observation: diffuse x_obs_t down to level t.
                    noise = torch.randn_like(x0_t)
                    xt = (at / alpha_obs).sqrt() * x_obs_t + (1 - at / alpha_obs).sqrt() * noise
                    et = predict_eps(xt, t, at)
                    x0_t_new = x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
                    x_obs_t_new = alpha_obs.sqrt() * x0_t_new + (1 - alpha_obs).sqrt() * torch.randn_like(x0_t_new)
                else:
                    # x_t is cleaner than the observation: reuse x_obs_t's residual as noise.
                    xt = at.sqrt() * x0_t + (1 - at).sqrt() * (x_obs_t - alpha_obs.sqrt() * x0_t) / (1 - alpha_obs).sqrt()
                    et = predict_eps(xt, t, at)
                    x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()
                    x_obs_t_new = x_obs_t
                # Out-of-place so tensors already stored in x0_preds/xs are never mutated.
                x0_t = x0_t + lr * (x0_t_new - x0_t)
                x_obs_t = x_obs_t + lr_obs * (x_obs_t_new - x_obs_t)
                if below_obs:
                    x_obs_t = H_funcs.proj(x_obs_t, y_0, alpha_obs)

                x0_preds.append(x0_t.to('cpu'))
                xs.append(x0_t.to('cpu'))
    return xs, x0_preds



def efficient_generalized_steps_noisy_SVD(x, seq, model, b, H_funcs, y_0, sigma_0, lr, N, cls_fn=None, classes=None):
    """Noisy-measurement sampler operating per spectral component of ``H = U S V^T``.

    Each spectral coordinate gets its own observation level
    ``alpha_obs = 1 / (1 + (sigma_0 / s)^2)`` where the singular value ``s > 0``,
    and 1 elsewhere. For each coordinate, ``x_t`` is built either:

    - by diffusing the observation estimate (when ``alpha_obs >= alpha_t``), or
    - from the clean estimate plus the observation residual (otherwise).

    After every update, the observed coordinates (``s > 0``) of ``x_obs_t`` are
    reset to ``sqrt(alpha_obs) * U^T y / s``.

    Args:
        x: 4-D tensor defining batch size, shape and device; also stored as ``xs[0]``.
        seq: Timesteps; traversed in reverse order.
        model: Noise predictor called as ``model(xt, t)`` (or with ``classes``);
            only the first 3 channels are kept when it returns 6.
        b: Beta schedule passed to ``compute_alpha``.
        H_funcs: Operator providing ``singulars``, ``Ut``, ``Vt`` and ``V``.
        y_0: Observed measurement.
        sigma_0: Measurement noise standard deviation.
        lr: Step size towards each new clean prediction.
        N: Inner iterations per timestep.
        cls_fn, classes: Optional classifier guidance evaluated at ``xt``.

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x`` followed by a CPU copy of
        ``x0_t`` after every inner iteration; ``x0_preds`` holds the same copies.
    """
    with torch.no_grad():
        n = x.size(0)
        shape4d = [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]

        def to_spectral(img):
            """Return ``V^T img`` laid out like ``x``."""
            return H_funcs.Vt(img).view(shape4d)

        def from_spectral(coeffs):
            """Map spectral coefficients back to image space via ``V``."""
            return H_funcs.V(coeffs.view([coeffs.shape[0], -1])).view(x.shape)

        # Per-coordinate singular values (zero-padded) and observation levels.
        singulars = H_funcs.singulars()
        Sigma = torch.zeros(x.shape[1] * x.shape[2] * x.shape[3], device=x.device)
        Sigma[:singulars.shape[0]] = singulars
        alpha_obs = torch.ones_like(Sigma)
        alpha_obs[Sigma > 0] = 1 / (1 + (sigma_0 / Sigma[Sigma > 0])**2).unsqueeze(0)
        U_t_y = H_funcs.Ut(y_0)
        # NOTE(review): the multiply by alpha_obs and the 4-D view below assume U_t_y
        # has one entry per pixel (C*H*W) — confirm for the operators in use.
        Sig_inv_U_t_y = U_t_y / singulars[:U_t_y.shape[-1]] * alpha_obs.sqrt()
        alpha_obs = alpha_obs.view([1, x.shape[1], x.shape[2], x.shape[3]]).repeat(x.shape[0], 1, 1, 1)
        Sig_inv_U_t_y = Sig_inv_U_t_y.view(shape4d)
        Sigma = Sigma.view([1, x.shape[1], x.shape[2], x.shape[3]]).repeat(x.shape[0], 1, 1, 1)
        observed = Sigma > 0  # loop-invariant mask of measured coordinates
        x0_preds = []
        xs = [x]

        # Initial clean estimate from pure noise, and its observation-level counterpart.
        t = (torch.ones(n) * seq[-1]).to(x.device)
        at = compute_alpha(b, t.long())
        x_T = torch.randn_like(x) * (1 - at).sqrt()
        et = model(x_T, t)
        if et.size(1) == 6:
            et = et[:, :3]
        x0_t = (x_T - et * (1 - at).sqrt()) / at.sqrt()
        V_t_x0 = to_spectral(x0_t)
        V_t_x_obs = alpha_obs.sqrt() * V_t_x0 + (1 - alpha_obs).sqrt() * torch.randn_like(V_t_x0)
        x_obs_t = from_spectral(V_t_x_obs)

        lr_obs = 1.0
        # The sample is unused, but the draw is kept so the RNG stream (and therefore
        # seeded results) stays identical.
        torch.randn_like(x0_t)
        for i in tqdm(reversed(seq)):
            for _ in range(N):
                t = (torch.ones(n) * i).to(x.device)
                at = compute_alpha(b, t.long())
                a = at[0, 0, 0, 0]
                V_t_x0 = to_spectral(x0_t)
                V_t_x_obs = to_spectral(x_obs_t)
                smaller_idx = (alpha_obs < a)
                larger_idx = (alpha_obs >= a)
                # Assemble x_t coordinate-wise in the spectral domain.
                V_t_x_t = torch.zeros_like(V_t_x_obs)
                V_t_x_t[larger_idx] = (a / alpha_obs[larger_idx]).sqrt() * V_t_x_obs[larger_idx] + (1 - a / alpha_obs[larger_idx]).sqrt() * torch.randn_like(V_t_x_obs[larger_idx])
                V_t_x_t[smaller_idx] = a.sqrt() * V_t_x0[smaller_idx] + (1 - a).sqrt() * (V_t_x_obs[smaller_idx] - V_t_x0[smaller_idx] * alpha_obs[smaller_idx].sqrt()) / (1 - alpha_obs[smaller_idx]).sqrt()
                xt = from_spectral(V_t_x_t)
                if cls_fn is None:
                    et = model(xt, t)
                else:
                    et = model(xt, t, classes)
                    et = et[:, :3]
                    # Guidance gradient is taken at the current noisy sample xt.
                    et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(xt, t, classes)

                if et.size(1) == 6:
                    et = et[:, :3]

                x0_t_new = (xt - et * (1 - at).sqrt()) / at.sqrt()

                V_t_x0_new = to_spectral(x0_t_new)
                V_t_x_obs_new = alpha_obs.sqrt() * V_t_x0_new + (1 - alpha_obs).sqrt() * torch.randn_like(V_t_x0_new)
                V_t_x0[larger_idx] = V_t_x0_new[larger_idx]
                V_t_x_obs_new[smaller_idx] = V_t_x_obs[smaller_idx]
                x0_t = from_spectral(V_t_x0)
                x_obs_t_new = from_spectral(V_t_x_obs_new)
                # Out-of-place so tensors already stored in x0_preds/xs are never mutated.
                x0_t = x0_t + lr * (x0_t_new - x0_t)
                x_obs_t = x_obs_t + lr_obs * (x_obs_t_new - x_obs_t)
                # Re-impose the measurement on observed spectral coordinates.
                V_t_x_obs = to_spectral(x_obs_t)
                V_t_x_obs[observed] = Sig_inv_U_t_y[observed]
                x_obs_t = from_spectral(V_t_x_obs)

                x0_preds.append(x0_t.to('cpu'))
                xs.append(x0_t.to('cpu'))

    return xs, x0_preds