
import torch
import random
from src.libs.SDE import VariancePreservingTruncatedSampling

def deconcat(z,mod_list,sizes):
    """Split ``z`` column-wise into one slice per modality.

    Args:
        z: tensor with at least two dims; slices are taken along dim 1.
        mod_list: modality names, in the order their columns appear in ``z``.
        sizes: width of each modality's slice, aligned index-by-index with
            ``mod_list``.

    Returns:
        dict mapping each name in ``mod_list`` to its slice ``z[:, a:b]``
        (a view, not a copy).
    """
    pieces = {}
    offset = 0
    for position, name in enumerate(mod_list):
        width = sizes[position]
        pieces[name] = z[:, offset:offset + width]
        offset += width
    return pieces

def concat_vect(encodings):
    """Concatenate the tensors stored in ``encodings`` along their last dim.

    Values are joined in the dict's iteration order. The running result starts
    as an empty ``torch.Tensor()`` and is moved to each value's device before
    that value is appended. An empty dict therefore yields an empty CPU tensor.
    """
    merged = torch.Tensor()
    for tensor in encodings.values():
        merged = torch.cat([merged.to(tensor.device), tensor], dim=-1)
    return merged

def unsequeeze_dict(data):
    """Give every 1-D tensor in ``data`` a trailing singleton dim, in place.

    Entries with ``ndim == 1`` and shape ``(n,)`` are replaced by ``(n, 1)``
    views. All other entries are left untouched. The same (mutated) dict is
    returned.
    """
    flat_keys = [key for key, value in data.items() if value.ndim == 1]
    for key in flat_keys:
        data[key] = data[key].unsqueeze(1)
    return data


class VP_SDE():
    """Variance-preserving SDE with a linear beta schedule, for multi-modal score models.

    The schedule is ``beta(t) = beta_min + t * (beta_max - beta_min)`` with
    ``T = 1``.

    - ``train_step`` / ``train_step_cond`` return per-sample denoising losses.
    - ``euler_*``, ``sample_euler*`` and ``modality_inpainting*`` integrate the
      reverse-time SDE with an Euler–Maruyama scheme of ``N`` steps from
      ``t = 1``.
    """

    def __init__(self,
                 beta_min=0.1,
                 beta_max=20,
                 N = 1000,
                 importance_sampling =True ,
                 liklihood_weighting= False,
                 nb_mod = 2,
                 device = "cuda",
                ):
        self.beta_min = beta_min
        self.beta_max = beta_max
        # Number of Euler steps used by the samplers.
        self.N = N
        self.T = 1
        # When True, training times come from ``sample_debiasing_t`` instead of a uniform draw.
        self.importance_sampling = importance_sampling
        # Stored for configuration; no method of this class reads it.
        self.liklihood_weighting = liklihood_weighting
        # Initial device; ``sample`` overwrites it with the device of its ``t``.
        self.device = device
        # Number of time columns passed to the score network.
        self.nb_mod = nb_mod
        # Truncation point for the importance-sampled time distribution.
        self.t_epsilon = 1e-3

    def beta_t(self,t):
        """Return the linear noise rate ``beta(t)``."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self,t):
        """Return the drift coefficient ``f(t) = -beta(t)/2`` and diffusion ``g(t) = sqrt(beta(t))``."""
        return -0.5*self.beta_t(t), torch.sqrt(self.beta_t(t))

    def marg_prob(self, t,x):
        """Return the mean coefficient and std of the perturbation kernel ``p(x(t) | x(0))``.

        ``t`` is reshaped to a column (``view(-1, 1)``), so a scalar,
        ``(1,)``, ``(B,)`` or ``(B, 1)`` time tensor is applied row-wise.
        Both outputs are broadcast against ``x`` and placed on ``self.device``.
        """
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        log_mean_coeff = log_mean_coeff.to(self.device)

        mean = torch.exp(log_mean_coeff)
        std = torch.sqrt(1 - torch.exp(2 * log_mean_coeff))
        ones = torch.ones_like(x).to(self.device)
        # Reshape BOTH coefficients to a column. Previously only ``std`` was
        # reshaped, so a (B,) time vector was broadcast across columns instead of rows.
        return mean.view(-1, 1) * ones, std.view(-1, 1) * ones

    def sample(self, t, data, mods_list ):
        """Diffuse every modality of ``data`` to its own time with the forward kernel.

        Args:
            t: tensor whose column ``i`` holds the diffusion time of
                ``mods_list[i]``, one row per sample. ``self.device`` is set to
                ``t.device`` as a side effect.
            data: dict mapping modality name -> clean tensor with the batch on dim 0.
            mods_list: ordered modality names.

        Returns:
            ``(x_t_m, z_m, std_m, g_m, mean_m)``: dicts keyed by modality holding
            - the perturbed sample ``mean * x + std * z``,
            - the Gaussian noise ``z``,
            - the std coefficient,
            - the diffusion coefficient ``g(t)``,
            - the mean coefficient.
        """
        self.device = t.device

        x_t_m = {}
        std_m = {}
        mean_m = {}
        z_m = {}
        g_m = {}

        for i, mod in enumerate(mods_list):
            x_mod = data[mod]
            t_mod = t[:, i]

            z = torch.randn_like(x_mod).to(self.device)
            _, g = self.sde(t_mod)
            mean_i, std_i = self.marg_prob(t_mod.view(x_mod.shape[0], 1), x_mod)

            std_m[mod] = std_i
            mean_m[mod] = mean_i
            z_m[mod] = z
            g_m[mod] = g
            x_t_m[mod] = mean_i * x_mod + std_i * z

        return x_t_m, z_m, std_m, g_m, mean_m

    def _sample_train_time(self, batch_size, eps):
        """Draw one training time per sample, shape ``(batch_size, 1)``, as ``(T - eps) * u + eps``.

        ``u`` comes from ``sample_debiasing_t`` when importance sampling is
        enabled, otherwise from ``torch.rand``.
        """
        if self.importance_sampling:
            u = self.sample_debiasing_t(shape=(batch_size, 1))
        else:
            u = torch.rand((batch_size, 1))
        return ((self.T - eps) * u + eps).to(self.device)

    def train_step(self,data,score_net, eps = 1e-5, d = 0.5 ):
        """Per-sample denoising loss for the joint model, conditional with probability ``d``.

        Args:
            data: dict modality name -> clean tensor; each modality's width is
                read from dim 1.
            score_net: called as ``score_net(x, t=..., std=None)``. Its negated
                output is compared element-wise with the injected noise.
            eps: smallest diffusion time that can be drawn.
            d: probability of a conditional step. In a conditional step one
                modality gets time 0 (so it is passed through clean) and is
                dropped from the loss.

        Returns:
            1-D tensor of per-sample losses.
        """
        x = concat_vect(data)
        batch_size = x.shape[0]

        mods_list = list(data.keys())
        mods_sizes = [data[key].size(1) for key in mods_list]
        nb_mods = len(mods_list)

        t = self._sample_train_time(batch_size, eps)
        t_n = t.expand((batch_size, nb_mods))

        learn_cond = (torch.bernoulli(torch.tensor([d])) == 1.0)
        mask = [1, 1]
        if learn_cond:
            # A mask entry of 0 zeroes that modality's time, so ``sample`` returns it
            # unperturbed (mean 1, std 0 at t = 0) and it acts as the condition.
            # NOTE(review): this subset table only covers two modalities; with more,
            # ``random.randint`` can index past it.
            subsets = [[0, 1], [1, 0]]
            i = random.randint(0, len(mods_list) - 1)
            mask = subsets[i]
            mask_time = torch.tensor(mask).to(self.device).expand(t_n.size())
            t_n = t_n * mask_time

        x_t_m, z_m, std_m, g_m, mean_m = self.sample(t=t_n, data=data, mods_list=mods_list)

        score = - score_net(concat_vect(x_t_m), t=t_n, std=None)

        weight = 1.0
        if learn_cond:
            score_m = deconcat(score, mods_list, mods_sizes)
            for idx, keep in enumerate(mask):
                mod = mods_list[idx]
                if keep == 0:
                    # The clean (conditioning) modality has no denoising target.
                    # Original note: all benchmarks have two equal-size modalities.
                    dim_clean = score_m[mod].size(1)
                    z_m.pop(mod)
                    score_m.pop(mod)
                else:
                    dim_diff = score_m[mod].size(1)
            # Conditional losses are scaled by 1 + (clean width / diffused width).
            weight += dim_clean / dim_diff
            score = concat_vect(score_m)

        loss = weight * torch.square(score + concat_vect(z_m)).sum(1, keepdim=False)
        return loss

    def train_step_cond(self,data,score_net, eps = 1e-5, d = 0.5 ):
        """Per-sample denoising loss for ``x`` with ``y`` as an optional condition.

        ``data`` must contain keys ``"x"`` and ``"y"``. The time mask
        ``[1, 0]`` assumes ``"x"`` is the first key.

        - With probability ``d``, the clean ``y`` is fed and its time column is
          set to 1.0.
        - Otherwise ``y`` is replaced by zeros and its time column is set to 0.0.

        Only the ``x`` noise enters the loss.

        Returns:
            1-D tensor of per-sample losses.
        """
        x = concat_vect(data)
        batch_size = x.shape[0]

        mods_list = list(data.keys())
        nb_mods = len(mods_list)

        t = self._sample_train_time(batch_size, eps)
        t_n = t.expand((batch_size, nb_mods))

        x_t_m, z_m, *_ = self.sample(t=t_n, data=data, mods_list=mods_list)

        learn_cond = (torch.bernoulli(torch.tensor([d])) == 1.0)
        mask_time = torch.tensor([1, 0]).to(self.device).expand(t_n.size())
        if learn_cond:
            # Condition present: clean y, flagged with time 1.0.
            y_time = 1.0
            y_input = data["y"]
        else:
            # Condition absent: zeroed y, flagged with time 0.0.
            y_time = 0.0
            y_input = torch.zeros_like(data["y"])
        t_n = t_n * mask_time + y_time * (1 - mask_time)
        x_t = concat_vect({"x": x_t_m["x"], "y": y_input})

        score = - score_net(x_t, t=t_n, std=None)
        return torch.square(score + z_m["x"]).sum(1, keepdim=False)

    def sample_debiasing_t(self, shape):
        """
        non-uniform sampling of t to debias the weight std^2/g^2
        the sampling distribution is proportional to g^2/std^2 for t >= t_epsilon
        for t < t_epsilon, it's truncated
        """
        return sample_vp_truncated_q(shape, self.beta_min, self.beta_max, t_epsilon=self.t_epsilon, T=self.T)

    @staticmethod
    def _is_last_step(t, dt):
        """Return True when ``t`` is the final point of a uniform reverse-time grid.

        Replaces an exact ``t == 0.001`` float comparison. That check only
        matched ``N = 1000`` and could be defeated by rounding from repeated
        ``t - dt``.
        """
        step = torch.as_tensor(dt)
        remaining = torch.as_tensor(t) - step
        return bool(torch.all(remaining < 0.5 * step))

    def euler_step(self, x_t, t,dt, score_net ):
        """One unconditional reverse-time Euler–Maruyama step from ``t`` to ``t - dt``.

        Every modality column of the time input gets ``t``. The marginal std is
        passed to ``score_net`` as its third argument.
        """
        time = t * torch.ones((x_t.shape[0], self.nb_mod)).to(self.device)

        mean, std = self.marg_prob(t, x_t)

        with torch.no_grad():
            s = - score_net(x_t, time, std).detach()

        f, g = self.sde(t)
        x = x_t - dt * (f * x_t - (g ** 2) * s) + g * torch.sqrt(dt).to(self.device) * torch.randn_like(x_t)
        return x, t - dt

    def euler_step_impaiting(self, x_t,x_0, t,dt, score_net , mask = None, subset = None):
        """One reverse-time Euler step that keeps the observed entries clamped to ``x_0``.

        Args:
            mask: tensor broadcastable to ``x_t``; 1 marks observed entries
                taken from ``x_0``.
            subset: iterable of modality indices whose time column is set to 0
                (treated as clean). Must not be None.

        Returns:
            ``(x, t - dt)``. The injected noise is dropped on the final step.
        """
        time = t * torch.ones((x_t.shape[0], self.nb_mod)).to(self.device)

        mean, std = self.marg_prob(t, x_t)

        observed = torch.tensor([1. if i in subset else 0. for i in range(self.nb_mod)])
        mask_time = observed.to(self.device).expand(time.size())
        time = time * (1. - mask_time)

        x_aux = (x_0 * mask) + x_t * (1. - mask)

        with torch.no_grad():
            s = - score_net(x_aux, time, std).detach()

        f, g = self.sde(t)
        noise = 0 if self._is_last_step(t, dt) else torch.randn_like(x_t)
        x = x_aux - dt * (f * x_aux - (g ** 2) * s) + g * torch.sqrt(dt) * noise

        # Re-impose the observed entries after the update.
        x = x * (1. - mask) + x_0 * mask
        return x, t - dt

    def euler_step_c(self, x_t, t,dt, score_net ):
        """One reverse-time Euler step for the ``train_step_cond`` model with the condition absent.

        - The first ``s.size(1)`` columns of ``x_t`` are updated.
        - The remaining columns pass through unchanged.
        - Time column 1 is set to 0.0, the "no condition" flag used in training.
        """
        time = t * torch.ones((x_t.shape[0], self.nb_mod)).to(self.device)

        mean, std = self.marg_prob(t, x_t)

        mask_time = torch.tensor([1, 0]).to(self.device).expand(time.size())
        time = time * mask_time + 0.0 * (1. - mask_time)

        x_aux = x_t

        with torch.no_grad():
            s = - score_net(x_aux, time, None).detach()
            # Training regresses score_net onto the injected noise; dividing by
            # std turns that prediction into a score estimate.
            s = s / std[:, :s.size(1)]

        f, g = self.sde(t)
        noise = 0 if self._is_last_step(t, dt) else torch.randn_like(s)

        width = s.size(1)
        x_t_up = x_aux[:, :width]
        x = x_t_up - dt * (f * x_t_up - (g ** 2) * s) + g * torch.sqrt(dt) * noise

        # Concatenate the UPDATED block. Previously ``x_t_up`` was used here,
        # which discarded the Euler step entirely.
        x = torch.cat([x, x_aux[:, width:]], dim=1)
        return x, t - dt

    def euler_step_impaiting_c(self, x_t,x_0, t,dt, score_net , mask = None):
        """One reverse-time Euler step for the ``train_step_cond`` model given the condition.

        - The condition is taken from ``x_0`` (entries where ``mask`` is 1).
        - Time column 1 is set to 1.0, the "condition present" flag used in
          training.
        - The first ``s.size(1)`` columns are updated.
        - The remaining columns are reset to ``x_0``.
        """
        time = t * torch.ones((x_t.shape[0], self.nb_mod)).to(self.device)

        mean, std = self.marg_prob(t, x_t)

        mask_time = torch.tensor([1, 0]).to(self.device).expand(time.size())
        time = time * mask_time + 1.0 * (1. - mask_time)

        x_aux = (x_0 * mask) + x_t * (1. - mask)

        with torch.no_grad():
            s = - score_net(x_aux, time, None).detach()
            # Noise prediction -> score estimate (see euler_step_c).
            s = s / std[:, :s.size(1)]

        f, g = self.sde(t)
        noise = 0 if self._is_last_step(t, dt) else torch.randn_like(s)

        width = s.size(1)
        x_t_up = x_aux[:, :width]
        x = x_t_up - dt * (f * x_t_up - (g ** 2) * s) + g * torch.sqrt(dt) * noise

        # Keep the update (previously discarded) and restore the condition from x_0.
        x = torch.cat([x, x_0[:, width:]], dim=1)
        return x, t - dt

    def modality_inpainting_c(self, score_net,x, mask , subset):
        """Generate the unobserved part of ``x`` with ``N`` conditional Euler steps from ``t = 1``.

        ``subset`` is accepted for interface parity with ``modality_inpainting``
        but is not used.
        """
        t = torch.Tensor([1.0]).to(self.device)
        dt = t / self.N
        # The step functions never modify x_0 in place, so one clone suffices.
        x_0 = x.clone()
        x_c = x
        # Fixed step count: a float countdown could add a spurious step near t = 0.
        for _ in range(self.N):
            x_c, t = self.euler_step_impaiting_c(x_t=x_c, x_0=x_0, t=t, dt=dt, score_net=score_net, mask=mask)
        return x_c

    def modality_inpainting(self, score_net,x, mask , subset):
        """Inpaint the entries of ``x`` where ``mask`` is 0 with ``N`` Euler steps from ``t = 1``."""
        t = torch.Tensor([1.0]).to(self.device)
        dt = t / self.N
        x_0 = x.clone()
        x_c = x
        for _ in range(self.N):
            x_c, t = self.euler_step_impaiting(x_t=x_c, x_0=x_0, t=t, dt=dt, score_net=score_net, mask=mask, subset=subset)
        return x_c

    def sample_euler(self,x_c, score_net):
        """Unconditional sampling.

        ``x_c`` is perturbed with the ``t = 1`` kernel, then ``N`` reverse
        Euler steps are run.
        """
        t = torch.Tensor([1.0]).to(self.device)
        dt = t / self.N

        mean, std = self.marg_prob(t, x_c)
        x_c = x_c * mean + std * torch.randn_like(x_c).to(self.device)

        for _ in range(self.N):
            x_c, t = self.euler_step(x_c, t, dt, score_net)
        return x_c

    def sample_euler_c(self,x_c, score_net):
        """Sampling from the ``train_step_cond`` model with the condition absent.

        ``x_c`` is perturbed with the ``t = 1`` kernel, then ``N`` steps of
        ``euler_step_c`` are run.
        """
        t = torch.Tensor([1.0]).to(self.device)
        dt = t / self.N

        mean, std = self.marg_prob(t, x_c)
        x_c = x_c * mean + std * torch.randn_like(x_c).to(self.device)

        for _ in range(self.N):
            x_c, t = self.euler_step_c(x_c, t, dt, score_net)
        return x_c


def sample_vp_truncated_q(shape, beta_min, beta_max, t_epsilon, T):
    """Draw diffusion times from the truncated importance distribution of a VP SDE.

    Uniform samples are mapped through
    ``VariancePreservingTruncatedSampling.inv_Phi``, built with the given
    schedule and ``t_epsilon``.

    Args:
        shape: output shape; the samples are drawn flat and reshaped to it.
        beta_min, beta_max: linear beta schedule endpoints.
        t_epsilon: truncation point passed to the sampler.
        T: end time; a Python int/float is wrapped into a 1-element float tensor.

    Returns:
        tensor of the given ``shape``, on ``T``'s device and dtype.
    """
    if isinstance(T, (float, int)):
        T = torch.Tensor([T]).float()
    u = torch.rand(*shape).to(T)
    # Forward the caller's schedule. These were previously hard-coded to
    # 0.1 / 20., silently ignoring the ``beta_min`` / ``beta_max`` arguments.
    vpsde = VariancePreservingTruncatedSampling(beta_min=float(beta_min), beta_max=float(beta_max), t_epsilon=t_epsilon)
    return vpsde.inv_Phi(u.view(-1), T).view(*shape)