import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import math

class FilterLinear(nn.Module):
    '''
    Linear layer whose weight matrix is multiplied element-wise by a fixed mask.

    The layer computes ``input @ (filter_square_matrix * weight).T + bias``, so
    every weight entry whose mask entry is 0 contributes nothing to the output
    and receives a zero gradient.
    '''

    def __init__(self, in_features, out_features, filter_square_matrix, bias=True, device='cuda'):
        '''
        in_features : size of each input sample.
        out_features : size of each output sample.
        filter_square_matrix : filter square matrix, whose each elements is 0 or 1.
            It is multiplied element-wise with the (out_features, in_features)
            weight, so it must be broadcastable to that shape.
        bias : if False, the layer has no additive bias.
        device : device on which the mask and the parameters are created.
        '''
        super(FilterLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # The mask is a constant, not a trainable weight. Registering it as a
        # non-persistent buffer makes it follow module.to()/cuda()/cpu()
        # without adding a key to state_dict.
        self.register_buffer(
            'filter_square_matrix',
            filter_square_matrix.detach().to(device),
            persistent=False,
        )

        # Allocate directly on the target device. Wrapping a tensor in
        # Parameter and *then* calling .to(device) yields a plain non-leaf
        # tensor whenever a copy is made, which silently drops the weight from
        # parameters()/state_dict() and therefore from the optimizer.
        self.weight = Parameter(torch.empty(out_features, in_features, device=device))
        if bias:
            self.bias = Parameter(torch.empty(out_features, device=device))
        else:
            # register_parameter returns None; there is nothing to move.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        '''Re-draw weight and bias from U(-1/sqrt(in_features), 1/sqrt(in_features)).'''
        stdv = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, input):
        '''Apply the masked linear transformation to ``input``.'''
        masked_weight = self.filter_square_matrix.mul(self.weight)
        return F.linear(input, masked_weight, self.bias)

    def __repr__(self):
        return '{}(in_features={}, out_features={}, bias={})'.format(
            self.__class__.__name__,
            self.in_features,
            self.out_features,
            self.bias is not None,
        )

class GRUD(nn.Module):
    def __init__(self, input_size, hidden_size, X_mean, device='cuda'):
        """
        Recurrent Neural Networks for Multivariate Times Series with Missing Values
        GRU-D: GRU exploit two representations of informative missingness patterns, i.e., masking and time interval.

        On top of the GRU-D cell this module stacks two ordinary GRU cells and
        a linear read-out that maps the last hidden state back to input_size.

        Implemented based on the paper:
        @article{che2018recurrent,
          title={Recurrent neural networks for multivariate time series with missing values},
          author={Che, Zhengping and Purushotham, Sanjay and Cho, Kyunghyun and Sontag, David and Liu, Yan},
          journal={Scientific reports},
          volume={8},
          number={1},
          pages={6085},
          year={2018},
          publisher={Nature Publishing Group}
        }

        GRU-D:
            input_size: variable dimension of each time; also used as the
                dimension of the masking and time-interval vectors
            hidden_size: dimension of hidden_state
            X_mean: the mean of the historical input data. It is indexed as
                X_mean[:, t, :], so it must be 3-D with time on axis 1 and a
                leading axis that broadcasts against the batch.
            device: device on which the constants and sub-modules are created
        """

        super(GRUD, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size
        self.delta_size = input_size
        self.mask_size = input_size
        self.device = device

        # Constants are created identically for every device. They are
        # non-persistent buffers so that they follow module.to()/cuda()/cpu()
        # while leaving state_dict unchanged.
        self.register_buffer('identity', torch.eye(input_size, device=device), persistent=False)
        self.register_buffer('zeros_x', torch.zeros(input_size, device=device), persistent=False)
        self.register_buffer('zeros_h', torch.zeros(hidden_size, device=device), persistent=False)
        # clone() so that later mutation of the caller's array cannot leak in.
        x_mean = torch.as_tensor(X_mean, dtype=torch.float32).detach().clone().to(device)
        self.register_buffer('X_mean', x_mean, persistent=False)

        # Update gate, reset gate and candidate state of the GRU-D cell.
        self.zl = nn.Linear(input_size + hidden_size, hidden_size).to(device)
        self.rl = nn.Linear(input_size + hidden_size, hidden_size).to(device)
        self.hl = nn.Linear(input_size + hidden_size, hidden_size).to(device)

        # Input decay: the identity mask keeps only the diagonal of the weight,
        # so each variable's decay depends on its own time interval only.
        self.gamma_x_l = FilterLinear(self.delta_size, self.delta_size, self.identity, device = self.device)

        # Hidden-state decay.
        self.gamma_h_l = nn.Linear(self.delta_size, self.hidden_size).to(device)

        self.gru2 = nn.GRUCell(
            self.hidden_size,
            self.hidden_size
        ).to(device)
        self.gru3 = nn.GRUCell(
            self.hidden_size,
            self.hidden_size
        ).to(device)
        self.linear = nn.Linear(hidden_size, input_size).to(device)

    def step(self, x, x_last_obsv, x_mean, h, mask, delta):
        """
        Run one GRU-D update and return the new hidden state.

            x: current input, (batch, input_size)
            x_last_obsv: last observed value of each variable
            x_mean: mean used as the decay target for missing values
            h: previous hidden state, (batch, hidden_size)
            mask: 1 where x is observed, 0 where it is missing
            delta: time interval since the last observation
        """
        # Decay rates gamma = exp(-max(0, W * delta + b)), which lie in (0, 1].
        delta_x = torch.exp(-torch.max(self.zeros_x, self.gamma_x_l(delta)))
        delta_h = torch.exp(-torch.max(self.zeros_h, self.gamma_h_l(delta)))

        # Missing entries are replaced by the last observation decayed toward the mean.
        x = mask * x + (1 - mask) * (delta_x * x_last_obsv + (1 - delta_x) * x_mean)
        h = delta_h * h

        combined = torch.cat((x, h), 1)
        z = torch.sigmoid(self.zl(combined))
        r = torch.sigmoid(self.rl(combined))
        combined_r = torch.cat((x, r * h), 1)
        h_tilde = torch.tanh(self.hl(combined_r))
        return (1 - z) * h + z * h_tilde

    def forward(self, input):
        """
        Produce one prediction per time step.

            input: 4-D tensor indexed as (batch, channel, step, variable) with
                channel 0 = X, 1 = last observation, 2 = mask, 3 = delta.

        Returns a (batch, step, input_size) tensor.

        Only the first time step of X is read: from the second step on, the
        previous prediction is fed back as both the current input and the last
        observation, and channel 1 is not used. The initial states of the two
        stacked GRU cells are drawn from a standard normal on every call, so
        outputs vary between calls unless the RNG is seeded.
        """
        batch_size = input.size(0)
        step_size = input.size(2)
        device = self.linear.weight.device

        # Plain indexing instead of torch.squeeze: squeeze would also drop the
        # batch (or variable) axis whenever that axis has size 1.
        X = input[:, 0]
        Mask = input[:, 2]
        Delta = input[:, 3]

        Hidden_State = self.initHidden(batch_size)
        # Sampled on the CPU and then moved, so seeded runs reproduce on any device.
        hidden2 = torch.randn(batch_size, self.hidden_size).to(device)
        hidden3 = torch.randn(batch_size, self.hidden_size).to(device)
        out = torch.zeros(batch_size, step_size, self.input_size, device=device)
        if step_size == 0:
            return out

        cur_X = X[:, 0, :]
        for i in range(step_size):
            Hidden_State = self.step(
                cur_X,
                cur_X,
                self.X_mean[:, i, :],
                Hidden_State,
                Mask[:, i, :],
                Delta[:, i, :],
            )
            hidden2 = self.gru2(Hidden_State, hidden2)
            hidden3 = self.gru3(hidden2, hidden3)
            output = self.linear(hidden3)
            out[:, i, :] = output
            cur_X = output
        return out

    def initHidden(self, batch_size):
        """Return an all-zero (batch_size, hidden_size) state on the module's device."""
        return torch.zeros(batch_size, self.hidden_size, device=self.linear.weight.device)