################
#   Packages   #
################
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

##############
#   Models   #
##############
class GRU(nn.Module):
    """GRU encoder followed by an MLP head that outputs sigmoid probabilities.

    Numeric features, categorical features and their masks are concatenated
    along the feature axis and encoded by a (batch-first) multi-layer
    ``nn.GRU``. The encoder output at the last time step goes through ReLU
    and a small MLP decoder, and the result is passed through a sigmoid.

    Args:
        NUM_LEN: Number of numeric features. Required.
        CAT_LEN: Number of categorical features. Required.
        encoder_hidden_size: Hidden size of the GRU encoder.
        encoder_num_layers: Number of stacked GRU layers.
        encoder_dropout: Dropout between stacked GRU layers. PyTorch only
            applies it between layers, so it has no effect when
            ``encoder_num_layers == 1``.
        decoder_down_factor: Decoder hidden width is
            ``input_size // decoder_down_factor``.
        decoder_dropout: Dropout applied inside the decoder MLP.
        output_size: Number of output units.

    Raises:
        ValueError: If ``NUM_LEN`` or ``CAT_LEN`` is ``None``.
    """

    def __init__(self,
                 NUM_LEN: int = None,
                 CAT_LEN: int = None,
                 encoder_hidden_size: int = 104,
                 encoder_num_layers: int = 4,
                 encoder_dropout: float = 0.44,
                 decoder_down_factor: int = 2,
                 decoder_dropout: float = 0.2,
                 output_size: int = 1,
                 ):
        super().__init__()
        # Explicit raises: `assert` is stripped under `python -O`, and the
        # original `assert cond, ValueError(...)` never raised ValueError.
        if NUM_LEN is None:
            raise ValueError("NUM_LEN should be positive integer.")
        if CAT_LEN is None:
            raise ValueError("CAT_LEN should be positive integer.")

        # Feature values plus mask columns are concatenated in `forward`;
        # the factor 2 assumes each mask has the same width as its features.
        self.input_size = (NUM_LEN + CAT_LEN) * 2
        self.encoder_hidden_size = encoder_hidden_size
        self.encoder_num_layers = encoder_num_layers
        self.encoder_dropout = encoder_dropout

        # GRU (encoder). Inter-layer dropout is meaningless for a single
        # layer and makes nn.GRU emit a UserWarning, so pass 0.0 in that case.
        self.encoder = nn.GRU(
            self.input_size,
            encoder_hidden_size,
            encoder_num_layers,
            batch_first=True,
            dropout=encoder_dropout if encoder_num_layers > 1 else 0.0,
        )

        decoder_hidden = self.input_size // decoder_down_factor
        self.decoder = nn.Sequential(
            nn.Linear(encoder_hidden_size, decoder_hidden),
            nn.GELU(),
            nn.Dropout(decoder_dropout),
            nn.Linear(decoder_hidden, output_size),
        )

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return an all-zero initial hidden state for the encoder.

        Args:
            batch_size: Number of sequences in the batch.
            device: Target device. Defaults to the device of the model's
                parameters, so the state follows ``.to()`` / ``.cuda()``.
            dtype: Target dtype. Defaults to the dtype of the parameters.

        Returns:
            Zero tensor of shape
            ``(encoder_num_layers, batch_size, encoder_hidden_size)``.
        """
        if device is None or dtype is None:
            reference = next(self.parameters())
            device = reference.device if device is None else device
            dtype = reference.dtype if dtype is None else dtype
        return torch.zeros(
            self.encoder_num_layers,
            batch_size,
            self.encoder_hidden_size,
            device=device,
            dtype=dtype,
        )

    def forward(self, x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask):
        """Return sigmoid-activated predictions for a batch of sequences.

        Args:
            x_num_idx: Unused here; kept for interface compatibility.
            x_num: Numeric feature values, batch-first ``(batch, seq_len, F)``.
            x_num_mask: Mask tensor for ``x_num``.
            x_cat_idx: Unused here; kept for interface compatibility.
            x_cat: Categorical feature values (presumably already numerically
                encoded — they are concatenated as-is).
            x_cat_mask: Mask tensor for ``x_cat``.

        The four tensors are concatenated along dim 2, so they must share the
        batch and sequence dimensions, and their feature widths must sum to
        ``self.input_size``.

        Returns:
            Tensor of shape ``(batch, output_size)`` after a sigmoid.
        """
        x = torch.cat([x_num, x_cat, x_num_mask, x_cat_mask], dim=2)

        # Build the initial state on the input's device/dtype instead of a
        # hard-coded cuda:0, so the model runs on CPU or any GPU.
        h = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)
        output, _ = self.encoder(x, h)

        # Only the last time step is decoded; ReLU is elementwise, so applying
        # it after slicing gives the same values with less work.
        last_vec = F.relu(output[:, -1])

        output = self.decoder(last_vec)
        return torch.sigmoid(output)

if __name__ == "__main__":
    # No script behavior is defined; running this file directly does nothing.
    ...
