###############
#   Package   #
###############
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from copy import deepcopy
from typing import Optional

######################
#   Local packages   #
######################
from models.grud_cell import GRUDCell

#############
#   Class   #
#############
class GRU_D(nn.Module):
    """Stack of ``GRUDCell`` layers followed by an MLP head with sigmoid output.

    Numeric and categorical inputs are concatenated along dim 2, passed
    through ``num_layers`` cells, and the last position along dim 1 of the
    final cell's output is decoded by a two-layer MLP.

    Args:
        input_size: Width of the concatenated (numeric + categorical) feature
            axis. Also used as the hidden and output size of every cell.
        output_size: Width of the MLP head's output.
        num_layers: Number of stacked cells; must be at least 1.
        x_mean: Forwarded to ``GRUDCell`` as its ``x_mean`` argument.
        dropout: Dropout value passed to every cell and used in the MLP head.
        decoder_down_factor: Integer divisor applied to ``input_size`` to get
            the MLP's hidden width.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 num_layers: int = 1,
                 x_mean: float = 0.,
                 dropout: float = 0.5,
                 decoder_down_factor: int = 2,
                 ):
        super().__init__()
        # Without at least one layer, forward() would have no output to decode.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers

        # Construct multilayer GRU_D.
        # x_mean is forwarded to the cell instead of being silently ignored.
        grud_layer = GRUDCell(input_size=input_size,
                              hidden_size=input_size,
                              output_size=input_size,
                              x_mean=x_mean,
                              dropout=dropout,
                              return_hidden=True,
                              )

        # Every layer is a deep copy of the same template cell, so all layers
        # start from identical parameter values but are trained independently.
        self.gru_d = nn.ModuleList(
            [deepcopy(grud_layer) for _ in range(num_layers)]
        )

        decoder_width = input_size // decoder_down_factor
        self.MLP = nn.Sequential(
            nn.Linear(input_size, decoder_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(decoder_width, output_size),
        )

    def forward(self, x_num: Tensor, x_num_mask: Tensor, x_num_delta: Tensor, x_cat: Tensor, x_cat_mask: Tensor, x_cat_delta: Tensor) -> Tensor:
        """Run the stacked cells and return sigmoid scores.

        All six inputs are 3-D tensors; each numeric/categorical pair is
        concatenated along dim 2 (presumably ``(batch, time, features)`` --
        TODO confirm against the data loader).

        Args:
            x_num: Numeric feature values.
            x_num_mask: Mask tensor paired with ``x_num``.
            x_num_delta: Delta tensor paired with ``x_num``.
            x_cat: Categorical feature values.
            x_cat_mask: Mask tensor paired with ``x_cat``.
            x_cat_delta: Delta tensor paired with ``x_cat``.

        Returns:
            Sigmoid of the MLP head applied to ``output[:, -1, :]`` of the
            last layer.
        """
        x = torch.cat([x_num, x_cat], dim=2)
        mask = torch.cat([x_num_mask, x_cat_mask], dim=2)
        delta = torch.cat([x_num_delta, x_cat_delta], dim=2)

        # Per-feature statistic of the current call: mean for numeric columns,
        # mode for categorical columns. reshape (not view) is used so that
        # non-contiguous inputs (e.g. transposed or sliced) are accepted.
        # NOTE(review): the statistics are taken over every entry, without
        # consulting the masks -- confirm this is intended.
        num_mean = torch.mean(x_num.reshape(-1, x_num.size(-1)), dim=0)
        cat_mode = torch.mode(x_cat.reshape(-1, x_cat.size(-1)), dim=0).values
        x_mean = torch.cat([num_mean, cat_mode], dim=0)

        # Each layer receives the second value returned by the previous one.
        # NOTE(review): confirm that GRUDCell's returned hidden has the shape
        # the next layer expects as input when num_layers > 1.
        hidden = x
        for layer in self.gru_d:
            output, hidden = layer(hidden, mask, delta, x_mean)

        last_vec = output[:, -1, :]
        logits = self.MLP(last_vec)
        return torch.sigmoid(logits)


if __name__ == '__main__':
    # Smoke test: push one random batch through the model on GPU when
    # available, otherwise on CPU.
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        # Selecting a CUDA device is only valid when CUDA is present; calling
        # this unconditionally raises on CPU-only machines.
        torch.cuda.set_device(0)
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Random inputs: 30 numeric and 8 categorical features (38 in total,
    # matching input_size below), with Bernoulli masks and random deltas.
    x_num = torch.rand(1, 4, 30)
    m_num = torch.empty_like(x_num).bernoulli_()
    d_num = torch.rand(1, 4, 30)
    x_cat = torch.rand(1, 4, 8)
    m_cat = torch.empty_like(x_cat).bernoulli_()
    d_cat = torch.rand(1, 4, 8)

    model = GRU_D(input_size=38,
                  output_size=1,
                  num_layers=1,
                  )

    x_num = x_num.to(device)
    m_num = m_num.to(device)
    d_num = d_num.to(device)
    x_cat = x_cat.to(device)
    m_cat = m_cat.to(device)
    d_cat = d_cat.to(device)

    model.to(device)

    ot = model(x_num, m_num, d_num, x_cat, m_cat, d_cat)
    # Report the result instead of dropping into the debugger, so the script
    # also completes when run non-interactively.
    print(f'output shape: {tuple(ot.shape)}')
    print(ot)
