import time
import torch

from experiments.utils_metrics import compute_log_normal_pdf, mean_squared_error
from model.ivp_vae import IVPVAE
import math
import torch.nn as nn
class IVPVAE_Extrap(IVPVAE):
    """IVP-VAE variant that computes prediction/extrapolation results.

    Relies on attributes provided by the parent ``IVPVAE`` (not visible
    here): ``forward``, ``ivp_solver1``, ``reconst_mapper``, ``adj`` and
    ``time_start``.
    """

    # Solver names whose ``ivp_solver1`` is called with an extra per-node
    # view of the state and the adjacency matrix ``self.adj``.
    _GRAPH_SOLVERS = ('gnn', 'invergnn', 'graph')

    def __init__(
            self,
            args,
            embedding_nn,
            embedding_nn_gnn,
            reconst_mapper,
            diffeq_solver1):

        super().__init__(
            args,
            embedding_nn,
            embedding_nn_gnn,
            reconst_mapper,
            diffeq_solver1)

        self.args = args

    def _solve(self, z, times):
        """Integrate state ``z`` over ``times`` with ``self.ivp_solver1``.

        For graph-based solvers (``self.args.ivp_solver`` in
        ``_GRAPH_SOLVERS``) the solver also receives ``z`` reshaped to
        ``(z.shape[0], z.shape[1], z.shape[2], -1, 1)`` and ``self.adj``.
        """
        if self.args.ivp_solver in self._GRAPH_SOLVERS:
            # The original code hard-coded 37 for the 4th dim; -1 yields the
            # same view when the trailing elements number 37 and also supports
            # other sizes. NOTE(review): presumably the node/feature count.
            z_h = z.view(z.shape[0], z.shape[1], z.shape[2], -1, 1)
            return self.ivp_solver1(z, z_h, times, self.adj)
        return self.ivp_solver1(z, times)

    def _solve_forward_backward(self, initial_state, times_out):
        """Solve forward over ``times_out``, then backward from the last state.

        Returns:
            (sol_z, sol_z_back): the forward solution, and the solution of the
            backward solve over ``-times_out`` flipped along dim 2 so that it
            lines up with ``sol_z``.
        """
        sol_z = self._solve(initial_state, times_out.unsqueeze(0))
        # Last time step along dim 2, kept as a length-1 dimension.
        zn = sol_z[:, :, -1:, :]
        sol_z_back = self._solve(zn, torch.neg(times_out).unsqueeze(0))
        return sol_z, sol_z_back.flip(2)

    def compute_prediction_results(self, batch, k_iwae=1):
        """Run the model on ``batch`` and add prediction metrics/losses.

        Args:
            batch: dict read for ``'data_out'``, ``'mask_out'`` and, depending
                on ``self.args.extrap_full``, ``'mask_extrap'`` or
                ``'times_out'``.
            k_iwae: number of samples forwarded to ``self.forward``; the
                targets are repeated ``k_iwae`` times along a leading dim.

        Returns:
            The ``results`` dict from ``self.forward`` with
            ``'forward_time'`` and metric entries added. When
            ``extrap_full`` is not True, ``'loss'`` is also updated.
        """
        results, forward_info = self.forward(batch, k_iwae)

        data_out = batch['data_out']
        mask_out = batch['mask_out']

        if self.args.extrap_full == True:
            # Evaluate the decoder output already produced by forward().
            mask_extrap = batch['mask_extrap']
            pred_x = forward_info['pred_x']

            results['forward_time'] = time.time() - self.time_start

            # NOTE(review): "mse" selects mask_extrap and "mse_extrap" selects
            # ~mask_extrap; the naming looks inverted, so confirm it against
            # mean_squared_error's mask_select semantics.
            results["mse"] = mean_squared_error(
                data_out, pred_x, mask=mask_out,
                mask_select=mask_extrap[..., None]).detach()
            results["mse_extrap"] = mean_squared_error(
                data_out, pred_x, mask=mask_out,
                mask_select=~mask_extrap[..., None]).detach()
            return results

        sol_z, sol_z_back = self._solve_forward_backward(
            forward_info['initial_state'], batch['times_out'])

        next_data = data_out.repeat(k_iwae, 1, 1, 1)
        next_mask = mask_out.repeat(k_iwae, 1, 1, 1)

        pred_x = self.reconst_mapper(sol_z)
        results['forward_time'] = time.time() - self.time_start

        rec_likelihood = compute_log_normal_pdf(
            next_data, mask_out, pred_x, self.args)

        # Importance-weighted NLL: log-sum-exp over the sample dim, then a
        # mean over the next dim.
        loss_next = -torch.logsumexp(rec_likelihood, dim=0)
        loss_next = torch.mean(loss_next, dim=0)
        # .any() keeps this check valid even if loss_next is not a scalar.
        assert not torch.isnan(loss_next).any(), "loss_next is NaN"

        # Cycle consistency between the forward and (flipped) backward solves.
        recon_loss = ((sol_z - sol_z_back) ** 2).mean()
        results["loss"] = results["loss"] + recon_loss

        if self.args.train_w_reconstr:
            results["loss"] = results["loss"] + \
                self.args.ratio_nl * loss_next
        else:
            # NOTE(review): this overwrites the loss, discarding recon_loss
            # added above; confirm whether that is intended.
            results["loss"] = loss_next

        mse_extrap = mean_squared_error(next_data, pred_x, mask=next_mask)
        results["mse_extrap"] = torch.mean(mse_extrap).detach()

        return results

    def run_validation(self, batch):
        """Compute prediction results using ``self.args.k_iwae`` samples."""
        return self.compute_prediction_results(batch, k_iwae=self.args.k_iwae)

class ScaleDotProductAttention(nn.Module):
    """Compute scaled dot-product attention scores.

    Query: the sequence we focus on (decoder).
    Key: every sequence checked for a relationship with the query (encoder).

    The 4-D input is permuted with ``permute(0, 2, 1, 3)`` and flattened to
    ``(batch, length, c)``. Queries and keys are linear projections of that
    tensor.

    Args:
        c: flattened feature size; for an input of shape
            ``(d0, d1, d2, d3)`` this must equal ``d1 * d3``.
        device: device for the projection layers. Defaults to ``'cuda:1'``,
            the previously hard-coded value, for backward compatibility.
    """

    def __init__(self, c, device='cuda:1'):
        super().__init__()
        self.w_q = nn.Linear(c, c, device=device)
        self.w_k = nn.Linear(c, c, device=device)
        # NOTE(review): normalises over dim 1 of the (batch, query, key)
        # score tensor, i.e. over queries rather than keys; confirm intended.
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x_input, mask=None, e=1e-12):
        """Return batch-averaged attention scores and the projected keys.

        Args:
            x_input: 4-D tensor; axes 1 and 2 are swapped, then everything
                after the (new) axis 1 is flattened into features.
            mask: optional tensor broadcastable to the
                ``(batch, length, length)`` scores; positions where it equals
                0 are filled with -1e9 before the softmax.
            e: unused; kept for interface compatibility.

        Returns:
            (score, k): ``score`` is a float ``(length, length)`` tensor, the
            softmax/dropout scores averaged over the batch. ``k`` is the
            ``(batch, length, c)`` key projection.
        """
        x = x_input.permute(0, 2, 1, 3)
        shape = x.shape
        x_shape = x.reshape((shape[0], shape[1], -1))
        batch_size, length, c = x_shape.size()
        q = self.w_q(x_shape)
        k = self.w_k(x_shape)
        # Must be a transpose: view(batch_size, c, length) reinterprets
        # memory and scrambles the key elements instead of swapping the axes.
        k_t = k.transpose(1, 2)
        score = (q @ k_t) / math.sqrt(c)

        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        score = self.dropout(self.softmax(score))
        score = torch.mean(score.float(), dim=0)
        return score, k
    

