import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformer.Constants as Constants
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from scipy import stats
from itertools import accumulate
import matplotlib.pyplot as plt
from base import Flow
#from scipy.integrate import cumtrapz
    
def _compute_knots(lengths, lower, upper):
    """Turn normalized bin lengths into knot positions on ``[lower, upper]``.

    Args:
        lengths: Bin lengths along the last axis (expected to sum to one there).
        lower: Position of the first knot.
        upper: Position of the last knot.

    Returns:
        Tuple ``(lengths, knots)``: the rescaled bin lengths and the knot
        positions, which have one more entry than ``lengths`` on the last axis.
    """
    span = upper - lower
    # Running sum with a leading zero gives the normalized knot positions.
    cumulative = F.pad(lengths.cumsum(dim=-1), pad=(1, 0), mode='constant', value=0.0)
    knots = span * cumulative + lower
    # Pin both ends exactly so rounding error cannot move the boundaries.
    knots[..., 0] = lower
    knots[..., -1] = upper
    rescaled = knots[..., 1:] - knots[..., :-1]
    return rescaled, knots


class Spline(Flow):
    """Monotone rational linear/quadratic spline with optional exponential tail."""

    def __init__(self, 
        n_knots, 
        left=0., 
        right=1., 
        bottom=0., 
        top=1., 
        tails='undefined', 
        spline_order=2,
        min_bin_width = 1e-2,
        min_bin_height = 1e-2,
        min_derivative = 1e-2,
        min_lambda = 0.025,
        **kwargs
        ):
        """Rational linear/quadratic spline flow.

        Rational quadratic spline is based on https://github.com/bayesiains/nsf
        Rational linear spline is based on https://github.com/hmdolatabadi/LRS_NF

        We found RLS to provide no noticeable improvement, so we use RQS in our experiments.

        Args:
            n_knots: Number of knots for the spline.
            left: min input
            right: max input
            bottom: min output
            top: max output
            tails: behavior at tails either linear or undefined
            spline_order: either 1 or 2 - use rational linear or rational quadratic splines
            min_bin_width: minimum width of a spline segment
            min_bin_height: minimum height of a spline segment
            min_derivative: minimum derivative at a knot
            min_lambda: minimum value for lambda - only applies to rational linear splines
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if the minimum bin width/height times ``n_knots`` exceeds one.
        """
        super().__init__()
        assert spline_order in [1, 2], 'Order rational polynomials of spline_order %i are not supported!' % spline_order
        assert tails in ['linear', 'undefined'], '%s tails are not supported!' % tails
        self.spline_order = spline_order
        self.tails = tails

        # With linear tails the two boundary derivatives are fixed, so two fewer are learned.
        self.n_derivatives = n_knots + (1 if tails == 'undefined' else -1)
        self.n_knots = n_knots
        # Layout of a parameter vector: widths | heights | derivatives | (lambdas, order 1 only).
        self.n_params = 2*self.n_knots + self.n_derivatives + (self.spline_order == 1) * self.n_knots
        self._params = nn.Parameter(torch.zeros(self.n_params))

        # We store everything as parameters to have it better accessible on the gpu
        self.left = nn.Parameter(torch.tensor(left), requires_grad=False)
        self.right = nn.Parameter(torch.tensor(right), requires_grad=False)
        self.bottom = nn.Parameter(torch.tensor(bottom), requires_grad=False)
        self.top = nn.Parameter(torch.tensor(top), requires_grad=False)

        self.min_bin_width = nn.Parameter(torch.tensor(min_bin_width), requires_grad=False)
        self.min_bin_height = nn.Parameter(torch.tensor(min_bin_height), requires_grad=False)
        self.min_derivative = nn.Parameter(torch.tensor(min_derivative), requires_grad=False)
        self.min_lambda = nn.Parameter(torch.tensor(min_lambda), requires_grad=False)

        # Logit for which softplus(logit) + min_derivative == 1, i.e. a unit slope at the boundary.
        self.min_d_constant = nn.Parameter(torch.log(torch.exp(1 - self.min_derivative) - 1), requires_grad=False)

        # Mass left to distribute via softmax once every bin got its minimum size.
        self.rest_width = nn.Parameter(1.0 - self.min_bin_width * self.n_knots, requires_grad=False)
        self.rest_height = nn.Parameter(1.0 - self.min_bin_height * self.n_knots, requires_grad=False)

        self.eps = torch.finfo(torch.get_default_dtype()).eps

        if self.min_bin_width * self.n_knots > 1.0:
            raise ValueError('Minimal bin width too large for the number of bins')
        if self.min_bin_height * self.n_knots > 1.0:
            raise ValueError('Minimal bin height too large for the number of bins')

    def reset_parameters(self, a=1.0):
        """Re-initialise the module's own spline parameters uniformly in ``[-a, a]``."""
        self._params.data.uniform_(-a, a)
        
    def forward(self, x):
        """Forward transformation using the module's own parameters.

        Values outside ``[left, right]`` are passed through unchanged with a
        zero log-determinant.

        Args:
            x: Inputs, shape [batch_size, seq_len, 1]

        Returns:
            y: Outputs, shape [batch_size, seq_len, 1]
            log_det_jac: Log determinant of Jacobian, shape [batch_size, seq_len, 1]
        """
        # No affect/base tensors exist on this path, so the identity tail is the only valid choice.
        y, log_det_jac = self._unconstrained_spline(x, self._params, inverse=False, tail="linear")
        return y, log_det_jac

    @torch.jit.export
    def inverse(self, y):
        """Inverse transformation using the module's own parameters.

        Values outside ``[bottom, top]`` are passed through unchanged with a
        zero log-determinant.

        Args:
            y: Inputs, shape [batch_size, seq_len, 1]

        Returns:
            x: Outputs, shape [batch_size, seq_len, 1]
            log_det_jac: Log determinant of Jacobian, shape [batch_size, seq_len, 1]
        """
        x, inv_log_det_jac = self._unconstrained_spline(y, self._params, inverse=True, tail="linear")
        return x, inv_log_det_jac

    def _unconstrained_spline(self, inputs, params, affect=None, base=None, inverse: bool = False, tail="exp"):
        """Apply the spline inside its interval and a tail rule outside of it.

        Args:
            inputs: Values to transform.
            params: Unconstrained spline parameters; the last axis has ``n_params``
                entries and the leading axes must broadcast against ``inputs[..., :1]``.
            affect: Per-element decay rate of the exponential tail, same shape as
                ``inputs``. Required when ``tail == "exp"``.
            base: Per-element log-density offset of the exponential tail, same shape
                as ``inputs``. Required when ``tail == "exp"``.
            inverse: Evaluate the inverse map (inputs live in ``[bottom, top]``)
                instead of the forward map (inputs live in ``[left, right]``).
            tail: ``"linear"`` sets the log-determinant outside the interval to zero;
                ``"exp"`` sets it to ``-(input - upper) * affect + base``. Any other
                value leaves the log-determinant outside the interval untouched.

        Returns:
            Tuple ``(outputs, logabsdet)``, both shaped like ``inputs``. Outside the
            interval ``outputs`` equals ``inputs``.

        Raises:
            ValueError: if ``tail == "exp"`` but ``affect`` or ``base`` is missing.
        """
        if tail == "exp" and (affect is None or base is None):
            raise ValueError('tail="exp" requires both `affect` and `base`')

        # The forward map consumes x-values, the inverse map consumes y-values,
        # so the validity interval has to be taken from the matching axis.
        if inverse:
            lower, upper = self.bottom, self.top
        else:
            lower, upper = self.left, self.right
        inside_interval_mask = (inputs >= lower) & (inputs <= upper)
        outside_interval_mask = ~inside_interval_mask

        # Park out-of-range values on the last knot so the spline evaluation stays
        # well defined; their results are overwritten below.
        outputs = inputs.clone()
        outputs[outside_interval_mask] = upper
        logabsdet = torch.zeros_like(inputs)
        if self.tails == 'undefined' or inside_interval_mask.any():
            outputs, logabsdet = self._rational_spline(
                outputs, params,
                inverse=inverse
            )

        # Outside the interval the map itself is always the identity.
        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        if tail == "linear":
            logabsdet[outside_interval_mask] = 0
        elif tail == "exp":
            logabsdet[outside_interval_mask] = (
                -(inputs[outside_interval_mask] - upper) * affect[outside_interval_mask]
                + base[outside_interval_mask]
            )
        return outputs, logabsdet

    @staticmethod
    def _expand_per_input(table, lead_shape):
        """Broadcast a per-spline table (..., K) to one copy per input (*lead_shape, K)."""
        return table.unsqueeze(-2).expand(*lead_shape, table.shape[-1])

    def _rational_spline(self, inputs, params, inverse: bool = False):
        """Evaluate the spline (or its inverse) for inputs that lie inside the interval.

        Args:
            inputs: Values to transform, shape ``(..., D)``.
            params: Unconstrained parameters, shape ``(..., n_params)``; the leading
                axes are broadcast against the leading axes of ``inputs``.
            inverse: Evaluate the inverse map.

        Returns:
            Tuple ``(outputs, logabsdet)`` shaped like ``inputs``.
        """
        # Decompose input
        width_logits = params[..., :self.n_knots]
        height_logits = params[..., self.n_knots:2*self.n_knots]
        derivative_logits = params[..., self.n_knots*2:self.n_knots*2+self.n_derivatives]

        # Append the fixed derivatives
        if self.tails == 'linear':
            derivative_logits = F.pad(derivative_logits, pad=(1, 1), mode='constant', value=self.min_d_constant)
        
        # Normalize widths and heights
        widths = self.min_bin_width + self.rest_width * F.softmax(width_logits, dim=-1)
        heights = self.min_bin_height + self.rest_height * F.softmax(height_logits, dim=-1)

        # Compute knots
        widths, cum_widths = _compute_knots(widths, self.left, self.right)
        heights, cum_heights = _compute_knots(heights, self.bottom, self.top)

        # Ensure positive derivatives
        derivatives = F.softplus(derivative_logits) + self.min_derivative

        # Find corresponding segments
        bin_idx = self._search_sorted(cum_heights if inverse else cum_widths, inputs)[..., None]

        # Every input needs its own view of the knot tables so that `gather` can pick
        # the bin it falls into. Using the broadcast shape of the bin indices keeps
        # this independent of the number of leading axes.
        lead_shape = bin_idx.shape[:-1]
        widths = self._expand_per_input(widths, lead_shape)
        cum_widths = self._expand_per_input(cum_widths, lead_shape)
        heights = self._expand_per_input(heights, lead_shape)
        cum_heights = self._expand_per_input(cum_heights, lead_shape)
        derivatives = self._expand_per_input(derivatives, lead_shape)

        # Select input data
        input_widths = widths.gather(-1, bin_idx)[..., 0]
        input_cum_widths = cum_widths.gather(-1, bin_idx)[..., 0]
        input_heights = heights.gather(-1, bin_idx)[..., 0]
        input_cum_heights = cum_heights.gather(-1, bin_idx)[..., 0]
        input_delta = input_heights / input_widths
        input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
        input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

        if self.spline_order == 1:
            # Prepare lambdas
            lambda_logits = params[..., self.n_knots*2+self.n_derivatives:]
            lambdas = (1 - 2*self.min_lambda) * torch.sigmoid(lambda_logits) + self.min_lambda
            lambdas = self._expand_per_input(lambdas, lead_shape)
            input_lambdas = lambdas.gather(-1, bin_idx)[..., 0]
            
            result = self._rational_linear_spline(
                inputs=inputs, 
                x_k1_x_k=input_widths, 
                x_k=input_cum_widths, 
                y_k1_y_k=input_heights, 
                y_k=input_cum_heights,
                s_k=input_delta,
                d_k=input_derivatives,
                d_k1=input_derivatives_plus_one,
                lambdas=input_lambdas,
                inverse=inverse)

        elif self.spline_order == 2:
            result = self._rational_quadratic_spline(
                inputs=inputs, 
                x_k1_x_k=input_widths, 
                x_k=input_cum_widths,
                y_k1_y_k=input_heights, 
                y_k=input_cum_heights,
                s_k=input_delta,
                d_k=input_derivatives,
                d_k1=input_derivatives_plus_one,
                inverse=inverse)

        else:
            raise ValueError('Order %i rational polynomials are not support' % self.spline_order)
        return result

    def _rational_quadratic_spline(self,
            inputs,
            x_k1_x_k,
            x_k,
            y_k1_y_k,
            y_k,
            s_k,
            d_k,
            d_k1,
            inverse: bool = False):
        """Evaluate one rational-quadratic segment per input.

        Args:
            inputs: x-values (forward) or y-values (inverse).
            x_k1_x_k: Width of the selected bin.
            x_k: Left knot of the selected bin.
            y_k1_y_k: Height of the selected bin.
            y_k: Lower knot value of the selected bin.
            s_k: Bin slope ``height / width``.
            d_k: Derivative at the left knot.
            d_k1: Derivative at the right knot.
            inverse: Solve for x given y instead of evaluating y at x.

        Returns:
            Tuple ``(outputs, logabsdet)`` of the requested direction.
        """
        # Notation from https://arxiv.org/abs/1906.04032
        if inverse:
            y = inputs

            y_y_k = y - y_k

            dk1_dk_2sk = d_k1 + d_k - 2 * s_k
            y_y_k_dk1_dk_2sk = y_y_k * dk1_dk_2sk

            # Coefficients of the quadratic a*xi^2 + b*xi + c = 0 in the relative position xi.
            a = y_k1_y_k * (s_k - d_k) + y_y_k_dk1_dk_2sk
            b = y_k1_y_k * d_k - y_y_k_dk1_dk_2sk
            c = -s_k * y_y_k

            root = b.pow(2) - 4*a*c
            # assert (root >= 0).all()  # this requires waiting for cuda synchronization -> expensive

            # Numerically stable form of the quadratic formula.
            xi = 2*c / (-b - torch.sqrt(root))
            outputs = xi * x_k1_x_k + x_k

            xi_inv_xi = xi * (1 - xi)
            xi2 = xi.pow(2)
            inv_xi2 = (1 - xi).pow(2)

            denominator = s_k + dk1_dk_2sk * xi_inv_xi
            derivative_numerator = s_k.pow(2) * (d_k1 * xi2 + 2 * s_k * xi_inv_xi + d_k * inv_xi2)

            # Log-derivative of the forward map at the recovered x; negated for the inverse.
            logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
            return outputs, -logabsdet
        else:
            xi = (inputs - x_k) / x_k1_x_k
            xi_inv_xi = xi * (1 - xi)
            xi2 = xi.pow(2)
            inv_xi2 = (1 - xi).pow(2)

            numerator = y_k1_y_k * (s_k * xi2 + d_k * xi_inv_xi)
            denominator = s_k + (d_k1 + d_k - 2*s_k) * xi_inv_xi
            
            outputs = y_k + numerator / denominator

            derivative_numerator = s_k.pow(2) * (d_k1 * xi2 + 2 * s_k * xi_inv_xi + d_k * inv_xi2)
            logabsdet = torch.log(derivative_numerator) - 2*torch.log(denominator)
            return outputs, logabsdet

    def _rational_linear_spline(self, 
            inputs,
            x_k1_x_k,
            x_k,
            y_k1_y_k,
            y_k,
            s_k,
            d_k,
            d_k1,
            lambdas,
            inverse: bool = False,
            w_k: float = 1.):
        """Evaluate one rational-linear segment per input.

        Each bin is split at the relative position ``lambdas`` into two
        linear-rational pieces; the piece is chosen per input.

        Args:
            inputs: x-values (forward) or y-values (inverse).
            x_k1_x_k: Width of the selected bin.
            x_k: Left knot of the selected bin.
            y_k1_y_k: Height of the selected bin.
            y_k: Lower knot value of the selected bin.
            s_k: Bin slope ``height / width``.
            d_k: Derivative at the left knot.
            d_k1: Derivative at the right knot.
            lambdas: Relative split position inside the bin.
            inverse: Solve for x given y instead of evaluating y at x.
            w_k: Weight assigned to the left knot.

        Returns:
            Tuple ``(outputs, logabsdet)`` of the requested direction.
        """
        # Notation from https://arxiv.org/pdf/2001.05168.pdf
        inv_lambdas = 1. - lambdas

        w_k1 = torch.sqrt(d_k / d_k1) * w_k
        w_m = (lambdas * w_k * d_k + inv_lambdas * w_k1 * d_k1) * s_k
        
        y_k1 = y_k + y_k1_y_k
        y_m = (inv_lambdas * w_k * y_k + lambdas * w_k1 * y_k1) / (inv_lambdas * w_k + lambdas * w_k1)

        if inverse:
            y = inputs
            left_of_lambda = (y <= y_m).float()
            right_of_lambda = 1. - left_of_lambda

            # Lets cache some variables to avoid redundant computations
            y_k_y = y_k - y
            y_k1_y = y_k1 - y
            y_y_m = y - y_m
            y_k1_y_m = y_k1 - y_m
            y_m_y_k = y_m - y_k

            numerator = lambdas * w_k * y_k_y * left_of_lambda \
                + (lambdas * w_k1 * y_k1_y + w_m * y_y_m) * right_of_lambda
            
            denominator = (w_k * y_k_y + w_m * y_y_m) * left_of_lambda \
                + (w_k1 * y_k1_y + w_m * y_y_m) * right_of_lambda

            phi = numerator / denominator
            outputs = x_k1_x_k*phi + x_k

            derivative_numerator = lambdas * w_k * w_m * y_m_y_k * left_of_lambda \
                + inv_lambdas * w_m * w_k1 * y_k1_y_m * right_of_lambda

            derivative_numerator *= x_k1_x_k

            logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator))
        else:
            x = inputs
            phi = (x - x_k) / x_k1_x_k
            left_of_lambda = (phi <= lambdas).float()
            right_of_lambda = 1. - left_of_lambda

            # Lets cache some variables to avoid redundant computations
            l_phi = lambdas - phi
            phi_l = phi - lambdas
            inv_phi = 1 - phi

            numerator = (w_k * y_k * l_phi + w_m * y_m * phi) * left_of_lambda \
                + (w_m * y_m * inv_phi + w_k1 * y_k1 * phi_l) * right_of_lambda

            denominator = (w_k * l_phi + w_m * phi) * left_of_lambda \
                + (w_m * inv_phi + w_k1 * phi_l) * right_of_lambda

            outputs = numerator/denominator

            derivative_numerator = lambdas * w_k * w_m * (y_m - y_k) * left_of_lambda \
                + (inv_lambdas * w_m * w_k1 * (y_k1 - y_m)) * right_of_lambda

            derivative_numerator /= x_k1_x_k
            
            logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator))
        return outputs, logabsdet

    def _search_sorted_old(self, bin_locations, inputs):
        """Legacy bin lookup for ``bin_locations`` that already broadcast against ``inputs[..., None]``.

        Returns, for every input, the index of the bin whose left edge is the
        largest edge not exceeding the input.
        """
        # Work on a copy: nudging the last edge in place would silently alter the caller's knots.
        bin_locations = bin_locations.clone()
        bin_locations[..., -1] += self.eps
        return torch.sum(
            inputs[..., None] >= bin_locations,
            dim=-1
        ) - 1
    
    def _search_sorted(self, bins,inputs):
        """Return the bin index of every input.

        Args:
            bins: Sorted bin edges, shape ``(..., Y)``.
            inputs: Values to locate, shape ``(..., X)``; leading axes broadcast
                against those of ``bins``.

        Returns:
            Integer tensor of shape ``(..., X)`` holding ``(#edges <= input) - 1``.
        """
        bins = bins.clone()
        # Nudge the last edge up so an input equal to the upper bound lands in the last bin.
        # NOTE(review): eps is the machine epsilon, so this has no effect once the
        # upper bound is large enough for `upper + eps` to round back to `upper`.
        bins[..., -1] += self.eps

        inputs_expanded = inputs.unsqueeze(-1)  # shape: (..., X, 1)
        bins_expanded = bins.unsqueeze(-2)      # shape: (..., 1, Y)

        comparison = inputs_expanded >= bins_expanded  # shape: (..., X, Y)
        bin_idx = comparison.sum(dim=-1) - 1  # shape: (..., X)
        return bin_idx


def softplus(x, beta):
    """Numerically stable softplus ``log(1 + exp(beta * x)) / beta``.

    Args:
        x: Input tensor.
        beta: Sharpness of the transition (number or tensor broadcastable to ``x``).

    Returns:
        Tensor shaped like the broadcast of ``beta * x``.
    """
    scaled = beta * x
    # log(1 + e^t) == max(t, 0) + log(1 + e^(-|t|)); the exponent is never
    # positive, so large inputs no longer overflow to inf.
    stable = torch.clamp(scaled, min=0) + torch.log1p(torch.exp(-torch.abs(scaled)))
    return stable / beta

def get_non_pad_mask(seq):
    """Return a float mask marking the entries of ``seq`` that are not padding.

    The result has a trailing singleton axis; it is 1.0 where ``seq`` differs
    from ``Constants.PAD`` and 0.0 elsewhere. ``seq`` must be two-dimensional.
    """
    assert seq.dim() == 2
    is_real = seq != Constants.PAD
    return is_real.type(torch.float).unsqueeze(-1)

def get_attn_key_pad_mask(seq_k, seq_q):
    """Build the attention mask that hides padded key positions.

    Returns a boolean tensor of shape b x lq x lk that is True wherever the
    key token equals ``Constants.PAD``, repeated for every query position.
    """
    n_queries = seq_q.size(1)
    key_is_pad = seq_k.eq(Constants.PAD).unsqueeze(1)  # b x 1 x lk
    # Repeat the key mask along the query axis to match the attention matrix.
    return key_is_pad.expand(-1, n_queries, -1)

def get_non_event_mask(seq):
    """Build the attention mask that hides non-event (grid) time points.

    Returns a boolean tensor of shape b x lq x lk that is True wherever the
    key token equals ``Constants.GRID``, repeated for every query position.
    """
    n_queries = seq.size(1)
    key_is_grid = seq.eq(Constants.GRID).unsqueeze(1)  # b x 1 x lk
    return key_is_grid.expand(-1, n_queries, -1)

def get_subsequent_mask(seq):
    """Build the causal mask used for masked self-attention.

    Returns a uint8 tensor of shape b x ls x ls whose entry ``[b, i, j]`` is 1
    when ``j > i`` (a future position) and 0 otherwise.
    """
    batch, length = seq.size()
    all_ones = torch.ones((length, length), device=seq.device, dtype=torch.uint8)
    future = all_ones.triu(diagonal=1)  # strictly upper triangle
    return future.unsqueeze(0).expand(batch, -1, -1)

class MultiMapping(nn.Module):
    """Bank of ``M`` independent linear heads applied to the same encoding."""

    def __init__(self, mode_dim, k, M):
        """Create ``M`` linear layers mapping ``mode_dim`` features to ``k`` outputs."""
        super(MultiMapping, self).__init__()
        self.M = M
        heads = [nn.Linear(mode_dim, k) for _ in range(M)]
        self.fc_layers = nn.ModuleList(heads)
    
    def forward(self, encoding):
        """Apply every head to ``encoding``.

        Args:
            encoding: Tensor of shape (batch_size, time_len, mode_dim).

        Returns:
            Tensor of shape (M, batch_size, time_len, k), one slice per head.
        """
        batch_size, time_len, mode_dim = encoding.shape
        # Flatten batch and time so each head sees a plain 2-D input.
        flat = encoding.view(-1, mode_dim)
        per_head = [
            self.fc_layers[idx](flat).view(batch_size, time_len, -1)
            for idx in range(self.M)
        ]
        return torch.stack(per_head, dim=0)

class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """

    def __init__(self, temperature, attn_dropout=0.2):
        """Store the scaling ``temperature`` and the dropout applied to the attention weights."""
        super().__init__()

        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: Query, key and value tensors; axes 2 and 3 of ``k`` are
                swapped to form the score matrix.
            mask: Optional boolean tensor; True entries are excluded from attention.

        Returns:
            Tuple ``(output, attn)``: the attended values and the (dropped-out)
            attention weights.
        """
        scaled_q = q / self.temperature
        scores = torch.matmul(scaled_q, k.transpose(2, 3))

        if mask is not None:
            # A large negative score drives the softmax weight of masked keys to zero.
            scores = scores.masked_fill(mask, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v), weights
    
class MultiHeadAttention(nn.Module):
    """ Multi-Head Attention module """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True):
        """Set up the projections, attention core, layer norm and dropout.

        Args:
            n_head: Number of attention heads.
            d_model: Width of the model (input and output feature size).
            d_k: Per-head width of queries and keys.
            d_v: Per-head width of values.
            dropout: Dropout rate for both the attention weights and the output.
            normalize_before: Apply layer norm to the query before attention
                (True) or to the residual sum afterwards (False).
        """
        super().__init__()

        self.normalize_before = normalize_before
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Bias-free input projections, Xavier-initialised in creation order.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        for projection in (self.w_qs, self.w_ks, self.w_vs):
            nn.init.xavier_uniform_(projection.weight)

        # Output projection that merges the heads back to the model width.
        self.fc = nn.Linear(d_v * n_head, d_model)
        nn.init.xavier_uniform_(self.fc.weight)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout)

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention with a residual connection around it.

        Args:
            q, k, v: Tensors of shape b x len x d_model.
            mask: Optional b x lq x lk mask; True entries are not attended to.

        Returns:
            Tuple ``(output, attn)``: output of shape b x lq x d_model and the
            attention weights of shape b x n x lq x lk.
        """
        n_head = self.n_head
        batch = q.size(0)
        q_len, k_len, v_len = q.size(1), k.size(1), v.size(1)

        skip = q
        if self.normalize_before:
            q = self.layer_norm(q)

        # Project, split the feature axis into heads, and move heads next to batch:
        # b x len x (n*d) -> b x len x n x d -> b x n x len x d
        q = self.w_qs(q).view(batch, q_len, n_head, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(batch, k_len, n_head, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(batch, v_len, n_head, self.d_v).transpose(1, 2)

        # Insert a head axis so one mask is shared by all heads.
        head_mask = None if mask is None else mask.unsqueeze(1)

        context, attn = self.attention(q, k, v, mask=head_mask)

        # b x n x lq x dv -> b x lq x (n*dv): concatenate the heads again.
        merged = context.transpose(1, 2).contiguous().view(batch, q_len, -1)
        output = self.dropout(self.fc(merged)) + skip

        if self.normalize_before:
            return output, attn
        return self.layer_norm(output), attn

class PositionwiseFeedForward(nn.Module):
    """ Two-layer position-wise feed-forward neural network. """

    def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
        """Create the two linear layers, the layer norm and the dropout.

        Args:
            d_in: Input and output feature size.
            d_hid: Hidden feature size.
            dropout: Dropout rate applied after each linear layer.
            normalize_before: Apply layer norm to the input (True) or to the
                residual sum (False).
        """
        super().__init__()

        self.normalize_before = normalize_before

        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)

        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x`` plus the GELU feed-forward transform of ``x``, with pre- or post-norm."""
        shortcut = x
        hidden = self.layer_norm(x) if self.normalize_before else x

        hidden = self.dropout(F.gelu(self.w_1(hidden)))
        hidden = self.dropout(self.w_2(hidden)) + shortcut

        if self.normalize_before:
            return hidden
        return self.layer_norm(hidden)

class EncoderLayer(nn.Module):
    """ Compose with two layers """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, normalize_before=True):
        """Build the self-attention sub-layer followed by the feed-forward sub-layer.

        Args:
            d_model: Model width.
            d_inner: Hidden width of the feed-forward sub-layer.
            n_head: Number of attention heads.
            d_k: Per-head width of queries and keys.
            d_v: Per-head width of values.
            dropout: Dropout rate used by both sub-layers.
            normalize_before: Pre-norm (True) or post-norm (False) in both sub-layers.
        """
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, normalize_before=normalize_before)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, normalize_before=normalize_before)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        """Run self-attention and the feed-forward network.

        Args:
            enc_input: Input sequence representation.
            non_pad_mask: Optional multiplicative mask that zeroes padded
                positions after each sub-layer; skipped when None.
            slf_attn_mask: Optional attention mask forwarded to the attention sub-layer.

        Returns:
            Tuple ``(enc_output, enc_slf_attn)``.
        """
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        # The mask is optional in the signature, so it must not be applied blindly.
        if non_pad_mask is not None:
            enc_output = enc_output * non_pad_mask

        enc_output = self.pos_ffn(enc_output)
        if non_pad_mask is not None:
            enc_output = enc_output * non_pad_mask

        return enc_output, enc_slf_attn

class MAS_Encoder(nn.Module):
    """ A encoder model with self attention mechanism. """

    def __init__(
            self,
            num_types, d_model, d_inner,
            n_layers, n_head, d_k, d_v, dropout, opt):
        """Build the embedding, the temporal-encoding frequencies and the layer stack.

        Args:
            num_types: Number of event types (index ``Constants.PAD`` is padding).
            d_model: Model width.
            d_inner: Hidden width of the feed-forward sub-layers.
            n_layers: Number of encoder layers.
            n_head: Number of attention heads.
            d_k: Per-head width of queries and keys.
            d_v: Per-head width of values.
            dropout: Dropout rate.
            opt: Options object; ``opt.device`` is read here.
        """
        super().__init__()

        self.d_model = d_model
        self.num_types = num_types
        self.opt=opt
        # position vector, used for temporal encoding
        position_vec = torch.tensor(
            [math.pow(10000.0, 2.0 * (i // 2) / d_model) for i in range(d_model)],
            device=torch.device(opt.device))
        # A non-persistent buffer follows `.to(device)` with the module while
        # leaving the state_dict layout untouched.
        self.register_buffer('position_vec', position_vec, persistent=False)

        # event type embedding
        self.event_emb = nn.Embedding(num_types + 1, d_model, padding_idx=Constants.PAD).to(opt.device)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout, normalize_before=False)
            for _ in range(n_layers)])

    def temporal_enc(self, time, non_pad_mask):
        """
        Input: batch*seq_len.
        Output: batch*seq_len*d_model.

        Even feature indices carry sin(time / position_vec), odd ones the
        cosine; padded positions are zeroed by ``non_pad_mask``.
        """
        tt = time.unsqueeze(-1) / self.position_vec
        mask = torch.zeros_like(tt).bool()
        mask[..., 0::2] = True

        result = torch.zeros_like(tt)
        result += torch.sin(tt)*mask
        result += torch.cos(tt)*~mask
        return result * non_pad_mask

    def forward(self, event_type, event_time, non_pad_mask):
        """ Encode event sequences via masked self-attention.

        Args:
            event_type: Event type indices, batch*seq_len.
            event_time: Event timestamps, batch*seq_len.
            non_pad_mask: Multiplicative mask that is zero at padded positions.

        Returns:
            Encoded sequence, batch*seq_len*d_model.
        """

        # prepare attention masks
        # slf_attn_mask is where we cannot look, i.e., the future and the padding
        slf_attn_mask_subseq = get_subsequent_mask(event_type)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=event_type, seq_q=event_type)
        slf_attn_mask_keypad = slf_attn_mask_keypad.type_as(slf_attn_mask_subseq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        # Sanity check against the actual embedding range (indices 0..num_types),
        # done before the lookup so the warning is visible even if the lookup fails.
        if event_type.max() > self.num_types or event_type.min() < 0:
            print("wrong event type")

        tem_enc = self.temporal_enc(event_time, non_pad_mask)
        enc_output = self.event_emb(event_type)

        for name, param in self.event_emb.named_parameters():
            if torch.any(torch.isnan(param.data)):
                print("wrong parameter")

        # The temporal encoding is re-injected in front of every layer.
        for enc_layer in self.layer_stack:
            enc_output = enc_output + tem_enc
            enc_output, _ = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
        return enc_output

class MAS_Transformer(nn.Module):
    
    """ A sequence to sequence model with attention mechanism. """

    def __init__(
            self,n_knots,
            num_types, d_model=16, d_inner=8,
            n_layers=1, n_head=1, d_k=16, d_v=16, dropout=0.1, 
            left=0., right=1., bottom=0., top=1., tails='undefined', 
            spline_order=2, min_bin_width = 1e-2,min_bin_height = 1e-2,
            min_derivative = 1e-2,min_lambda = 0.025,opt=None):
        """Build the encoder, the per-type spline heads and the tail layers.

        Args:
            n_knots: Number of spline knots.
            num_types: Number of event types.
            d_model, d_inner, n_layers, n_head, d_k, d_v, dropout: Encoder
                hyper-parameters, forwarded to ``MAS_Encoder``.
            left, right, bottom, top, tails, spline_order, min_bin_width,
                min_bin_height, min_derivative, min_lambda: Forwarded to ``Spline``.
            opt: Options object. Required despite the ``None`` default:
                ``opt.device``, ``opt.data_name`` and ``opt.inconsistent_T`` are read.
        """
        super().__init__()

        self.encoder = MAS_Encoder(
            num_types=num_types,
            d_model=d_model,
            d_inner=d_inner,
            n_layers=n_layers,
            n_head=n_head,
            d_k=d_k,
            d_v=d_v,
            dropout=dropout,
            opt=opt
        )

        self.name = 'thp'
        self.num_types = num_types
        self.normalize = None
        self.d_inner = d_inner
        self.data_name = opt.data_name

        # Offset of the log-intensity in the exponential tail, one value per type.
        self.base_layer = nn.Sequential(
                nn.Linear(d_model, num_types, bias=True)
                )

        # Decay rate of the exponential tail, squashed to (-1, 1) by tanh.
        self.affect_layer = nn.Sequential(
                nn.Linear(d_model, num_types, bias=True),
                nn.Tanh()
                )
        self.intensity_layer = nn.Sequential(
                nn.Softplus(beta=1.0)
                )

        self.spline=Spline(n_knots=n_knots, left=left, right=right, bottom=bottom, top=top, tails=tails, spline_order=spline_order, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, min_lambda=min_lambda)

        # One linear head per event type producing that type's spline parameters.
        self.multimap=MultiMapping(d_model, self.spline.n_params, num_types)
        
        self.inconsistent_T = opt.inconsistent_T

    def _prepend_start(self, event_type, event_time, time_gap, opt):
        """Prepend a start event (type 1, time 0) and the matching first time gap.

        Returns:
            Tuple ``(event_type, event_time, time_gap)``; the first two are one
            step longer than the inputs, ``time_gap`` gains the first event time.
        """
        time_gap = torch.cat((event_time[:, 0:1], time_gap), dim=1)
        start_time = torch.zeros(event_time.shape[0], 1).to(opt.device)
        event_time = torch.cat((start_time, event_time), dim=1)
        start_type = torch.ones(event_type.shape[0], 1).type(torch.long).to(opt.device)
        event_type = torch.cat((start_type, event_type), dim=1)
        return event_type, event_time, time_gap

    def _encode(self, event_type, event_time):
        """Encode the sequence, cache the result on ``self.enc_output`` and return it with the pad mask."""
        non_pad_mask = get_non_pad_mask(event_type)
        enc_output = self.encoder(event_type, event_time, non_pad_mask)
        self.enc_output = enc_output
        return enc_output, non_pad_mask
    
    def forward(self, event_type, event_time, time_gap, opt):
        """
        Compute the negative log-likelihood of a batch of sequences.

        Input: event_type: batch*seq_len;
               event_time: batch*seq_len;
               time_gap: gaps between consecutive events;
               opt: options (``device``, ``num_grid``, ``MLE_method`` are read).
        Output: loss: scalar negative log-likelihood;
                enc_output: hidden representations of the sequence with a start
                event prepended.
        """
        event_type, event_time, time_gap = self._prepend_start(event_type, event_time, time_gap, opt)
        enc_output, non_pad_mask = self._encode(event_type, event_time)

        loss = self.compute_loss_mle(event_type, event_time, time_gap, non_pad_mask, num_grid = opt.num_grid, opt=opt)
        
        return loss, enc_output
            
    def get_intensity(self, t, event_type, event_time, time_gap, non_pad_mask):
        """Evaluate intensity and cumulative intensity of every type at the gaps ``t``.

        Uses the encoding cached in ``self.enc_output`` and stores the tail
        tensors on ``self.affect`` / ``self.base``.

        Args:
            t: Time gaps, batch*len or batch*len*num_points.
            event_type, event_time, time_gap: Unused here; kept for the call signature.
            non_pad_mask: Pad mask of the (start-extended) sequence.

        Returns:
            Tuple ``(all_lambda, all_cum_intensity)``, each
            batch*len*num_points*num_types.
        """
        if t.ndim == 2:
            t = t.unsqueeze(2)
        assert t.ndim == 3
        n_points = t.shape[-1]
        # The last encoder position has no following gap, hence the [:-1] slices.
        para=self.multimap(self.enc_output)[:,:,:-1,:]
        self.affect = self.affect_layer(self.enc_output).permute(2,0,1).unsqueeze(-1)[:,:,:-1,:].expand(-1,-1,-1,n_points)
        self.base = self.base_layer(self.enc_output).permute(2,0,1).unsqueeze(-1)[:,:,:-1,:].expand(-1,-1,-1,n_points)
        # The spline output is the cumulative intensity, its log-derivative the log-intensity.
        cum_intensity,intensity = self.spline._unconstrained_spline(t.unsqueeze(dim=0).expand(self.num_types,-1,-1,-1),para, self.affect, self.base)
        history_mask = non_pad_mask[:, :-1, None, :]
        # Move the type axis last: types*batch*len*points -> batch*len*points*types.
        all_lambda = intensity.exp().permute(1,2,3,0) * history_mask
        all_cum_intensity = cum_intensity.permute(1,2,3,0) * history_mask
        return all_lambda, all_cum_intensity

    def compute_loss_mle(self, event_type, event_time, time_gap, non_pad_mask, num_grid ,opt):
        """Return the summed negative log-likelihood of the batch.

        Args:
            event_type, event_time, time_gap: Start-extended sequence tensors.
            non_pad_mask: Unused; the mask is rebuilt from ``event_type``.
            num_grid: Number of Monte Carlo samples per interval.
            opt: Options; ``opt.MLE_method`` selects "MC" or "Exact".
        """
        event_ll, non_event_ll = self.log_likelihood(event_time, time_gap, event_type, num_grid, opt.MLE_method)
        return -torch.sum(event_ll - non_event_ll)
    
    def predict(self, event_type, event_time, time_gap, opt):
        """
        Predict the type of every next event.

        The predicted type is the one with the largest intensity at the observed
        time gap.

        Input: event_type: batch*seq_len;
               event_time: batch*seq_len.
        Output: type_pred: batch*seq_len, 1-based type indices.
        """
        event_type, event_time, time_gap = self._prepend_start(event_type, event_time, time_gap, opt)
        _, non_pad_mask = self._encode(event_type, event_time)
        intensity_pred ,_= self.get_intensity(time_gap, event_type, event_time, time_gap, non_pad_mask)
        intensity_pred=intensity_pred.squeeze(2)
        _, type_pred = torch.max(intensity_pred, dim=-1)
        return type_pred + 1
    
    def predict_time(self,event_type,event_time,time_gap,opt):
        """Predict the gap to every next event on a grid over ``[0, opt.pre_window]``.

        Returns:
            Tensor batch*seq_len: grid sum of ``t * lambda(t) * exp(-Lambda(t))``
            with the intensities summed over types.
        """
        event_type, event_time, time_gap = self._prepend_start(event_type, event_time, time_gap, opt)
        grid = torch.linspace(0,opt.pre_window,opt.num_grid).to(opt.device)
        time_gap_new = grid.reshape(1,1,-1).expand(time_gap.shape[0],time_gap.shape[1],-1)
        time_gap_new=time_gap_new.to(torch.double)
        _, non_pad_mask = self._encode(event_type, event_time)
        intensity_pred , cum_intensity_pre= self.get_intensity(time_gap_new, event_type, event_time, time_gap, non_pad_mask)
        intensity_pred=intensity_pred.sum(-1)
        cum_intensity_pre=cum_intensity_pre.sum(-1)
        # NOTE(review): the sum is not multiplied by the grid spacing, so this is a
        # Riemann sum only up to that factor - confirm callers rescale if needed.
        t_predict=(intensity_pred*time_gap_new*(-cum_intensity_pre).exp()).sum(-1)
        return t_predict

    def _type_one_hot(self, event_type, device):
        """One-hot encode 1-based event types; padding (type 0) maps to an all-zero row."""
        type_ids = torch.arange(1, self.num_types + 1, device=event_type.device)
        one_hot = event_type.unsqueeze(-1) == type_ids
        return one_hot.to(device=device, dtype=torch.get_default_dtype())

    def _log_event_intensity(self, all_lambda, event_type, non_pad_mask, device):
        """Log-intensity of the type that actually occurred at each event, zero at padding."""
        type_mask = self._type_one_hot(event_type, device)
        valid = non_pad_mask[:,1:].squeeze(2)

        event = torch.sum(all_lambda * type_mask[:, 1:, :], dim=2)
        event += math.pow(10, -9)
        # Padded positions become log(1) ~ 0 and are zeroed again by the mask.
        event.masked_fill_(~valid.bool(), 1.0)
        return torch.log(event+1e-10) * valid

    def compute_event(self, event_time, time_gap, event_type, non_pad_mask):
        """Log-likelihood of the observed events, batch*len."""
        all_lambda,_ = self.get_intensity(time_gap, event_type, event_time, time_gap, non_pad_mask)
        all_lambda=all_lambda.squeeze(2)
        return self._log_event_intensity(all_lambda, event_type, non_pad_mask, event_time.device)
    
    def compute_event_exact(self, event_time, time_gap, event_type, non_pad_mask):
        """Event log-likelihood plus the spline's cumulative intensity summed over types.

        Returns:
            Tuple ``(event_ll, integral)``, both batch*len.
        """
        all_lambda, all_cum_lambda = self.get_intensity(time_gap, event_type, event_time, time_gap, non_pad_mask)
        all_lambda=all_lambda.squeeze(2)
        all_cum_lambda=all_cum_lambda.squeeze(2)

        result = self._log_event_intensity(all_lambda, event_type, non_pad_mask, event_time.device)
        return result, all_cum_lambda.sum(dim=-1)

    def compute_integral_unbiased(self, event_time, time_gap, event_type, non_pad_mask, num_grid):

        """ Log-likelihood of non-events, using Monte Carlo integration. """

        num_samples = num_grid
        if self.normalize == 'log':
            time_low = min(-1.0,time_gap.min()-1.0)
        else:
            time_low = 0
        # Uniform samples inside every inter-event interval.
        temp_time = (time_gap.unsqueeze(2) - time_low) * \
                    torch.rand([*time_gap.size(), num_samples], device=event_time.device) + time_low

        if self.num_types >= 100:
            # With many types, evaluate one sample at a time to bound memory use.
            all_lambda = None
            for i in range(num_samples):
                lambda_i,_ = self.get_intensity(temp_time[:,:,i:i+1], event_type, event_time, time_gap, non_pad_mask)
                sample_sum = torch.sum(lambda_i, dim=(2,3))
                if all_lambda is None:
                    all_lambda = sample_sum
                else:
                    all_lambda += sample_sum
            all_lambda /= num_samples
        else:
            all_lambda,_ = self.get_intensity(temp_time, event_type, event_time, time_gap, non_pad_mask)
            all_lambda = torch.sum(all_lambda, dim=(2,3)) / num_samples

        # Mean intensity times interval length estimates the integral.
        unbiased_integral = all_lambda * (time_gap - time_low) * non_pad_mask.squeeze(-1)[:,1:]
        
        return unbiased_integral

    def log_likelihood(self, event_time, time_gap, event_type, num_grid, method="MC"):
        """ Log-likelihood of sequence.

        Args:
            method: "MC" integrates the intensity by Monte Carlo sampling and
                returns per-sequence sums; "Exact" reads the integral from the
                spline and returns per-event values.

        Returns:
            Tuple ``(event_ll, non_event_ll)``.

        Raises:
            NotImplementedError: for any other ``method``.
        """
        if method not in ("MC", "Exact"):
            raise NotImplementedError("Method {} not implemented.".format(method))

        non_pad_mask = get_non_pad_mask(event_type)
        if method == "Exact":
            return self.compute_event_exact(event_time, time_gap, event_type, non_pad_mask)

        event_ll = self.compute_event(event_time, time_gap, event_type, non_pad_mask)
        event_ll = torch.sum(event_ll, dim=-1)

        non_event_ll = self.compute_integral_unbiased(event_time, time_gap, event_type, non_pad_mask, num_grid)
        non_event_ll = torch.sum(non_event_ll, dim=-1)

        return event_ll, non_event_ll
