# Adapted from https://openreview.net/forum?id=F3s69XzWOia
# Modified to handle irregularly sampled time series
from torch import nn
import torch
from torch.autograd import Variable
import math

class coRNNCell(nn.Module):
    """Single recurrent step of the coRNN with a per-sample time step.

    The cell keeps two hidden states of width ``n_hid``: ``hy`` and its
    rate of change ``hz``.  One call advances both by an explicit step of
    size ``dt * tau``.
    """

    def __init__(self, n_inp, n_hid, gamma, epsilon, tau=1.0):
        """Store the dynamics coefficients and build the input projection.

        Args:
            n_inp: Number of input features per step.
            n_hid: Width of each of the two hidden states.
            gamma: Coefficient multiplying ``hy`` in the ``hz`` update.
            epsilon: Coefficient multiplying ``hz`` in the ``hz`` update.
            tau: Constant factor applied to every time step ``dt``.
        """
        super().__init__()
        self.gamma = gamma
        self.epsilon = epsilon
        self.tau = tau
        # One linear map over the concatenation [x, hz, hy].
        self.i2h = nn.Linear(n_inp + 2 * n_hid, n_hid)

    def forward(self, x, dt, hy, hz):
        """Advance the hidden states by one step.

        Args:
            x: Input for this step; concatenated with the states along
                dim 1, so it must be 2-D with ``n_inp`` columns.
            dt: Elapsed time for this step; must broadcast against the
                hidden states.
            hy: Current ``hy`` state.
            hz: Current ``hz`` state.

        Returns:
            Tuple ``(hy, hz)`` holding the updated states.
        """
        step = dt * self.tau
        stacked = torch.cat((x, hz, hy), 1)
        drive = torch.tanh(self.i2h(stacked))
        # Forcing term minus the two linear terms in hy and hz.
        rate = drive - self.gamma * hy - self.epsilon * hz
        next_hz = hz + step * rate
        # hy is advanced with the freshly updated hz, not the old one.
        next_hy = hy + step * next_hz
        return next_hy, next_hz

class coRNN(nn.Module):
    """Run a ``coRNNCell`` over a sequence and read out from the ``hy`` state.

    Each step is advanced by its own elapsed time taken from ``timespans``,
    which lets the model consume irregularly sampled sequences.
    """

    def __init__(self, n_inp, n_hid, n_out,  gamma, epsilon,tau,return_sequences):
        """Build the recurrent cell and the linear readout.

        Args:
            n_inp: Number of input features per step.
            n_hid: Width of each hidden state of the cell.
            n_out: Number of readout features.
            gamma: Forwarded to ``coRNNCell``.
            epsilon: Forwarded to ``coRNNCell``.
            tau: Forwarded to ``coRNNCell``.
            return_sequences: If true, ``forward`` returns the readout of
                every step; otherwise only one readout per sample.
        """
        super(coRNN, self).__init__()
        self.n_hid = n_hid
        self.return_sequences = return_sequences
        self.cell = coRNNCell(n_inp, n_hid, gamma, epsilon, tau)
        self.readout = nn.Linear(n_hid, n_out)
        self.n_out = n_out

    def forward(self, x, timespans, mask=None):
        """Process a batch of sequences.

        Args:
            x: Inputs indexed as ``x[:, t]``, i.e. ``(batch, seq_len, n_inp)``.
            timespans: Elapsed time per step; ``timespans[:, t]`` must hold
                exactly one value per sample.
            mask: Optional per-step validity mask with one value per sample
                and step.  Float, integer and boolean masks are accepted.
                Only consulted when ``return_sequences`` is false.

        Returns:
            ``(batch, seq_len, n_out)`` readouts when ``return_sequences``
            is true.  Otherwise ``(batch, n_out)``: the readout of the last
            step, or, with a mask, of the last step whose mask value is 1
            (zeros for samples without any such step).
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        device = x.device
        # Allocate state in the parameters' dtype so that half/double
        # models do not mix precisions inside the cell.
        dtype = self.readout.weight.dtype

        hy = torch.zeros(batch_size, self.n_hid, device=device, dtype=dtype)
        hz = torch.zeros(batch_size, self.n_hid, device=device, dtype=dtype)

        # `1.0 - mask` is rejected for bool tensors, so convert bool/int
        # masks up front; floating-point masks are used unchanged.
        if mask is not None and not torch.is_floating_point(mask):
            mask = mask.to(dtype)

        outputs = []
        last_output = torch.zeros(
            (batch_size, self.n_out), device=device, dtype=dtype
        )
        for t in range(seq_len):
            ts = timespans[:, t].view(batch_size, 1)
            hy, hz = self.cell(x[:, t], ts, hy, hz)
            current_output = self.readout(hy)

            if self.return_sequences:
                # Only keep per-step outputs when they are actually returned.
                outputs.append(current_output)
            elif mask is not None:
                cur_mask = mask[:, t].view(batch_size, 1)
                # Carry the previous output forward through masked steps.
                last_output = (
                    cur_mask * current_output + (1.0 - cur_mask) * last_output
                )
            else:
                last_output = current_output

        if not self.return_sequences:
            return last_output  # only last item
        if not outputs:
            # torch.stack cannot stack an empty list; return an empty sequence.
            return torch.zeros(
                (batch_size, 0, self.n_out), device=device, dtype=dtype
            )
        return torch.stack(outputs, dim=1)  # return entire sequence
        

