import torch
from noise import *
from utilities import *
from visualize import *
import tqdm
from tqdm import tqdm
from forward_process import *


def slerp(z1, z2, alpha, eps=1e-7):
    """Spherically interpolate between ``z1`` and ``z2``.

    Both tensors are treated as single vectors: the angle between them is
    computed from a sum over *all* elements, not per batch item.

    Args:
        z1: Start tensor (returned when ``alpha == 0``).
        z2: End tensor with the same shape as ``z1`` (returned when ``alpha == 1``).
        alpha: Interpolation weight, typically a Python float in ``[0, 1]``.
        eps: If ``sin(theta)`` is smaller than this, the vectors are (anti)parallel
            and the slerp weights become 0/0. Plain linear interpolation is used
            instead.

    Returns:
        The interpolated tensor, shaped like ``z1``/``z2``.
    """
    cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
    # Floating-point rounding can push the cosine slightly outside [-1, 1],
    # where acos returns NaN.
    theta = torch.acos(torch.clamp(cos_theta, -1.0, 1.0))
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < eps:
        # Degenerate angle: the slerp formula would divide 0 by 0. Linear
        # interpolation is the limit of slerp as theta -> 0.
        return (1 - alpha) * z1 + alpha * z2
    return (
        torch.sin((1 - alpha) * theta) / sin_theta * z1
        + torch.sin(alpha * theta) / sin_theta * z2
    )


def my_generalized_steps(y, x, seq, model, b, config, eta2, eta3, constants_dict, eraly_stop = True):
    """Run a DDIM-style reverse process with DDAD error correction, using per-sample schedules.

    At every step the model's noise prediction ``et`` is corrected by a term
    proportional to ``(yt - xt)``, where ``yt`` is the target ``y`` re-noised with
    ``et``. The corrected estimate then drives the usual
    ``sqrt(at_next) * x0 + c1 * noise + c2 * eps`` update.

    Args:
        y: Conditioning target; combined as ``at.sqrt() * y + (1 - at).sqrt() * et``.
        x: Starting sample; ``x.size(0)`` is used as the batch size ``n``.
        seq: Collection of per-sample timestep schedules. ``seq[0].size(0)`` gives
            the number of steps ``m``, and ``s[idx]`` is read from every entry.
            Presumably a list of ``n`` equal-length 1-D tensors (TODO confirm).
        model: Noise predictor, called as ``model(xt, t)``.
        b: Noise schedule passed to ``compute_alpha`` / ``compute_alpha2``.
        config: Reads ``config.model.one_step``, ``config.model.device`` and
            ``config.model.eta``.
        eta2: Scale of the DDAD correction term ``(yt - xt)``.
        eta3: Unused.
        constants_dict: Unused.
        eraly_stop: Unused (misspelled name kept for caller compatibility).

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x``, and one CPU tensor is
        appended per step. ``x0_preds`` holds each step's CPU ``x0`` estimate;
        it stays empty in the one-step branch.
    """
    with torch.no_grad():
        n = x.size(0)
        
        m = seq[0].size(0)  # Length of each sequence
        x0_preds = []
        xs = [x]

        if config.model.one_step:
                # Estimate x0 directly from x in a single step.
                # NOTE(review): seq[-1] is the last sample's whole schedule here
                # (not a scalar timestep), so torch.ones(n) * seq[-1] only
                # broadcasts when m == n or m == 1. This branch also uses
                # compute_alpha, while the loop below uses compute_alpha2.
                # Confirm both are intended.
                t = (torch.ones(n) * seq[-1]).to(x.device)
                at = compute_alpha(b, t.long(), config)
                xt = xs[-1].to(config.model.device)
                et = model(xt, t)
                yt = at.sqrt() * y + (1- at).sqrt() *  et
                et_hat = et - (1 - at).sqrt() * eta2 * (yt-xt)
                x0 = (xt - (1- at).sqrt() *  et_hat) * (1/at.sqrt())
                xs.append(x0.to('cpu'))
        else:      
            # Walk every schedule backwards, idx = m-1 down to 0.
            for idx in reversed(range(m)):
                # One timestep per sample, taken from position idx of its schedule.
                t = torch.tensor([s[idx] for s in seq]).to(config.model.device)
                if idx == 0:
                    # Final step: next timestep is -1. compute_alpha2 presumably
                    # maps -1 to alpha = 1 (TODO confirm).
                    next_t = torch.full((n,), -1, device=x.device, dtype=torch.long)
                else:
                    next_t = torch.tensor([s[idx-1] for s in seq]).to(x.device)
                # NOTE(review): t lives on config.model.device but next_t lives on
                # x.device. This only matters if the two devices differ.
                
                at = compute_alpha2(b, t.long(), config)

                at_next = compute_alpha2(b, next_t.long(),config)

                # xs keeps CPU copies; move the latest sample back to the model device.
                xt = xs[-1].to(config.model.device)
                
                et = model(xt, t)

                # Target y, noised to level t with the predicted noise.
                yt = at.sqrt() * y + (1- at).sqrt() *  et

                #DDAD error correction
                et_hat = et - (1 - at).sqrt() * eta2 * (yt-xt)

                # x0 estimate implied by the corrected noise.
                x0_t = (xt - et_hat * (1 - at).sqrt()) / at.sqrt()


                x0_preds.append(x0_t.to('cpu')) 

                # c1 scales fresh Gaussian noise (zero when config.model.eta == 0);
                # c2 scales the corrected noise estimate.
                # NOTE(review): DA_generalized_steps forces c1 = c2 = 0 on its
                # first iteration; this loop does not. Confirm which is intended.
                c1 = (
                    config.model.eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
                )
                c2 = ((1 - at_next) - c1 ** 2).sqrt()

                # Noise is drawn like x (same device/dtype as x, not xt).
                xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et_hat


                xs.append(xt_next.to('cpu'))


    return xs, x0_preds


def my_generalized_steps_visual(y, x, orig_image, seq, model, b, config, eta2, eta3, constants_dict, idx, vae, show_image, step_list):
    """Run the DDAD-corrected reverse process and save a figure showing every step.

    The sampling loop matches ``DA_generalized_steps`` (multi-step branch). In
    addition, the input latents and every intermediate latent are decoded with
    ``vae`` and appended to ``step_list``. Afterwards the whole of
    ``step_list`` is plotted in one row and saved under ``results_time/``.

    Args:
        y: Conditioning target latent.
        x: Starting noisy latent; ``x.size(0)`` is the batch size.
        orig_image: Original image, shown first in the figure.
        seq: Increasing sequence of timesteps, walked in reverse.
        model: Noise predictor, called as ``model(xt, t)``.
        b: Noise schedule passed to ``compute_alpha``.
        config: Reads ``config.model.device``, ``config.model.eta``,
            ``config.data.category`` and ``config.model.noise_sampling``.
        eta2: Scale of the DDAD correction term ``(yt - xt)``.
        eta3: Unused.
        constants_dict: Unused.
        idx: Sample index; used only in the output file name.
        vae: Decoder called as ``vae.decode(z).sample``.
        show_image: Converts a decoded tensor into something ``plt.imshow`` accepts.
        step_list: List that is mutated in place. Figure titles assume it is
            empty on entry.

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x``, and one CPU latent is
        appended per step. ``x0_preds`` holds each step's CPU ``x0`` estimate.
    """
    import os

    with torch.no_grad():
        n = x.size(0)
        # Timestep each step moves to; -1 after the final step.
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        # Undo the latent scaling (multiply by 1 / 0.18215) before decoding.
        reconstruction_y = vae.decode((1 / 0.18215 * y).to(config.model.device)).sample
        reconstruction_x = vae.decode((1 / 0.18215 * x).to(config.model.device)).sample
        step_list.append(orig_image)
        step_list.append(reconstruction_x)
        step_list.append(reconstruction_y)

        for index, (i, j) in enumerate(zip(reversed(seq), reversed(seq_next))):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = compute_alpha(b, t.long(), config)
            at_next = compute_alpha(b, next_t.long(), config)
            # xs keeps CPU copies; move the latest latent back to the model device.
            xt = xs[-1].to(config.model.device)

            et = model(xt, t)
            # Target y, noised to level t with the predicted noise.
            yt = at.sqrt() * y + (1 - at).sqrt() * et
            # DDAD error correction, scaled by eta2.
            et_hat = et - (1 - at).sqrt() * eta2 * (yt - xt)

            x0_t = (xt - et_hat * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))

            if index == 0:
                # First step adds no fresh noise and no noise-estimate term.
                c1 = torch.zeros_like(x0_t)
                c2 = torch.zeros_like(x0_t)
            else:
                c1 = (
                    config.model.eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
                )
                c2 = ((1 - at_next) - c1 ** 2).sqrt()

            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et_hat

            step_list.append(vae.decode((1 / 0.18215 * xt_next).to(config.model.device)).sample)
            xs.append(xt_next.to('cpu'))

        # savefig does not create missing directories.
        os.makedirs('results_time', exist_ok=True)
        plt.figure(figsize=(50, 10))
        try:
            total = len(step_list)
            for pos, image in enumerate(step_list):
                plt.subplot(1, total, pos + 1).axis('off')
                plt.imshow(show_image(image))
                if pos == 0:
                    plt.title("x_0_orig")
                elif pos == 1:
                    plt.title("x_T_hat")
                elif pos == 2:
                    # The third entry is the decoded target y, not a denoising step.
                    plt.title("y")
                else:
                    plt.title(f'step {total - pos}')
            plt.savefig('results_time/{}sample{}steps{}_noise_{}.png'.format(config.data.category, idx, total, config.model.noise_sampling))
        finally:
            # Always release the figure, even if plotting fails.
            plt.close()

    return xs, x0_preds



def DA_generalized_steps(y, x, seq, model, b, config, eta2, eta3, constants_dict, eraly_stop = True):
    """Run a DDIM-style reverse process with DDAD error correction over a shared schedule.

    At every step the model's noise prediction ``et`` is corrected by a term
    proportional to ``(yt - xt)``, where ``yt`` is the target ``y`` re-noised with
    ``et``. The corrected estimate drives the
    ``sqrt(at_next) * x0 + c1 * noise + c2 * eps`` update.

    Args:
        y: Conditioning target.
        x: Starting noisy sample; ``x.size(0)`` is the batch size ``n``.
        seq: Sequence of timesteps shared by the whole batch, walked in reverse.
        model: Noise predictor, called as ``model(xt, t)``.
        b: Noise schedule passed to ``compute_alpha``.
        config: Reads ``config.model.one_step``, ``config.model.device`` and
            ``config.model.eta``.
        eta2: Scale of the DDAD correction term ``(yt - xt)``.
        eta3: Unused.
        constants_dict: Unused.
        eraly_stop: Unused (misspelled name kept for caller compatibility).

    Returns:
        ``(xs, x0_preds)``. ``xs`` starts with ``x``, and one CPU tensor is
        appended per step. ``x0_preds`` holds each step's CPU ``x0`` estimate;
        it stays empty in the one-step branch.
    """
    with torch.no_grad():
        n = x.size(0)
        # Timestep each step moves to; -1 after the final step.
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        if config.model.one_step:
            # Estimate x0 directly from x at timestep seq[-1].
            t = (torch.ones(n) * seq[-1]).to(x.device)
            at = compute_alpha(b, t.long(), config)
            xt = xs[-1].to(config.model.device)
            et = model(xt, t)
            yt = at.sqrt() * y + (1 - at).sqrt() * et
            et_hat = et - (1 - at).sqrt() * eta2 * (yt - xt)
            x0 = (xt - (1 - at).sqrt() * et_hat) * (1 / at.sqrt())
            xs.append(x0.to('cpu'))
        else:
            for index, (i, j) in enumerate(zip(reversed(seq), reversed(seq_next))):
                t = (torch.ones(n) * i).to(x.device)
                next_t = (torch.ones(n) * j).to(x.device)
                at = compute_alpha(b, t.long(), config)
                at_next = compute_alpha(b, next_t.long(), config)
                # xs keeps CPU copies; move the latest sample back to the model device.
                xt = xs[-1].to(config.model.device)

                et = model(xt, t)
                # Target y, noised to level t with the predicted noise.
                yt = at.sqrt() * y + (1 - at).sqrt() * et
                # DDAD error correction, scaled by eta2.
                et_hat = et - (1 - at).sqrt() * eta2 * (yt - xt)

                x0_t = (xt - et_hat * (1 - at).sqrt()) / at.sqrt()
                x0_preds.append(x0_t.to('cpu'))

                if index == 0:
                    # First step adds no fresh noise and no noise-estimate term.
                    c1 = torch.zeros_like(x0_t)
                    c2 = torch.zeros_like(x0_t)
                else:
                    c1 = (
                        config.model.eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
                    )
                    c2 = ((1 - at_next) - c1 ** 2).sqrt()

                # randn_like is still called on the first step (scaled by zero),
                # so the RNG draw order is unchanged.
                xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et_hat
                xs.append(xt_next.to('cpu'))

    return xs, x0_preds
