######################
# So you want to train a Neural CDE model?
# Let's get started!
######################

import math
import torch
import torchcde
import datasets
import sklearn.model_selection
 
import torchdiffeq
from random import SystemRandom
import random
import numpy as np 
from parse import parse_args
args = parse_args()
import pathlib
import os 
import time 
import tqdm

# Ask CUDA to order devices by PCI bus id.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Name groups used below to dispatch on `args.interpolation` / `args.model`.
CUBICS = [
    'natural_cubic',
    'cubic',
]
BASELINES = [
    'node',
]
BASELINES2 = [
    'gru_d',
    'dt',
    'odernn',
]
GRUCONT = [
    'contgru_delay_aug',
    'gruode',
    'contgru_delay',
]
ODE_BASED = [
    'node',
]

def split_data(tensor, stratify):
    """Split `tensor` into stratified train / validation / test parts.

    70% of the samples go to the training part; the remaining 30% is halved
    into validation and test parts. Both splits shuffle with a fixed
    `random_state` (0, then 1) and stratify on `stratify`.

    Returns:
        A `(train, val, test)` tuple of pieces of `tensor`.
    """
    splitter = sklearn.model_selection.train_test_split

    # First cut: training part versus a held-out part, plus matching labels.
    train_part, holdout_part, _train_labels, holdout_labels = splitter(
        tensor,
        stratify,
        train_size=0.7,
        random_state=0,
        shuffle=True,
        stratify=stratify,
    )

    # Second cut: halve the held-out part, stratifying on its own labels.
    val_part, test_part = splitter(
        holdout_part,
        train_size=0.5,
        random_state=1,
        shuffle=True,
        stratify=holdout_labels,
    )
    return train_part, val_part, test_part
def normalise_data(X, y):
    """Standardise every channel (last dimension) of `X`.

    Mean and standard deviation are computed per channel from the non-NaN
    entries of the training part only (as returned by `split_data(X, y)`),
    then applied to all of `X`. A small constant is added to the standard
    deviation before dividing.

    Returns:
        A tensor with the same shape as `X`.
    """
    train_X, _, _ = split_data(X, y)

    channel_pairs = zip(X.unbind(dim=-1), train_X.unbind(dim=-1))
    normalised_channels = []
    for channel, train_channel in channel_pairs:
        # Statistics come from observed (non-NaN) training values only.
        observed = train_channel[~torch.isnan(train_channel)]
        centre = observed.mean()
        spread = observed.std()
        normalised_channels.append((channel - centre) / (spread + 1e-5))

    return torch.stack(normalised_channels, dim=-1)




class ODEFunc(torch.nn.Module):
    """Vector field for a neural ODE: a two-layer MLP acting on the hidden state.

    Maps a tensor with `hidden_channels` trailing features to one of the same
    width through Linear(hidden, 128) -> ReLU -> Linear(128, hidden) -> tanh.
    `input_channels` is stored but not used by the network itself.
    """

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, hidden_channels)

    def forward(self, t, z):
        """Return the derivative of `z`; the time `t` is ignored."""
        hidden = torch.relu(self.linear1(z))
        return torch.tanh(self.linear2(hidden))




class CDEFunc(torch.nn.Module):
    """Vector field for a neural CDE.

    Maps a hidden state with `hidden_channels` trailing features to a matrix
    of shape (hidden_channels, input_channels) per batch element, via
    Linear(hidden, 128) -> ReLU -> Linear(128, input * hidden) -> tanh.

    Args:
        input_channels: number of channels of the driving path.
        hidden_channels: size of the hidden state.
    """

    def __init__(self, input_channels, hidden_channels):
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, t, z):
        """Evaluate the vector field at hidden state `z` (time `t` is unused).

        Args:
            t: current time; ignored.
            z: tensor of shape (..., hidden_channels).

        Returns:
            Tensor of shape (..., hidden_channels, input_channels).
        """
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimensions, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        # All leading dimensions are kept as they are, so any number of batch dimensions works.
        ######################
        z = z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z

class ContGruFunc_Delay(torch.nn.Module):
    """Time derivative of a GRU hidden state, using a delayed hidden state.

    Each call evaluates GRU gates on the hidden state cached by the previous
    call (`h_past`) and returns dh/dt. Two arrays are carried between calls
    through `.npy` files:

        <file_path>/h_past/h_past_<rnd>.npy       GRU update of the last call
        <file_path>/dhpastdt/dhpastdt_<rnd>.npy   dh/dt returned by the last call

    At `t == 0` nothing is read from disk: the incoming `h` is used as
    `h_past`, and the delayed derivative is `(1 - z) * (g - h)`.

    Because state lives on disk under a name derived from `rnd`, two instances
    sharing `file_path` and `rnd` overwrite each other's cache.

    Args:
        input_channels: width of the input `x`.
        hidden_channels: width of the hidden state `h`.
        file_path: directory holding the `h_past` and `dhpastdt` caches; both
            sub-directories are created if missing.
        rnd: identifier used in the cache file names.
    """

    def __init__(self, input_channels, hidden_channels,file_path,rnd):
        super(ContGruFunc_Delay, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)
        self.file = file_path
        self.rnd = rnd
        # np.save does not create missing parent directories, so make sure the
        # cache folders exist before the first forward pass.
        for kind in ("h_past", "dhpastdt"):
            os.makedirs("{}/{}".format(self.file, kind), exist_ok=True)

    def _cache_path(self, kind):
        """Return the `.npy` path of cache `kind` ('h_past' or 'dhpastdt')."""
        return "{}/{}/{}_{}.npy".format(self.file, kind, kind, self.rnd)

    def _load_cache(self, kind, like):
        """Load cache `kind` and move it to the dtype and device of `like`."""
        return torch.from_numpy(np.load(self._cache_path(kind))).to(like)

    def _save_cache(self, kind, value):
        """Write `value` (detached, on CPU) to cache `kind`."""
        np.save(self._cache_path(kind), value.cpu().detach().numpy())

    def forward(self, t,x, h,dxdt):
        """Return dh/dt at time `t`.

        Args:
            t: current time; `t == 0` marks the first evaluation.
            x: input part of the state, `input_channels` trailing features.
            h: hidden part of the state, `hidden_channels` trailing features.
            dxdt: object whose `derivative(t)` gives the time derivative of
                the control path (presumably a torchcde interpolation).

        Returns:
            Tensor shaped like `h` holding dh/dt.
        """
        at_start = bool(t == 0)

        h_past = h if at_start else self._load_cache("h_past", h)

        # GRU gates evaluated on the delayed hidden state.
        r = (self.W_r(x) + self.U_r(h_past)).sigmoid()
        z = (self.W_z(x) + self.U_z(h_past)).sigmoid()
        g0 = self.W_h(x) + self.U_h(r * h_past)
        g = g0.tanh()

        # Discrete GRU update; cached as the next call's h_past. No gradient
        # flows through the cache because it is detached on save.
        h_ = torch.mul(z, h_past) + torch.mul((1 - z), g)
        self._save_cache("h_past", h_)

        hg = h_past - g

        if at_start:
            dhpast_dt = (1 - z) * (g - h)
        else:
            dhpast_dt = self._load_cache("dhpastdt", h)

        control_gradient = dxdt.derivative(t)
        dx_col = control_gradient.unsqueeze(-1)
        dh_col = dhpast_dt.unsqueeze(-1)

        # Chain rule through the gate pre-activations; sigmoid' = s * (1 - s).
        dAdt = ((self.W_z.weight @ dx_col) + (self.U_z.weight @ dh_col)).squeeze(-1)
        dzdt = torch.mul(torch.mul(z, (1 - z)), dAdt)

        drdt = torch.mul(
            torch.mul(r, (1 - r)),
            ((self.W_r.weight @ dx_col) + (self.U_r.weight @ dh_col)).squeeze(-1),
        )

        # NOTE(review): the second term multiplies by the current `h`, while
        # the gates above are built from `h_past` - confirm this is intended.
        dBdt = (
            (self.W_h.weight @ dx_col).squeeze(-1)
            + torch.mul((self.U_h.weight @ drdt.unsqueeze(-1)).squeeze(-1), h)
            + torch.mul((self.U_h.weight @ r.unsqueeze(-1)).squeeze(-1), dhpast_dt)
        )

        # tanh' = 1 - g^2 = (1 - g) * (1 + g).
        dgdt = torch.mul(torch.mul((1 - g), (1 + g)), dBdt)

        dhgdt = dhpast_dt - dgdt

        # d/dt [ z * (h_past - g) + g ]
        dhdt = torch.mul(dzdt, hg) + torch.mul(z, dhgdt) + dgdt
        self._save_cache("dhpastdt", dhdt)

        return dhdt


class DisGruFunc(torch.nn.Module):
    """GRU-style vector field: dh/dt = (1 - z) * (g - h).

    `r`, `z` and `g` are the usual GRU reset gate, update gate and candidate
    state, computed from the input `x` and the hidden state `h`.

    Args:
        input_channels: width of the input `x`.
        hidden_channels: width of the hidden state `h`.
    """

    def __init__(self, input_channels, hidden_channels):
        super(DisGruFunc, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)

    def forward(self, t,x, h,dxdt):
        """Return dh/dt for input `x` and hidden state `h`.

        `t` and `dxdt` are accepted so the signature matches the other GRU
        vector fields in this file; neither is used here.
        """
        r = (self.W_r(x) + self.U_r(h)).sigmoid()
        z = (self.W_z(x) + self.U_z(h)).sigmoid()
        g = (self.W_h(x) + self.U_h(r * h)).tanh()

        dh_dt = (1 - z) * (g - h)

        return dh_dt

class ContGruFunc_Aug(torch.nn.Module):
    """Delayed GRU vector field that also returns the gate derivatives.

    Each call evaluates GRU gates on the hidden state cached by the previous
    call (`h_past`) and returns `(dh/dt, dr/dt, dg/dt, dz/dt)`. Two arrays are
    carried between calls through `.npy` files:

        <file_path>/h_past/h_past_<rnd>.npy       GRU update of the last call
        <file_path>/dhpastdt/dhpastdt_<rnd>.npy   dh/dt returned by the last call

    At `t == 0` nothing is read from disk: the incoming `h` is used as
    `h_past`, and the delayed derivative is `(1 - z) * (g - h)`.

    Because state lives on disk under a name derived from `rnd`, two instances
    sharing `file_path` and `rnd` overwrite each other's cache.

    Args:
        input_channels: width of the input `x`.
        hidden_channels: width of the hidden state `h`.
        file_path: directory holding the `h_past` and `dhpastdt` caches; both
            sub-directories are created if missing.
        rnd: identifier used in the cache file names.
    """

    def __init__(self, input_channels, hidden_channels,file_path,rnd):
        super(ContGruFunc_Aug, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.W_r = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_z = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.W_h = torch.nn.Linear(input_channels, hidden_channels, bias=False)
        self.U_r = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_z = torch.nn.Linear(hidden_channels, hidden_channels)
        self.U_h = torch.nn.Linear(hidden_channels, hidden_channels)
        self.file =file_path
        self.rnd = rnd
        # np.save does not create missing parent directories, so make sure the
        # cache folders exist before the first forward pass.
        for kind in ("h_past", "dhpastdt"):
            os.makedirs("{}/{}".format(self.file, kind), exist_ok=True)

    def _cache_path(self, kind):
        """Return the `.npy` path of cache `kind` ('h_past' or 'dhpastdt')."""
        return "{}/{}/{}_{}.npy".format(self.file, kind, kind, self.rnd)

    def _load_cache(self, kind, like):
        """Load cache `kind` and move it to the dtype and device of `like`."""
        return torch.from_numpy(np.load(self._cache_path(kind))).to(like)

    def _save_cache(self, kind, value):
        """Write `value` (detached, on CPU) to cache `kind`."""
        np.save(self._cache_path(kind), value.cpu().detach().numpy())

    def forward(self, t,x, h,dxdt):
        """Return `(dhdt, drdt, dgdt, dzdt)` at time `t`.

        Args:
            t: current time; `t == 0` marks the first evaluation.
            x: input part of the state, `input_channels` trailing features.
            h: hidden part of the state, `hidden_channels` trailing features.
            dxdt: object whose `derivative(t)` gives the time derivative of
                the control path (presumably a torchcde interpolation).

        Returns:
            Four tensors shaped like `h`: the derivatives of the hidden state,
            reset gate, candidate state and update gate.
        """
        at_start = bool(t == 0)

        h_past = h if at_start else self._load_cache("h_past", h)

        # GRU gates evaluated on the delayed hidden state.
        r = (self.W_r(x) + self.U_r(h_past)).sigmoid()
        z = (self.W_z(x) + self.U_z(h_past)).sigmoid()
        g0 = self.W_h(x) + self.U_h(r * h_past)
        g = g0.tanh()

        # Discrete GRU update; cached as the next call's h_past. No gradient
        # flows through the cache because it is detached on save.
        h_ = torch.mul(z, h_past) + torch.mul((1 - z), g)
        self._save_cache("h_past", h_)

        hg = h_past - g

        if at_start:
            dhpast_dt = (1 - z) * (g - h)
        else:
            dhpast_dt = self._load_cache("dhpastdt", h)

        control_gradient = dxdt.derivative(t)
        dx_col = control_gradient.unsqueeze(-1)
        dh_col = dhpast_dt.unsqueeze(-1)

        # Chain rule through the gate pre-activations; sigmoid' = s * (1 - s).
        dAdt = ((self.W_z.weight @ dx_col) + (self.U_z.weight @ dh_col)).squeeze(-1)
        dzdt = torch.mul(torch.mul(z, (1 - z)), dAdt)

        drdt = torch.mul(
            torch.mul(r, (1 - r)),
            ((self.W_r.weight @ dx_col) + (self.U_r.weight @ dh_col)).squeeze(-1),
        )

        # NOTE(review): the second term multiplies by the current `h`, while
        # the gates above are built from `h_past` - confirm this is intended.
        dBdt = (
            (self.W_h.weight @ dx_col).squeeze(-1)
            + torch.mul((self.U_h.weight @ drdt.unsqueeze(-1)).squeeze(-1), h)
            + torch.mul((self.U_h.weight @ r.unsqueeze(-1)).squeeze(-1), dhpast_dt)
        )

        # tanh' = 1 - g^2 = (1 - g) * (1 + g).
        dgdt = torch.mul(torch.mul((1 - g), (1 + g)), dBdt)

        dhgdt = dhpast_dt - dgdt

        # d/dt [ z * (h_past - g) + g ]
        dhdt = torch.mul(dzdt, hg) + torch.mul(z, dhgdt) + dgdt
        self._save_cache("dhpastdt", dhdt)
        return dhdt,drdt,dgdt,dzdt

class _ODERNNFunc(torch.nn.Module):
    """Tanh MLP used as the vector field between observations in `ODERNN`.

    The network maps `hidden_channels` -> `hidden_hidden_channels` (repeated
    so that there are `num_hidden_layers` hidden widths, at least one) ->
    `hidden_channels`, with a Tanh between consecutive linear layers.
    """

    def __init__(self, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__()
        # At least one hidden width is always present, even for
        # num_hidden_layers < 1.
        hidden_count = max(num_hidden_layers, 1)
        widths = [hidden_channels] + [hidden_hidden_channels] * hidden_count + [hidden_channels]

        modules = []
        for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if index > 0:
                modules.append(torch.nn.Tanh())
            modules.append(torch.nn.Linear(fan_in, fan_out))
        self.sequential = torch.nn.Sequential(*modules)

    def forward(self, t, x):
        """Apply the MLP to `x`; the time `t` is ignored."""
        return self.sequential(x)


class ContinuousRNNConverter(torch.nn.Module):
    """Adapter that runs an RNN-style vector field on an augmented state.

    The state `z` is the concatenation `[x, h]` along the last dimension:
    the first `input_channels` entries are the input part, the rest the
    hidden part. The wrapped `model` is called as `model(t, x, h, dxdt)` and
    its output is additionally projected by a linear layer to width
    `input_channels + hidden_channels`.
    """

    def __init__(self, input_channels, hidden_channels, model):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.model = model
        self.linear = torch.nn.Linear(hidden_channels, input_channels + hidden_channels)

    def forward(self,t, z,dxdt):
        """Return `(model_out, projected)` for the augmented state `z`."""
        split_at = self.input_channels
        x_part = z[..., :split_at]
        # The hidden part is clamped to [-1, 1] before it reaches the model.
        h_part = z[..., split_at:].clamp(-1, 1)

        model_out = self.model(t, x_part, h_part, dxdt)
        projected = self.linear(model_out)
        return model_out, projected



class ContinuousRNNConverter_Aug(torch.nn.Module):
    """Adapter like `ContinuousRNNConverter` for models returning gate derivatives.

    The state `z` is `[x, h]` along the last dimension. The wrapped `model`
    is called as `model(t, x, h, dxdt)` and must return four tensors; the
    first is additionally projected by a linear layer to width
    `input_channels + hidden_channels`.

    A buffer `out_base` of shape (input + hidden, input) with ones on its
    main diagonal is registered; `forward` does not use it.
    """

    def __init__(self, input_channels, hidden_channels, model):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.model = model
        self.linear = torch.nn.Linear(hidden_channels, input_channels + hidden_channels)

        # Rectangular identity: ones at (i, i) for i < input_channels.
        total_channels = input_channels + hidden_channels
        self.register_buffer("out_base", torch.eye(total_channels, input_channels))

    def forward(self,t, z,dxdt):
        """Return `(model_out, projected, model_r, model_g, model_z)`."""
        split_at = self.input_channels
        x_part = z[..., :split_at]
        # The hidden part is clamped to [-1, 1] before it reaches the model.
        h_part = z[..., split_at:].clamp(-1, 1)

        model_out, model_r, model_g, model_z = self.model(t, x_part, h_part, dxdt)
        projected = self.linear(model_out)
        return model_out, projected, model_r, model_g, model_z






def GRU_ODE_Aug(input_channels, hidden_channels,file_path,rnd):
    """Build a `ContGruFunc_Aug` vector field wrapped in `ContinuousRNNConverter_Aug`."""
    vector_field = ContGruFunc_Aug(input_channels, hidden_channels, file_path, rnd)
    return ContinuousRNNConverter_Aug(input_channels, hidden_channels, vector_field)
def GRU_ODE(input_channels, hidden_channels):
    """Build a `DisGruFunc` vector field wrapped in `ContinuousRNNConverter`."""
    vector_field = DisGruFunc(input_channels, hidden_channels)
    return ContinuousRNNConverter(input_channels, hidden_channels, vector_field)

def GRU_ODE_Delay(input_channels, hidden_channels,file_path,rnd):
    """Build a `ContGruFunc_Delay` vector field wrapped in `ContinuousRNNConverter`."""
    vector_field = ContGruFunc_Delay(input_channels, hidden_channels, file_path, rnd)
    return ContinuousRNNConverter(input_channels, hidden_channels, vector_field)


def DisGRU_ODE(input_channels, hidden_channels):
    """Build a `DisGruFunc` vector field wrapped in `ContinuousRNNConverter`.

    Constructs the same objects as `GRU_ODE`.
    """
    vector_field = DisGruFunc(input_channels, hidden_channels)
    return ContinuousRNNConverter(input_channels, hidden_channels, vector_field)

class _GRU(torch.nn.Module):
    """Base class for GRU baselines that jump at observations.

    The input is expected to hold one time channel, then one observation
    counter per actual input channel, then the actual input channels, hence
    the odd `input_channels` requirement. Subclasses implement `evolve`,
    which defines how the hidden state changes between two time points.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, use_intensity):
        super().__init__()

        assert (input_channels % 2) == 1, "Input channels must be odd: 1 for time, plus 1 for each actual input, " \
                                          "plus 1 for whether an observation was made for the actual input."

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.use_intensity = use_intensity

        # Without intensity only the actual input channels are fed to the cell.
        if use_intensity:
            cell_inputs = input_channels
        else:
            cell_inputs = (input_channels - 1) // 2
        self.gru_cell = torch.nn.GRUCell(input_size=cell_inputs, hidden_size=hidden_channels)
        self.linear = torch.nn.Linear(hidden_channels, output_channels)

    def extra_repr(self):
        template = "input_channels={}, hidden_channels={}, output_channels={}, use_intensity={}"
        return template.format(self.input_channels, self.hidden_channels,
                               self.output_channels, self.use_intensity)

    def evolve(self, h, time_diff):
        """Advance `h` over `time_diff`; must be provided by subclasses."""
        raise NotImplementedError

    def _step(self, Xi, h, dt, half_num_channels):
        """Apply the GRU cell to the batch rows that have an observation.

        `dt` is updated in place (and returned): it is left unchanged on rows
        with an observation and grows by the row's time channel elsewhere.
        Nothing changes when no row has an observation.
        """
        counters = Xi[:, 1: 1 + half_num_channels]
        has_observation = counters.max(dim=1).values > 0.5
        if not has_observation.any():
            return h, dt

        if self.use_intensity:
            cell_input = Xi.clone()
        else:
            cell_input = Xi[:, 1 + half_num_channels:].clone()
        cell_input[:, 0] += dt

        candidate_h = self.gru_cell(cell_input.float(), h.float())
        h = torch.where(has_observation.unsqueeze(1), candidate_h, h)

        no_increment = torch.tensor(0., dtype=Xi.dtype, device=Xi.device)
        dt += torch.where(has_observation, no_increment, Xi[:, 0])
        return h, dt

    def forward(self,coeffs, times,   z0=None):
        """Run the model over `times` and return the readout of the last hidden state."""
        spline = torchcde.CubicSpline(coeffs)
        X = torch.stack([spline.evaluate(t) for t in times], dim=-2)
        half_num_channels = (self.input_channels - 1) // 2

        # Turn the counter channels into per-step increments.
        X[:, 1:, 1:1 + half_num_channels] -= X[:, :-1, 1:1 + half_num_channels]

        # Turn the time channel into per-step time differences.
        X[:, 0, 0] -= times[0]
        X[:, 1:, 0] -= times[:-1]

        batch_dims = X.shape[:-2]

        if z0 is None:
            z0 = torch.zeros(*batch_dims, self.hidden_channels, dtype=X.dtype, device=X.device)

        steps = X.unbind(dim=1)
        elapsed = torch.zeros(*batch_dims, dtype=X.dtype, device=X.device)
        h, elapsed = self._step(steps[0], z0, elapsed, half_num_channels)
        history = [h]

        gaps = times[1:] - times[:-1]
        for gap, Xi in zip(gaps, steps[1:]):
            h = self.evolve(h, gap)
            h, elapsed = self._step(Xi, h, elapsed, half_num_channels)
            history.append(h)

        stacked = torch.stack(history, dim=1)
        return self.linear(stacked[:, -1, :])
class GRU_dt(_GRU):
    """GRU baseline whose hidden state stays constant between time points."""

    def evolve(self, h, time_diff):
        # No dynamics between time points: `time_diff` is deliberately unused.
        unchanged = h
        return unchanged
class GRU_D(_GRU):
    """GRU baseline whose hidden state decays exponentially between time points."""

    def __init__(self, input_channels, hidden_channels, output_channels, use_intensity):
        super().__init__(input_channels=input_channels,
                         hidden_channels=hidden_channels,
                         output_channels=output_channels,
                         use_intensity=use_intensity)
        # Maps a scalar time gap to one decay rate per hidden channel.
        self.decay = torch.nn.Linear(1, hidden_channels)

    def evolve(self, h, time_diff):
        """Scale `h` by `exp(-relu(decay(time_diff)))`."""
        rate = self.decay(time_diff.unsqueeze(0)).squeeze(0).relu()
        factor = torch.exp(-rate)
        return h * factor

class ODERNN(_GRU):
    """GRU baseline whose hidden state follows an ODE between time points.

    `func` is the vector field handed to `torchdiffeq.odeint_adjoint`.
    `hidden_hidden_channels` and `num_hidden_layers` are only stored and
    reported by `extra_repr`.
    """

    def __init__(self, func,input_channels, hidden_channels, output_channels, hidden_hidden_channels, num_hidden_layers,
                 use_intensity):
        super().__init__(input_channels=input_channels,
                         hidden_channels=hidden_channels,
                         output_channels=output_channels,
                         use_intensity=use_intensity)
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        self.func = func

    def extra_repr(self):
        template = "hidden_hidden_channels={}, num_hidden_layers={}"
        return template.format(self.hidden_hidden_channels, self.num_hidden_layers)

    def evolve(self, h, time_diff):
        """Integrate `func` from time 0 to `time_diff` with RK4, starting at `h`."""
        span = torch.tensor([0, time_diff.item()], dtype=time_diff.dtype, device=time_diff.device)
        trajectory = torchdiffeq.odeint_adjoint(func=self.func, y0=h, t=span, method='rk4')
        # Index 0 is the initial state, index 1 the state at `time_diff`.
        return trajectory[1]


class NeuralCDE(torch.nn.Module):
    """Neural CDE / continuous-GRU model selected by `args.model`.

    Supported models are 'ncde' and the names in `GRUCONT`
    ('contgru_delay', 'contgru_delay_aug', 'gruode').

    * 'ncde': the hidden state is integrated with `torchcde.cdeint` and the
      final state is passed to `readout`.
    * GRU-style models: the state is augmented to `[zeros(input), hidden]`,
      integrated with `torchcde.contint_delay` / `contint` / `contint_aug`
      (presumably from a project-patched torchcde - TODO confirm), and the
      prediction is `alpha * readout(h_T) + beta * readout2(z_T)`.

    Args:
        func: vector field module handed to the integrator.
        input_channels: channels of the interpolated path (multiplied by 4
            internally when `interpolation == 'linear'`).
        hidden_channels: size of the hidden state.
        output_channels: size of the prediction.
        device: forwarded to the GRU-style integrators.
        file_path, rnd: accepted for interface compatibility; not used here.
        alpha, beta: weights of the two readouts for GRU-style models.
        interpolation: 'linear' or one of `CUBICS`.

    Note:
        Layers are chosen in `__init__` from the module-level `args.model`,
        while `forward` dispatches on the `args` it receives; the two must
        agree.
    """

    def __init__(self,func, input_channels, hidden_channels, output_channels,device,file_path,rnd,alpha,beta, interpolation="cubic"):
        super(NeuralCDE, self).__init__()
        self.interpolation = interpolation
        if self.interpolation =='linear':
            input_channels = input_channels*4
        self.func = func

        gru_style = args.model in GRUCONT
        if args.model == 'ncde' or gru_style:
            self.readout = torch.nn.Linear(hidden_channels, output_channels)
        if gru_style:
            # Second readout acts on the augmented [input, hidden] state.
            self.readout2 = torch.nn.Linear(hidden_channels+input_channels, output_channels)

        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.input_channels = input_channels
        self.device=device
        self.alpha = alpha
        self.beta = beta

    def _integrate_gru(self, args, X, z0, batch_dims):
        """Integrate a GRU-style model; returns `(h_T, z_T)` from the integrator."""
        h0 = z0
        # Augmented state: a zero block for the input part, then the hidden part.
        z0_extra = torch.zeros(
            *batch_dims, self.input_channels, dtype=z0.dtype, device=z0.device
        )
        z0 = torch.cat([z0_extra, z0], dim=-1)

        if args.model == 'contgru_delay':
            return torchcde.contint_delay(X=X,
                                          z0=z0,
                                          h0=h0,
                                          func=self.func,
                                          t=X.interval,
                                          device=self.device)
        if args.model == 'gruode':
            return torchcde.contint(X=X,
                                    z0=z0,
                                    h0=h0,
                                    input_channels=self.input_channels,
                                    func=self.func,
                                    t=X.interval,
                                    device=self.device)
        # Remaining GRUCONT member: 'contgru_delay_aug'.
        return torchcde.contint_aug(X=X,
                                    z0=z0,
                                    h0=h0,
                                    func=self.func,
                                    t=X.interval,
                                    method=args.solver_method,
                                    device=self.device)

    def forward(self,args, coeffs,adjoint=True,**kwargs):
        """Return the prediction for a batch of interpolation coefficients.

        Args:
            args: run configuration; `args.model` selects the integrator and
                `args.solver_method` is used by 'contgru_delay_aug'.
            coeffs: interpolation coefficients for the batch.
            adjoint, **kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: for an unknown interpolation or unsupported model.
        """
        if self.interpolation in CUBICS:
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")

        gru_style = args.model in GRUCONT
        if args.model != 'ncde' and not gru_style:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError("NeuralCDE does not support model {!r}.".format(args.model))

        batch_dims = coeffs.shape[:-2]
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        if not gru_style:
            z_T = torchcde.cdeint(X=X,
                                  z0=z0,
                                  func=self.func,
                                  t=X.interval)
            # t = X.interval has two entries, so index 1 is the final state.
            return self.readout(z_T[:,1])

        h_T, z_T = self._integrate_gru(args, X, z0, batch_dims)
        z_T = z_T[:,1]
        h_T = h_T[:,1]
        pred_y1 = self.readout(h_T)
        pred_y2 = self.readout2(z_T)
        return (self.alpha* pred_y1) + (self.beta * pred_y2)


class NeuralODE(torch.nn.Module):
    """Neural ODE baseline for the models listed in `ODE_BASED`.

    The initial hidden state is a linear map of the interpolated path at the
    start of its interval; it is integrated with torchdiffeq and passed to a
    linear readout.

    Args:
        func: vector field module handed to torchdiffeq.
        input_channels: channels of the interpolated path.
        hidden_channels: size of the hidden state.
        output_channels: size of the prediction.
        times: stored on the instance; `forward` uses its own `times` argument.
        device: stored on the instance.
        mask: accepted for interface compatibility; unused.
        interpolation: 'linear' or one of `CUBICS`.
    """

    def __init__(self,func, input_channels, hidden_channels, output_channels,times,device,mask='False', interpolation="cubic"):
        super(NeuralODE, self).__init__()
        self.func = func
        if args.model in ODE_BASED:
            self.readout = torch.nn.Linear(hidden_channels, output_channels)

        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.interpolation = interpolation
        self.input_channels = input_channels
        self.device=device
        self.times=times

    def forward(self, coeffs,times,adjoint=True,**kwargs):
        """Return `(pred_y, z_T)` for a batch of interpolation coefficients.

        Args:
            coeffs: interpolation coefficients for the batch.
            times: time grid passed to the ODE solver.
            adjoint: use `odeint_adjoint` when true, else `odeint`.
            **kwargs: solver options. Defaults: atol=1e-6, rtol=1e-4,
                method='rk4' with step_size=0.1; with `adjoint`, the adjoint
                tolerances default to the forward ones.

        Raises:
            ValueError: for an unknown interpolation or unsupported model.
        """
        # Accept the same cubic aliases as NeuralCDE.
        if self.interpolation in CUBICS:
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("Only 'linear' and 'cubic' interpolation methods are implemented.")

        if args.model not in ODE_BASED:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError("NeuralODE does not support model {!r}.".format(args.model))

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        kwargs.setdefault('atol', 1e-6)
        kwargs.setdefault('rtol', 1e-4)
        if adjoint:
            kwargs.setdefault("adjoint_atol", kwargs["atol"])
            kwargs.setdefault("adjoint_rtol", kwargs["rtol"])
        kwargs.setdefault('method', 'rk4')
        if kwargs['method'] == 'rk4':
            # Copy so a caller-supplied options dict is not modified.
            options = dict(kwargs.get('options', {}))
            if 'step_size' not in options and 'grid_constructor' not in options:
                options['step_size'] = 0.1
            kwargs['options'] = options

        odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
        z_T = odeint(func=self.func, y0=z0, t=times, **kwargs)

        # Move the time axis behind the batch axis.
        z_T = z_T.permute(1,0,2)
        # NOTE(review): index 1 is the state at times[1]. If `times` has more
        # than two entries this is not the final state - confirm whether
        # z_T[:, -1] was intended.
        z_T = z_T[:,1]

        pred_y = self.readout(z_T)
        return pred_y,z_T



def main(num_epochs=args.epoch):
    
    manual_seed = args.seed
    static_intensity=True
    time_intensity=True
    batch_size=1024
    np.random.seed(manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    torch.cuda.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)
    torch.random.manual_seed(manual_seed)
    args.intensity = static_intensity
    print(f"Setting of this Experiments{args}")
    device="cuda"
    
    
    
    loc = datasets.sepsis.get_data(static_intensity,time_intensity,batch_size)

    here = pathlib.Path(__file__).resolve().parent
    base_base_loc = here / 'datasets/processed_data'
    loc = base_base_loc / ('Newsepsis' + ('_intensity' if static_intensity else '_nointensity') + ('_time' if time_intensity else '_notime'))
    if args.interpolation=='natural_cubic':
        coeff_loc = loc / ('NaturalCoeffs')
    else:    
        coeff_loc = loc / ('Coeffs')
    
    
    times        = torch.load(str(loc)+'/times.pt')
   
    full_y       =  torch.load(str(loc) +'/full_y.pt')
    train_y, val_y, test_y = split_data(full_y, full_y)
    full_X       = torch.load(str(loc)+'/full_X.pt')

    if static_intensity:
        X_static =  torch.load(str(loc) +'/X_static.pt')
    
    X = normalise_data(full_X,full_y)
    
    augmented_X = []
    if time_intensity:
        augmented_X.append(times.unsqueeze(0).repeat(X.size(0), 1).unsqueeze(-1))
    if static_intensity:
        intensity = ~torch.isnan(X)  # of size (batch, stream, channels)
        intensity = intensity.to(X.dtype).cumsum(dim=1)
        augmented_X.append(intensity)
    augmented_X.append(X)
    if len(augmented_X) == 1:
        X = augmented_X[0]
    else:
        X = torch.cat(augmented_X, dim=2)
    

    if static_intensity:
        input_channels = X.shape[-1]
    else:
        input_channels = X.shape[-1]
    
    if os.path.exists(coeff_loc):
        
        pass
    else:
        if not os.path.exists(base_base_loc):
            os.mkdir(base_base_loc)
        if not os.path.exists(coeff_loc):
            os.mkdir(coeff_loc)
        if not os.path.exists(coeff_loc):
            os.mkdir(coeff_loc)
        
        
        train_X, val_X, test_X = split_data(X, full_y)
        train_X_static, val_X_static, test_X_static = split_data(X_static, full_y)
        if args.interpolation =='natural_cubic':
            print("Start extrapolation!")
            train_coeffs = torchcde.natural_cubic_coeffs(train_X)
            
            torch.save(train_coeffs,str(coeff_loc)+'/train_coeffs.pt')
            print("finish extrapolation Train coeff")
            val_coeffs = torchcde.natural_cubic_coeffs(val_X)
            
            
            torch.save(val_coeffs,str(coeff_loc)+'/val_coeffs.pt')
            print("finish extrapolation Val coeff")
            test_coeffs = torchcde.natural_cubic_coeffs(test_X)
            torch.save(test_coeffs,str(coeff_loc)+'/test_coeffs.pt')
            print("finish extrapolation Test coeff")
            print("success!")
        else:
            print("Start extrapolation!")
            train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)
            
            torch.save(train_coeffs,str(coeff_loc)+'/train_coeffs.pt')
            print("finish extrapolation Train coeff")
            val_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(val_X)
            
            
            torch.save(val_coeffs,str(coeff_loc)+'/val_coeffs.pt')
            print("finish extrapolation Val coeff")
            test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
            torch.save(test_coeffs,str(coeff_loc)+'/test_coeffs.pt')
            print("finish extrapolation Test coeff")
            print("success!")
       
    train_coeffs = torch.load(str(coeff_loc)+'/train_coeffs.pt')
    val_coeffs = torch.load(str(coeff_loc)+'/val_coeffs.pt')
    test_coeffs = torch.load(str(coeff_loc)+'/test_coeffs.pt')
    
    
    train_coeffs=train_coeffs.to(device)
    val_coeffs=val_coeffs.to(device)
    test_coeffs=test_coeffs.to(device)
    train_y = train_y.to(device)
    val_y = val_y.to(device)
    test_y = test_y.to(device)
    hidden_channels=args.h_channels
    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    experiment_id = int(SystemRandom().random()*100000)
    file_path = os.path.dirname(os.path.abspath(__file__))+"/SEPSIS/"
    print(f"Past hidden and dh_dt are saved at {file_path} and Experiment id is : {experiment_id}")
    
    if args.model=='ncde':
        func = CDEFunc(input_channels, hidden_channels)
        
    if args.model =='contgru_delay':
        func=GRU_ODE_Delay(input_channels,hidden_channels,file_path=file_path,rnd=experiment_id)
        
    if args.model=='contgru_delay_aug':
        
        func = GRU_ODE_Aug(input_channels,hidden_channels,file_path=file_path,rnd=experiment_id)
        
    if args.model =='gruode':
        func = DisGRU_ODE(input_channels,hidden_channels)
    if args.model =='node':
        func = ODEFunc(input_channels,hidden_channels)
    if args.model =='odernn':
        func = _ODERNNFunc(hidden_channels, 128, 2)
    if args.model =='odernn':
        model = ODERNN(func=func,input_channels=input_channels, hidden_channels=hidden_channels,
                                  hidden_hidden_channels=128, num_hidden_layers=2,
                                  output_channels=1, use_intensity=True)
    elif args.model =='dt':
        model = GRU_dt(input_channels=input_channels, hidden_channels=hidden_channels,
                                  output_channels=1, use_intensity=True)
    elif args.model =='gru_d':
        model = GRU_D(input_channels=input_channels, hidden_channels=hidden_channels,
                                 output_channels=1, use_intensity=True)
    elif args.model =='node':
        
        model = NeuralODE(func=func,input_channels=input_channels, hidden_channels=hidden_channels, output_channels=1,times=times,device=device)
    
    else:    
        model = NeuralCDE(func=func,input_channels=input_channels, hidden_channels=hidden_channels, output_channels=1,
        file_path=file_path,rnd = experiment_id,alpha=args.alpha,beta = args.beta,device=device)
    model=model.to(device)
    times=times.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.weight_decay)
    
    plateau_terminate = 50
    best_train_loss_epoch=0
    breaking = False
    tqdm_range = tqdm.tqdm(range(num_epochs))
    tqdm_range.write('Starting training for model:\n\n' + str(model) + '\n\n')
    best_val_auroc=0
    if device != 'cpu':
        torch.cuda.reset_max_memory_allocated(device)
        baseline_memory = torch.cuda.memory_allocated(device)
    else:
        baseline_memory = None   
    for epoch in tqdm_range:
        if breaking :
            break
        model.train()
        full_pred_y=torch.Tensor().to(device)
        full_true_y=torch.Tensor().to(device)
        start_time= time.time()
        total_accuracy = 0
        total_dataset_size = 0
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            if breaking :
                break
            batch_size = batch_y.shape[0]
            if args.model in BASELINES:
                
                pred_y,z_t = model(batch_coeffs,times)

            elif args.model in BASELINES2:
                pred_y = model(batch_coeffs,times)
            else:
                pred_y = model(args,batch_coeffs)
            pred_y=pred_y.squeeze(-1)

            binary_prediction = (pred_y>0).to(test_y.dtype)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            total_accuracy += (binary_prediction == batch_y).sum().to(train_y.dtype)
            total_dataset_size += batch_size
            full_pred_y=torch.cat([full_pred_y,binary_prediction])
            full_true_y=torch.cat([full_true_y,batch_y])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        if (full_pred_y.sum() ==0) or (full_pred_y.sum() == full_pred_y.shape[0]) or (full_true_y.sum() ==0) or (full_true_y.sum() == full_true_y.shape[0]): 
            train_auroc = 0.5
        else:
            train_auroc = sklearn.metrics.roc_auc_score(full_pred_y.cpu(), full_true_y.cpu())
        total_accuracy = total_accuracy.item()
        total_accuracy /= total_dataset_size
        print('Epoch: {}   Training loss: {} Training AUROC: {},Training ACC  : {} ,Time :{}'.format(epoch, loss.item(),train_auroc,total_accuracy,(time.time()-start_time)))
        
        val_dataset = torch.utils.data.TensorDataset(val_coeffs, val_y)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_y.shape[0])
        model.eval()
        val_pred_y=torch.Tensor().to(device)
        val_true_y=torch.Tensor().to(device)
        
        start_time= time.time()
        total_val_accuracy = 0
        total_dataset_size = 0
        total_loss = 0 
        for batch in val_dataloader:
            batch_coeffs, batch_y = batch
            batch_size = batch_y.shape[0]
            
            if args.model in BASELINES:
                pred_y,z_t = model(batch_coeffs,times)
            elif args.model in BASELINES2:
                pred_y = model(batch_coeffs,times)
            else:
                pred_y = model(args,batch_coeffs)
            pred_y=pred_y.squeeze(-1)
            
            binary_prediction = (pred_y>0).to(test_y.dtype)
            
            val_loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            total_val_accuracy += (binary_prediction == batch_y).sum().to(val_y.dtype)
            total_dataset_size += batch_size
            total_loss += val_loss * batch_size
            val_pred_y=torch.cat([val_pred_y,binary_prediction])
            val_true_y=torch.cat([val_true_y,batch_y])


        total_val_accuracy = total_val_accuracy.item()
        total_val_accuracy /= total_dataset_size
        total_loss /=total_dataset_size
        if (val_pred_y.sum() ==0) or (val_pred_y.sum() == val_pred_y.shape[0]) or (val_true_y.sum() ==0) or (val_true_y.sum() == val_true_y.shape[0]): 
            
            print('Epoch: {}   Validation loss: {} Validation AUROC: 0.5, Validation ACC : {} Time :{}'.format(epoch, total_loss.item(),total_val_accuracy,(time.time()-start_time)))
            val_auroc = 0.5
        else:
            
            val_auroc = sklearn.metrics.roc_auc_score(val_pred_y.cpu(), val_true_y.cpu())
            print('Epoch: {}   Validation loss: {} Validation AUROC: {}, Validation ACC : {} , Time :{}'.format(epoch, total_loss.item(),val_auroc,total_val_accuracy,(time.time()-start_time)))
        
        test_dataset = torch.utils.data.TensorDataset(test_coeffs, test_y)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=test_y.shape[0])
        test_pred_y=torch.Tensor().to(device)
        test_true_y=torch.Tensor().to(device)
        total_test_accuracy = 0
        total_dataset_size = 0
        total_loss = 0 
        start_time= time.time()
        for batch in test_dataloader:
            batch_coeffs, batch_y = batch
            batch_size = batch_y.shape[0]
            if args.model in BASELINES:
                pred_y,z_t = model(batch_coeffs,times)
            elif args.model in BASELINES2:
                pred_y = model(batch_coeffs,times)
            else:
                pred_y = model(args,batch_coeffs)
            pred_y=pred_y.squeeze(-1)
            
            binary_prediction = (pred_y>0).to(test_y.dtype)
            
            test_loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            total_loss += test_loss* batch_size
            total_test_accuracy += (binary_prediction == batch_y).sum().to(test_y.dtype)
            total_dataset_size += batch_size
            test_pred_y=torch.cat([test_pred_y,binary_prediction])
            test_true_y=torch.cat([test_true_y,batch_y])
        total_test_accuracy = total_test_accuracy.item()
        total_test_accuracy /= total_dataset_size
        total_loss /= total_dataset_size
        if val_auroc>best_val_auroc:
            best_val_auroc=val_auroc
            
            if (test_pred_y.sum() ==0) or (test_pred_y.sum() == test_pred_y.shape[0]) or (test_true_y.sum() ==0) or (test_true_y.sum() == test_true_y.shape[0]): 
        
                print('Epoch: {}   Best Test loss: {} Test AUROC: 0.5,Test ACC : {}, Time :{}\n'.format(epoch, total_loss.item(),total_test_accuracy,(time.time()-start_time)))
            else:

                test_auroc = sklearn.metrics.roc_auc_score(test_pred_y.cpu(), test_true_y.cpu())
                print('Epoch: {}   Best Test loss: {} Test AUROC: {}, Test ACC : {}, Time :{}\n'.format(epoch, total_loss.item(),test_auroc,total_test_accuracy,(time.time()-start_time)))
                MODEL_PATH = os.path.dirname(os.path.abspath(__file__))+'/trained_model/contgru_aug/Sepsis/'
                if val_auroc > 0.85:
                    ckpt_file = MODEL_PATH+"Sepsis"+str(epoch)+"_model_AUCROC_"+str(val_auroc)+".pth"
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                        }, ckpt_file)
                    print(f"model_saved {ckpt_file}")
        else:
            if (test_pred_y.sum() ==0) or (test_pred_y.sum() == test_pred_y.shape[0]) or (test_true_y.sum() ==0) or (test_true_y.sum() == test_true_y.shape[0]):  
                print('Epoch: {}   Test loss: {} Test AUROC: 0.5,Test ACC : {}, Time :{}\n'.format(epoch, total_loss.item(),total_test_accuracy,(time.time()-start_time)))
            else:

                test_auroc = sklearn.metrics.roc_auc_score(test_pred_y.cpu(), test_true_y.cpu())
                print('Epoch: {}   Test loss: {} Test AUROC: {}, Test ACC {}, Time :{}\n'.format(epoch, total_loss.item(),test_auroc,total_test_accuracy,(time.time()-start_time)))
                MODEL_PATH = os.path.dirname(os.path.abspath(__file__))+'/trained_model/contgru_aug/Sepsis/'
                if val_auroc > 0.88:
                    ckpt_file = MODEL_PATH+"0910_Sepsis"+str(epoch)+"_model_AUCROC_"+str(val_auroc)+".pth"
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': total_loss
                        }, ckpt_file)
                    print(f"model_saved {ckpt_file}")
        
        if epoch > best_train_loss_epoch + plateau_terminate:
                tqdm_range.write('Breaking because of no improvement in training loss for {} epochs.'
                                 ''.format(plateau_terminate))
                breaking = True
    


# Script entry point: invoke main() only when this file is executed directly,
# so that importing the module does not trigger a run.
if __name__ == "__main__":
    main()
