
import math
import torch
import torchcde
import datasets
import sklearn.model_selection

import torchdiffeq
from random import SystemRandom
import random
import numpy as np 
from parse import parse_args
import pathlib
import os 
import time 
import tqdm
# Command-line options are parsed once at import time; the model classes below
# read ``args.model`` directly instead of receiving it as a constructor argument.
args = parse_args()
# Enumerate CUDA devices in PCI bus order (presumably so device indices match
# the numbering shown by nvidia-smi).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Model names handled by NeuralODE_forecasting.
ODE_BASED=['node','masknode']
# Discrete GRU-style model names; not referenced in this part of the file
# (presumably used by main to pick GRU_dt / GRU_D / ODERNN forecasters).
GRUS = ['dt','gru_d','odernn']
# Interpolation names that NeuralCDE_forecasting routes to torchcde.CubicSpline.
CUBICS = ['natural_cubic','cubic']


def split_data(tensor, stratify):
    """Split ``tensor`` into train / validation / test parts (70% / 15% / 15%).

    Both splits are shuffled and stratified on ``stratify``, with fixed random
    states so the partition is reproducible.

    Returns:
        (train_tensor, val_tensor, test_tensor)
    """
    # 70% train vs. 30% held out; keep the held-out labels for the second split.
    first_split = sklearn.model_selection.train_test_split(
        tensor, stratify,
        train_size=0.7, random_state=0, shuffle=True, stratify=stratify,
    )
    train_part, held_out, _, held_out_labels = first_split

    # Halve the held-out portion into validation and test.
    val_part, test_part = sklearn.model_selection.train_test_split(
        held_out,
        train_size=0.5, random_state=1, shuffle=True, stratify=held_out_labels,
    )
    return train_part, val_part, test_part
def normalise_data(X,y):
    """Standardise every channel (last axis) of ``X``.

    The per-channel mean and standard deviation are computed from the
    training split only (see ``split_data``), ignoring NaN entries. They are then
    applied to the whole of ``X``. A small epsilon guards against zero std.
    """
    train_X = split_data(X, y)[0]

    normalised_channels = []
    for channel, train_channel in zip(X.unbind(dim=-1), train_X.unbind(dim=-1)):
        observed = train_channel[~torch.isnan(train_channel)]
        centre, spread = observed.mean(), observed.std()
        normalised_channels.append((channel - centre) / (spread + 1e-5))

    return torch.stack(normalised_channels, dim=-1)

class CDEFunc(torch.nn.Module):
    """Vector field of a neural CDE.

    Maps the hidden state z_t to a matrix that represents a linear map from
    R^input_channels to R^hidden_channels (it is contracted with the control's
    derivative by the CDE solver).
    """

    def __init__(self, input_channels, hidden_channels):
        """
        Args:
            input_channels: number of channels in the data X (determined by the data).
            hidden_channels: number of channels for z_t (a free modelling choice).
        """
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, t, z):
        """Evaluate the vector field.

        Args:
            t: current time; ignored (the field does not depend on time) but
                required by the solver's calling convention.
            z: hidden state of shape ``(*batch, hidden_channels)``.

        Returns:
            Tensor of shape ``(*batch, hidden_channels, input_channels)`` with
            entries in [-1, 1].
        """
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        # Easy-to-forget gotcha: best results tend to be obtained by adding a
        # final tanh nonlinearity.
        z = z.tanh()
        # Ignoring batch dimensions the output must be a matrix, because it
        # represents a linear map R^input_channels -> R^hidden_channels.
        # Keep every leading batch dimension rather than assuming exactly one.
        return z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)

class ODEFunc(torch.nn.Module):
    """Autonomous ODE vector field dz/dt = tanh(W2 relu(W1 z + b1) + b2).

    Unlike ``CDEFunc`` the output is a vector with the same shape as ``z``,
    not a matrix.
    """

    def __init__(self, input_channels, hidden_channels):
        """
        Args:
            input_channels: number of channels in the data X (stored for reference only).
            hidden_channels: size of the hidden state z_t.
        """
        super(ODEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, hidden_channels)

    def forward(self, t, z):
        """Return dz/dt for ``z`` of shape ``(..., hidden_channels)``; ``t`` is unused."""
        hidden = torch.relu(self.linear1(z))
        # The final tanh keeps the derivative bounded in [-1, 1].
        return torch.tanh(self.linear2(hidden))

class _ODERNNFunc(torch.nn.Module):
    def __init__(self, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super(_ODERNNFunc, self).__init__()
        # import pdb ; pdb.set_trace()
        layers = [torch.nn.Linear(hidden_channels, hidden_hidden_channels)]
        for _ in range(num_hidden_layers - 1):
            layers.append(torch.nn.Tanh())
            layers.append(torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels))
        layers.append(torch.nn.Tanh())
        layers.append(torch.nn.Linear(hidden_hidden_channels, hidden_channels))
        self.sequential = torch.nn.Sequential(*layers)

    def forward(self, t, x):
        
        return self.sequential(x)


class DisGruFunc(torch.nn.Module):
    """GRU-ODE vector field: dh/dt = (1 - z) * (g - h), built from GRU gates.

    ``forward`` matches the signature of the continuous GRU fields (it also
    receives ``t`` and ``dxdt``) but uses neither.
    """

    def __init__(self, input_channels, hidden_channels):
        super(DisGruFunc, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        # Input projections (no bias) and recurrent projections for the reset,
        # update and candidate gates.
        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)

    def forward(self, t,x, h,dxdt):
        reset_gate = torch.sigmoid(self.W_r(x) + self.U_r(h))
        update_gate = torch.sigmoid(self.W_z(x) + self.U_z(h))
        candidate = torch.tanh(self.W_h(x) + self.U_h(reset_gate * h))
        return (1 - update_gate) * (candidate - h)


class ContGruFunc_Aug(torch.nn.Module):
    """Continuous-time GRU vector field that also returns gate derivatives.

    Each call computes the GRU update h_ = z * h_past + (1 - z) * g and its time
    derivative dh/dt. It returns ``(dhdt, drdt, dgdt, dzdt)``.

    The previous-step hidden state (``h_past``) and its derivative
    (``dhpastdt``) are carried between calls. When ``file_path`` is given they
    are written to ``<file_path>/h_past/h_past_<rnd>.npy`` and
    ``<file_path>/dhpastdt/dhpastdt_<rnd>.npy``; those subdirectories must
    already exist. When ``file_path`` is None the state is kept in memory
    instead.
    """

    def __init__(self, input_channels, hidden_channels, file_path=None, rnd=None):
        super(ContGruFunc_Aug, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)
        # forward() reads these; they must be set here, or the first call fails.
        self.file = file_path
        self.rnd = rnd
        # In-memory substitute for the .npy files when no file_path is given.
        self._memory_state = {}

    def _state_path(self, name):
        """Return the .npy path for state ``name`` ('h_past' or 'dhpastdt')."""
        return str(self.file) + "/" + name + "/" + name + "_" + str(self.rnd) + ".npy"

    def _load_state(self, name, like):
        """Load saved state ``name`` and cast it to the dtype/device of ``like``."""
        if self.file is None:
            return self._memory_state[name].to(like)
        return torch.Tensor(np.load(self._state_path(name))).to(like)

    def _save_state(self, name, value):
        """Persist a detached copy of ``value`` as state ``name``."""
        value = value.detach()
        if self.file is None:
            # Mirror the float32 round trip that the on-disk path performs.
            self._memory_state[name] = value.float().clone()
        else:
            np.save(self._state_path(name), value.cpu().numpy())

    def forward(self, t,x, h,dxdt):
        # At t == 0 there is no saved state yet, so the current h serves as h_past.
        if t ==0:
            h_past = h
        else:
            h_past = self._load_state("h_past", h)

        r = self.W_r(x) + self.U_r(h_past)
        r = r.sigmoid()
        z = self.W_z(x) + self.U_z(h_past)
        z = z.sigmoid()
        g0 = self.W_h(x) + self.U_h(r * h_past)
        g = g0.tanh()
        h_ = torch.mul(z,h_past) + torch.mul((1-z),g)  # GRU update at time t
        self._save_state("h_past", h_)

        hg = h_past - g

        if t==0:
            dhpast_dt = (1 - z) * (g - h)
        else:
            dhpast_dt = self._load_state("dhpastdt", h)

        # dxdt is expected to expose .derivative(t) (e.g. a torchcde interpolation).
        control_gradient = dxdt.derivative(t)

        # A = W_z x + U_z h_past  ->  dA/dt = W_z dx/dt + U_z dh_past/dt
        dAdt =((self.W_z.weight @ control_gradient.unsqueeze(-1)) + (self.U_z.weight@dhpast_dt.unsqueeze(-1))).squeeze(-1)
        dzdt =torch.mul(torch.mul(z,(1-z)),dAdt)

        drdt = torch.mul(torch.mul(r,(1-r)),((self.W_r.weight @ control_gradient.unsqueeze(-1))+(self.U_r.weight@dhpast_dt.unsqueeze(-1))).squeeze(-1))
        # NOTE(review): the product rule for B = W_h x + U_h(r * h_past) would give
        # U_h(dr/dt * h_past + r * dh_past/dt); this applies U_h to each factor
        # separately and uses h rather than h_past. Kept as-is; confirm intent.
        dBdt =(self.W_h.weight @ control_gradient.unsqueeze(-1)).squeeze(-1) + torch.mul((self.U_h.weight@drdt.unsqueeze(-1)).squeeze(-1),h) +torch.mul((self.U_h.weight@r.unsqueeze(-1)).squeeze(-1),dhpast_dt)

        dgdt = torch.mul(torch.mul((1-g),(1+g)),dBdt)

        dhgdt = dhpast_dt - dgdt

        # d/dt [z * (h_past - g) + g]
        dhdt = torch.mul(dzdt,hg) + torch.mul(z,dhgdt) + dgdt
        self._save_state("dhpastdt", dhdt)

        return dhdt,drdt,dgdt,dzdt
        

class ContGruFunc_Delay(torch.nn.Module):
    """Continuous-time GRU vector field returning dh/dt.

    Each call computes the GRU update h_ = z * h_past + (1 - z) * g and returns
    its time derivative. State from the previous call (``h_past`` and its
    derivative) is exchanged through .npy files under ``file_path``, keyed by
    ``rnd``, so every call reads and writes disk. The ``h_past/`` and
    ``dhpastdt/`` subdirectories must already exist because ``np.save`` does
    not create them. Two instances sharing ``file_path`` and ``rnd`` will
    overwrite each other's state.
    """
    def __init__(self, input_channels, hidden_channels,file_path,rnd):
        super(ContGruFunc_Delay, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        # Input projections (no bias) and recurrent projections for the
        # reset (r), update (z) and candidate (h) gates.
        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)
        # Directory and run id used to build the state-file paths.
        self.file = file_path
        self.rnd=rnd
    def forward(self, t,x, h,dxdt):
        """Return dh/dt at time ``t`` for input ``x`` and hidden state ``h``.

        ``dxdt`` must provide ``.derivative(t)`` (presumably a torchcde
        interpolation of the input path — confirm against the caller).
        """
        
        
        # At t == 0 no state has been saved yet, so h itself serves as h_past;
        # afterwards h_past is the GRU update written by the previous call.
        if t ==0:
            h_past = h
        else:
            
            h_past = torch.Tensor(np.load(self.file+"/h_past/h_past_"+str(self.rnd)+".npy")).to(h)
            
        r = self.W_r(x) + self.U_r(h_past)
        r = r.sigmoid()
        z = self.W_z(x) + self.U_z(h_past)
        z = z.sigmoid()
        g0 = self.W_h(x) + self.U_h(r * h_past)
        g = g0.tanh()
        h_ = torch.mul(z,h_past) + torch.mul((1-z),g) # save h at t 
        
        # Persist (detached) for the next call; no gradient flows through this.
        np.save(self.file+'/h_past/h_past_'+str(self.rnd)+'.npy',h_.cpu().detach().numpy())

        hg = h_past - g
        
        # dh_past/dt: at t == 0 use the GRU-ODE form (1 - z)(g - h); otherwise
        # reuse the dh/dt saved by the previous call.
        if t==0:
            dhpast_dt = (1 - z) * (g - h)
        else:
            
            dhpast_dt = torch.Tensor(np.load(self.file+"/dhpastdt/dhpastdt_"+str(self.rnd)+".npy")).to(h)
        
        control_gradient = dxdt.derivative(t)  # 256,28
        
        # A = W_z x + U_z h_past  ->  dA/dt = W_z dx/dt + U_z dh_past/dt
        dAdt =((self.W_z.weight @ control_gradient.unsqueeze(-1)) + (self.U_z.weight@dhpast_dt.unsqueeze(-1))).squeeze(-1) # dAdt = 10,49,1
        
        # z = sigmoid(A)  ->  dz/dt = z (1 - z) dA/dt
        dzdt =torch.mul(torch.mul(z,(1-z)),dAdt)
        
        drdt = torch.mul(torch.mul(r,(1-r)),((self.W_r.weight @ control_gradient.unsqueeze(-1))+(self.U_r.weight@dhpast_dt.unsqueeze(-1))).squeeze(-1)) #drdt : 10,49
        # NOTE(review): the product rule for B = W_h x + U_h(r * h_past) would give
        # U_h(dr/dt * h_past + r * dh_past/dt); this line applies U_h to each
        # factor separately and multiplies by h rather than h_past. Confirm intent.
        dBdt =(self.W_h.weight @ control_gradient.unsqueeze(-1)).squeeze(-1) + torch.mul((self.U_h.weight@drdt.unsqueeze(-1)).squeeze(-1),h) +torch.mul((self.U_h.weight@r.unsqueeze(-1)).squeeze(-1),dhpast_dt)
        
        # g = tanh(B)  ->  dg/dt = (1 - g)(1 + g) dB/dt
        dgdt = torch.mul(torch.mul((1-g),(1+g)),dBdt)
        
        dhgdt = dhpast_dt - dgdt
        
        # h_ = z (h_past - g) + g  ->  dh/dt = dz/dt (h_past - g) + z d(h_past - g)/dt + dg/dt
        dhdt = torch.mul(dzdt,hg) + torch.mul(z,dhgdt) + dgdt 
        np.save(self.file+'/dhpastdt/dhpastdt_'+str(self.rnd)+'.npy',dhdt.cpu().detach().numpy())
        
        
        return dhdt




class ContinuousRNNConverter(torch.nn.Module):
    """Wrap an RNN-style vector field so it acts on a packed state [x, h].

    The state ``z`` holds ``input_channels`` input values followed by
    ``hidden_channels`` hidden values. ``forward`` returns the wrapped model's
    output together with a linear projection of it back to the packed width.
    """

    def __init__(self, input_channels, hidden_channels, model):
        super(ContinuousRNNConverter, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.model = model
        packed_width = self.input_channels + self.hidden_channels
        self.linear = torch.nn.Linear(self.hidden_channels, packed_width)
        # Identity on the first input_channels rows, zeros elsewhere.
        self.register_buffer("out_base", torch.eye(packed_width, self.input_channels))

    def forward(self,t, z,dxdt):
        split_at = self.input_channels
        x, h = z[..., :split_at], z[..., split_at:].clamp(-1, 1)
        model_out = self.model(t, x, h, dxdt)
        return model_out, self.linear(model_out)


class ContinuousRNNConverter_Aug(torch.nn.Module):
    """Packed-state [x, h] wrapper for vector fields that also return gate terms.

    The wrapped model must return four values ``(dh, dr, dg, dz)``. ``forward``
    returns ``(dh, linear(dh), dr, dg, dz)``.
    """

    def __init__(self, input_channels, hidden_channels, model):
        super(ContinuousRNNConverter_Aug, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.model = model
        packed_width = self.input_channels + self.hidden_channels
        self.linear = torch.nn.Linear(self.hidden_channels, packed_width)
        # Identity on the first input_channels rows, zeros elsewhere.
        self.register_buffer("out_base", torch.eye(packed_width, self.input_channels))

    def forward(self,t, z,dxdt):
        split_at = self.input_channels
        x, h = z[..., :split_at], z[..., split_at:].clamp(-1, 1)
        d_hidden, d_reset, d_candidate, d_update = self.model(t, x, h, dxdt)
        projected = self.linear(d_hidden)
        return d_hidden, projected, d_reset, d_candidate, d_update



class _GRU_forecasting(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, forecast_window,device):
        super(_GRU_forecasting, self).__init__()

        
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.forecast_window = forecast_window
        self.use_intensity = True
        self.device = device
        gru_channels = input_channels if self.use_intensity else (input_channels - 1) // 2
        # gru_channels = input_channels
        self.gru_cell = torch.nn.GRUCell(input_size=gru_channels, hidden_size=hidden_channels)
        self.linear = torch.nn.Linear(hidden_channels, input_channels-1)
        self.readout = torch.nn.Linear(1,output_channels)
        self.read_time = torch.nn.Linear(hidden_channels,forecast_window)
    def extra_repr(self):
        return "input_channels={}, hidden_channels={}, output_channels={}, use_intensity={}" \
               "".format(self.input_channels, self.hidden_channels, self.output_channels, self.use_intensity)

    def evolve(self, h, time_diff):
        raise NotImplementedError

    def _step(self, Xi, h, dt, half_num_channels):
        
        observation = Xi[:, 1: 1 + half_num_channels].max(dim=1).values > 0.5
        if observation.any():
            Xi_piece = Xi if self.use_intensity else Xi[:, 1 + half_num_channels:]
            Xi_piece = Xi_piece.clone() # 200,3
            Xi_piece[:, 0] += dt
            new_h = self.gru_cell(Xi_piece.float(), h.float()) 
            h = torch.where(observation.unsqueeze(1), new_h, h)
            dt += torch.where(observation, torch.tensor(0., dtype=Xi.dtype, device=Xi.device), Xi[:, 0])
        return h, dt

    def forward(self, coeffs, z0=None):
        
        XX = torchcde.CubicSpline(coeffs)
        times = torch.Tensor(np.arange(XX.interval[-1].item())).to(self.device)
        X = torch.stack([XX.evaluate(t) for t in times], dim=-2) # 200,24,7
        half_num_channels = (self.input_channels - 1) // 2

        X[:, 1:, 1:1 + half_num_channels] -= X[:, :-1, 1:1 + half_num_channels]

        X[:, 0, 0] -= times[0]
        X[:, 1:, 0] -= times[:-1]

        batch_dims = X.shape[:-2]

        if z0 is None:
            z0 = torch.zeros(*batch_dims, self.hidden_channels, dtype=X.dtype, device=X.device)
        
        X_unbound = X.unbind(dim=1)
        h, dt = self._step(X_unbound[0].float(), z0.float(), torch.zeros(*batch_dims, dtype=X.dtype, device=X.device).float(),
                           half_num_channels)
        hs = [h]
        time_diffs = times[1:] - times[:-1]
        for time_diff, Xi in zip(time_diffs, X_unbound[1:]):
            h = self.evolve(h, time_diff)
            h, dt = self._step(Xi, h, dt, half_num_channels)
            hs.append(h)
        out = torch.stack(hs, dim=1)
        out= out[:,-1,:] 
        input_time = out.shape[1]
        out = self.read_time(out)
        out = self.readout(out.unsqueeze(-1))
        
        return out

class GRU_D_forecasting(_GRU_forecasting):
    """GRU-D forecaster: between observations the hidden state decays exponentially."""

    def __init__(self, input_channels, hidden_channels, output_channels,forecast_window,device ):
        super().__init__(input_channels=input_channels,
                         hidden_channels=hidden_channels,
                         output_channels=output_channels,
                         forecast_window=forecast_window,
                         device=device)
        # Maps the elapsed time to one non-negative (after relu) rate per hidden unit.
        self.decay = torch.nn.Linear(1, hidden_channels)

    def evolve(self, h, time_diff):
        rate = self.decay(time_diff.unsqueeze(0)).squeeze(0).relu()
        return h * torch.exp(-rate)

class GRU_dt_forecasting(_GRU_forecasting):
    """GRU forecaster whose hidden state is carried unchanged between observations."""

    def evolve(self, h, time_diff):
        """Identity evolution; ``time_diff`` is ignored."""
        return h



class ODERNN_forecasting(_GRU_forecasting):
    """ODE-RNN forecaster: between observations the hidden state follows an MLP ODE."""

    def __init__(self, input_channels,forecast_window, hidden_channels, output_channels, hidden_hidden_channels, num_hidden_layers,
                 device):
        super().__init__(input_channels=input_channels,
                         hidden_channels=hidden_channels,
                         output_channels=output_channels,
                         forecast_window=forecast_window,
                         device=device)

        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        self.func = _ODERNNFunc(hidden_channels, hidden_hidden_channels, num_hidden_layers)

    def extra_repr(self):
        return (f"hidden_hidden_channels={self.hidden_hidden_channels}, "
                f"num_hidden_layers={self.num_hidden_layers}")

    def evolve(self, h, time_diff):
        """Integrate ``h`` from 0 to ``time_diff`` with fixed-step RK4 (adjoint method)."""
        span = torch.tensor([0, time_diff.item()], dtype=time_diff.dtype, device=time_diff.device)
        trajectory = torchdiffeq.odeint_adjoint(func=self.func, y0=h, t=span, method='rk4')
        # Index 0 is the initial state; index 1 is the state at time_diff.
        return trajectory[1]


def GRU_ODE_Aug(input_channels, hidden_channels):
    """Build a ContGruFunc_Aug vector field wrapped for the packed [x, h] state."""
    return ContinuousRNNConverter_Aug(
        input_channels=input_channels,
        hidden_channels=hidden_channels,
        model=ContGruFunc_Aug(input_channels=input_channels, hidden_channels=hidden_channels),
    )

def GRU_ODE_Delay(input_channels, hidden_channels,file_path,rnd):
    """Build a ContGruFunc_Delay vector field (state files under ``file_path``, id ``rnd``)
    wrapped for the packed [x, h] state."""
    return ContinuousRNNConverter(
        input_channels=input_channels,
        hidden_channels=hidden_channels,
        model=ContGruFunc_Delay(input_channels=input_channels, hidden_channels=hidden_channels,
                                file_path=file_path, rnd=rnd),
    )
def Dis_GRU_ODE(input_channels, hidden_channels):
    """Build a DisGruFunc (GRU-ODE) vector field wrapped for the packed [x, h] state."""
    return ContinuousRNNConverter(
        input_channels=input_channels,
        hidden_channels=hidden_channels,
        model=DisGruFunc(input_channels=input_channels, hidden_channels=hidden_channels),
    )

class NeuralCDE_forecasting(torch.nn.Module):
    """Forecaster driven by a controlled differential equation over interpolated input.

    The vector field and read-out head are chosen by the global ``args.model``:
      * 'ncde'             - neural CDE; read out the last ``forecast_window`` states.
      * 'contgru_delay'    - continuous GRU with delayed state (ContGruFunc_Delay).
      * 'contgru_aug' / 'contgru_delay_aug'
                           - continuous GRU that also returns gate derivatives.
      * 'gruode'           - GRU-ODE (DisGruFunc).
    ``forward`` returns ``(pred_y, z_T[:, input_channels:])``.
    """

    # Both spellings select the augmented continuous GRU, so that __init__ and
    # forward always agree on which model is being run.
    _AUG_MODELS = ('contgru_aug', 'contgru_delay_aug')

    def __init__(self, input_channels, hidden_channels, output_channels,forecast_window,device,file_path,rnd,alpha,beta, interpolation="cubic"):
        super(NeuralCDE_forecasting, self).__init__()
        self.interpolation = interpolation
        if self.interpolation =='linear':
            input_channels = input_channels*4
        if args.model=='ncde':
            self.func = CDEFunc(input_channels, hidden_channels)
            self.readout = torch.nn.Linear(hidden_channels,output_channels)

        if args.model=='contgru_delay':
            self.func = GRU_ODE_Delay(input_channels,hidden_channels,file_path=file_path,rnd=rnd)
            self.readout = torch.nn.Linear(1,output_channels)
            self.read_time = torch.nn.Linear(hidden_channels,forecast_window)

        if args.model in self._AUG_MODELS:
            self.func = GRU_ODE_Aug(input_channels,hidden_channels)
            self.readout = torch.nn.Linear(1,output_channels)
            self.read_time = torch.nn.Linear(hidden_channels,forecast_window)
        if args.model =='gruode':
            self.func = Dis_GRU_ODE(input_channels,hidden_channels)
            self.readout = torch.nn.Linear(1,output_channels)
            self.read_time = torch.nn.Linear(hidden_channels,forecast_window)

        self.initial = torch.nn.Linear(input_channels, hidden_channels)

        self.input_channels = input_channels
        self.device=device
        self.forecast_window = forecast_window
        self.file = file_path
        self.alpha= alpha
        self.beta= beta

    def forward(self, coeffs,times):
        if self.interpolation in CUBICS:
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")

        batch_dims = coeffs.shape[:-2]
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        is_aug = args.model in self._AUG_MODELS
        if args.model in ('contgru_delay', 'gruode') or is_aug:
            # GRU-style fields integrate a packed state [x, h]: the x part starts
            # at zero and the hidden part at the learned initial state.
            z0_extra = torch.zeros(
                *batch_dims, self.input_channels, dtype=z0.dtype, device=z0.device
            )
            h0 = z0
            z0 = torch.cat([z0_extra, z0], dim=-1)

        if args.model == 'contgru_delay':
            h_T,z_T = torchcde.contint_delay(X=X,
                                  z0=z0,
                                  h0=h0,
                                  func=self.func,
                                  t=times,
                                  device=self.device)
        elif is_aug:
            h_T,z_T = torchcde.contint_delay_aug(X=X,
                                  z0=z0,
                                  h0=h0,
                                  func=self.func,
                                  t=times,
                                  device=self.device)
        elif args.model == 'gruode':
            h_T,z_T = torchcde.contint(X=X,
                                  z0=z0,
                                  h0=h0,
                                  input_channels=self.input_channels,
                                  func=self.func,
                                  t=X.interval,
                                  device=self.device)
        elif args.model == 'ncde':
            z_T = torchcde.cdeint(X=X,
                                z0=z0,
                                func=self.func,
                                t=times)
        else:
            raise ValueError(f"Unsupported model {args.model!r} for NeuralCDE_forecasting.")

        look_window = times.shape[0]
        if args.model == 'ncde':
            # Keep only the trailing forecast_window states and read them out.
            z_T = z_T[:, look_window-self.forecast_window:, :]
            pred_y = self.readout(z_T)
        else:
            # Index 1 selects the second (final) solver output for the GRU models.
            h_T = h_T[:, 1]
            if args.model != 'gruode':
                z_T = z_T[:, 1]
            # (batch, hidden) -> (batch, forecast_window) -> (batch, forecast_window, output_channels)
            pred_y = self.read_time(h_T)
            pred_y = self.readout(pred_y.unsqueeze(-1))
        return pred_y,z_T[:,self.input_channels:]


class NeuralODE_forecasting(torch.nn.Module):
    """Neural-ODE forecaster: integrates an autonomous ODE from the initial observation.

    Used when the global ``args.model`` is in ``ODE_BASED``. The input path is
    only used to compute the initial hidden state. ``forward`` returns
    ``(pred_y, z_T[:, input_channels:])``, where ``pred_y`` has one prediction per
    step of the last ``forecast_window`` solver times.
    """

    def __init__(self, input_channels, hidden_channels, output_channels,forecast_window,times,device,mask='False', interpolation="cubic"):
        super(NeuralODE_forecasting, self).__init__()
        if args.model in ODE_BASED:
            self.func = ODEFunc(input_channels, hidden_channels)
            self.readout = torch.nn.Linear(hidden_channels, output_channels)

        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.interpolation = interpolation
        self.input_channels = input_channels
        self.device=device
        self.times=times
        self.forecast_window = forecast_window

    def forward(self, coeffs,times,adjoint=True,**kwargs):
        """Integrate the ODE over ``times`` and read out the forecast window.

        Extra ``kwargs`` are forwarded to torchdiffeq. Unset tolerances default
        to atol=1e-6 and rtol=1e-4 (mirrored for the adjoint pass), the method
        defaults to 'rk4', and rk4 gets step_size=0.1 unless a step size or grid
        constructor is supplied. A caller-supplied ``options`` dict is copied,
        not modified.
        """
        # Accept the same cubic variants as NeuralCDE_forecasting.
        if self.interpolation in CUBICS:
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)
        if args.model not in ODE_BASED:
            raise ValueError(f"Unsupported model {args.model!r} for NeuralODE_forecasting.")

        kwargs.setdefault('atol', 1e-6)
        kwargs.setdefault('rtol', 1e-4)
        if adjoint:
            kwargs.setdefault('adjoint_atol', kwargs['atol'])
            kwargs.setdefault('adjoint_rtol', kwargs['rtol'])
        kwargs.setdefault('method', 'rk4')
        if kwargs['method'] == 'rk4':
            # Copy so the default step size never leaks into the caller's dict.
            options = dict(kwargs.get('options') or {})
            if 'step_size' not in options and 'grid_constructor' not in options:
                options['step_size'] = 0.1
            kwargs['options'] = options

        odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
        z_T = odeint(func=self.func, y0=z0, t=times, **kwargs)

        look_window = times.shape[0]
        # torchdiffeq returns time-major output; switch to (batch, time, hidden).
        z_T = z_T.permute(1,0,2)
        z_T = z_T[:,look_window-self.forecast_window:,:]
        pred_y = self.readout(z_T)

        return pred_y,z_T[:,self.input_channels:]



def _ushcn_augment_inputs(X, times, time_intensity, static_intensity):
    """Return ``X`` with optional time / observation-count channels prepended.

    Channels are concatenated along dim 2 in this order:
    1. time, if ``time_intensity``;
    2. the cumulative count (along dim 1) of non-NaN entries, if ``static_intensity``;
    3. the original ``X``.
    """
    pieces = []
    if time_intensity:
        # Repeat the shared time grid for every sample as one extra channel.
        pieces.append(times.unsqueeze(0).repeat(X.size(0), 1).unsqueeze(-1))
    if static_intensity:
        observed = ~torch.isnan(X)
        pieces.append(observed.to(X.dtype).cumsum(dim=1))
    pieces.append(X)
    return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=2)


def _ushcn_build_coeffs(coeff_loc, interpolation, train_X, val_X, test_X):
    """Compute and cache ``{train,val,test}_coeffs.pt`` under ``coeff_loc``.

    Nothing is recomputed when all three files already exist. A directory left
    half-filled by an interrupted run is regenerated. ``'natural_cubic'`` uses
    ``torchcde.natural_cubic_coeffs``; any other value uses Hermite cubic
    coefficients with backward differences.
    """
    splits = (('train', 'Train', train_X), ('val', 'Val', val_X), ('test', 'Test', test_X))
    coeff_loc = pathlib.Path(coeff_loc)
    if all((coeff_loc / f'{name}_coeffs.pt').exists() for name, _, _ in splits):
        return
    coeff_loc.mkdir(parents=True, exist_ok=True)
    if interpolation == 'natural_cubic':
        coeff_fn = torchcde.natural_cubic_coeffs
    else:
        coeff_fn = torchcde.hermite_cubic_coefficients_with_backward_differences
    print("Start extrapolation!")
    for name, label, X in splits:
        torch.save(coeff_fn(X), str(coeff_loc / f'{name}_coeffs.pt'))
        print(f"finish extrapolation {label} coeff")
    print("success!")


def _ushcn_clean_targets(y, device):
    """Replace NaN/inf in ``y`` with ``np.nan_to_num``, convert to ``torch.Tensor`` and move to ``device``."""
    return torch.Tensor(np.nan_to_num(np.array(y.cpu()))).to(device)


def _ushcn_report_memory(device, baseline_memory):
    """Print peak CUDA memory above ``baseline_memory``; does nothing when no baseline was taken."""
    if baseline_memory is None:
        return
    memory_usage = torch.cuda.max_memory_allocated(device) - baseline_memory
    print(f"memory_usage:{memory_usage}")


def _ushcn_epoch_mse(model, dataloader, times, model_name, loss_fn, optimizer=None):
    """Return the batch-size-weighted mean of ``loss_fn`` over ``dataloader`` as a float.

    With an ``optimizer``, each batch also takes a gradient step (training pass).
    Without one, the pass runs under ``torch.no_grad()`` (evaluation pass).
    Models named in ``GRUS`` are called as ``model(coeffs)``; all others are
    called as ``model(coeffs, times)`` and return ``(pred_y, extra)``.
    """
    total_loss = 0.0
    total_size = 0
    grad_mode = torch.enable_grad() if optimizer is not None else torch.no_grad()
    with grad_mode:
        for batch_coeffs, batch_y in dataloader:
            if model_name in GRUS:
                pred_y = model(batch_coeffs)
            else:
                pred_y, _ = model(batch_coeffs, times)
            loss = loss_fn(pred_y, batch_y)
            if optimizer is not None:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            b_size = batch_y.size(0)
            # .item() keeps a plain float, so no autograd graph is held across batches.
            total_loss += loss.item() * b_size
            total_size += b_size
    if total_size == 0:
        raise ValueError("cannot compute MSE over an empty dataloader")
    return total_loss / total_size


def main(model_name=args.model, num_epochs=args.epoch):
    """Train and evaluate one forecasting model on the processed USHCN data.

    Steps:
    - Load the look/forecast/stride-specific tensors.
    - Prepend a time channel to the inputs.
    - Compute (or reuse cached) interpolation coefficients.
    - Train with Adam on MSE.

    Each epoch prints train, validation and test MSE. When validation MSE
    improves and train MSE <= 0.24, a checkpoint is saved. Training stops after
    ``plateau_terminate`` epochs without a training-loss improvement.
    """
    manual_seed = args.seed
    static_intensity = False
    time_intensity = True
    batch_size = 256
    np.random.seed(manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)
    torch.random.manual_seed(manual_seed)
    print(f"Setting of this Experiments{args}")
    device = "cuda"

    look_window = args.look_window
    forecast_window = args.forecast_window
    stride_window = args.stride_window
    # Presumably prepares the processed tensors loaded below; its return
    # value was never used (the location is rebuilt here).
    datasets.ushcn.get_data(look_window, forecast_window, stride_window)

    here = pathlib.Path(__file__).resolve().parent
    base_base_loc = here / 'datasets/processed_data'
    loc = base_base_loc / f'USHCN_look_{look_window}_forecast_{forecast_window}_stride_{stride_window}'
    coeff_loc = loc / ('NaturalCoeffs' if args.interpolation == 'natural_cubic' else 'Coeffs')

    times = torch.load(str(loc) + '/times.pt')
    train_y = torch.load(str(loc) + '/train_y_seq_data.pt')
    val_y = torch.load(str(loc) + '/val_y_seq_data.pt')
    test_y = torch.load(str(loc) + '/test_y_seq_data.pt')
    train_X = torch.load(str(loc) + '/train_seq_data.pt')
    val_X = torch.load(str(loc) + '/val_seq_data.pt')
    test_X = torch.load(str(loc) + '/test_seq_data.pt')

    train_X = _ushcn_augment_inputs(train_X, times, time_intensity, static_intensity)
    val_X = _ushcn_augment_inputs(val_X, times, time_intensity, static_intensity)
    test_X = _ushcn_augment_inputs(test_X, times, time_intensity, static_intensity)
    input_channels = train_X.shape[-1]

    _ushcn_build_coeffs(coeff_loc, args.interpolation, train_X, val_X, test_X)
    train_coeffs = torch.load(str(coeff_loc) + '/train_coeffs.pt').to(device)
    val_coeffs = torch.load(str(coeff_loc) + '/val_coeffs.pt').to(device)
    test_coeffs = torch.load(str(coeff_loc) + '/test_coeffs.pt').to(device)

    train_y = _ushcn_clean_targets(train_y, device)
    val_y = _ushcn_clean_targets(val_y, device)
    test_y = _ushcn_clean_targets(test_y, device)
    output_channels = train_y.shape[-1]
    hidden_channels = args.h_channels

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    # Validation and test are each evaluated as a single full-size batch.
    val_dataset = torch.utils.data.TensorDataset(val_coeffs, val_y)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_y.shape[0])
    test_dataset = torch.utils.data.TensorDataset(test_coeffs, test_y)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=test_y.shape[0])

    experiment_id = int(SystemRandom().random() * 100000)
    file_path = os.path.dirname(os.path.abspath(__file__)) + "/USHCN"
    print(f"Past hidden and dh_dt are saved at {file_path} and Experiment id is : {experiment_id}")

    if model_name == 'gru_d':
        model = GRU_D_forecasting(input_channels=input_channels, hidden_channels=hidden_channels,
                                  output_channels=output_channels, forecast_window=forecast_window, device=device)
    elif model_name == 'dt':
        model = GRU_dt_forecasting(input_channels=input_channels, hidden_channels=hidden_channels,
                                   output_channels=output_channels, forecast_window=forecast_window, device=device)
    elif model_name == 'node':
        model = NeuralODE_forecasting(input_channels=input_channels, hidden_channels=hidden_channels,
                                      output_channels=output_channels, forecast_window=forecast_window,
                                      times=times, device=device)
    elif model_name == 'odernn':
        model = ODERNN_forecasting(input_channels=input_channels, forecast_window=forecast_window,
                                   hidden_channels=hidden_channels, hidden_hidden_channels=128,
                                   num_hidden_layers=2, output_channels=output_channels, device=device)
    else:
        model = NeuralCDE_forecasting(input_channels=input_channels, hidden_channels=hidden_channels,
                                      output_channels=output_channels, forecast_window=forecast_window,
                                      device=device, file_path=file_path, rnd=experiment_id,
                                      alpha=args.alpha, beta=args.beta, interpolation=args.interpolation)
    model = model.to(device)
    times = times.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = torch.nn.MSELoss()

    best_val_mse = np.inf
    best_train_mse = np.inf
    best_train_loss_epoch = 0
    plateau_terminate = 50
    model_dir = os.path.dirname(os.path.abspath(__file__)) + '/trained_model/contgru_aug/USHCN/'

    tqdm_range = tqdm.tqdm(range(num_epochs))
    tqdm_range.write('Starting training for model:\n\n' + str(model) + '\n\n')
    if device != 'cpu':
        torch.cuda.reset_peak_memory_stats(device)
        baseline_memory = torch.cuda.memory_allocated(device)
    else:
        baseline_memory = None

    for epoch in tqdm_range:
        model.train()
        start_time = time.time()
        train_mse = _ushcn_epoch_mse(model, train_dataloader, times, model_name, loss_fn, optimizer)
        # Require a relative improvement of more than 0.01% to reset the plateau counter.
        if train_mse * 1.0001 < best_train_mse:
            best_train_mse = train_mse
            best_train_loss_epoch = epoch
        print('Epoch: {}  Training MSE: {}, Time :{}'.format(epoch, train_mse, (time.time() - start_time)))
        _ushcn_report_memory(device, baseline_memory)

        model.eval()
        start_time = time.time()
        val_mse = _ushcn_epoch_mse(model, val_dataloader, times, model_name, loss_fn)
        print('Epoch: {}   Validation MSE: {}, Time :{}'.format(epoch, val_mse, (time.time() - start_time)))
        _ushcn_report_memory(device, baseline_memory)

        start_time = time.time()
        test_mse = _ushcn_epoch_mse(model, test_dataloader, times, model_name, loss_fn)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            print('Epoch: {}   Best Test MSE: {}, Time :{}\n'.format(epoch, test_mse, (time.time() - start_time)))
            if train_mse <= 0.24:
                os.makedirs(model_dir, exist_ok=True)
                ckpt_file = model_dir + "USHCN" + str(epoch) + "_model_MSE_" + str(val_mse) + ".pth"
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # The test loader yields one full batch, so this equals that batch's loss.
                    'loss': test_mse,
                }, ckpt_file)
                print(f"model_saved {ckpt_file}")
        else:
            print('Epoch: {}   Test MSE {}, Time :{}\n'.format(epoch, test_mse, (time.time() - start_time)))

        if epoch > best_train_loss_epoch + plateau_terminate:
            tqdm_range.write('Breaking because of no improvement in training loss for {} epochs.'
                             ''.format(plateau_terminate))
            break



if __name__ == "__main__":
    # Script entry point: run one experiment configured by the CLI args.
    main()
