import numpy as np
import torch
import torch
import torch.nn as nn
import pytorch_lightning as pl
from tqdm import tqdm
from src.model.score_net import UnetMLP
from src.libs.ema import EMA
from src.libs.SDE_m import VP_SDE ,concat_vect ,deconcat
# NOTE(review): not referenced in the visible part of this file (Minde_c takes its
# own `lr` argument). Possibly imported by other modules — confirm before removing.
learning_rate = 1e-4



# Wrapped by Minde_c.__init__ as the non-trainable parameter `T`; the MI estimators
# use `T` as the upper end of the range from which uniform time values are drawn.
T0 = 1
# NOTE(review): not referenced in the visible part of this file — confirm usage.
vtype = 'rademacher'
# NOTE(review): not referenced in the visible part of this file; Minde_c.__init__
# has its own `lr` parameter with a separate default — confirm usage.
lr = 0.001

from src.libs.SDE import VariancePreservingSDE, PluginReverseSDE ,get_normalizing_constant


class Swish(nn.Module):
    """Parameter-free activation that returns ``sigmoid(x) * x`` element-wise."""

    def forward(self, x):
        """Apply the activation to `x` and return a tensor of the same shape."""
        gate = torch.sigmoid(x)
        return gate * x



class Minde_c(pl.LightningModule):
    
    def __init__(self, dim_x, dim_y, lr=1e-3, mod_list=None, use_skip=True,
                 debias=False, weighted=False, use_ema=False,
                 d=0.5, test_samples=None, gt=0.0, batch_size=64,
                 ):
        """Build the score network, the optional EMA copy and the SDE helper.

        Args:
            dim_x: Size of the "x" variable; also passed to the score network
                as `out_dim`.
            dim_y: Size of the "y" variable; `dim_x + dim_y` is the score
                network's input `dim`.
            lr: Learning rate stored on the module and handed to the optimizer.
            mod_list: Names of the modalities. Defaults to ``["x", "y"]``; a
                copy is stored so the caller's list is never shared.
            use_skip: Must be truthy. Only the `UnetMLP` score network is
                available in this class.
            debias: Forwarded to `VP_SDE` as `importance_sampling`.
            weighted: Stored on the module; not otherwise used here.
            use_ema: When truthy, an `EMA` wrapper (decay 0.999) of the score
                network is created.
            d: Stored on the module and forwarded to `sde.train_step_cond` by
                the training/validation steps.
            test_samples: Optional mapping with "x" and "y" entries used for
                the periodic MI estimates at the end of training epochs.
            gt: Reference value logged next to the MI estimates.
            batch_size: Only recorded through `save_hyperparameters`.

        Raises:
            NotImplementedError: If `use_skip` is falsy, because no alternative
                score network exists and the module would be left without
                `self.score`.
        """
        super().__init__()

        # Fail fast instead of leaving the module without `self.score`, which
        # previously surfaced later as an opaque AttributeError.
        if not use_skip:
            raise NotImplementedError(
                "Minde_c only provides the UnetMLP score network; "
                "use_skip must be True."
            )

        self.dim_x = dim_x
        self.dim_y = dim_y
        # Avoid sharing one mutable default list between instances.
        self.mod_list = ["x", "y"] if mod_list is None else list(mod_list)
        self.gt = gt
        self.weighted = weighted

        # Width of the network grows with the joint dimensionality.
        dim = dim_x + dim_y
        if dim <= 10:
            hidden_dim = 64
        elif dim <= 50:
            hidden_dim = 128
        else:
            hidden_dim = 256

        time_dim = hidden_dim
        self.score = UnetMLP(dim=dim,
                             init_dim=hidden_dim,
                             dim_mults=[],
                             time_dim=time_dim,
                             nb_mod=2,
                             out_dim=dim_x)

        self.d = d
        self.stat = None
        self.debias = debias
        self.lr = lr
        self.use_ema = use_ema

        # The named constructor arguments are not rebound above, so the values
        # recorded here are the ones the caller supplied.
        self.save_hyperparameters("d", "debias", "lr", "use_ema", "weighted",
                                  "dim_x", "dim_y", "gt", "batch_size")

        self.test_samples = test_samples
        # Stored as a frozen parameter so it follows the module across devices.
        self.T = torch.nn.Parameter(
            torch.tensor([T0], dtype=torch.float32), requires_grad=False
        )
        self.model_ema = EMA(self.score, decay=0.999) if use_ema else None
        # NOTE: `liklihood_weighting` is spelled as the VP_SDE keyword is spelled here.
        self.sde = VP_SDE(importance_sampling=self.debias, liklihood_weighting=False)
        
        
    def training_step(self, batch, batch_idx):
        """Compute the training loss for `batch` and log it under "loss".

        The loss is the mean of what `sde.train_step_cond` returns for the
        batch, the score network and the module's `d` setting.

        Returns:
            A dict with the scalar loss under the key "loss".
        """
        # Other methods of this class call eval(); re-enable train mode here.
        self.train()

        raw_loss = self.sde.train_step_cond(batch, self.score, d=self.d)
        loss = raw_loss.mean()

        self.log("loss", loss)
        return {"loss": loss}
    
    


        # self.logger.experiment.add_scalars('e_scores', 
        #                                    {'e_x': logs ["x"], 
        #                                     'e_y': logs ["y"],
        #                                     'e_xy':logs ["xy"]}, 
        #                                    global_step=self.global_step)
      


    def on_before_backward(self, loss: torch.Tensor) -> None:
        """Refresh the EMA copy of the score network, when one is configured.

        Args:
            loss: The loss about to be back-propagated (unused here).
        """
        # Compare against None explicitly: the EMA wrapper is an arbitrary
        # object and its truthiness should not decide whether it is updated.
        if self.model_ema is not None:
            self.model_ema.update(self.score)


    def score_inference(self,x,t,std):
        """Evaluate the score network on `(x, t, std)` without tracking gradients.

        The whole module is switched to eval mode first. When `use_ema` is set,
        the EMA wrapper's `module` is used (and put in eval mode) instead of
        the live score network.

        Returns:
            Whatever the selected network returns for `(x, t, std)`.
        """
        self.eval()

        if self.use_ema:
            network = self.model_ema.module
            network.eval()
        else:
            network = self.score

        with torch.no_grad():
            return network(x, t, std)

    def validation_step(self, batch, batch_idx):
        """Compute the same loss as training on `batch` and log it as "loss_test".

        Returns:
            A dict with the scalar loss under the key "loss".
        """
        self.eval()

        raw_loss = self.sde.train_step_cond(batch, self.score, d=self.d)
        loss = raw_loss.mean()

        self.log("loss_test", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over the score network's parameters.

        Uses the module's `lr` and leaves AMSGrad disabled.
        """
        trainable = self.score.parameters()
        return torch.optim.Adam(trainable, lr=self.lr, amsgrad=False)

    def on_train_epoch_end(self) -> None:
        """Every 50th epoch, estimate MI on `test_samples` and log the results.

        Four estimates are computed (`mi_compute` and `mi_compute_non_square`,
        each with and without `debias`) and written together with the stored
        reference value `gt` through `logger.experiment.add_scalars`.

        Nothing is computed when no `test_samples` were given to the
        constructor or when no logger is attached, since either case would
        otherwise raise at epoch 0.
        """
        super().on_train_epoch_end()

        if self.current_epoch % 50 != 0:
            return
        # `test_samples` defaults to None and the logger can be disabled.
        if self.test_samples is None or self.logger is None:
            return

        # Keep this call order: each estimator draws random time values.
        mi_debias_square = self.mi_compute(self.test_samples, debias=True)
        mi_non_debias_square = self.mi_compute(self.test_samples, debias=False)
        mi_debias = self.mi_compute_non_square(self.test_samples, debias=True)
        mi_non_debias = self.mi_compute_non_square(self.test_samples, debias=False)

        # Assumes a logger whose `experiment` exposes `add_scalars`.
        self.logger.experiment.add_scalars(
            'Estimation mi',
            {
                'gt': self.gt,
                'mi_imp': mi_debias,
                'mi': mi_non_debias,
                "mi_square_imp": mi_debias_square,
                "mi_square": mi_non_debias_square,
            },
            global_step=self.global_step,
        )
        
    def mi_compute_non_square(self,data,debias =False,sigma =1.0, eps = 1e-5):
        """Estimate MI as the difference of two reference-relative score terms.

        `data["x"]` is perturbed through `sde.sample` at random time values.
        The score network is then evaluated twice on the perturbed x: once with
        the y slot zeroed and its time entry set to 0, and once with the clean
        `data["y"]` and its time entry set to 1. Each output is compared with
        the reference term `Y["x"] / (mean_x**2 * sigma**2 + std_x**2)`, and
        the difference of the two resulting quantities is returned.

        Args:
            data: Mapping with tensor entries "x" and "y", in that order.
            debias: When truthy, time values come from
                `sde.sample_debiasing_t`, the network is called with
                `std=None` and the result is scaled by
                `get_normalizing_constant`. Otherwise time values are uniform
                in `[eps, T)` and the terms are weighted by `g**2`.
            sigma: Scale used in the reference term — presumably the standard
                deviation of a Gaussian reference; confirm against the SDE.
            eps: Lower bound of the uniform time range; also shortens the
                horizon passed to `get_normalizing_constant`.

        Returns:
            Tensor holding `e_x - e_xc`.
        """
        self.sde.device = self.device
        self.score.eval()

        # `data` is a plain attribute/argument, so nothing has moved it to the
        # module's device; `.to` is a no-op when it is already there.
        data = {key: value.to(self.device) for key, value in data.items()}
        x = data["x"]

        mods_list = list(data.keys())
        nb_mods = len(mods_list)

        # One time value per sample, broadcastable against x.
        t_shape = [x.size(0)] + [1] * (x.ndim - 1)
        if debias:
            t_ = self.sde.sample_debiasing_t(t_shape).to(self.device)
        else:
            t_ = torch.rand(t_shape).to(self.device) * (self.T - eps) + eps

        t_n = t_.expand((x.shape[0], nb_mods))

        Y, _, std, g, mean = self.sde.sample(t_n, data, mods_list)

        std_x = std["x"]
        mean_x = mean["x"]

        # Perturbed x with an all-zero y slot.
        y_x = concat_vect({"x": Y["x"], "y": torch.zeros_like(Y["y"])})
        # Perturbed x together with the clean y.
        y_xc = concat_vect({"x": Y["x"], "y": data["y"]})

        # Keep the sampled time for x; the y entry becomes 0 or 1.
        mask_time_x = torch.tensor([1, 0]).to(self.device).expand(t_n.size())
        t_n_x = t_n * mask_time_x + 0.0 * (1 - mask_time_x)
        t_n_c = t_n * mask_time_x + 1.0 * (1 - mask_time_x)

        # With debias the network is called without std.
        net_std = None if debias else std_x
        with torch.no_grad():
            a_x = -self.score_inference(y_x, t_n_x, net_std)
            a_xy = -self.score_inference(y_xc, t_n_c, net_std)

        M = x.size(0)

        chi_t_x = mean_x ** 2 * sigma ** 2 + std_x ** 2
        ref_score_x = Y["x"] / chi_t_x

        if debias:
            # NOTE(review): mi_compute passes T=1 here, this method T=1-eps —
            # confirm which horizon is intended.
            const = get_normalizing_constant((1,), T=1 - eps).to(x)
            e_x = -const * 0.5 * ((a_x + std_x * ref_score_x) ** 2).sum() / M
            e_xc = -const * 0.5 * ((a_xy + std_x * ref_score_x) ** 2).sum() / M
        else:
            g_x = g["x"].reshape(g["x"].size(0), 1)
            e_x = -0.5 * (g_x ** 2 * (a_x + ref_score_x) ** 2).sum() / M
            e_xc = -0.5 * (g_x ** 2 * (a_xy + ref_score_x) ** 2).sum() / M

        return e_x - e_xc
    





    def mi_compute(self,data,debias =False, eps = 1e-5):
        """Estimate MI from the squared gap between two network outputs.

        `data["x"]` is perturbed through `sde.sample` at random time values.
        The score network is then evaluated twice on the perturbed x: once with
        the y slot zeroed and its time entry set to 0, and once with the clean
        `data["y"]` and its time entry set to 1. The estimate is half the
        summed squared difference of the two outputs, divided by the number of
        samples and scaled as described under `debias`.

        Args:
            data: Mapping with tensor entries "x" and "y", in that order.
            debias: When truthy, time values come from
                `sde.sample_debiasing_t`, the network is called with
                `std=None` and the result is scaled by
                `get_normalizing_constant`. Otherwise time values are uniform
                in `[eps, T)` and the squared gap is weighted by `g**2`.
            eps: Lower bound of the uniform time range.

        Returns:
            Detached tensor holding the estimate.
        """
        self.sde.device = self.device
        self.score.eval()

        # `data` is a plain attribute/argument, so nothing has moved it to the
        # module's device; `.to` is a no-op when it is already there.
        data = {key: value.to(self.device) for key, value in data.items()}
        x = data["x"]

        mods_list = list(data.keys())
        nb_mods = len(mods_list)

        # One time value per sample, broadcastable against x.
        t_shape = [x.size(0)] + [1] * (x.ndim - 1)
        if debias:
            t_ = self.sde.sample_debiasing_t(t_shape).to(self.device)
        else:
            t_ = torch.rand(t_shape).to(self.device) * (self.T - eps) + eps

        t_n = t_.expand((x.shape[0], nb_mods))

        Y, _, std, g, mean = self.sde.sample(t_n, data, mods_list)

        std_x = std["x"]

        # Perturbed x with an all-zero y slot.
        y_x = concat_vect({"x": Y["x"], "y": torch.zeros_like(Y["y"])})
        # Perturbed x together with the clean y.
        y_xc = concat_vect({"x": Y["x"], "y": data["y"]})

        # Keep the sampled time for x; the y entry becomes 0 or 1.
        mask_time_x = torch.tensor([1, 0]).to(self.device).expand(t_n.size())
        t_n_x = t_n * mask_time_x + 0.0 * (1 - mask_time_x)
        t_n_c = t_n * mask_time_x + 1.0 * (1 - mask_time_x)

        # With debias the network is called without std.
        net_std = None if debias else std_x
        with torch.no_grad():
            a_x = -self.score_inference(y_x, t_n_x, net_std)
            a_xy = -self.score_inference(y_xc, t_n_c, net_std)

        M = x.size(0)

        if debias:
            # NOTE(review): mi_compute_non_square passes T=1-eps here, this
            # method T=1 — confirm which horizon is intended.
            const = get_normalizing_constant((1,), T=1).to(x)
            est_score = const * 0.5 * ((a_x - a_xy) ** 2).sum() / M
        else:
            g_x = g["x"].reshape(g["x"].size(0), 1)
            est_score = 0.5 * (g_x ** 2 * (a_x - a_xy) ** 2).sum() / M

        return est_score.detach()
    

