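# Demo: 5-fold cross-validated training and evaluation of the DeepEH / DeepAFT network on the RotGBSG data.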
import torch
import torch.nn as nn
import pandas as pd
import torch.optim 
import numpy as np
import sys

from loss import eaftloss 
from c_index import C_index
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from sklearn.model_selection import RepeatedKFold


###########################################################################################################
###  RotGBSG Data 
# Location of the data set, relative to the current working directory.
path = '../data/gbsg.csv'
data = pd.read_csv(path)

# ---- training schedule ------------------------------------------------------
EPOCH = 2000           # number of full-batch gradient steps per fit
rept = 1               # seeds 0 .. rept-1 are trained for every fold

# ---- optimiser (AdamW) ------------------------------------------------------
LR = 0.001             # base learning rate
LR_decay = 0.0         # step t uses LR / (1 + t * LR_decay)
betas = (0.91, 0.999)  # AdamW `betas`
wt_decay = 0.5         # AdamW `weight_decay`

# ---- network architecture ---------------------------------------------------
layer_a = 2            # how many times the shared hidden layer is applied
neuron_n = 256         # width of the hidden layers
dropout_pr = 0.0       # dropout probability after every hidden activation

# ---- evaluation (both values are forwarded unchanged to C_index) ------------
Riemann_sum_gap = 0.5
integral_id = 0

# ---- data splitting ---------------------------------------------------------
rs = 1                 # seed for the fold split and the validation sample
# 5-fold CV, performed once
kf = RepeatedKFold(n_splits=5, n_repeats=1, random_state=rs)

# Result tables, each allocated exactly once (the original allocated half of
# them twice through copy-paste).
#
# Layout shared by all tables: rows 0-4 hold one cross-validation fold each,
# row 5 receives the mean over the folds and row 6 the standard deviation;
# columns are indexed by the repetition (seed) counter.
# Only `Test_c_index` and `IBS` are written by the rest of this script; the
# remaining tables are allocated but left at zero here.
_RESULT_SHAPE = (7, 10)

val_loss_min = np.zeros(_RESULT_SHAPE)
val_loss_percentile = np.zeros(_RESULT_SHAPE)

Test_c_index = np.zeros(_RESULT_SHAPE)   # test-set C-index
IBS = np.zeros(_RESULT_SHAPE)            # integrated Brier score

Time_mean_1 = np.zeros(_RESULT_SHAPE)
Time_std_1 = np.zeros(_RESULT_SHAPE)

AP_mean = np.zeros(_RESULT_SHAPE)
AP_std = np.zeros(_RESULT_SHAPE)
AP_rate = np.zeros(_RESULT_SHAPE)
AFT_rate = np.zeros(_RESULT_SHAPE)



class EAFTNet(nn.Module):
    """Fully connected network producing the raw outputs fed to `eaftloss`.

    The input passes through one ``feature_p -> neuron_n`` layer and then
    through the SAME ``neuron_n -> neuron_n`` layer ``layer_a`` times (the
    weights are shared between those repetitions).  Every hidden layer is
    followed by SELU and dropout.  `forward` returns the output of the
    two-unit head `g2`; the one-unit head `g1` is created but not used.

    Args:
        feature_p: number of input covariates.
        neuron_n: width of the hidden layers.
        layer_a: number of times the shared hidden layer is applied.
        dropout_pr: dropout probability.
    """

    def __init__(self, feature_p, neuron_n, layer_a, dropout_pr):
        super(EAFTNet, self).__init__()
        self.p = feature_p
        self.a = layer_a
        self.n = neuron_n
        self.dp = dropout_pr

        # Creation order is kept as in the original so that a given
        # torch.manual_seed yields the same initial weights.
        self.f1 = nn.Linear(self.p, self.n)
        self.f2 = nn.Linear(self.n, self.n)

        self.g1 = nn.Linear(self.n, 1)
        self.g2 = nn.Linear(self.n, 2)
        self.dropout = nn.Dropout(p=self.dp)

    def forward(self, x):
        x = torch.selu(self.f1(x))
        x = self.dropout(x)

        for _ in range(self.a):
            x = torch.selu(self.f2(x))
            x = self.dropout(x)

        # g1 for DeepAFT, g2 for DeepEH
        out = self.g2(x)

        return out


def _as_tensors(x, frame):
    """Return (covariates, 'delta' column, 'time' column) as float32 tensors."""
    delta = frame['delta'].values.astype('float32')
    time = frame['time'].values.astype('float32')
    return torch.from_numpy(x), torch.from_numpy(delta), torch.from_numpy(time)


# Cross-validated training.  For every fold, 20% of the training rows are held
# out as a validation set; for every seed a fresh network is trained with
# full-batch AdamW for K steps, and the predictions made at the step with the
# lowest validation loss are scored on the test fold by `C_index`, which
# returns the pair stored in `Test_c_index` and `IBS`.
K = EPOCH
for cv_ind, (train_index, test_index) in enumerate(kf.split(data)):
    train_index, test_index = pd.Index(train_index), pd.Index(test_index)
    df_train, df_test = data.iloc[train_index, :], data.iloc[test_index]
    df_val = df_train.sample(frac=0.2, random_state=rs)
    df_train = df_train.drop(df_val.index)

    # Columns '3'..'6' are standardised (scaler fitted on the training rows
    # only); columns '0'..'2' are passed through unchanged.
    cols_standardize = ['3', '4', '5', '6']
    cols_leave = ['0', '1', '2']

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = [(col, None) for col in cols_leave]

    x_mapper = DataFrameMapper(standardize + leave)

    x_train, delta_train, time_train = _as_tensors(
        x_mapper.fit_transform(df_train).astype('float32'), df_train)
    x_val, delta_val, time_val = _as_tensors(
        x_mapper.transform(df_val).astype('float32'), df_val)
    x_test, delta_test, time_test = _as_tensors(
        x_mapper.transform(df_test).astype('float32'), df_test)

    feature_p = x_train.shape[1]

    for t_seed in range(rept):
        kk = t_seed  # column of the result tables used by this seed

        torch.manual_seed(t_seed)
        net = EAFTNet(feature_p, neuron_n, layer_a, dropout_pr)

        # Build the optimiser ONCE per fit.  Re-creating it inside the step
        # loop (as the original did) throws away AdamW's running moment
        # estimates at every step.  The learning-rate schedule is applied by
        # updating the parameter groups instead.
        optimizer = torch.optim.AdamW(net.parameters(), lr=LR,
                                      betas=betas, weight_decay=wt_decay)

        # Loss curves, stored as plain numbers so no autograd graph is kept.
        Loss_train = torch.zeros(1, K)
        Loss_val = torch.zeros(1, K)

        # Best-validation snapshot, reset for every fit so that a failed run
        # can never reuse predictions from an earlier fold or seed.
        loss_temp = float('inf')
        tt = None
        out_train = None
        out_test = None

        for t in range(K):
            learning_rate = LR / (1 + t * LR_decay)
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

            net.train()
            out = net(x_train)
            loss = eaftloss(out, time_train, delta_train)
            Loss_train[0, t] = loss.item()

            # Everything below is evaluation only: no gradients, dropout off.
            net.eval()
            with torch.no_grad():
                out_val = net(x_val)
                loss_val = eaftloss(out_val, time_val, delta_val)
                Loss_val[0, t] = loss_val.item()

                if torch.isnan(out[0, 0]) or torch.isnan(out_val[0, 0]):
                    print('Break due to NaN at step:', t)
                    break

                # Keep the predictions of the step with the lowest validation
                # loss seen so far (ties go to the later step).
                if loss_val.item() <= loss_temp:
                    tt = t
                    out_train = net(x_train)
                    out_test = net(x_test)
                    loss_temp = loss_val.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if out.shape[1] == 1:
            print('DeepAFT model')
        else:
            print('DeepEH model')

        if out_train is None:
            # No step produced a usable validation loss (e.g. NaN from the
            # first step on); record that explicitly instead of crashing.
            print('No valid validation loss for fold', cv_ind, 'seed', t_seed,
                  '- results set to NaN')
            Test_c_index[cv_ind, kk] = np.nan
            IBS[cv_ind, kk] = np.nan
            continue

        c_index_test, IBS[cv_ind, kk] = C_index(
            out_train, time_train, delta_train,
            out_test, time_test, delta_test,
            Riemann_sum_gap, integral_id)
        Test_c_index[cv_ind, kk] = c_index_test
            

# Rows 0-4 of each table hold the per-fold results.  Write their column-wise
# mean into row 5 and their (population) standard deviation into row 6.
for table in (Test_c_index, IBS):
    fold_rows = table[:5, :]
    table[5, :] = fold_rows.mean(axis=0)
    table[6, :] = fold_rows.std(axis=0)

# Report both tables, Brier score first.
for title, table in (('integrated Brier Score', IBS),
                     ('C-index:', Test_c_index)):
    print(title)
    print(table)




