#import os
#os.environ['R_HOME'] = '/Library/Frameworks/R.framework/Resources'

from pylab import *
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sksurv.metrics import cumulative_dynamic_auc, brier_score
import time, sys, dill
import rpy2.robjects as ro
from rpy2.robjects.conversion import localconverter, py2rpy,rpy2py
from rpy2.robjects import pandas2ri, numpy2ri
from scipy.interpolate import splrep, BSpline, PchipInterpolator, interp1d

# Suppress rpy2 warnings (only errors are logged)
from rpy2.rinterface_lib.callbacks import logger as rpy2_logger
import logging
# Hide rpy2 output below ERROR level (e.g. warnings forwarded from R).
rpy2_logger.setLevel(level=logging.ERROR)

# Covariates of the PBC data used by the models, listed in the same order as
# in the model formulas.
cov_names = ('age edema alk.phos chol ast platelet '
             'spiders hepato ascites albumin bili protime').split()
    

def load_data(data_dir="../TimeVaryingData_LTRCforests/analysis/data/"):
    """Load the PBC data set prepared by the LTRCforests analysis scripts.

    Sources ``pbc_complete.R`` from ``data_dir`` in the embedded R session,
    evaluates ``make_realset(0)`` and converts the result with the pandas
    converter. The R working directory is restored afterwards, even if one
    of the R calls raises.

    Parameters
    ----------
    data_dir : str
        Directory containing ``pbc_complete.R``, relative to the current R
        working directory. The default is the path the script always used.

    Returns
    -------
    The converted ``DATA`` object (a pandas DataFrame for an R data.frame).
    """
    ro.r('cwd <- getwd()')
    try:
        ro.r('library(survival)')
        ro.r('setwd("' + data_dir + '")')
        ro.r('source("pbc_complete.R")')
        ro.r('DATA = make_realset(0)')
        with localconverter(ro.default_converter + pandas2ri.converter):
            data = rpy2py(ro.r('DATA'))
    finally:
        # Always return to the caller's directory: later relative paths
        # (e.g. in estimation_sf) are resolved from it.
        ro.r('setwd(cwd)')
    return data

def main():
    """Cross-validate the survival models on the PBC data and save the scores.

    Subjects (not rows) are split into ``n_split`` folds, so all records of
    one patient fall into the same fold. For each model in ``models`` the
    per-fold AUC, Brier score, test log-likelihood and fit time returned by
    ``estimation_<model>`` are collected, together with the score times, and
    written with ``dill`` to ``result/<model>_pbc.dill``.
    """
    import os

    n_split = 10
    score_t = array([1000,2000,3000,4000])

    # DATA SPLITTING #################################
    df = load_data()

    #for c in cov_names:
    #    df[c] = 0.1*(df[c] - df[c].mean()) / df[c].std()

    # KFold.split yields *positions* into the array it is given, so map them
    # back to subject IDs before selecting rows. Comparing positions with IDs
    # directly drops or mis-assigns subjects unless the IDs are exactly 0..n-1.
    ids = np.unique(df['ID'].to_numpy())
    kf = KFold(n_splits=n_split, shuffle=True, random_state=0)
    df_train, df_test = [], []
    for indx_train, indx_test in kf.split(ids):
        df_train.append(df[df['ID'].isin(ids[indx_train])])
        df_test.append(df[df['ID'].isin(ids[indx_test])])

    # START ESTIMATION & PREDICTION ##################
    # Explicit dispatch table (resolved at call time) instead of eval().
    estimators = {'gbm': estimation_gbm,
                  'cox': estimation_cox,
                  'sf': estimation_sf}
    models = ['gbm']  # any subset of estimators, e.g. ['sf', 'cox', 'gbm']

    # Create the output directory up front so a long run cannot fail at the
    # very end when the results are written.
    os.makedirs('result', exist_ok=True)
    for m in models:
        print(m+'...')
        estimate = estimators[m]
        score = {x:[] for x in ['auc','cpu','bri','tll']}
        score['t'] = score_t
        for i, (df_tr, df_te) in enumerate(zip(df_train,df_test)):
            print(i)
            auc, bri, tll, cpu = estimate(df_tr, df_te, score_t)
            score['auc'].append(auc)
            score['bri'].append(bri)
            score['tll'].append(tll)
            score['cpu'].append(cpu)
        print(score['tll'])
        with open('result/'+m+'_pbc.dill', 'wb') as fh:
            dill.dump(score, fh)

#----------------------------------------------------
def estimation_gbm(df_tr,df_te,score_t):
    """Fit a Cox-PH gradient-boosting model (R ``gbm3``) and score it on a test fold.

    A small grid over (num_trees, shrinkage) is searched. Each candidate is
    fitted with ``gbmt`` and scored by its minimum validation error. The best
    candidate is refitted and its number of trees is chosen with
    ``gbmt_performance(method="cv")``. The fitted relative risks
    ``exp(f(x))`` on the training rows give a Breslow estimate of the
    cumulative baseline hazard. Each test subject's cumulative hazard is then
    obtained by integrating the baseline against ``exp(f(x(t)))`` along the
    subject's step-interpolated covariate path.

    Parameters
    ----------
    df_tr, df_te : pandas.DataFrame
        Training / test rows in counting-process form with columns ``ID``,
        ``Start``, ``Stop``, ``Event`` and every name in ``cov_names``.
    score_t : array-like
        Evaluation times for the AUC, Brier score and test log-likelihood.

    Returns
    -------
    auc : time-dependent AUC at each ``score_t`` (sksurv).
    bri : Brier score at each ``score_t`` (sksurv).
    tll : mean over test subjects of ``-H(min(score_t, T)) + e * log h(T)``,
        where ``e`` is 1 if the subject's event occurred by ``score_t``.
    cpu : wall-clock seconds for fitting plus the baseline-hazard estimate.
    """
    ro.r('library(survival) \n' +
         'library(gbm3)'
         )

    # Shape data for evaluation
    surv_train, surv_test, list_id, survival_train, survival_test \
        = shaping_data(df_tr,df_te)

    # Estimation and prediction
    with localconverter(ro.default_converter + pandas2ri.converter):
        rdf = py2rpy(df_tr)
    ro.r.assign("df",rdf)

    t_start = time.time()
    ro.r('set.seed(0)')
    formula = 'Surv(Start, Stop, Event) ~ age+edema+alk.phos+chol+ast+platelet+spiders+hepato+ascites+albumin+bili+protime'
    ro.r('params_grid <- expand.grid(num_trees = c(500,1000,2000), shrinkage = c(0.001, 0.005, 0.01))')
    ro.r('score <- c()')
    ro.r('for (i in 1:length(params_grid[,1])) {\n' +
         '  params <- ' +
         '  training_params(num_trees = params_grid[i,]$num_trees,' +
         '  shrinkage = params_grid[i,]$shrinkage,' +
         '  interaction_depth = 1,' +
         '  id = df$ID,' +
         '  num_train = round(0.5 * length(unique(df$ID)))) \n' +
         '  res_gbm <- ' +
         '  gbmt('+formula+',' +
         '  data = df, distribution = gbm_dist("CoxPH"),' +
         '  train_params = params, cv_folds=10, keep_gbm_data = TRUE,' +
         '  par_details = gbmParallel(num_threads = 12)) \n' +
         '  score <- append(score,min(res_gbm$valid.error)) \n' +
         '}'
         )
    ro.r('indx <- which.min(score)')
    ro.r('params <- ' +
         'training_params(num_trees = params_grid[indx,]$num_trees,' +
         'shrinkage = params_grid[indx,]$shrinkage,' +
         'interaction_depth = 1,' +
         'id = df$ID,' +
         'num_train = round(0.5 * length(unique(df$ID)))) \n' +
         'res_gbm <- ' +
         'gbmt('+formula+',' +
         'data = df, distribution = gbm_dist("CoxPH"),' +
         'train_params = params, cv_folds=10, keep_gbm_data = TRUE,' +
         'par_details = gbmParallel(num_threads = 12))'
         )
    ro.r('best_iter <- gbmt_performance(res_gbm, method="cv")')
    ro.r('res_gbm.pred <- exp(predict(res_gbm, df, n.trees = best_iter))')
    with localconverter(ro.default_converter + pandas2ri.converter):
        pred = rpy2py(ro.r('res_gbm.pred'))

    # Breslow estimate of the cumulative baseline hazard. At each training
    # event time s, add 1 / sum(exp(f(x))) over the rows at risk
    # (Start < s <= Stop). `pred` holds exp(f(x)) for the rows of df_tr in
    # order. Columns are addressed by name; the relative risk is `pred`
    # itself, not a positional column of df_tr.
    start = df_tr['Start'].to_numpy(dtype=float)
    stop = df_tr['Stop'].to_numpy(dtype=float)
    rel_risk = np.asarray(pred, dtype=float).ravel()
    spk = np.sort(stop[df_tr['Event'].to_numpy() == 1])
    if spk.size == 0:
        raise ValueError('training fold has no events; cannot estimate the baseline hazard')
    cumhaz = np.cumsum([1. / rel_risk[(start < s) & (stop >= s)].sum()
                        for s in spk])
    t_end = time.time()

    # Keep one knot per distinct event time: the value after all tied events
    # at that time. Duplicate knot times make the extrapolation slope in
    # cumhaz2haz divide by zero and bias the interpolation just before a tie.
    last_of_tie = np.r_[np.diff(spk) > 0, True]
    ttt = r_[0, spk[last_of_tie]]
    bbb = r_[0, cumhaz[last_of_tie]]

    # Evaluation grid: regular grid + every test follow-up time + score times.
    t = np.unique(np.r_[linspace(0, 5000, 1000), surv_test[:, 1], score_t])

    # Smooth the baseline. Its domain must reach the end of the grid for
    # risk(t) to be evaluable. The max() only takes effect where
    # 1.1 * max(test Stop) fell short of t[-1], which previously raised.
    risk, basehaz = cumhaz2haz(ttt, bbb,
                               t1=max(max(surv_test[:, 1])*1.1, t[-1]))

    d_risk = diff(risk(t))
    tt = 0.5*(t[1:]+t[:-1])

    def _covariates_at(dff, times):
        # Step-interpolate each covariate of one subject at `times`: a time
        # takes the value of the record with the next Stop at/after it; times
        # outside the records reuse the first/last record's value.
        stop_i = dff['Stop'].to_numpy()
        return [interp1d(stop_i, dff[c].to_numpy(), kind='next',
                         fill_value=(dff[c].to_numpy()[0], dff[c].to_numpy()[-1]),
                         bounds_error=False)(times)
                for c in cov_names]

    def _predict(df_cov, expr):
        # Bind df_cov in R, evaluate `expr` into res_gbm.pred, return it.
        with localconverter(ro.default_converter + pandas2ri.converter):
            rdf_cov = py2rpy(df_cov)
        ro.r.assign("df_cov", rdf_cov)
        ro.r('res_gbm.pred <- ' + expr)
        with localconverter(ro.default_converter + pandas2ri.converter):
            return rpy2py(ro.r('res_gbm.pred'))

    # Calculate cumulative hazard function and performances
    risk_id, tll_id = [], []
    for pid, event, t1 in zip(list_id, surv_test[:, 0], surv_test[:, 1]):
        dff = df_te[df_te['ID'] == pid]

        # Cumulative hazard on the grid: covariates at interval midpoints.
        df_cov = pd.DataFrame(array(_covariates_at(dff, tt)).T, columns=cov_names)
        rr = np.asarray(_predict(df_cov, 'exp(predict(res_gbm, df_cov, n.trees = best_iter))'),
                        dtype=float)
        cum_haz = cumsum(r_[0, d_risk * rr])
        risk_id.append(cum_haz[isin(t, score_t)])

        # Log-likelihood up to each score time. The event term uses the
        # linear predictor (no exp) at T plus log of the baseline hazard.
        s, e = minimum(score_t, t1), event * (t1 <= score_t)
        df_cov = pd.DataFrame(array(_covariates_at(dff, t1))[:, newaxis].T,
                              columns=cov_names)
        lp = np.asarray(_predict(df_cov, 'predict(res_gbm, df_cov, n.trees = best_iter)'),
                        dtype=float)[0]
        tll = - array([cum_haz[where(t == ss)][0] for ss in s]) + \
            e*(lp + log(basehaz(t1)))
        tll_id.append(tll)

    # Performance score
    auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
    _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
    tll = mean(tll_id,0)
    cpu = t_end - t_start
    # rm() without arguments removes nothing; clear the R global environment
    # so this fold's objects (df, res_gbm, ...) do not carry over.
    ro.r('rm(list = ls())')

    return auc, bri, tll, cpu

#----------------------------------------------------
def estimation_cox(df_tr,df_te,score_t):
    """Fit a Cox PH model with time-varying covariates (R ``coxph``) and score it.

    The cumulative baseline hazard is read from ``survfit`` at the all-zero
    covariate profile. Each test subject's cumulative hazard is obtained by
    integrating it against ``exp(coef . x(t))`` along the subject's
    step-interpolated covariate path.

    Parameters
    ----------
    df_tr, df_te : pandas.DataFrame
        Training / test rows in counting-process form with columns ``ID``,
        ``Start``, ``Stop``, ``Event`` and every name in ``cov_names``.
    score_t : array-like
        Evaluation times for the AUC, Brier score and test log-likelihood.

    Returns
    -------
    auc : time-dependent AUC at each ``score_t`` (sksurv).
    bri : Brier score at each ``score_t`` (sksurv).
    tll : mean over test subjects of ``-H(min(score_t, T)) + e * log h(T)``.
    cpu : wall-clock seconds spent in ``coxph`` + ``survfit``.
    """
    ro.r('library(survival)')

    # Shape data for evaluation (surv_test / list_id are reused below).
    surv_train, surv_test, list_id, survival_train, survival_test \
        = shaping_data(df_tr,df_te)

    # Estimation and prediction
    with localconverter(ro.default_converter + pandas2ri.converter):
        rdf = py2rpy(df_tr)
    ro.r.assign("df",rdf)

    t_start = time.time()
    formula = 'Surv(Start, Stop, Event) ~ age+edema+alk.phos+chol+ast+platelet+spiders+hepato+ascites+albumin+bili+protime'
    # Reference covariate profile (all zeros) for the baseline cumulative hazard.
    baseline_cov = 'list(age=0,edema=0,alk.phos=0,chol=0,ast=0,platelet=0,spiders=0,hepato=0,ascites=0,albumin=0,bili=0,protime=0)'
    ro.r('res_coxph <- coxph('+formula+', df) \n' +
         'sfit <- survfit(res_coxph,'+baseline_cov+') \n'
    )
    t_end = time.time()

    # Cumulative baseline hazard and coefficients from R
    with localconverter(ro.default_converter + pandas2ri.converter):
        t_risk  = rpy2py(ro.r('sfit$time'))
        b_risk  = rpy2py(ro.r('sfit$cumhaz'))
        n_event = rpy2py(ro.r('sfit$n.event'))
        coef    = rpy2py(ro.r('res_coxph$coefficients'))

    # Evaluation grid: regular grid + every test follow-up time + score times.
    t = np.unique(np.r_[linspace(0, 5000, 1000), surv_test[:, 1], score_t])

    # Smooth the baseline at event times only. Extending it to t[-1] (which
    # is >= every test time) keeps risk(t) inside its interpolation range.
    # This differs from extending to max(test Stop) only in the cases where
    # that shorter domain made risk(t) raise ValueError.
    has_event = n_event >= 1
    ttt = r_[0, t_risk[has_event]]
    bbb = r_[0, b_risk[has_event]]
    risk, basehaz = cumhaz2haz(ttt, bbb, t1=t[-1])

    d_risk = diff(risk(t))
    tt = 0.5*(t[1:]+t[:-1])

    def _covariates_at(dff, times):
        # Step-interpolate each covariate of one subject at `times`: a time
        # takes the value of the record with the next Stop at/after it; times
        # outside the records reuse the first/last record's value.
        stop_i = dff['Stop'].to_numpy()
        return [interp1d(stop_i, dff[c].to_numpy(), kind='next',
                         fill_value=(dff[c].to_numpy()[0], dff[c].to_numpy()[-1]),
                         bounds_error=False)(times)
                for c in cov_names]

    # Calculate cumulative hazard function and performances
    risk_id, tll_id = [], []
    for pid, event, t1 in zip(list_id, surv_test[:, 0], surv_test[:, 1]):
        dff = df_te[df_te['ID'] == pid]

        # Linear predictor along the midpoint grid -> cumulative hazard.
        lp_path = coef @ array(_covariates_at(dff, tt))
        cum_haz = cumsum(r_[0, d_risk * exp(lp_path)])
        risk_id.append(cum_haz[isin(t, score_t)])

        # Log-likelihood up to each score time.
        s, e = minimum(score_t, t1), event * (t1 <= score_t)
        lp_t1 = coef @ array(_covariates_at(dff, t1))
        tll = - array([cum_haz[where(t == ss)][0] for ss in s]) + \
            e*(lp_t1 + log(basehaz(t1)))
        tll_id.append(tll)

    # Performance score
    auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
    _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
    tll = mean(tll_id,0)
    cpu = t_end - t_start
    # rm() without arguments removes nothing; clear the R global environment
    # so this fold's objects (df, res_coxph, sfit) do not carry over.
    ro.r('rm(list = ls())')

    return auc, bri, tll, cpu

#----------------------------------------------------
def estimation_sf(df_tr,df_te,score_t):
    """Fit an LTRC relative-risk forest (R ``LTRCforests::ltrcrrf``) and score it.

    Loads the modified transformation packages and helper scripts from
    ``../TimeVaryingData_LTRCforests/`` and fits ``ltrcrrf`` on the training
    fold. For every test subject it then predicts survival probabilities
    along the subject's step-interpolated covariate path with
    ``predictProb``. The hazard in the log-likelihood comes from the
    predicted cumulative hazard via ``cumhaz2haz``. The R working directory
    is restored on exit, including on error.

    Parameters
    ----------
    df_tr, df_te : pandas.DataFrame
        Training / test rows in counting-process form with columns ``ID``,
        ``Start``, ``Stop``, ``Event`` and every name in ``cov_names``.
    score_t : array-like
        Evaluation times for the AUC, Brier score and test log-likelihood.

    Returns
    -------
    auc : time-dependent AUC at each ``score_t`` (sksurv).
    bri : Brier score at each ``score_t`` (sksurv).
    tll : mean over test subjects of ``-H(min(score_t, T)) + e * log h(T)``.
    cpu : wall-clock seconds spent in ``ltrcrrf``.
    """
    ro.r('cwd <- getwd()')
    cdir = '../TimeVaryingData_LTRCforests/'

    def _chdir(subdir):
        # cdir is relative to the caller's directory, so resolve every target
        # from `cwd`. Chaining relative setwd() calls would look for
        # .../mlt_mod/../TimeVaryingData_LTRCforests/..., which does not exist.
        ro.r('setwd(cwd) \n' + 'setwd("' + cdir + subdir + '")')

    def _covariates_at(dff, times):
        # Step-interpolate each covariate of one subject at `times`: a time
        # takes the value of the record with the next Stop at/after it; times
        # outside the records reuse the first/last record's value.
        stop_i = dff['Stop'].to_numpy()
        return [interp1d(stop_i, dff[c].to_numpy(), kind='next',
                         fill_value=(dff[c].to_numpy()[0], dff[c].to_numpy()[-1]),
                         bounds_error=False)(times)
                for c in cov_names]

    try:
        for pkg in ('mlt_mod', 'tram_mod', 'trtf_mod'):
            _chdir('analysis/utils/transformation/' + pkg)
            ro.r('devtools::load_all()')
        ro.r('library(LTRCforests) \n' +
             'library(ipred) \n' +
             'library(partykit) \n' +
             'library(prodlim) \n' +
             'library(pec)')
        _chdir('analysis/utils/')
        ro.r('source("Loss_funct_tvary.R") \n' +
             'source("Surv_funct.R") \n' +
             'source("tsf_tvary_funct.R") \n' +
             'library(survival)'
             )

        # Shape data for evaluation
        surv_train, surv_test, list_id, survival_train, survival_test \
            = shaping_data(df_tr,df_te)

        # Estimation and prediction
        with localconverter(ro.default_converter + pandas2ri.converter):
            rdf = py2rpy(df_tr)
        ro.r.assign("df",rdf)

        t_start = time.time()
        ro.r('Formula = Surv(Start, Stop, Event) ~ age+edema+alk.phos+chol+ast+platelet+spiders+hepato+ascites+albumin+bili+protime')
        ro.r('modelT <- ltrcrrf(formula = Formula, data = df, id = ID, stepFactor = 1.5)')
        t_end = time.time()

        # Evaluation grid: regular grid + all train/test follow-up times + score times.
        t = np.unique(np.r_[linspace(0, 5000, 1000), surv_train[:, 1],
                            surv_test[:, 1], score_t])
        n_int = len(t) - 1
        risk_id, tll_id = [], []
        for pid, event, t1 in zip(list_id, surv_test[:, 0], surv_test[:, 1]):
            dff = df_te[df_te['ID'] == pid]
            # Pseudo counting-process records (t[k], t[k+1]] carrying the
            # subject's covariates at each interval end.
            cov = _covariates_at(dff, t[1:])
            df_cov = c_[[pid]*n_int, t[:-1], t[1:], [0]*n_int, array(cov).T]
            df_cov = pd.DataFrame(df_cov, columns=['ID','Start','Stop','Event']+cov_names)
            with localconverter(ro.default_converter + pandas2ri.converter):
                rdf_cov = py2rpy(df_cov)
            ro.r.assign("df_cov",rdf_cov)

            ro.r('predT <- predictProb(object = modelT, newdata = df_cov, ' +
                 'newdata.id = ID, time.eval = sort(df_cov$Stop))')
            with localconverter(ro.default_converter + pandas2ri.converter):
                # The small offset keeps log() finite when a probability is 0.
                cum_haz = -log(rpy2py(ro.r('predT$survival.probs+1.e-5')))[:,0]
            cum_haz = r_[0,cum_haz]
            risk_id.append(cum_haz[isin(t,score_t)])

            # Hazard function from the increasing part of the cumulative hazard
            indx = where(diff(r_[0,cum_haz])>0)
            ttt = r_[0,t[indx]]
            yyy = r_[0,cum_haz[indx]]
            _, haz = cumhaz2haz(ttt,yyy)

            s, e = minimum(score_t, t1), event * (t1 <= score_t)
            tll = - array([cum_haz[where(t == ss)][0] for ss in s]) + \
                e*log(haz(t1))
            tll_id.append(tll)
            print(pid)

        # Performance score
        auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
        _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
        tll = mean(tll_id,0)
        cpu = t_end - t_start
    finally:
        ro.r('setwd(cwd)')
    # rm() without arguments removes nothing; clear the R global environment
    # so this fold's objects (df, modelT, predT, ...) do not carry over.
    ro.r('rm(list = ls())')

    return auc, bri, tll, cpu
    
    
#----------------------------------------------------
def shaping_data(df_tr,df_te):
    """Summarise the train/test folds into one (event, time) record per subject.

    For every subject the event indicator and follow-up time are the maxima
    of its ``Event`` and ``Stop`` values.

    Returns
    -------
    surv_train : ndarray, shape (n_train_subjects, 2)
        Columns ``[Event, Stop]`` for the training subjects, ordered by ID.
    surv_test : ndarray, shape (n_test_subjects, 2)
        Same for the test subjects.
    list_id : ndarray
        Test subject IDs, aligned with the rows of ``surv_test``.
    survival_train, survival_test : structured ndarray
        The same data as record arrays with fields ``Event`` (bool) and
        ``Stop`` (float64), as passed to the sksurv metrics by the callers.
    """
    def _per_subject(frame):
        # One row per ID holding the largest Event and Stop of that subject.
        return frame.groupby('ID')[['Event', 'Stop']].max()

    def _as_structured(rows):
        # Pack an (n, 2) [Event, Stop] array into a (bool, float64) record array.
        packed = np.zeros(len(rows), dtype=[('Event', bool), ('Stop', np.float64)])
        packed['Event'] = rows[:, 0]
        packed['Stop'] = rows[:, 1]
        return packed

    train_rows = _per_subject(df_tr).to_numpy()
    test_summary = _per_subject(df_te)
    test_ids = test_summary.index.values
    test_rows = test_summary.to_numpy()

    return (train_rows, test_rows, test_ids,
            _as_structured(train_rows), _as_structured(test_rows))

def cumhaz2haz(ttt,bbb,t0=0.0,t1=1.0,n_knots=10):
    """Smooth a cumulative hazard and derive the matching hazard function.

    If the knots end before ``t1``, they are first extended linearly to
    ``t1`` using the slope of the last two knots. The cumulative hazard is
    then resampled on ``n_knots`` equally spaced points between ``t0`` and
    the last knot.

    Parameters
    ----------
    ttt, bbb : array-like
        Knot times and cumulative-hazard values at those times. At least two
        knots are required when the extension to ``t1`` is needed.
    t0 : float
        Start of the resampling grid; must not lie before ``ttt``'s range.
    t1 : float
        Time up to which the cumulative hazard must be defined.
    n_knots : int
        Number of equally spaced resampling points (default 10).

    Returns
    -------
    risk : callable
        Piecewise-linear cumulative hazard through the resampled points.
        Raises ValueError outside ``[t0, max(ttt[-1], t1)]``.
    haz : callable
        Piecewise-constant hazard: the slope of ``risk`` on each grid
        interval, looked up by nearest interval midpoint and extrapolated
        beyond the grid.
    """
    ttt = np.asarray(ttt, dtype=float)
    bbb = np.asarray(bbb, dtype=float)
    if ttt[-1] < t1:
        b_end = bbb[-1] + (bbb[-1]-bbb[-2])/(ttt[-1]-ttt[-2])*(t1-ttt[-1])
        bbb = np.r_[bbb, b_end]
        ttt = np.r_[ttt, t1]
    s = np.linspace(t0, ttt[-1], n_knots)
    bbb_interp = interp1d(ttt, bbb)(s)
    # Build both interpolators once here rather than on every call.
    risk = interp1d(s, bbb_interp)
    haz = interp1d(0.5*(s[:-1]+s[1:]), np.diff(bbb_interp)/np.diff(s),
                   kind='nearest', fill_value='extrapolate')
    return risk, haz

if __name__ == '__main__':
    # Run the cross-validated evaluation only when executed as a script.
    main()
