from pylab import *
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sksurv.metrics import cumulative_dynamic_auc, brier_score
import time, sys, dill
from scipy.interpolate import interp1d

import rpy2.robjects as ro
from rpy2.robjects.conversion import localconverter, py2rpy,rpy2py
from rpy2.robjects import pandas2ri, numpy2ri
from rpy2.rinterface_lib.callbacks import logger as rpy2_logger
import logging
rpy2_logger.setLevel(logging.ERROR)

sys.path.append('../')
from mylib import survival_permanental_process as SPP
import tensorflow as tf

_DEFAULT_DATA_DIR = '/Users/kim/Documents/research/Survival_PP/TimeVaryingData_LTRCforests/analysis/data/'


def load_data(data_dir=_DEFAULT_DATA_DIR):
    """Build the data set in R and return it converted through pandas2ri.

    The R script ``pbc_complete.R`` found in ``data_dir`` is sourced and its
    ``make_realset(0)`` result is converted to a Python object.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``pbc_complete.R``.  R's working directory is
        switched to it while the script runs.  Defaults to the path that was
        previously hard-coded.

    Returns
    -------
    The converted ``DATA`` object (presumably a pandas DataFrame; the caller
    indexes it by column name).

    Raises
    ------
    Whatever rpy2 raises when an R call fails; R's working directory is
    restored before the exception propagates.
    """
    # Remember R's current working directory in the R global ``cwd``.
    ro.r('cwd <- getwd()')
    try:
        ro.r('library(survival)')
        # Call setwd as an R function so the path is passed as a value and
        # needs no manual quoting/escaping inside an R source string.
        ro.r['setwd'](data_dir)
        ro.r('source("pbc_complete.R")')
        ro.r('DATA = make_realset(0)')
        with localconverter(ro.default_converter + pandas2ri.converter):
            data = rpy2py(ro.r('DATA'))
    finally:
        # Always restore the working directory, even if an R call failed.
        ro.r('setwd(cwd)')
    return data

# Covariate columns fed to the model, in the same order as the terms of the
# model formula.  'ttt' is the copy of the 'Stop' column created in main().
cov_names = [
    'age',
    'edema',
    'alk.phos',
    'chol',
    'ast',
    'platelet',
    'spiders',
    'hepato',
    'ascites',
    'albumin',
    'bili',
    'protime',
    'ttt',
]

def main():
    """Run 10-fold cross-validation of the SPP model and save the scores.

    The data returned by ``load_data()`` is split by subject ID so that all
    rows of one subject land in the same fold.  For every fold
    ``estimation_spp`` is called and its four results are collected under the
    keys ``'auc'``, ``'bri'``, ``'tll'`` and ``'cpu'``; the evaluation times
    are stored under ``'t'``.  The resulting dict is written with dill to
    ``result/spp_pbc.dill``.
    """
    import os

    n_split = 10
    score_t = np.array([1000, 2000, 3000, 4000])

    # DATA SPLITTING #################################
    df = load_data()
    df['ttt'] = df['Stop']

    # Standardize every covariate and shrink it by a factor 0.1.
    # NOTE: mean/std are computed on the full data set, before the split.
    for c in cov_names:
        df[c] = 0.1 * (df[c] - df[c].mean()) / df[c].std()

    kf = KFold(n_splits=n_split, shuffle=True, random_state=0)
    ids = np.unique(df['ID'].to_numpy())
    df_train, df_test = [], []
    for pos_train, pos_test in kf.split(ids):
        # KFold.split yields *positions* into ``ids``, not ID values, so they
        # must be mapped back to IDs before selecting rows.
        df_train.append(df[df['ID'].isin(ids[pos_train])])
        df_test.append(df[df['ID'].isin(ids[pos_test])])

    # START ESTIMATION & PREDICTION ##################
    score = {x: [] for x in ['auc', 'cpu', 'bri', 'tll']}
    score['t'] = score_t
    for df_tr, df_te in zip(df_train, df_test):
        auc, bri, tll, cpu = estimation_spp(df_tr, df_te, score_t)
        score['auc'].append(auc)
        score['bri'].append(bri)
        score['tll'].append(tll)
        score['cpu'].append(cpu)
    print(score['auc'])

    # Make sure the output directory exists so the results of a long run are
    # not lost, and close the file deterministically.
    os.makedirs('result', exist_ok=True)
    with open('result/spp_pbc.dill', 'wb') as f:
        dill.dump(score, f)

def estimation_spp(df_tr, df_te, score_t):
    """Fit the SPP model on one fold and score it on the held-out subjects.

    Parameters
    ----------
    df_tr, df_te : pandas.DataFrame
        Training / test rows.  Both must contain the columns ``'ID'``,
        ``'Event'``, ``'Stop'`` and every name in ``cov_names``; ``df_tr``
        additionally needs whatever the model formula references
        (``'Start'``).
    score_t : array-like of numbers
        Time points at which the model is evaluated.

    Returns
    -------
    auc : result of ``cumulative_dynamic_auc`` at ``score_t`` (first output).
    bri : result of ``brier_score`` at ``score_t`` (second output).
    tll : ndarray, one value per entry of ``score_t``; mean over test
        subjects of ``-H(min(score_t, T)) + e * log(intensity(T))`` where
        ``T`` is the subject's last stop time, ``H`` the cumulative hazard
        and ``e`` is the event flag if ``T <= score_t`` and 0 otherwise.
    cpu : whatever ``model.fit`` returns (presumably the fitting time --
        TODO confirm against SPP.fit).
    """
    # Shape data for evaluation: one (Event, Stop) row per subject, taking
    # the maximum over that subject's rows.
    surv_train = df_tr.groupby('ID')[['Event', 'Stop']].max().to_numpy()
    surv_test = df_te.groupby('ID')[['Event', 'Stop']].max()
    list_id, surv_test = surv_test.index.values, surv_test.to_numpy()

    # Structured arrays in the (event, time) layout handed to the
    # scikit-survival metrics below.
    surv_dtype = [('Event', bool), ('Stop', np.float64)]
    survival_train = np.zeros(len(surv_train), dtype=surv_dtype)
    survival_test = np.zeros(len(surv_test), dtype=surv_dtype)
    survival_train['Event'], survival_train['Stop'] = surv_train[:, 0], surv_train[:, 1]
    survival_test['Event'], survival_test['Stop'] = surv_test[:, 0], surv_test[:, 1]

    # Estimation and prediction
    model = SPP(kernel='Gaussian', eq_kernel='RFM', eq_kernel_options={'n_rfm': 500})
    with tf.device('/cpu:0'):
        # One candidate parameter vector per scale value: a leading 1
        # followed by the same value repeated for each of the 13 covariates.
        set_par = [[1] + [x] * 13 for x in [0.1, 0.2, 0.5, 0.7, 1.0, 2.0, 5.0, 7.0, 10.0]]
        formula = 'Surv(Start, Stop, Event) ~ age+edema+alk.phos+chol+ast+platelet+spiders+hepato+ascites+albumin+bili+protime+ttt'
        cpu = model.fit(formula, df=df_tr, set_par=set_par)

    # Common time grid: a regular grid on [0, 5000] plus every test stop time
    # and every scoring time, so that all of them are exact grid points.
    event_test = surv_test[:, 0]
    stop_test = surv_test[:, 1].astype(float)
    t = np.unique(np.concatenate([np.linspace(0, 5000, 1000),
                                  stop_test,
                                  np.asarray(score_t, dtype=float)]))
    # ``t`` is sorted and contains score_t exactly, so searchsorted returns
    # the position of each scoring time on the grid.
    idx_score = np.searchsorted(t, score_t)

    risk_id, tll_id = [], []
    for subject_id, event, t1 in zip(list_id, event_test, stop_test):
        dff = df_te[df_te['ID'] == subject_id]
        stops = dff['Stop'].to_numpy()

        # Piecewise-constant covariate paths: at time u use the value of the
        # next recorded stop; outside the recorded range fall back to the
        # first / last recorded value.  Built once and reused below.
        interps = []
        for c in cov_names:
            values = dff[c].to_numpy()
            interps.append(interp1d(stops, values, kind='next',
                                    fill_value=(values[0], values[-1]),
                                    bounds_error=False))

        # Model output along the grid, integrated with the trapezoidal rule
        # to obtain the cumulative hazard.  (The [0.5] argument presumably
        # requests the median estimate -- TODO confirm against SPP.predict.)
        cov_grid = np.array([f(t) for f in interps]).T
        intensity = model.predict(cov_grid, [0.5])[0]
        cum_haz = np.r_[0, np.cumsum(0.5 * (intensity[:-1] + intensity[1:]) * np.diff(t))]
        risk_id.append(cum_haz[idx_score])

        # Log-likelihood of the observation censored at each scoring time.
        s = np.minimum(score_t, t1)
        e = event * (t1 <= score_t)
        cov_t1 = np.array([f(t1) for f in interps])[np.newaxis, :]
        intensity_t1 = model.predict(cov_t1, [0.5])[0][0]
        # Every entry of ``s`` is a scoring time or t1, both exact grid points.
        tll_id.append(-cum_haz[np.searchsorted(t, s)] + e * np.log(intensity_t1))

    # Performance score
    risk = np.asarray(risk_id)
    auc, _ = cumulative_dynamic_auc(survival_train, survival_test, risk, score_t)
    # brier_score expects the probability of remaining event-free, i.e. the
    # survival function S(t) = exp(-H(t)), not the cumulative hazard itself.
    _, bri = brier_score(survival_train, survival_test, np.exp(-risk), score_t)
    tll = np.mean(tll_id, 0)

    return auc, bri, tll, cpu
    

# Run the cross-validation experiment only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
