###############
#   Package   #
###############
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Tuple

##############
#   Models   #
##############
class TE(nn.Module):
    """Transformer-encoder classifier over sequences of tabular features.

    Pipeline (see ``forward``): feature values and their missing-value masks
    are concatenated, linearly projected to ``d_model``, summed with a fixed
    sinusoidal positional table, passed through a ``nn.TransformerEncoder``,
    mean-pooled over the sequence axis and decoded by a small MLP followed by
    a sigmoid.

    The positional table is registered as a non-persistent buffer named
    ``pos_emb``: it follows ``model.to(device)`` / ``model.half()`` like any
    parameter, but is not written to ``state_dict`` so checkpoint keys are
    unaffected.
    """

    def __init__(self,
                d_model: int = 144,
                num_heads: int = 1,
                ff_dim: int = 144,
                dropout: float = 0.49,
                num_layers: int = 3,
                seq_len: int = None,
                NUM_LEN: int = None,
                CAT_LEN: int = None,
                decoder_down_factor: int = 2,
                decoder_dropout: float = 0.44,
                output_size: int = 1,
                ) -> None:
        """Build the embedding, encoder and decoder sub-modules.

        Args:
            d_model: width of the token embedding / transformer layers.
            num_heads: number of attention heads per encoder layer.
            ff_dim: hidden width of each encoder layer's feed-forward block.
            dropout: dropout probability inside the encoder layers.
            num_layers: number of stacked encoder layers.
            seq_len: number of time steps per sample (required, > 0).
            NUM_LEN: number of numerical features per time step (required, > 0).
            CAT_LEN: number of categorical features per time step (required, > 0).
            decoder_down_factor: the decoder's hidden width is
                ``d_model // decoder_down_factor``.
            decoder_dropout: dropout probability inside the decoder MLP.
            output_size: number of output units of the decoder.

        Raises:
            ValueError: if ``seq_len``, ``NUM_LEN`` or ``CAT_LEN`` is missing
                or not positive.
        """
        super().__init__()

        # Explicit raises rather than ``assert``: asserts are stripped under
        # ``python -O`` and would surface as AssertionError, not ValueError.
        if seq_len is None or seq_len <= 0:
            raise ValueError("seq_len should be positive integer.")
        if NUM_LEN is None or NUM_LEN <= 0:
            raise ValueError("NUM_LEN should be positive integer.")
        if CAT_LEN is None or CAT_LEN <= 0:
            raise ValueError("CAT_LEN should be positive integer.")

        self.seq_len = seq_len
        self.d_model = d_model
        # Each feature contributes its value and its mask, hence the factor 2.
        self.input_size = (NUM_LEN + CAT_LEN) * 2

        # sinusoidal embedding
        # Registered as a buffer so it lives on whatever device/dtype the
        # module is moved to (no hard-coded "cuda:0"); persistent=False keeps
        # it out of state_dict so existing checkpoints load unchanged.
        self.register_buffer(
            "pos_emb",
            self._sinusoidal_table(seq_len, d_model),
            persistent=False,
        )

        # total embedding
        self.total_embedding = nn.Linear(self.input_size, d_model)

        # transformer encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model,
            num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layers, num_layers)

        # linear decoder
        hidden_dim = d_model // decoder_down_factor
        self.decoder = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(decoder_dropout),
            nn.Linear(hidden_dim, output_size),
        )

    @staticmethod
    def _sinusoidal_table(n_position: int, d_pos_vec: int) -> torch.Tensor:
        """Return the fixed ``(n_position, d_pos_vec)`` float32 positional table.

        Row 0 is all zeros. For rows ``pos >= 1`` column ``i`` holds
        ``sin`` (even ``i``) or ``cos`` (odd ``i``) of
        ``pos / 10000 ** (2 * i / d_pos_vec)``.

        NOTE(review): the exponent uses ``2 * i`` rather than the
        ``2 * (i // 2)`` of the original Transformer formulation. It is kept
        as-is so that the table stays numerically identical for models
        trained with it.
        """
        table = np.zeros((n_position, d_pos_vec))
        for pos in range(1, n_position):
            table[pos] = [
                pos / np.power(10000, 2 * i / d_pos_vec)
                for i in range(d_pos_vec)
            ]
        table[1:, 0::2] = np.sin(table[1:, 0::2])
        table[1:, 1::2] = np.cos(table[1:, 1::2])
        return torch.from_numpy(table).type(torch.FloatTensor)

    def forward(self, x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask):
        '''
            param:
                x_num_idx: the vector to indicate the idx of each numerical data.
                        dim = (batch size, summarization times, NUM_LEN)
                        (accepted for interface compatibility; not used here)
                x_num: the value of each numerical data.
                        dim = (batch size, summarization times, NUM_LEN)
                x_num_mask: the mask to indicate which value is missing.
                        0 = missing value, 1 = non-missing value.
                        dim = (batch size, summarization times, NUM_LEN)
                x_cat_idx: the vector to indicate the idx of each categorical data.
                        dim = (batch size, summarization times, CAT_LEN)
                        (accepted for interface compatibility; not used here)
                x_cat: the value of each categorical data.
                        dim = (batch size, summarization times, CAT_LEN)
                x_cat_mask: the mask to indicate which value is missing.
                        0 = missing value, 1 = non-missing value.
                        dim = (batch size, summarization times, CAT_LEN)
            output:
                prob: sigmoid of the decoder output.
                        dim = (batch size, output_size)
        '''
        # concatenate numerical and categorical data, then append the masks
        # as extra input features: last dim becomes (NUM_LEN + CAT_LEN) * 2
        x_feat = torch.cat([x_num, x_cat], dim=2)
        x_mask = torch.cat([x_num_mask, x_cat_mask], dim=2)
        x = torch.cat([x_feat, x_mask], dim=2)

        # total embedding
        output = self.total_embedding(x)

        # positional encoding
        # In-place add on purpose: it refuses to silently broadcast when the
        # input's sequence length does not match the table built for seq_len.
        output += self.pos_emb

        # put x to transformer encoder (masks are fed as features above,
        # so no key-padding mask is passed here)
        output = self.encoder(output, src_key_padding_mask=None)

        # deal with all vectors.
        # TODO: this pooling choice should be surveyed; currently the encoder
        # outputs are averaged over the sequence axis.
        last_vec = torch.mean(output, dim=1)

        # decoder layer
        output = self.decoder(last_vec)
        prob = torch.sigmoid(output)
        return prob

# Script entry point: intentionally a no-op, so nothing is executed when
# this file is run directly.
if __name__ == '__main__':
    pass
