#REF: https://github.com/fteufel/PyTorch-GRU-D

###############
#   Package   #
###############
import os
import math
import warnings
import numbers
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from copy import deepcopy
from typing import Optional

#############
#   Class   #
#############
class GRUDCell(nn.Module):
    """GRU-D style recurrent layer for sequences with missing values.

    For every time step the layer
      1. computes decay factors from the time gaps ``delta``,
      2. replaces missing inputs by a decayed mix of the last observed value
         and a mean value,
      3. decays the hidden state and applies a GRU update that also receives
         the observation mask,
      4. projects the hidden state through a sigmoid output layer.

    The whole sequence is processed inside ``forward``; the recurrent state
    always starts from the registered ``hidden_state`` / ``x_last_obs``
    buffers (zeros at construction) and is not written back afterwards, so
    nothing is carried over between calls.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the recurrent hidden state.
        output_size: size of the per-step output.
        x_mean: mean used for imputation when ``forward`` is called without
            an explicit ``mean``. Anything accepted by ``torch.tensor``.
        dropout: probability passed to ``nn.Dropout``; must lie in [0, 1].
        return_hidden: if True, ``forward`` also returns the hidden states
            (useful when another GRU-D layer follows).

    Raises:
        ValueError: if ``dropout`` is not a (non-bool) number in [0, 1].
    """

    def __init__(self,
                input_size: int,
                hidden_size: int,
                output_size: int,
                x_mean: float = 0.,
                dropout: float = 0.,
                return_hidden: bool = False,
                ):
        super(GRUDCell, self).__init__()

        # Validate with a real exception: an ``assert`` disappears under
        # ``python -O`` and would surface as AssertionError otherwise.
        if (isinstance(dropout, bool)
                or not isinstance(dropout, numbers.Number)
                or not 0 <= dropout <= 1):
            raise ValueError('dropout should be a number in range [0, 1] representing the probability of an element being zeroed.')

        # variable in the module
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.return_hidden = return_hidden # controls the output. True if another GRU-D layer follows

        # Buffers are not seen by optimizers, so they are stored without
        # requires_grad. A grad-requiring buffer turns into a non-leaf tensor
        # after .to()/.double(), which copy.deepcopy cannot handle.
        self.register_buffer('x_mean', torch.tensor(x_mean))
        self.dropout = nn.Dropout(dropout)

        # decay layers: time gap -> decay rate for the inputs / hidden state
        self.w_dg_x = nn.Linear(input_size, input_size, bias=True)
        self.w_dg_h = nn.Linear(input_size, hidden_size, bias=True)

        # update gate (input, hidden and mask contributions)
        self.w_xz = nn.Linear(input_size, hidden_size, bias=False)
        self.w_hz = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_mz = nn.Linear(input_size, hidden_size, bias=True)

        # reset gate
        self.w_xr = nn.Linear(input_size, hidden_size, bias=False)
        self.w_hr = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_mr = nn.Linear(input_size, hidden_size, bias=False)

        # candidate hidden state
        self.w_xh = nn.Linear(input_size, hidden_size, bias=False)
        self.w_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_mh = nn.Linear(input_size, hidden_size, bias=True)

        # output projection
        self.w_hy = nn.Linear(hidden_size, output_size, bias=True)

        # initial recurrent state; broadcast over the batch in forward()
        self.register_buffer('hidden_state', torch.zeros(self.hidden_size))
        self.register_buffer('x_last_obs', torch.zeros(input_size))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise every parameter (weights and biases) from U(-k, k), k = 1/sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    @property
    def _flat_weights(self):
        """Parameters registered directly on this module (sub-module parameters are not included)."""
        return list(self._parameters.values())

    def forward(self, x: Tensor, mask: Tensor, delta: Tensor, mean: Optional[Tensor] = None) -> Tensor:
        """Run the layer over a full sequence.

        Args:
            x: inputs, indexed as ``(batch, time, input_size)``.
            mask: same shape as ``x``; entries > 0 mark observed values.
            delta: same shape as ``x``; fed to the decay layers (time since
                the last observation in GRU-D).
            mean: value used to impute missing inputs; must broadcast against
                ``(batch, input_size)``. Defaults to the ``x_mean`` buffer
                given at construction.

        Returns:
            ``output`` of shape ``(batch, time, output_size)``, or the tuple
            ``(output, hidden)`` with ``hidden`` of shape
            ``(batch, time, hidden_size)`` when ``return_hidden`` is True.
        """
        h = self.hidden_state
        x_mean = self.x_mean if mean is None else mean
        x_last_obs = self.x_last_obs

        device = next(self.parameters()).device
        batch_size, n_steps = x.size(0), x.size(1)
        output_tensor = torch.empty(batch_size, n_steps, self.output_size, dtype=x.dtype, device=device)
        hidden_tensor = torch.empty(batch_size, n_steps, self.hidden_size, dtype=x.dtype, device=device)

        for timestep in range(n_steps):
            # Integer indexing already yields (batch, input_size). No squeeze
            # here: it would also drop a size-1 batch or feature axis.
            x_ = x[:, timestep, :]
            m_ = mask[:, timestep, :]
            d_ = delta[:, timestep, :]

            # decay factors in (0, 1]
            gamma_x = torch.exp(-F.relu(self.w_dg_x(d_)))
            gamma_h = torch.exp(-F.relu(self.w_dg_h(d_)))

            # Imputing: keep observed values, otherwise blend the last
            # observation with the mean according to the input decay.
            x_last_obs = torch.where(m_ > 0, x_, x_last_obs)
            x_ = m_ * x_ + (1 - m_) * (gamma_x * x_last_obs + (1 - gamma_x) * x_mean)

            # decay the hidden state, then do the mask-aware GRU update
            h = gamma_h * h
            z = torch.sigmoid(self.w_xz(x_) + self.w_hz(h) + self.w_mz(m_))
            r = torch.sigmoid(self.w_xr(x_) + self.w_hr(h) + self.w_mr(m_))

            h_tilde = torch.tanh(self.w_xh(x_) + self.w_hh(r * h) + self.w_mh(m_))

            h = (1 - z) * h + z * h_tilde
            # Dropout on the recurrent state. The referenced GRU-D source has
            # no dropout layer; see https://arxiv.org/pdf/1603.05118.pdf
            h = self.dropout(h)

            output_tensor[:, timestep, :] = torch.sigmoid(self.w_hy(h))
            hidden_tensor[:, timestep, :] = h

        return (output_tensor, hidden_tensor) if self.return_hidden else output_tensor


if __name__ == '__main__':
    # Smoke test: push one random batch through the layer.
    # Only select a CUDA device when one exists, so the script also runs on
    # CPU-only machines.
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.set_device(0)
    device = torch.device('cuda' if use_gpu else 'cpu')

    #variable
    x = torch.rand(1, 4, 38)
    m = torch.empty_like(x).bernoulli_()
    d = torch.rand(1, 4, 38)
    mean = torch.rand(38)  # placeholder per-feature mean used for imputation

    model = GRUDCell(input_size = 38,
                    hidden_size = 128,
                    output_size = 128,
                    )

    x = x.to(device)
    m = m.to(device)
    d = d.to(device)
    mean = mean.to(device)

    model.to(device)

    # forward() takes the imputation mean as its fourth argument
    output = model(x, m, d, mean)
    print(output.shape)

