# -*- coding: utf-8 -*-
"""
Created on Sat May 12 16:48:54 2018

@author: Zhiyong
"""

import torch.utils.data as utils
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import math
import numpy as np
import pandas as pd
import time
import pdb

class FilterLinear(nn.Module):
    """Element-wise affine layer: ``output = input * weight + bias``.

    Despite the name, the layer currently does not perform a masked matrix
    product.  ``weight`` is a vector with ``in_features`` entries that scales
    every input feature independently, and ``bias`` (when enabled) is a vector
    with ``out_features`` entries that is added afterwards.  Both are
    broadcast over the leading (batch) dimension, so ``in_features`` and
    ``out_features`` have to be broadcast-compatible (normally equal).

    ``filter_square_matrix`` is stored on the instance but is not used by
    ``forward``.
    """

    def __init__(self, in_features, out_features, filter_square_matrix, bias=True):
        """
        in_features : number of input features (length of ``weight``).
        out_features : number of output features (length of ``bias``).
        filter_square_matrix : filter square matrix, each element of which is
            0 or 1.  Only stored; see the class docstring.
        bias : when False, no bias parameter is created and ``forward`` only
            applies the per-feature scaling.
        """
        super(FilterLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # The filter is a constant, never a trainable quantity, hence the
        # detach.  It is placed on the GPU whenever one is available, exactly
        # as before, so code reading this attribute sees no change.
        filter_matrix = filter_square_matrix.detach()
        if torch.cuda.is_available():
            filter_matrix = filter_matrix.cuda()
        self.filter_square_matrix = filter_matrix

        # One scale per input feature (a diagonal weight, not a full matrix).
        self.weight = Parameter(torch.empty(in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias from U(-1/sqrt(n), 1/sqrt(n)), n = len(weight)."""
        stdv = 1. / math.sqrt(self.weight.size(0))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, input):
        """Scale each feature of ``input`` by ``weight`` and add ``bias`` if present."""
        scaled = input * self.weight[None, :]
        if self.bias is None:
            # Built with bias=False: there is nothing to add.  Indexing the
            # missing bias here used to raise a TypeError.
            return scaled
        return scaled + self.bias[None, :]

    def __repr__(self):
        return '{}(in_features={}, out_features={}, bias={})'.format(
            self.__class__.__name__,
            self.in_features,
            self.out_features,
            self.bias is not None,
        )
        
class GRUD(nn.Module):
    """GRU-D: a GRU variant for multivariate time series with missing values.

    GRU-D exploits two representations of informative missingness patterns,
    i.e. masking and time interval.  Missing inputs are replaced by a decayed
    mix of the last observed value and the empirical mean, and the hidden
    state is decayed according to the time elapsed since the last observation.

    Implemented based on the paper:
    @article{che2018recurrent,
      title={Recurrent neural networks for multivariate time series with missing values},
      author={Che, Zhengping and Purushotham, Sanjay and Cho, Kyunghyun and Sontag, David and Liu, Yan},
      journal={Scientific reports},
      volume={8},
      number={1},
      pages={6085},
      year={2018},
      publisher={Nature Publishing Group}
    }
    """

    def __init__(self, input_size, cell_size, hidden_size,
                 X_mean, output_last=False, use_encoder=False):
        """
        input_size: number of variables observed at each time step; the
            masking vector and the time-interval vector have the same size.
        cell_size: accepted for interface compatibility; not used by this
            implementation.
        hidden_size: dimension of the hidden state.
        X_mean: the mean of the historical input data (anything accepted by
            ``torch.Tensor(...)``); imputation target for missing values.
        output_last: if True ``forward`` returns only the final hidden state,
            otherwise the hidden state of every time step.
        use_encoder: if True the concatenated (x, h, mask) vector is first
            projected to ``hidden_size`` by a shared linear encoder before it
            is fed to the gates.
        """
        super(GRUD, self).__init__()

        self.hidden_size = hidden_size
        self.delta_size = input_size
        self.mask_size = input_size

        # These constants are created on the GPU whenever one is available,
        # as they always were.  forward() moves what it needs to the device
        # of its input, so the model also runs on CPU tensors.
        self.use_gpu = use_gpu = torch.cuda.is_available()
        device = torch.device('cuda' if use_gpu else 'cpu')
        self.identity = torch.eye(input_size).to(device)
        self.zeros = torch.zeros(input_size).to(device)
        self.zeros_hidden = torch.zeros(hidden_size).to(device)
        self.X_mean = torch.Tensor(X_mean).to(device)

        self.use_encoder = use_encoder

        combined_size = input_size + hidden_size + self.mask_size
        if self.use_encoder:
            # The gates operate on the encoded representation.
            self.encoder = nn.Linear(combined_size, hidden_size)
            self.zl = nn.Linear(hidden_size, hidden_size)
            self.rl = nn.Linear(hidden_size, hidden_size)
            self.hl = nn.Linear(hidden_size, hidden_size)
        else:
            self.zl = nn.Linear(combined_size, hidden_size)
            self.rl = nn.Linear(combined_size, hidden_size)
            self.hl = nn.Linear(combined_size, hidden_size)

        # Input decay acts on each variable independently, hidden-state decay
        # mixes all time intervals through a dense layer.
        self.gamma_x_l = FilterLinear(self.delta_size, self.delta_size, self.identity)
        self.gamma_h_l = nn.Linear(self.delta_size, hidden_size)

        self.output_last = output_last

    def step(self, x, x_last_obsv, x_mean, h, mask, delta):
        """Advance the hidden state by one time step.

        x: current input, (batch, input_size); entries where ``mask`` is 0
            are ignored.
        x_last_obsv: last observed value of every variable, same shape as x.
        x_mean: empirical mean of every variable, same shape as x.
        h: previous hidden state, (batch, hidden_size).
        mask: 1.0 where x is observed and 0.0 where it is missing, same
            shape as x.
        delta: time elapsed since each variable was last observed, same
            shape as x.

        Returns the new hidden state, (batch, hidden_size).
        """
        # Decay rates exp(-max(0, gamma(delta))), always in (0, 1].  clamp is
        # used instead of a max against a stored zero tensor so that the
        # computation follows the device and dtype of its input.
        delta_x = torch.exp(-torch.clamp(self.gamma_x_l(delta), min=0))
        delta_h = torch.exp(-torch.clamp(self.gamma_h_l(delta), min=0))

        # Keep observed values; impute missing ones with a blend of the last
        # observation and the mean, weighted by the input decay.
        x = mask * x + (1 - mask) * (delta_x * x_last_obsv + (1 - delta_x) * x_mean)

        h = delta_h * h

        combined = torch.cat((x, h, mask), 1)
        if self.use_encoder:
            combined = self.encoder(combined)
        z = torch.sigmoid(self.zl(combined))
        r = torch.sigmoid(self.rl(combined))

        combined_r = torch.cat((x, r * h, mask), 1)
        if self.use_encoder:
            combined_r = self.encoder(combined_r)
        h_tilde = torch.tanh(self.hl(combined_r))

        return (1 - z) * h + z * h_tilde

    def forward(self, input, valid_mask, time_points):
        """Run GRU-D over a batch of sequences.

        input: (batch, steps, input_size) floating point tensor.  Any
            floating dtype works as long as the module parameters have the
            same dtype.
        valid_mask: (batch, steps, input_size) tensor, non-zero / True where
            ``input`` holds an observed value.
        time_points: (batch, steps) tensor with the time stamp of each step.

        Returns the final hidden state (batch, hidden_size) when
        ``output_last`` is set, otherwise all hidden states stacked as
        (batch, steps, hidden_size).  For an input with zero steps the result
        is the initial hidden state, respectively None.
        """
        batch_size = input.size(0)
        step_size = input.size(1)
        feature_size = input.shape[-1]
        # Follow the input's device and dtype rather than forcing CUDA/double,
        # so the model also runs on CPU-only machines.
        device, dtype = input.device, input.dtype

        Hidden_State = self.initHidden(batch_size, device=device)
        # Loop invariant: the mean expanded to one row per sample.
        x_mean = self.X_mean.to(device).repeat(batch_size, 1)

        last_values = torch.zeros((batch_size, feature_size), dtype=dtype, device=device)
        time_delta = torch.zeros((batch_size, feature_size), dtype=dtype, device=device)
        hidden_states = []
        for i in range(step_size):
            current_input = input[:, i, :]
            current_mask = valid_mask[:, i, :].bool()
            last_values[current_mask] = current_input[current_mask]

            # last_values and time_delta are updated in place below, so the
            # step gets private copies to keep autograd's saved tensors intact.
            Hidden_State = self.step(current_input,
                                     last_values.clone(),
                                     x_mean,
                                     Hidden_State,
                                     current_mask.to(dtype),
                                     time_delta.clone())

            # Variables observed now restart their clock; every variable then
            # accumulates the gap to the next time stamp.
            time_delta[current_mask] = 0.0
            if i < step_size - 1:
                gap = (time_points[:, i + 1] - time_points[:, i]).to(dtype)
                time_delta = time_delta + gap[:, None]

            if not self.output_last:
                hidden_states.append(Hidden_State)

        if self.output_last:
            return Hidden_State
        if not hidden_states:
            return None
        return torch.stack(hidden_states, dim=1)

    def initHidden(self, batch_size, device=None):
        """Return a zero hidden state of shape (batch_size, hidden_size).

        device: where to allocate the state.  When omitted the GPU is used if
            one is available, otherwise the CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.zeros(batch_size, self.hidden_size, device=device)
