# -*- coding: utf-8 -*-
import torch.utils.data as utils
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import math
import numpy as np
import pandas as pd
import time
import pdb
from torch import Tensor
from typing import Optional, Tuple
from torch.nn import init

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension (annotated-Transformer style).

    Computes ``a_2 * (x - mean) / (std + eps) + b_2``. It differs from
    ``torch.nn.LayerNorm`` in two ways:
      * it uses the *unbiased* standard deviation;
      * ``eps`` is added to the std rather than to the variance.

    Args:
        features: size of the last dimension; length of the learnable gain
            ``a_2`` (initialised to ones) and bias ``b_2`` (initialised to zeros).
        eps: small constant added to the standard deviation.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable shift
        self.eps = eps

    def forward(self, x):
        centred = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centred / denom + self.b_2

class ResidualConnection(nn.Module):
    """Add a normalised, dropped-out tensor onto a residual stream.

    Computes ``x + dropout(LayerNorm(x_other))``. No sublayer is applied
    here; the caller passes the already-transformed tensor as ``x_other``.

    Args:
        size: last-dimension size of ``x_other`` (passed to ``LayerNorm``).
        dropout: dropout probability applied after normalisation.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_other):
        """Return ``x + dropout(norm(x_other))``; the two must broadcast."""
        normalised = self.norm(x_other)
        return x + self.dropout(normalised)

class MultiHeadAttention(nn.Module):
    """Multi-head attention whose values are tiled rather than projected.

    Queries and keys are linearly projected and split into ``num_heads``
    heads of width ``d_model // num_heads``. The value tensor is NOT
    projected (the ``value_proj`` layer is disabled). Instead it is repeated
    along its last dimension to width ``d_model`` and split into heads.

    Args:
        d_model: width of ``query`` and ``key``; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model: int = 16, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        #self.value_proj = nn.Linear(1, self.d_head * num_heads)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Attend from ``query`` over ``key`` and return ``(context, attn)``.

        Args:
            query: ``(B, Q_LEN, d_model)``.
            key: ``(B, K_LEN, d_model)``.
            value: presumably ``(B, K_LEN, 1)``. It is tiled to ``d_model``
                along the last dim. TODO confirm with callers: a wider last
                dim would be folded into the length axis by ``view``.
            mask: optional boolean ``(B, Q_LEN, K_LEN)``. ``True`` entries get
                a ``-inf`` score before the softmax.

        Returns:
            context: ``(B, Q_LEN, d_model)`` with the heads concatenated.
            attn: ``(num_heads * B, Q_LEN, K_LEN)`` in head-major order
                (row ``head * B + batch``).
        """
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD
        #value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD
        # Tile the (unprojected) value to d_model channels, then split into heads.
        value = value.repeat(1, 1, self.d_head * self.num_heads).view(batch_size, -1, self.num_heads, self.d_head)

        # permute(2, 0, 1, 3) moves the head axis first, so the flattened
        # leading axis is head-major: row = head * B + batch.
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # NBxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # NBxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # NBxV_LENxD

        if mask is not None:
            # Replicate along a leading head axis so that, once flattened, the
            # mask rows are also ordered head * B + batch and line up with the
            # scores. Repeating along dim 1 would give batch-major order and
            # pair masks with the wrong sample whenever B > 1 and N > 1.
            mask = mask.unsqueeze(0).repeat(self.num_heads, 1, 1, 1)  # NxBxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxQ_LENxND

        return context, attn

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention from "Attention Is All You Need".

    Scores are ``query @ key^T / sqrt(query.size(-1))``. A softmax over the
    key axis turns them into weights, which are then applied to ``value``.

    Args: dim, dropout_rate
        dim (int): attention dimension. It is stored as ``self.sqrt_dim``, but
            the forward pass scales by ``sqrt(query.size(-1))`` instead.
        dropout_rate (float): accepted for compatibility; attention dropout
            is currently disabled, so it is unused.

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d): query vectors.
        - **key** (batch, k_len, d): key vectors.
        - **value** (batch, k_len, d_v): values to be mixed.
        - **mask** (optional): boolean tensor with batch * q_len * k_len
          elements. ``True`` entries get score ``-inf`` before the softmax.

    Returns: context, attn
        - **context** (batch, q_len, d_v): attention-weighted sum of values.
        - **attn** (batch, q_len, k_len): softmax attention weights.
    """
    def __init__(self, dim: int, dropout_rate=0.2):
        super().__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        scale = np.sqrt(query.size(-1))
        scores = torch.bmm(query, key.transpose(1, 2)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask.view(scores.size()), -float('Inf'))
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, value), weights

class FilterLinear(nn.Module):
    """Element-wise ("diagonal") affine layer computing ``input * weight + bias``.

    ``weight`` has ``in_features`` entries and ``bias`` has ``out_features``
    entries, so the two sizes must broadcast (in practice be equal) for
    ``forward`` to run.
    """

    def __init__(self, in_features, out_features, filter_square_matrix, bias=True):
        '''
        filter_square_matrix : filter square matrix, whose each elements is 0 or 1.
            It is stored but not used by the current ``forward``.
        bias : if False, no bias parameter is created and none is added.
        '''
        super(FilterLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Plain attribute (not a registered buffer): it is moved to the GPU
        # here if one is available, is not part of the state_dict and does
        # not follow module.to()/.cuda().
        if torch.cuda.is_available():
            filter_square_matrix = filter_square_matrix.cuda()
        self.filter_square_matrix = Variable(filter_square_matrix, requires_grad=False)

        # One weight per input feature (element-wise scaling, not a matrix).
        self.weight = Parameter(torch.Tensor(in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (and bias, if present) from U(-1/sqrt(in), 1/sqrt(in))."""
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        out = input * self.weight[None, :]
        # bias is None when constructed with bias=False; indexing it would crash.
        if self.bias is not None:
            out = out + self.bias[None, :]
        return out

    def __repr__(self):
        return (f"{self.__class__.__name__}(in_features={self.in_features}, "
                f"out_features={self.out_features}, bias={self.bias is not None})")

class ScalingLayer(nn.Module):
    """Per-feature affine scaling over the last dimension: ``input * w + b``.

    Args:
        in_features: size of the last dimension of ``input``.
        bias: if False, no additive bias is created or applied. Previously
            this flag was ignored and a bias was always used.
    """
    def __init__(self, in_features, bias=True):
        super(ScalingLayer, self).__init__()
        self.in_features = in_features

        # Shapes (in_features, 1) are kept for state_dict compatibility.
        self.weight = Parameter(torch.Tensor(in_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(in_features, 1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # With fan_in = size(1) = 1 and a = sqrt(5), kaiming_uniform_ samples
        # from U(-1, 1).
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    def forward(self, input):
        # weight.T is (1, in_features) and broadcasts over the leading dims.
        out = self.weight.T * input
        if self.bias is not None:
            out = out + self.bias[:, 0]
        return out
        
class TACD_GRU(nn.Module):
    def __init__(self, input_size, cell_size, hidden_size, X_mean=None, 
        output_last = False, args=None, use_encoder=False, f_out=None):
        """Build the recurrent cell, its time-decay layer and the
        refinement-specific prediction heads.

        Args:
            input_size: number of input features. Also used as the size of
                the observation mask and of the per-feature time deltas.
            cell_size: accepted for interface compatibility; not used here.
            hidden_size: width of the GRU hidden state.
            X_mean: accepted but ignored; ``self.X_mean`` is always None.
            output_last: if True, ``forward`` returns only the final hidden
                state.
            args: configuration namespace, required despite the ``None``
                default (``args.tacd_pred_horizon`` is read unconditionally).
                Always read: ``tacd_pred_horizon``, ``sample_period``,
                ``mean_inter_arr``, ``grudplus_refine``. Read depending on
                the refinement: ``tacd_time_emb``, ``tacd_event_emb`` and
                ``task``.
            use_encoder: if True, a linear encoder maps the concatenated
                ``[x, h, mask]`` to ``hidden_size`` before the gate layers.
            f_out: accepted but not used.

        NOTE(review): several refinement branches below begin with
        ``assert False`` and are kept only for reference. ``step`` itself
        only implements ``'refinement_6'``.
        """
        
        super(TACD_GRU, self).__init__()
        
        self.hidden_size = hidden_size
        self.delta_size = input_size
        self.mask_size = input_size
        # max number of future time indices predicted per step (slices t_pred_idx in step())
        self.n_future_ts = args.tacd_pred_horizon
        #self.f_out = f_out
        
        # NOTE: identity/zeros/zeros_hidden are plain tensor attributes, not
        # registered buffers. They are put on the GPU here if one is available
        # and are NOT moved by module.to()/.cuda()/.double().
        self.use_gpu = use_gpu = torch.cuda.is_available()
        if use_gpu:
            self.identity = torch.eye(input_size).cuda()
            self.zeros = Variable(torch.zeros(input_size).cuda())
            self.zeros_hidden = Variable(torch.zeros(hidden_size).cuda())
            #self.X_mean = Variable(torch.Tensor(X_mean).cuda())
            #self.X_mean = Variable(torch.zeros_like(X_mean).cuda())
        else:
            self.identity = torch.eye(input_size)
            self.zeros = Variable(torch.zeros(input_size))
            self.zeros_hidden = Variable(torch.zeros(hidden_size))
            #self.X_mean = Variable(torch.Tensor(X_mean))
        self.X_mean = None
        # Frozen double Parameters: saved in the state_dict and moved with the
        # module, never trained. In step() they are only referenced by
        # disabled (string-literal) code.
        sample_period = torch.tensor(args.sample_period).double()
        self.sample_period = nn.Parameter(sample_period, 
            requires_grad=False)
        mean_inter_arr = torch.tensor(args.mean_inter_arr).double()
        self.mean_inter_arr = nn.Parameter(mean_inter_arr, requires_grad=False)
        
        self.use_encoder = use_encoder

        self.refinement = args.grudplus_refine
        self.args = args

        # GRU gate layers (zl/rl/hl); not built for 'attention_rnn'.
        if self.refinement == 'attention_rnn':
            pass
        else:
            if self.use_encoder:
                # [x, h, mask] -> hidden_size, then square gate layers
                self.encoder = nn.Linear(input_size+hidden_size+self.mask_size, hidden_size) 
                self.zl = nn.Linear(hidden_size, hidden_size)
                self.rl = nn.Linear(hidden_size, hidden_size)
                self.hl = nn.Linear(hidden_size, hidden_size)
            else:
                # gates read [x, h, mask] directly
                self.zl = nn.Linear(input_size + hidden_size + self.mask_size, hidden_size)

                #self.zl = nn.Linear(input_size + self.mask_size, hidden_size)
                #print('bias in Z layer in GRU is set to -100.')
                #with torch.no_grad(): self.zl.bias.fill_(-100.)
                #print('zl bias is fixed!!')
                #self.zl.bias.requires_grad = False

                self.rl = nn.Linear(input_size + hidden_size + self.mask_size, hidden_size)
                self.hl = nn.Linear(input_size + hidden_size + self.mask_size, hidden_size)
            

        if self.refinement == 'decayed_hidden_to_inputs':
            # retired refinement: fails immediately
            assert False, "incorrect refinement"
            self.hidden_to_obs = nn.Linear(hidden_size, input_size)
        elif self.refinement == 'hidden_delta_t_to_inputs':
            # retired refinement: fails immediately
            assert False, "incorrect refinement"
            self.time_embed_dim = 8
            #self.time_periodic = nn.Linear(1, self.time_embed_dim - 1)
            #self.time_linear = nn.Linear(1, 1)
            self.time_linear = nn.Linear(input_size, self.time_embed_dim)
            self.hidden_to_obs = nn.Linear(hidden_size + self.time_embed_dim, input_size)

        elif self.refinement == 'decayed_h_input_delta_to_inputs':
            # REFINEMENT 5
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.hidden_state_emb_dim = 2
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            self.att = ScaledDotProductAttention(self.time_embed_dim+self.event_embed_dim+1)
            self.hidden_to_obs = nn.Linear(hidden_size, input_size * self.hidden_state_emb_dim)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)

        elif self.refinement == 'refinement_6':
            # REFINEMENT 6 (the only refinement implemented by step())
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.hidden_state_emb_dim = 1
            # time embedding = [time_linear (1 dim), time_periodic (time_embed_dim-1 dims)]
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            # attention over per-feature [time emb, event emb] vectors
            self.att = ScaledDotProductAttention(self.time_embed_dim+self.event_embed_dim)
            #self.combine_att_and_f_x = nn.Linear(hidden_size, 1)
            # gate (sigmoid in step()) mixing the attention and decayed-hidden estimates
            self.combine_att_and_f_x = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1))
            # decodes a decayed hidden state into per-feature predictions
            self.hidden_to_obs = nn.Linear(hidden_size, input_size * self.hidden_state_emb_dim)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            #self.query_event_embedding = nn.Embedding(input_size, self.time_embed_dim + self.event_embed_dim)
            # frozen index tensor [0, input_size) registered as a Parameter so it follows the module's device
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)
            # per-feature affine rescaling of the attention output
            self.att_scaling = ScalingLayer(input_size)
            if self.args.task == 'classification':
                # NOTE: number of classes is hard-coded
                self.n_classes = n_classes = 7 # 0-6
                self.class_label_ids = nn.Parameter(data=torch.LongTensor(range(n_classes)), requires_grad=False) 
                self.class_embedding = nn.Embedding(n_classes, self.event_embed_dim+self.time_embed_dim)
                self.logits_linear = nn.Linear(n_classes, n_classes)

        elif self.refinement == 'attention_rnn':
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.m_attn_n_heads = self.hidden_size
            assert self.event_embed_dim % self.m_attn_n_heads == 0, "event embeddings should be divisible by # hidden units"
            self.hidden_decoder = nn.Linear(self.event_embed_dim, self.hidden_size)
            self.decoder_linear = nn.Linear(self.hidden_size, 1)
            # NOTE(review): attribute name is misspelled ("droupout"); renaming would change the module's attribute API
            self.droupout = nn.Dropout(p=0.1)
            self.hidden_state_emb_dim = 1
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            self.att = ScaledDotProductAttention(self.time_embed_dim+self.event_embed_dim)
            self.m_att = MultiHeadAttention(self.event_embed_dim, num_heads=self.m_attn_n_heads)
            self.combine_att_and_f_x = nn.Linear(hidden_size, 1)
            self.hidden_to_obs = nn.Linear(hidden_size, input_size * self.hidden_state_emb_dim)
            # event IDs and embeddings
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            self.event_embedding2 = nn.Embedding(input_size, 2)
            # hidden IDs and hidden embedding
            self.hidden_ids = nn.Parameter(data=torch.LongTensor(range(hidden_size)), requires_grad=False)
            self.hidden_embedding = nn.Embedding(hidden_size, self.event_embed_dim)

        elif self.refinement == 'refinement_7':
            # REFINEMENT 7: bypass when no observations are made. for example: extrapolation
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.hidden_state_emb_dim = 1
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            self.att = ScaledDotProductAttention(self.time_embed_dim+self.event_embed_dim)
            self.combine_att_and_f_x = nn.Linear(hidden_size, 1)
            self.hidden_to_obs = nn.Linear(hidden_size, input_size * self.hidden_state_emb_dim)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)

        elif self.refinement == 'refinement_8':
            # REFINEMENT 6 = REFINEMENT 5 - self.hidden_to_obs
            # same as refinement 5 but not hidden embedding for attention
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            self.att = ScaledDotProductAttention(self.time_embed_dim + self.event_embed_dim)
            #self.hidden_to_obs = nn.Linear(hidden_size, input_size)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)

        elif self.refinement == 'refinement_9':
            # REFINEMENT 7 = Refinement 6 + MHA
            self.time_embed_dim = args.tacd_time_emb
            self.event_embed_dim = args.tacd_event_emb
            self.time_periodic = nn.Linear(1, self.time_embed_dim-1)
            self.time_linear = nn.Linear(1, 1)
            self.att = MultiHeadAttention(self.time_embed_dim + self.event_embed_dim, num_heads=4)
            self.att_to_x_hat = nn.Linear(self.time_embed_dim + self.event_embed_dim, 1)
            self.residual_connection = ResidualConnection(self.time_embed_dim + self.event_embed_dim, 0.2)
            #self.att = ScaledDotProductAttention(self.time_embed_dim + self.event_embed_dim)
            #self.hidden_to_obs = nn.Linear(hidden_size, input_size)
            self.event_embedding = nn.Embedding(input_size, self.event_embed_dim)
            self.event_ids = nn.Parameter(data=torch.LongTensor(range(input_size)), requires_grad=False)

        elif self.refinement == 'decayed_h_input_delta_to_delta_inputs':
            # retired refinement: fails immediately
            assert False, "incorrect refinement"
            self.time_embed_dim = 8
            self.time_periodic = nn.Linear(input_size, self.time_embed_dim-1)
            self.time_linear = nn.Linear(input_size, 1)
            #self.hidden_to_obs = nn.Linear(hidden_size + self.time_embed_dim + input_size, input_size)
            self.hidden_to_obs = nn.Linear(hidden_size + self.time_embed_dim, input_size)

        #self.gamma_x_l = FilterLinear(self.delta_size, self.delta_size, self.identity)
        
        #self.gamma_h_l = nn.Linear(self.delta_size, hidden_size)
        # scalar time gap -> per-hidden-unit rate; step() decays h by exp(-max(0, gamma_h_l(dt)))
        self.gamma_h_l = nn.Linear(1, hidden_size)
        self.output_last = output_last


        
    def step(self, x, x_last_obsv, x_mean, h, mask, delta, curr_delta_t, 
        all_time_points, curr_time, all_attn_logits, curr_step_idx):
        """Advance the cell one step and predict values at future time points.

        Only ``self.refinement == 'refinement_6'`` is implemented. Any other
        refinement hits ``assert False`` in the final ``else`` branch.

        Two estimators are combined:
          * attention estimate: self-attention over per-feature
            [time, event] embeddings, with the last observed values as the
            attention values. The output is rescaled by ``att_scaling`` and
            added to ``x_last_obsv`` as a residual.
          * context estimate: the hidden state *after* consuming ``x`` is
            multiplied by ``exp(-max(0, gamma_h_l(dt)))`` for each prediction
            gap ``dt`` and decoded by ``hidden_to_obs``.
        A sigmoid gate (``combine_att_and_f_x``) mixes the two, unless
        ``args.grudplus_ablation_mode`` selects only one of them.

        Args (shapes as passed by ``forward``):
            x: current values ``(batch, feat)``. Unobserved entries are zeroed
                with ``mask`` before the GRU update.
            x_last_obsv: last observed value per feature ``(batch, feat)``.
            x_mean: unused on the implemented path (``forward`` passes None).
            h: previous hidden state ``(batch, hidden)``.
            mask: observation mask as a double tensor ``(batch, feat)``.
            delta: time since each feature was last observed ``(batch, feat)``.
            curr_delta_t: time elapsed since the previous step ``(batch,)``.
            all_time_points: times of every step ``(batch, T)``.
            curr_time: time of the current step ``(batch, 1)``.
            all_attn_logits: only referenced inside disabled string-literal code.
            curr_step_idx: only referenced inside disabled string-literal code.

        Returns:
            output_h: updated hidden state ``(batch, hidden)``.
            final_x_hat: predictions ``(batch, len(t_pred_idx), feat)``.
            t_pred_idx: 1-D tensor of time-axis indices that are predicted.
            clf_logits: ``(batch, len(t_pred_idx), n_classes)`` for
                classification, else None.
            combo_coefficient: gate weight on the attention estimate,
                ``(batch, len(t_pred_idx), 1)``.
        """
        
        batch_size = x.size(0)
        feat_size = x.size(1)
        t_size = all_time_points.size(1)
        
        # gamma(delta tau) to decay the hidden state
        delta_h = torch.exp(-torch.max(self.zeros_hidden, self.gamma_h_l(curr_delta_t.unsqueeze(-1))))

        # Disabled alternative (no-op string literal): normalise the gap by mean_inter_arr.
        '''
        delta_h = torch.exp(-torch.max(self.zeros_hidden, 
            self.gamma_h_l(curr_delta_t.unsqueeze(-1) / self.mean_inter_arr)))
        '''
         

        if self.refinement == 'refinement_6':
            # decay the hidden state
            h = delta_h * h

            # Prediction function: choose which time indices to predict.
            # extrapolation/next_obs: indices strictly after curr_time in ANY batch row;
            # classification/interpolation: indices at/after curr_time in ALL rows.
            # At most self.n_future_ts indices are kept.
            if self.args.task == 'extrapolation' or \
                self.args.task == 'next_obs_prediction':
                # add \Delta T prediction horizon to delta_t
                t_pred_idx = torch.where(((all_time_points - curr_time) > 0).any(axis=0))[0][:self.n_future_ts]

            elif self.args.task == 'classification':
                # include current timestep as well
                t_pred_idx = torch.where(((all_time_points - curr_time) >= 0).all(axis=0))[0][:self.n_future_ts]

            elif self.args.task == 'interpolation':
                t_pred_idx = torch.where(((all_time_points - curr_time) >= 0).all(axis=0))[0][:self.n_future_ts]

            else:
                assert False, "unknown task"

            t_pred_len = t_pred_idx.size(0)

            # per-feature time gap from last observation to each prediction time: (batch, t_pred_len, feat)
            delta_w_pred_horizon = delta[:,None,:] +  (all_time_points - curr_time)[:, t_pred_idx ,None]
            # Disabled (no-op string literal): rescale gaps by sample_period.
            '''
            # sample period rescaling
            sample_period = self.sample_period[None,None,:].repeat(batch_size, t_pred_len, 1)
            delta_w_pred_horizon = delta_w_pred_horizon / sample_period
            '''

            # Disabled (no-op string literal): earlier pre-update context estimate.
            '''
            # prediction based on deacyed hidden state
            future_delta_ts = (all_time_points - curr_time)[:, t_pred_idx]
            future_delta_hs = torch.exp(-torch.max(self.zeros_hidden, self.gamma_h_l(future_delta_ts.unsqueeze(-1))))
            future_hs = future_delta_hs * h[:,None,:]
            x_hat_decayed_h = self.hidden_to_obs(future_hs)
            combo_coefficient = F.sigmoid(self.combine_att_and_f_x(future_hs))
            '''

            # NOTE(review): `eps` is never used below (base_eps_vector is [0.0]),
            # and it stays undefined for any dataset other than these three.
            if self.args.dataset == 'physionet':
                #eps = 0.01 # for physionet. t in [0,48.0] hours
                eps = 0.0001 # for physionet. t in [0,48.0] hours
            elif self.args.dataset == 'ushcn':
                eps = 0.01 # for ushcn. mean del t=2
            elif self.args.dataset == 'mimic':
                eps = 10 # in seconds 

            # number of time offsets per prediction; must equal len(base_eps_vector)
            N_EPS = 1

            # event embeddings
            #event_emb = self.event_embedding(self.event_ids)[None,:].repeat(batch_size, t_pred_len, 1, 1)
            event_emb = self.event_embedding(self.event_ids)[None,None,:,None].repeat(batch_size, t_pred_len, 1, N_EPS, 1)
            #query_emb = self.query_event_embedding(self.event_ids)[None, :].repeat(batch_size, t_pred_len, 1, 1)

            #embed the times
            # add the eps dim
            #base_eps_vector = torch.tensor([-eps, 0.0, eps]).double().to(x.device)
            #base_eps_vector = torch.tensor([-eps, 0.0, eps, self.mean_inter_arr.item()]).double().to(x.device)
            base_eps_vector = torch.tensor([0.0]).double().to(x.device)
            delta_w_pred_horizon = (base_eps_vector[None,None,None,:] + delta_w_pred_horizon[...,None])
            # delta emb = batch x time x feat x emb x eps_dim
            # NOTE(review): despite its name, no periodic activation (e.g. sin)
            # is applied to time_periodic here; it is a plain linear map.
            delta_emb = torch.cat([self.time_linear(delta_w_pred_horizon.unsqueeze(-1)), 
                self.time_periodic(delta_w_pred_horizon.unsqueeze(-1)), 
                event_emb], dim=-1)
    
            delta_emb_3d = delta_emb.view(-1, feat_size, delta_emb.size(-1)) # keep only feat and emb dim
            #query_emb_3d = query_emb.view(-1, feat_size, delta_emb.size(3))

            # classification
            if self.args.task == 'classification':
            
                # Disabled (no-op string literal): time-aware class queries.
                '''
                base_class_last_obs = all_time_points[:,curr_step_idx,None].repeat(1, self.n_classes)
                if curr_step_idx > 0:
                    pdb.set_trace()
                    prev_cls_deltas = all_time_points[:,curr_step_idx:curr_step_idx+1] - all_time_points[:, :curr_step_idx]
                    prev_pred_classes = all_attn_logits[:,:curr_step_idx, :].argmax(axis=-1)


                target_delta = torch.zeros((batch_size,t_pred_len,self.n_classes), 
                        dtype=torch.float64).to(delta_emb_3d.device)# predict at current time
                '''

                # class embeddings act as queries over the per-feature embeddings;
                # values are the last observed feature values
                target_emb = target_class_embeddings = self.class_embedding(self.class_label_ids)[None, None, :, :].repeat(batch_size, t_pred_len, 1, 1)
                #target_emb = torch.cat([self.time_linear(target_delta.unsqueeze(-1)),
                #    self.time_periodic(target_delta.unsqueeze(-1)),
                #    target_class_embeddings], dim=-1)
                target_emb = target_emb.view(-1,
                    self.n_classes, target_emb.size(-1))
                clf_x_last_obsv = x_last_obsv[:,None,:].repeat(1, t_pred_len, 1)
                clf_x_last_obsv_for_attn = clf_x_last_obsv.view(-1, feat_size, 1)
                clf_logits, clf_attn = self.att(target_emb, delta_emb_3d, clf_x_last_obsv_for_attn)
                clf_logits = clf_logits.view(batch_size, t_pred_len, self.n_classes)
                # residual non-linear refinement of the logits
                clf_logits = clf_logits + F.relu(self.logits_linear(clf_logits))
            else:
                clf_logits = None

            # attention weights = self attention, value = last value
            x_last_obsv = x_last_obsv[:,None,:,None].repeat(1, t_pred_len, 1, N_EPS)
            x_last_obsv_for_attn = x_last_obsv.view(-1, feat_size, 1)
            # self attention like (paper formulation)
            x_hat, attn = self.att(delta_emb_3d, delta_emb_3d, x_last_obsv_for_attn)
            #x_hat, attn = self.att(delta_emb_3d, query_emb_3d, x_last_obsv_for_attn)
            #x_hat = x_hat.view(batch_size, t_pred_len, N_EPS, feat_size)
            x_hat = x_hat.view(batch_size, t_pred_len, feat_size, N_EPS)
            # att_scaling acts on the last (feature) dim, hence the permutes
            x_hat = self.att_scaling(x_hat.permute(0,1,3,2))
            x_hat = x_hat.permute(0,1,3,2)
            # attention weights model the residuals
            x_hat = x_last_obsv + x_hat
    
            # mask out non-observations
            x = mask * x


        else:
            # default GRUD
            # NOTE(review): unreachable unless asserts are stripped (python -O);
            # `delta_x` is undefined there and t_pred_idx etc. would be unbound.
            assert False, "running GRUD from GRUDPlus"
            h = delta_h * h
            x = mask * x + (1 - mask) * (delta_x * x_last_obsv + (1 - delta_x) * x_mean)
        

        # default GRU logic (inputs: [x, decayed h, mask])
        combined = torch.cat((x, h, mask), 1)
        if self.use_encoder:
            combined = self.encoder(combined)
        z = F.sigmoid(self.zl(combined))
        r = F.sigmoid(self.rl(combined))
        combined_r = torch.cat((x, r * h, mask), 1)
        if self.use_encoder:
            combined_r = self.encoder(combined_r)
        h_tilde = F.tanh(self.hl(combined_r))
        output_h = (1 - z) * h + z * h_tilde

        # Disabled (no-op string literal): update gate zeroed when nothing is observed.
        '''
        # new logic for all-missing input consistency
        combined = torch.cat((x, h, mask), 1)
        if self.use_encoder:
            combined = self.encoder(combined)
        max_val, _ = mask.max(axis=-1)
        z = max_val[:,None] * F.sigmoid(self.zl(combined))
        r = F.sigmoid(self.rl(combined))
        combined_r = torch.cat((x, r * h, mask), 1)
        if self.use_encoder:
            combined_r = self.encoder(combined_r)
        h_tilde = F.tanh(self.hl(combined_r))
        output_h = (1 - z) * h + z * h_tilde
        '''
    
        # combine after observing x_t and updating to h_t+1
        # prediction based on deacyed hidden state
        future_delta_ts = (all_time_points - curr_time)[:, t_pred_idx, None].repeat(1,1,N_EPS)
        #future_delta_ts = (all_time_points - curr_time)[:, t_pred_idx] / self.mean_inter_arr
        future_delta_ts = future_delta_ts + base_eps_vector[None,None,:]
        # (batch, t_pred_len, N_EPS, hidden) decay factors in (0, 1]
        future_delta_hs = torch.exp(-torch.max(self.zeros_hidden, self.gamma_h_l(future_delta_ts.unsqueeze(-1))))
        future_hs = future_delta_hs * output_h[:,None,None,:]

        # single W for all future timesteps
        x_hat_decayed_h = self.hidden_to_obs(future_hs)
        x_hat_decayed_h = x_hat_decayed_h.permute(0,1,3,2)
        #x_hat_decayed_h = x_last_obsv + x_hat_decayed_h

        # gate computed from the first (only) eps offset; drop the eps axis afterwards
        combo_coefficient = F.sigmoid(self.combine_att_and_f_x(future_hs[:,:,0,:]))
        x_hat = x_hat[...,0]
        x_hat_decayed_h = x_hat_decayed_h[...,0]
        # previous combination logic
        # NOTE(review): final_x_hat is unbound (UnboundLocalError at the return)
        # if grudplus_ablation_mode is none of these three values.
        if self.args.grudplus_ablation_mode == 'no_ablation':
            if self.args.tacd_add_noise > 0.0:
                # add noise to x^c: blend in uniform [0, 1) noise with weight noise_prop
                noise_prop = self.args.tacd_add_noise 
                uni_rand_preds = torch.rand(x_hat_decayed_h.shape)\
                    .double().to(x_hat_decayed_h.device)
                x_hat_decayed_h = noise_prop * uni_rand_preds + \
                    (1 - noise_prop) * x_hat_decayed_h

            final_x_hat = combo_coefficient * x_hat + \
                        (1 - combo_coefficient) * x_hat_decayed_h
        elif self.args.grudplus_ablation_mode == 'attention_only':
            final_x_hat = x_hat
        elif self.args.grudplus_ablation_mode == 'context_only':
            final_x_hat = x_hat_decayed_h

        # Disabled (no-op string literal): variance-based estimator weighting.
        '''
        if self.args.grudplus_ablation_mode == 'no_ablation':
            # prefer function with larger long-term changes and 
            # small local changes
            non_zero_eps = 1e-12
            attn_small = ((x_hat[:,:,:,1] - x_hat[:,:,:,2]).abs() + non_zero_eps).log()
            attn_large = ((x_hat[:,:,:,1] - x_hat[:,:,:,3]).abs() + non_zero_eps).log()
            ctxt_small = ((x_hat_decayed_h[:,:,:,1] - x_hat_decayed_h[:,:,:,2]).abs() + non_zero_eps).log()
            ctxt_large = ((x_hat_decayed_h[:,:,:,1] - x_hat_decayed_h[:,:,:,3]).abs() + non_zero_eps).log()
            #w_attn = attn_large - attn_small
            #w_ctxt = ctxt_large - ctxt_small
            w_attn =  - attn_small
            w_ctxt =  - ctxt_small
            denom = w_attn + w_ctxt
            w_attn = w_attn / denom
            if w_attn.numel() > 0:
                assert w_attn.min() >= 0.0 and w_attn.max() <= 1.0, "should be between 0 and 1"
            combo_coefficient = w_attn.mean(axis=-1)[:,:,None].detach()
            # feat-wise estimator weighting
            w_attn = w_attn.detach()
            final_x_hat = w_attn * x_hat[...,1] + (1-w_attn) * x_hat_decayed_h[...,1]
        
            # avg(feat) estimator weighting
            #final_x_hat = combo_coefficient * x_hat[...,1] + (1-combo_coefficient) * x_hat_decayed_h[...,1]


            # another combination strategy
            #mask out higher variance
            #final_x_hat = x_hat[...,1] # eps=0
            #final_x_hat = x_hat.mean(axis=-1)
            #final_x_hat[std_ctxt < std_attn] = x_hat_decayed_h[...,1][std_ctxt < std_attn]
            #final_x_hat[std_ctxt < std_attn] = x_hat_decayed_h.mean(axis=-1)[std_ctxt < std_attn]
        elif self.args.grudplus_ablation_mode == 'attention_only':
            #final_x_hat = x_hat.mean(axis=-1)
            final_x_hat = x_hat[...,1]
        elif self.args.grudplus_ablation_mode == 'context_only':
            final_x_hat = x_hat_decayed_h.mean(axis=-1)
            #final_x_hat = x_hat_decayed_h[...,1]
        '''

        return output_h, final_x_hat, t_pred_idx, clf_logits, combo_coefficient

    def pad_time_correctly(self, time_points):
        '''
            O(batch_size) loop solution
            Converts trailing zeros after max ts to max ts
        '''
        max_ts, max_idx = time_points[:,:].max(axis=1)
        for i in range(len(time_points)):
            time_points[i, max_idx[i]+1:] = max_ts[i]
        return time_points
    
    def forward(self, input, valid_mask, time_points):

        # pre process time_points
        time_points = self.pad_time_correctly(time_points)
        assert torch.diff(time_points).min() >= 0, "moving back in time"

        batch_size = input.size(0)
        #type_size = input.size(1)
        step_size = input.size(1)
        #spatial_size = input.size(3)
        
        hidden_state = self.initHidden(batch_size)
        X = input
        Mask = valid_mask
        #X = torch.squeeze(input[:,0,:,:])
        #X_last_obsv = torch.squeeze(input[:,1,:,:])
        #Mask = torch.squeeze(input[:,2,:,:])
        #Delta = torch.squeeze(input[:,3,:,:])
        
        outputs = None
        last_values = torch.zeros((batch_size, input.shape[-1]), dtype=torch.double).cuda()
        #prev_last_values = torch.zeros((batch_size, input.shape[-1]), dtype=torch.double).cuda()
        time_delta = torch.zeros((batch_size, input.shape[-1])).cuda().double()
        #prev_time_delta = torch.zeros((batch_size, input.shape[-1])).cuda().double()

        cumm_x_hat_t = torch.zeros_like(X)
        cumm_combo_wts = 0.5 * torch.ones_like(X[...,0])[...,None]
        #agg_clf_logits = []
        if self.args.task == 'classification':
            agg_clf_logits = torch.zeros((X.size(0), X.size(1), self.n_classes)).double().to(X.device)
        else:
            agg_clf_logits = None
        for i in range(step_size):
            current_input = torch.squeeze(X[:,i:i+1,:], dim=1)
            current_mask = torch.squeeze(Mask[:,i:i+1,:], dim=1)

            ### Update state model ###
            # update last observed values x*
            last_values[current_mask] = current_input[current_mask]

            if i > 0:
                curr_delta_t = (time_points[:,i] - time_points[:,i-1]).double()
            else:
                curr_delta_t = torch.zeros((batch_size), dtype=torch.double).cuda()

            # update delta ts for x*
            if i > 0:
                #time_delta = time_delta + (time_points[:,i] - time_points[:,i-1])[:,None].repeat(1, input.shape[-1])
                time_delta = time_delta + (curr_delta_t)[:,None].repeat(1, input.shape[-1])

            time_delta[current_mask] = 0.0
            current_time = time_points[:,i:i+1]
             
            #print('processing step: {}'.format(i))
            curr_atleast_one_obs = current_mask.any(-1)
            #print('processing: ',i)
            hidden_state, x_hat_t, t_pred_idx, clf_logits, \
            combo_weights = self.step(current_input
                                     , last_values.clone()
                                     , None
                                     , hidden_state
                                     , current_mask.double()
                                     , time_delta.clone()
                                     , curr_delta_t
                                     , time_points
                                     , current_time
                                     , agg_clf_logits
                                     , i)

            # considers history and future
            #cumm_x_hat_t = cumm_x_hat_t + (1 / step_size) * x_hat_t

            if self.args.task == 'classification':
                #assert len(t_pred_idx) == 1 and t_pred_idx[0] == i, \
                assert t_pred_idx[0] == i, \
                    "unexpeced prediction indices for classification"
                #agg_clf_logits.append(clf_logits)
                agg_clf_logits[:,t_pred_idx ,:] = agg_clf_logits[:,t_pred_idx ,:] + \
                    (1/t_pred_idx.size(0)) * clf_logits

            elif self.args.task == 'extrapolation':
                assert (t_pred_idx > i).all().item(), "unexpeced prediction indices for extrapolation"

            if t_pred_idx.size(0) > 0:
                cumm_x_hat_t[:,t_pred_idx ,:] = cumm_x_hat_t[:,t_pred_idx ,:] + (1/t_pred_idx.size(0)) * x_hat_t
                cumm_combo_wts[:,t_pred_idx ,:] = combo_weights
    
            if i == 0:
                # hack to make last observed values for zero-th time prediction
                cumm_x_hat_t[:,0,:] = last_values.clone()
                
            if outputs is None:
                outputs = hidden_state.unsqueeze(1)
            else:
                outputs = torch.cat((outputs, hidden_state.unsqueeze(1)), 1)

                
        if self.output_last:
            return outputs[:,-1,:]
        else:
            #return outputs, torch.stack(x_hats).permute(1,0,2)
            #agg_clf_logits = torch.cat(agg_clf_logits, dim=-1).transpose(2,1)
            return outputs, cumm_x_hat_t, agg_clf_logits, cumm_combo_wts
    
    def initHidden(self, batch_size):
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            hidden_state = Variable(torch.zeros(batch_size, self.hidden_size).cuda())
            return hidden_state
        else:
            hidden_state = Variable(torch.zeros(batch_size, self.hidden_size))
            return hidden_state
