import torch
import torch.nn.functional as F
import math
import scipy
import numpy as np
import tqdm 

def LIM_sampler(args, config, x, y, model, sde, levy, sde_clamp=None,
                masked_data=None, mask=None, t0=None, device='cuda'):
    """Run the reverse-time sampler for ``args.nfe`` steps and return the result.

    The sampler walks a quadratically spaced time grid from ``sde.T`` down to
    ``1e-5`` and applies one update per grid interval, selected by
    ``args.sample_type``:

    * ``'sde'``            -- stochastic update, optionally followed by a norm
      re-normalisation when ``config.sampling.clamp_threshold`` is truthy.
    * ``'ode'``            -- deterministic update (no noise injected).
    * ``'sde_imputation'`` -- stochastic update in which the region selected by
      ``mask`` is overwritten at every step with a freshly noised copy of
      ``masked_data``.

    Args:
        args: object exposing ``sample_type`` and ``nfe`` (number of steps).
        config: object exposing ``diffusion.is_isotropic``,
            ``diffusion.clamp_threshold``, ``sampling.sde_clamp``,
            ``sampling.clamp_threshold`` and ``model.is_conditional``.
        x: initial sample. Per-sample coefficients are broadcast with
            ``[:, None, None, None]``, so a 4-D batch is assumed
            (presumably ``(batch, C, H, W)`` -- TODO confirm).
        y: conditioning input, forwarded to ``model`` only when
            ``config.model.is_conditional`` is truthy.
        model: callable ``model(x, t)`` or ``model(x, t, y)``.
        sde: object exposing ``alpha``, ``T``, ``beta``, ``marginal_std``,
            ``diffusion_coeff`` and ``marginal_log_mean_coeff``.
        levy: object exposing ``sample(alpha=, size=, is_isotropic=, clamp=, ...)``.
        sde_clamp: clamp applied to the Levy noise; defaults to
            ``config.sampling.sde_clamp``.
        masked_data: reference data for ``'sde_imputation'``.
        mask: multiplicative mask for ``'sde_imputation'``; entries equal to 1
            take the noised ``masked_data``, entries equal to 0 keep the sample.
        t0: per-sample times used for the initial noising in
            ``'sde_imputation'`` (indexed like the other per-sample vectors).
        device: device on which the time grid and the noise are placed.

    Returns:
        The sample after the last update.

    Raises:
        ValueError: if ``args.sample_type`` is unknown, or if
            ``'sde_imputation'`` is requested without ``masked_data``,
            ``mask`` and ``t0``.
    """
    method = args.sample_type
    if method not in ('sde', 'ode', 'sde_imputation'):
        raise ValueError("Invalid sample type")
    if method == 'sde_imputation' and (
            masked_data is None or mask is None or t0 is None):
        # Fail early with a clear message instead of an opaque error deep
        # inside the first noising call.
        raise ValueError(
            "sde_imputation requires masked_data, mask and t0")

    is_isotropic = config.diffusion.is_isotropic
    steps = args.nfe
    eps = 1e-5
    alpha = sde.alpha

    if sde_clamp is None:
        sde_clamp = config.sampling.sde_clamp

    def expand(coeff):
        """Reshape a per-sample vector so it broadcasts over a 4-D batch."""
        return coeff[:, None, None, None]

    def score_model(x, t):
        """Evaluate the network, passing ``y`` only for conditional models."""
        if config.model.is_conditional:
            return model(x, t, y)
        return model(x, t)

    def scaled_score(x, s, std_eps=0.0):
        """Network output divided by ``marginal_std(s) ** (alpha - 1)``."""
        std = sde.marginal_std(s)
        if std_eps:
            std = std + std_eps
        return score_model(x, s) * expand(torch.pow(std, -(alpha - 1)))

    def levy_noise(shape, **isotropic_kwargs):
        """Draw clamped Levy noise of the given shape on ``device``.

        ``isotropic_kwargs`` are forwarded to ``levy.sample`` only in the
        isotropic case, mirroring the per-call-site arguments of the sampler.
        """
        if is_isotropic:
            return levy.sample(alpha=alpha, size=shape, is_isotropic=True,
                               clamp=sde_clamp, **isotropic_kwargs).to(device)
        raw = levy.sample(alpha=alpha, size=shape, is_isotropic=False,
                          clamp=None).to(device)
        return torch.clamp(raw, min=-sde_clamp, max=sde_clamp)

    def impainted_noise(data, noise, mask, t):
        """Noise ``data`` to time ``t`` and paste it over ``noise`` where mask is 1."""
        sigma = sde.marginal_std(t)
        x_coeff = sde.diffusion_coeff(t)

        if alpha == 2:
            e_L = torch.randn(size=data.shape).to(device)
        else:
            e_L = levy_noise(
                data.shape,
                clamp_threshold=config.diffusion.clamp_threshold)

        noised = expand(x_coeff) * data + expand(sigma) * e_L
        return noised * mask + noise * (1 - mask)

    def ode_score_update(x, s, t):
        """
        input: x_s, s, t
        output: x_t
        """
        score_s = scaled_score(x, s)
        x_coeff = sde.diffusion_coeff(t) * torch.pow(sde.diffusion_coeff(s), -1)

        if alpha == 2:
            sigma_s = sde.marginal_std(s)
            sigma_t = sde.marginal_std(t)
            # NOTE(review): the original code used `h_t` here without ever
            # defining it (NameError). It is defined as the increment of
            # log(diffusion_coeff / marginal_std) between s and t, i.e. the
            # usual exponential-integrator step size. The sign of the
            # (1 - exp(h_t)) factor is kept as written -- confirm it against
            # the sign convention of the trained score network.
            h_t = torch.log(x_coeff) + torch.log(sigma_s) - torch.log(sigma_t)
            score_coeff = (torch.pow(sigma_s, alpha - 1) * sigma_t
                           * (1 - torch.exp(h_t)))
        else:
            score_coeff = -alpha * (1 - x_coeff)

        return expand(x_coeff) * x + expand(score_coeff) * score_s

    def sde_score_update(x, s, t):
        """
        input: x_s, s, t
        output: x_t
        """
        if alpha == 2:
            # Single network evaluation (the original evaluated it twice and
            # discarded the first result).
            score_s = scaled_score(x, s, std_eps=1e-5)
            e_B = torch.randn_like(x).to(device)

            beta_step = sde.beta(s) * (s - t)
            x_coeff = 1 + beta_step / alpha
            score_coeff = beta_step
            noise_coeff = torch.pow(beta_step, 1 / alpha)
            return (expand(x_coeff) * x
                    + expand(score_coeff) * score_s
                    + expand(noise_coeff) * e_B)

        score_s = scaled_score(x, s)
        a = torch.exp(sde.marginal_log_mean_coeff(t)
                      - sde.marginal_log_mean_coeff(s))
        noise_coeff = torch.pow(-1 + torch.pow(a, alpha), 1 / alpha)
        e_L = levy_noise(x.shape)
        score_coeff = alpha ** 2 * (-1 + a)
        return (expand(a) * x
                + expand(score_coeff) * score_s
                + expand(noise_coeff) * e_L)

    def sde_score_update_imputation(data, mask, x, s, t):
        """One stochastic step followed by re-imposing the known region."""
        score_s = scaled_score(x, s)

        beta_step = (sde.marginal_log_mean_coeff(t)
                     - sde.marginal_log_mean_coeff(s)) * alpha
        x_coeff = 1 + beta_step / alpha
        a = sde.diffusion_coeff(t) * torch.pow(sde.diffusion_coeff(s), -1)
        noise_coeff = torch.pow(-1 + a, 1 / alpha)

        if alpha == 2:
            noise = torch.randn_like(x).to(device)
            score_coeff = beta_step
        else:
            noise = levy_noise(
                x.shape, clamp_threshold=config.diffusion.clamp_threshold)
            score_coeff = alpha ** 2 * (-1 + a)

        x_t = (expand(x_coeff) * x
               + expand(score_coeff) * score_s
               + expand(noise_coeff) * noise)
        return impainted_noise(data, x_t, mask, t)

    def renormalize(x, threshold):
        """Rescale every sample whose flattened norm exceeds ``threshold`` onto it."""
        shape = x.shape
        flat = x.reshape((len(x), -1))
        over = flat.norm(dim=1) > threshold
        flat[over] = flat[over] / flat[over].norm(dim=1)[:, None] * threshold
        return flat.reshape(shape)

    # Time grid: quadratic spacing (uniform in sqrt(t)) from sde.T down to eps.
    timesteps = torch.pow(
        torch.linspace(np.sqrt(sde.T), np.sqrt(eps), steps + 1), 2).to(device)

    with torch.no_grad():
        if method == "sde_imputation":
            x = impainted_noise(masked_data, x, mask, t0)

        for i in tqdm.tqdm(range(steps)):
            batch_ones = torch.ones((x.shape[0],)).to(device)
            vec_s = batch_ones * timesteps[i]
            vec_t = batch_ones * timesteps[i + 1]

            if method == 'sde':
                x = sde_score_update(x, vec_s, vec_t)
                # clamp threshold : re-normalization
                if config.sampling.clamp_threshold:
                    x = renormalize(x, config.sampling.clamp_threshold)

            elif method == 'sde_imputation':
                x = sde_score_update_imputation(masked_data, mask, x, vec_s, vec_t)
                # NOTE(review): a second, unconstrained step over the same
                # (s, t) interval is applied right after the imputation step.
                # Kept as in the original -- confirm this is intended.
                x = sde_score_update(x, vec_s, vec_t)

                if config.diffusion.clamp_threshold is not None:
                    x = renormalize(x, config.diffusion.clamp_threshold)

            elif method == 'ode':
                x = ode_score_update(x, vec_s, vec_t)

    return x