#import os
#os.environ['R_HOME'] = '/Library/Frameworks/R.framework/Resources'

from pylab import *
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sksurv.metrics import cumulative_dynamic_auc, brier_score
import time, sys, dill
import rpy2.robjects as ro
from rpy2.robjects.conversion import localconverter, py2rpy,rpy2py
from rpy2.robjects import pandas2ri, numpy2ri
from scipy.interpolate import splrep, BSpline, PchipInterpolator, interp1d

# Suppress rpy2 warnings (only errors are logged)
from rpy2.rinterface_lib.callbacks import logger as rpy2_logger
import logging
rpy2_logger.setLevel(logging.ERROR)


def main():
    """Run the K-fold benchmark for every configured model.

    Loads the simulated data set from ``data/nonlinear/``, splits the
    *subjects* (not the rows) into ``n_split`` folds so that all rows of a
    subject end up in the same fold, calls ``estimation_<model>`` on every
    fold and dumps the collected scores to
    ``result/nonlinear/<model>_<data file>``.
    """
    n_split = 10
    score_t = np.arange(0.3,1.0,0.1)
    f_data = 'data_nsub100000nint10.dill'
    # NOTE: dill, like pickle, can execute arbitrary code while loading;
    # only load data files from a trusted source.
    with open('data/nonlinear/'+f_data,'rb') as f_in:
        data = dill.load(f_in)

    # DATA SPLITTING #################################
    df, cov_func = data['df'], data['func']
    kf = KFold(n_splits=n_split, shuffle=True, random_state=0)
    # KFold.split returns *positions* into the array it is given, so they
    # must be mapped back to the actual subject ids before filtering rows.
    ids = np.unique(df['id'].to_numpy())
    df_train, df_test = [], []
    for indx_train, indx_test in kf.split(ids):
        df_train.append(df[df['id'].isin(ids[indx_train])])
        df_test.append(df[df['id'].isin(ids[indx_test])])

    # START ESTIMATION & PREDICTION ##################
    # Explicit dispatch table instead of eval() on a built-up name.
    estimators = {
        'gbm': estimation_gbm,
        'cox': estimation_cox,
        'sf': estimation_sf,
    }
    models = ['cox']#'sf','gbm']#['sf','cox','gbm']
    for m in models:
        print(m+'...')
        estimate = estimators[m]
        score = {x:[] for x in ['auc','cpu','bri','tll']}
        score['t'] = score_t
        for i, (df_tr, df_te) in enumerate(zip(df_train,df_test)):
            print(i)
            auc, bri, tll, cpu = estimate(df_tr,df_te,score_t,cov_func)
            score['auc'].append(auc)
            score['bri'].append(bri)
            score['tll'].append(tll)
            score['cpu'].append(cpu)
        print(score['tll'])
        with open('result/nonlinear/'+m+'_'+f_data,'wb') as f_out:
            dill.dump(score, f_out)

#----------------------------------------------------
def estimation_gbm(df_tr,df_te,score_t,cov_func):
    """Fit a gbm3 CoxPH boosting model in R on one training fold and time it.

    The timed section runs a grid search over ``num_trees`` x ``shrinkage``,
    refits with the grid entry that had the smallest ``valid.error``,
    predicts ``exp(score)`` for every training row and accumulates a
    cumulative baseline hazard at the training event times.

    The evaluation part (AUC, Brier score, log-likelihood) is currently
    disabled: it is kept below as an inert string literal, so the function
    returns zeros for those three scores.

    Args:
        df_tr: Training fold (pandas DataFrame). The R formula uses the
            columns ``id``, ``t0``, ``t1``, ``event``, ``cov1``, ``cov2``.
            The NumPy part indexes columns by position (1, 2, 3)
            -- assumes the column order id, t0, t1, event, cov1, cov2;
            TODO confirm.
        df_te: Test fold; only passed to ``shaping_data`` here.
        score_t: Evaluation time points (unused while evaluation is disabled).
        cov_func: Covariate functions (unused while evaluation is disabled).

    Returns:
        Tuple ``(0, 0, 0, cpu)`` where ``cpu`` is the elapsed wall-clock
        time in seconds of the timed section.
    """
    # Attach the R packages used below.
    ro.r('library(survival) \n' +
         'library(gbm3)'
         )

    # Shape data for evaluation
    # (the results are only consumed by the disabled evaluation code).
    surv_train, surv_test, list_id, survival_train, survival_test \
        = shaping_data(df_tr,df_te)
    
    # Estimation and prediction
    # Hand the training fold to R as the global data frame `df`.
    with localconverter(ro.default_converter + pandas2ri.converter):
        rdf = py2rpy(df_tr)    
    ro.r.assign("df",rdf)
    
    # Everything between t_start and t_end is what gets reported as `cpu`.
    t_start = time.time()
    ro.r('set.seed(0)')

    # Grid search: fit one model per (num_trees, shrinkage) combination and
    # record its smallest validation error.
    ro.r('params_grid <- expand.grid(num_trees = c(500,1000,2000), shrinkage = c(0.001, 0.005, 0.01))')
    ro.r('score <- c()')
    ro.r('for (i in 1:length(params_grid[,1])) {\n' +
         '  params <- ' +
         '  training_params(num_trees = params_grid[i,]$num_trees,' +
         '  shrinkage = params_grid[i,]$shrinkage,' +
         '  interaction_depth = 1,' +
         '  id = df$id,' +
         '  num_train = round(0.5 * length(unique(df$id)))) \n' +
         '  res_gbm <- ' +
         '  gbmt(Surv(t0, t1, event) ~ cov1 + cov2,' +
         '  data = df, distribution = gbm_dist("CoxPH"),' +
         '  train_params = params, cv_folds=10, keep_gbm_data = TRUE,' +
         '  par_details = gbmParallel(num_threads = 12)) \n' +
         '  score <- append(score,min(res_gbm$valid.error)) \n' +
         '}'
         )
    # Refit with the grid entry that gave the smallest validation error.
    ro.r('indx <- which.min(score)')
    ro.r('params <- ' +
         'training_params(num_trees = params_grid[indx,]$num_trees,' +
         'shrinkage = params_grid[indx,]$shrinkage,' +
         'interaction_depth = 1,' +
         'id = df$id,' +
         'num_train = round(0.5 * length(unique(df$id)))) \n' +
         'res_gbm <- ' +
         'gbmt(Surv(t0, t1, event) ~ cov1 + cov2,' +
         'data = df, distribution = gbm_dist("CoxPH"),' +
         'train_params = params, cv_folds=10, keep_gbm_data = TRUE,' +
         'par_details = gbmParallel(num_threads = 12))'
         )
    # Pick the number of trees via the "cv" performance method and predict
    # exp(score) for every training row.
    ro.r('best_iter <- gbmt_performance(res_gbm, method="cv")')
    ro.r('res_gbm.pred <- exp(predict(res_gbm, df, n.trees = best_iter))')
    with localconverter(ro.default_converter + pandas2ri.converter):
        pred = rpy2py(ro.r('res_gbm.pred'))
    
    # Calculate cumulative base hazard function 
    # The prediction is appended as column 6. For every event time s
    # (column 2 of rows with column 3 == 1) the increment is
    # 1 / sum(pred over rows with column 1 < s <= column 2).
    df_tr_numpy = df_tr.to_numpy()
    df_tr_numpy = c_[df_tr_numpy,pred]
    spk = sort(df_tr_numpy[where(df_tr_numpy[:,3]==1)][:,2])
    #spk = sort(spk + 0.0001*np.random.random(len(spk)))
    cumhaz = [1./sum(df_tr_numpy[(df_tr_numpy[:,1]<s)&(df_tr_numpy[:,2]>=s)][:,6])
              for s in spk]
    cumhaz = cumsum(cumhaz)
    t_end = time.time()
    # Disabled evaluation code: the string literal below is never executed.
    """
    # Calculate base hazard function with spline
    indx = where(diff(r_[0,cumhaz])>0)
    ttt = r_[0,spk[indx]]
    bbb = r_[0,cumhaz[indx]]
    risk, basehaz = cumhaz2haz(ttt,bbb)
            
    # Calculate cumulative hazard function
    # and performances
    t = array(sorted(unique(list(linspace(0,1,1000))+\
                            list(surv_test[:,1])+list(score_t))))
    d_risk = diff(risk(t))
    tt = 0.5*(t[1:]+t[:-1])
    risk_id, tll_id = [], []
    for id, event, t1 in zip(list_id,surv_test[:,0],surv_test[:,1]):
        [[a1,w1,b1],[a2,w2,b2]] = cov_func[id]
        cov1, cov2 = a1*cos(2*pi*w1*tt+pi*b1), a2*cos(2*pi*w2*tt+pi*b2)
        df_cov = pd.DataFrame(c_[cov1,cov2],columns=['cov1','cov2'])
        with localconverter(ro.default_converter + pandas2ri.converter):
            rdf_cov = py2rpy(df_cov)    
        ro.r.assign("df_cov",rdf_cov)
        ro.r('res_gbm.pred <- exp(predict(res_gbm, df_cov, n.trees = best_iter))')
        with localconverter(ro.default_converter + pandas2ri.converter):
            pred = rpy2py(ro.r('res_gbm.pred'))
        s = d_risk * pred
        cum_haz = cumsum(r_[0,s])
        risk_id.append(cum_haz[isin(t,score_t)])
        
        s, e = minimum(score_t,t1), event * (t1<=score_t)
        cov1, cov2 = a1*cos(2*pi*w1*t1+pi*b1), a2*cos(2*pi*w2*t1+pi*b2)
        df_cov = pd.DataFrame(c_[cov1,cov2],columns=['cov1','cov2'])
        with localconverter(ro.default_converter + pandas2ri.converter):
            rdf_cov = py2rpy(df_cov)    
        ro.r.assign("df_cov",rdf_cov)
        ro.r('res_gbm.pred <- predict(res_gbm, df_cov, n.trees = best_iter)')
        with localconverter(ro.default_converter + pandas2ri.converter):
            pred = rpy2py(ro.r('res_gbm.pred'))[0]
        
        tll = - array([cum_haz[where(t==ss)][0] for ss in s]) + \
            e*(pred+log(basehaz(t1)))
        tll_id.append(tll)
            
    # Performance score 
    auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
    _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
    tll = mean(tll_id,0)
    cpu = t_end - t_start
    ro.r('rm()')

    return auc, bri, tll, cpu
    """

    # NOTE(review): in R, `rm()` without arguments removes nothing; if the
    # intent is to clear the R workspace between folds this would need
    # `rm(list = ls())` -- confirm.
    ro.r('rm()')
    return 0, 0, 0, t_end - t_start

#----------------------------------------------------
def estimation_cox(df_tr,df_te,score_t,cov_func):
    """Fit a Cox model in R (``coxph``) on one training fold and time it.

    The timed section fits ``Surv(t0, t1, event) ~ cov1 + cov2`` and calls
    ``survfit`` at ``cov1 = 0, cov2 = 0``.

    The evaluation part (AUC, Brier score, log-likelihood) is currently
    disabled: it is kept below as an inert string literal, so the function
    returns zeros for those three scores.

    Args:
        df_tr: Training fold (pandas DataFrame); the R formula uses the
            columns ``t0``, ``t1``, ``event``, ``cov1`` and ``cov2``.
        df_te: Test fold; only passed to ``shaping_data`` here.
        score_t: Evaluation time points (unused while evaluation is disabled).
        cov_func: Covariate functions (unused while evaluation is disabled).

    Returns:
        Tuple ``(0, 0, 0, cpu)`` where ``cpu`` is the elapsed wall-clock
        time in seconds of the timed section.
    """
    ro.r('library(survival)')

    # Shape data for evaluation
    # (the results are only consumed by the disabled evaluation code).
    surv_train, surv_test, list_id, survival_train, survival_test \
        = shaping_data(df_tr,df_te)
    
    # Estimation and prediction
    # Hand the training fold to R as the global data frame `df`.
    with localconverter(ro.default_converter + pandas2ri.converter):
        rdf = py2rpy(df_tr)    
    ro.r.assign("df",rdf)

    # Only the model fit and the survfit call are timed.
    t_start = time.time()
    ro.r('res_coxph <- coxph(Surv(t0, t1, event) ~ cov1 + cov2, df) \n' +
         'sfit <- survfit(res_coxph,list(cov1=0,cov2=0)) \n'
    )
    t_end = time.time()
    # Disabled evaluation code: the string literal below is never executed.
    """
    # Calculate cumulative base hazard function 
    with localconverter(ro.default_converter + pandas2ri.converter):
        t_risk  = rpy2py(ro.r('sfit$time'))
        b_risk  = rpy2py(ro.r('sfit$cumhaz'))
        n_event = rpy2py(ro.r('sfit$n.event'))
        coef    = rpy2py(ro.r('res_coxph$coefficients'))
    surv_train = df_tr.groupby('id').max()[['event','t1']].to_numpy()
    surv_test  = df_te.groupby('id').max()[['event','t1']]
    list_id, surv_test = surv_test.index.values, surv_test.to_numpy()
    
    # Calculate base hazard function with spline
    ttt = r_[0,t_risk[where(n_event>=1)]]
    bbb = r_[0,b_risk[where(n_event>=1)]]
    risk, basehaz = cumhaz2haz(ttt,bbb)
    
    # Calculate cumulative hazard function
    # and performances
    t = array(sorted(unique(list(linspace(0,1,1000))+\
                            list(surv_test[:,1])+list(score_t))))
    d_risk = diff(risk(t))
    tt = 0.5*(t[1:]+t[:-1])
    risk_id, tll_id = [], []
    for id, event, t1 in zip(list_id,surv_test[:,0],surv_test[:,1]):
        [[a1,w1,b1],[a2,w2,b2]] = cov_func[id]
        cov1, cov2 = a1*cos(2*pi*w1*tt+pi*b1), a2*cos(2*pi*w2*tt+pi*b2)
        s = d_risk * exp(coef[0]*cov1+coef[1]*cov2)
        cum_haz = cumsum(r_[0,s])
        risk_id.append(cum_haz[isin(t,score_t)])

        s, e = minimum(score_t,t1), event * (t1<=score_t)
        cov1, cov2 = a1*cos(2*pi*w1*t1+pi*b1), a2*cos(2*pi*w2*t1+pi*b2)
        tll = - array([cum_haz[where(t==ss)][0] for ss in s]) + \
            e*(coef[0]*cov1+coef[1]*cov2+log(basehaz(t1)))
        tll_id.append(tll)
    
    # Performance score 
    auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
    _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
    tll = mean(tll_id,0)
    cpu = t_end - t_start
    ro.r('rm()')
    
    return auc, bri, tll, cpu
    """

    # NOTE(review): in R, `rm()` without arguments removes nothing; if the
    # intent is to clear the R workspace between folds this would need
    # `rm(list = ls())` -- confirm.
    ro.r('rm()')
    return 0, 0, 0, t_end - t_start 

#----------------------------------------------------
def estimation_sf(df_tr,df_te,score_t,cov_func):
    """Fit an LTRC forest (``ltrcrrf``) in R on one training fold and time it.

    Loads the modified ``mlt``/``tram``/``trtf`` packages and the helper
    scripts from the ``TimeVaryingData_LTRCforests`` checkout, hands the
    training fold to R and times the ``ltrcrrf`` fit. The R working
    directory is saved on entry and restored before returning.

    The evaluation part (AUC, Brier score, log-likelihood) is currently
    disabled: it is kept below as an inert string literal, so the function
    returns zeros for those three scores.

    Args:
        df_tr: Training fold (pandas DataFrame) with six columns, renamed
            in R to ID, Start, Stop, Event, cov1, cov2.
        df_te: Test fold; only passed to ``shaping_data`` here.
        score_t: Evaluation time points (unused while evaluation is disabled).
        cov_func: Covariate functions (unused while evaluation is disabled).

    Returns:
        Tuple ``(0, 0, 0, cpu)`` where ``cpu`` is the elapsed wall-clock
        time in seconds of the model fit.
    """
    ro.r('cwd <- getwd()')
    cdir = '../TimeVaryingData_LTRCforests/'
    # `cdir` is relative, and setwd() resolves relative paths against the
    # *current* directory. Every path is therefore anchored at the saved
    # starting directory `cwd`; otherwise each setwd() after the first one
    # would be resolved relative to the directory entered just before.
    ro.r('setwd(file.path(cwd, "'+cdir+'analysis/utils/transformation/mlt_mod")) \n' +
         'devtools::load_all()')
    ro.r('setwd(file.path(cwd, "'+cdir+'analysis/utils/transformation/tram_mod")) \n' +
         'devtools::load_all()')
    ro.r('setwd(file.path(cwd, "'+cdir+'analysis/utils/transformation/trtf_mod")) \n' +
         'devtools::load_all()')
    ro.r('library(LTRCforests) \n' +
         'library(ipred) \n' +
         'library(partykit) \n' +
         'library(prodlim) \n' +
         'library(pec) \n' +
         'setwd(file.path(cwd, "'+cdir+'analysis/utils/")) \n' +
         'source("Loss_funct_tvary.R") \n' +
         'source("Surv_funct.R") \n' +
         'source("tsf_tvary_funct.R") \n' +
         'library(survival)'
         )
    #ro.r('setwd("'+cdir+'pkg/LTRCforests") \n' +
    #     'devtools::load_all()')

    # Shape data for evaluation
    # (the results are only consumed by the disabled evaluation code).
    surv_train, surv_test, list_id, survival_train, survival_test \
        = shaping_data(df_tr,df_te)
    
    # Estimation and prediction
    # Hand the training fold to R as `df`, using the column names that the
    # formulas below refer to.
    with localconverter(ro.default_converter + pandas2ri.converter):
        rdf = py2rpy(df_tr)    
    ro.r.assign("df",rdf)
    ro.r('colnames(df) <- c("ID", "Start", "Stop", "Event", "cov1", "cov2")')
    
    ntree = 100
    mtrypool = [20,10,5,3,2,1]
    ro.r('ntree = '+str(ntree)+'L')
    ro.r.assign("mtrypool",py2rpy(mtrypool))
    ro.r('mtrypool <- unlist(mtrypool)')
    ro.r('mtryD <- ceiling(10)')
    ro.r('Formula = Surv(Start,Stop,Event) ~ cov1 + cov2 \n' +
         'Formula_TD = Surv(Start,Stop,Event, type = "counting") ~ cov1 + cov2')
    
    # Only the forest fit itself is timed.
    t_start = time.time()
    # Disabled alternative model (ltrccif); the string below is not executed.
    """
    ro.r('modelT <- ltrccif(formula = Formula, data = df, id = ID,' +
         'mtry = mtryD, ntree = ntree,' +
         'control = partykit::ctree_control(teststat = "quad", testtype = "Univ",' +
         'mincriterion = 0, saveinfo = FALSE,minsplit = 20,minbucket = 7,minprob = 0.01))'
         )
    """
    ro.r('modelT <- ltrcrrf(formula = Formula, data = df, id = ID,' +
         'mtry = mtryD, ntree = ntree)'
         )
    t_end = time.time()
    # Disabled evaluation code: the string literal below is never executed.
    """
    # Calculate cumulative hazard function
    t = array(sorted(unique(list(linspace(0,1,1000))+list(surv_train[:,1])+\
                            list(surv_test[:,1])+list(score_t))))
    risk_id, tll_id = [], []
    for id, event, t1 in zip(list_id,surv_test[:,0],surv_test[:,1]):
        [[a1,w1,b1],[a2,w2,b2]] = cov_func[id]
        cov1, cov2 = a1*cos(2*pi*w1*t[1:]+pi*b1), a2*cos(2*pi*w2*t[1:]+pi*b2)
        df_cov = c_[[id]*len(cov1),t[:-1],t[1:],[0]*len(cov1),cov1,cov2]
        df_cov = pd.DataFrame(df_cov, columns=['ID','Start','Stop','Event','cov1','cov2'])
        with localconverter(ro.default_converter + pandas2ri.converter):
            rdf_cov = py2rpy(df_cov)    
        ro.r.assign("df_cov",rdf_cov)
        
        ro.r('predT <- predictProb(object = modelT, newdata = df_cov, ' +
             'newdata.id = ID, time.eval = sort(df_cov$Stop))')
        with localconverter(ro.default_converter + pandas2ri.converter):
            cum_haz = -log(rpy2py(ro.r('predT$survival.probs+1.e-5')))[:,0]
        cum_haz = r_[0,cum_haz]
        risk_id.append(cum_haz[isin(t,score_t)])
        
        # Calculate hazard function with spline
        indx = where(diff(r_[0,cum_haz])>0)
        ttt = r_[0,t[indx]]
        yyy = r_[0,cum_haz[indx]]
        risk, haz = cumhaz2haz(ttt,yyy)
                
        s, e = minimum(score_t,t1), event * (t1<=score_t)
        tll = - array([cum_haz[where(t==ss)][0] for ss in s]) + \
            e*log(haz(t1))
        tll_id.append(tll)
        print(id)
        
    # Performance score 
    auc, _ = cumulative_dynamic_auc(survival_train,survival_test,risk_id,score_t)
    _, bri = brier_score(survival_train,survival_test,risk_id,score_t)
    tll = mean(tll_id,0)
    cpu = t_end - t_start
    ro.r('setwd(cwd)')
    ro.r('rm()')

    return auc, bri, tll, cpu
    """

    # Restore the working directory that was active on entry.
    ro.r('setwd(cwd)')
    # NOTE(review): in R, `rm()` without arguments removes nothing; if the
    # intent is to clear the R workspace between folds this would need
    # `rm(list = ls())` -- confirm.
    ro.r('rm()')
    return 0, 0, 0, t_end - t_start  
    
#----------------------------------------------------
def shaping_data(df_tr,df_te):
    """Collapse both folds to one (event, t1) record per subject.

    Rows are grouped by ``id`` and the column-wise maximum of ``event``
    and ``t1`` is taken for every subject.

    Args:
        df_tr: Training fold with at least the columns id, event, t1.
        df_te: Test fold with the same columns.

    Returns:
        Tuple ``(surv_train, surv_test, list_id, survival_train,
        survival_test)``: the two plain (n, 2) arrays of (event, t1), the
        subject ids of the test fold in group order, and the same data as
        structured arrays with the fields ``event`` (bool) and ``t1``
        (float64).
    """
    record_dtype = [('event',bool),('t1',np.float64)]

    def _per_subject(frame):
        # One row per subject: maximum of `event` and `t1` over its rows.
        return frame.groupby('id').max()[['event','t1']]

    def _as_records(table):
        # Copy the two columns into a structured array.
        records = np.zeros(len(table),dtype=record_dtype)
        records['event'] = table[:,0]
        records['t1'] = table[:,1]
        return records

    test_summary = _per_subject(df_te)
    surv_train = _per_subject(df_tr).to_numpy()
    surv_test = test_summary.to_numpy()
    list_id = test_summary.index.values

    survival_train = _as_records(surv_train)
    survival_test = _as_records(surv_test)
    return surv_train, surv_test, list_id, survival_train, survival_test

def cumhaz2haz(ttt,bbb,t0=0.0,t1=1.0):
    """Turn cumulative-hazard samples into cumulative-hazard and hazard functions.

    The samples ``(ttt, bbb)`` are linearly interpolated onto 10 equally
    spaced knots on ``[t0, t1]``. If the samples end before ``t1`` they are
    first extended to ``t1`` along the slope of their last segment.

    Args:
        ttt: Increasing sample times; needs at least two points when
            ``ttt[-1] < t1``, and must cover ``t0``.
        bbb: Cumulative hazard at ``ttt``.
        t0: Left end of the knot grid.
        t1: Right end of the knot grid.

    Returns:
        Tuple ``(risk, haz)`` of callables. ``risk(t)`` is the piecewise
        linear cumulative hazard on the knots (valid for ``t`` in
        ``[t0, t1]``). ``haz(t)`` is its slope, looked up as the value of
        the nearest knot-interval midpoint and extrapolated outside the
        midpoints.
    """
    if ttt[-1] < t1:
        # Extend to t1 along the slope of the last observed segment.
        slope = (bbb[-1]-bbb[-2])/(ttt[-1]-ttt[-2])
        bbb = np.r_[bbb, bbb[-1] + slope*(t1-ttt[-1])]
        ttt = np.r_[ttt, t1]
    s = np.linspace(t0,t1,10)
    bbb_interp = interp1d(ttt,bbb)(s)
    # Build both interpolators once; callers evaluate them once per subject,
    # so rebuilding them on every call would be wasted work.
    risk = interp1d(s,bbb_interp)
    haz = interp1d(0.5*(s[:-1]+s[1:]),np.diff(bbb_interp)/np.diff(s),
                   kind='nearest',fill_value='extrapolate')
    return risk, haz

# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
