from pycox.evaluation import EvalSurv
from SurvivalEVAL.Evaluator import SurvivalEvaluator
from sksurv.metrics import cumulative_dynamic_auc
from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw
import numpy as np
import pandas as pd


# Time-dependent concordance index
def calculate_td_cindex(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
    """Return Antolini's time-dependent concordance index on the test set.

    ``predict`` has one predicted survival curve per row, sampled on the
    ``time_index_train`` grid (one column per grid point). pycox's ``EvalSurv``
    takes the transposed layout, so the frame is built with the time grid as
    its index and one column per test sample.

    ``time_index_test``, ``label_train`` and ``censoring_train`` are unused.
    They exist only so every metric function shares the same signature.
    """
    survival_frame = pd.DataFrame(data=predict.T, index=time_index_train)
    return EvalSurv(survival_frame, label_test, censoring_test).concordance_td('antolini')


# D-calibration
def calculate_d_calibration(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
    """Return the D-calibration p-value reported by SurvivalEVAL.

    The evaluator is built from the predicted curves on the ``time_index_train``
    grid, the test labels/indicators and the training labels/indicators (used
    by SurvivalEVAL for its reference estimates). ``time_index_test`` is unused
    and kept only for a uniform signature with the other metric functions.
    """
    p_value, _bin_statistics = SurvivalEvaluator(
        predict, time_index_train,
        label_test, censoring_test,
        label_train, censoring_train,
    ).d_calibration()
    return p_value


# Brier score
def calculate_brier_score(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test, *, per_time=False):
    """Return the integrated Brier score and, optionally, per-time Brier scores.

    Parameters
    ----------
    predict : array-like
        Predicted survival curves, one row per test sample, sampled on
        ``time_index_train``.
    time_index_train : array-like
        Time grid on which ``predict`` is evaluated.
    time_index_test : iterable
        Evaluation times for the per-time scores (only used when
        ``per_time`` is True).
    label_train, censoring_train : array-like
        Training times and indicators, passed to SurvivalEVAL.
    label_test, censoring_test : array-like
        Test times and indicators, passed to SurvivalEVAL.
    per_time : bool, keyword-only, default False
        When True, also compute ``SurvivalEvaluator.brier_score`` at every time
        in ``time_index_test``. The default keeps the previous output, where the
        second element was always an empty list.

    Returns
    -------
    tuple
        ``(integrated_brier_score, brier_scores)``. ``brier_scores`` is aligned
        with ``time_index_test`` when ``per_time`` is True and is ``[]`` otherwise.
    """
    evaluator = SurvivalEvaluator(predict, time_index_train, label_test, censoring_test, label_train, censoring_train)
    # Per-time scores are opt-in; previously this code was commented out and
    # the list was always returned empty.
    brier_scores = [evaluator.brier_score(eval_time) for eval_time in time_index_test] if per_time else []
    integrated_brier_score = evaluator.integrated_brier_score()
    return integrated_brier_score, brier_scores


# MAE
def calculate_mae(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
    """Return SurvivalEVAL's MAE using ``method='margin'`` with ``weighted=True``.

    ``time_index_test`` is unused and kept only so all metric functions share
    one signature.
    """
    return SurvivalEvaluator(
        predict, time_index_train,
        label_test, censoring_test,
        label_train, censoring_train,
    ).mae(method='margin', weighted=True)


# Truncated time-dependent concordance index
def calculate_truncated_td_cindex(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
    """Return the mean and per-time IPCW concordance truncated at each test time.

    For every ``tau`` in ``time_index_test``:

    * the survival column on the ``time_index_train`` grid nearest to ``tau`` is
      taken;
    * it is turned into a risk score ``1 - S(tau)``;
    * scikit-survival's ``concordance_index_ipcw`` is scored with ``tau`` as the
      truncation time.

    Note that ``censoring_*`` is passed as the *event* argument of
    ``Surv.from_arrays(event, time)``.

    Returns
    -------
    tuple
        ``(mean_cindex, per_time_cindex_list)``.
    """
    y_train = Surv.from_arrays(censoring_train, label_train)
    y_test = Surv.from_arrays(censoring_test, label_test)

    def _cindex_up_to(tau):
        # Nearest grid point stands in for the exact evaluation time.
        nearest = np.argmin(np.abs(tau - time_index_train))
        risk = 1 - predict[:, nearest]
        return concordance_index_ipcw(y_train, y_test, risk, tau=tau)[0]

    per_time = [_cindex_up_to(tau) for tau in time_index_test]
    return np.mean(per_time), per_time
        

# Time-dependent AUC
def calculate_td_auc(predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
    """Return the mean and per-time cumulative/dynamic AUC.

    The caller's ``time_index_test`` is ignored. Evaluation times are instead
    drawn from the sorted distinct test times, at positions ``int(q * n)`` for
    ``q`` in ``0, 0.2, 0.4, 0.6, 0.8``.

    At each time, the survival column on the ``time_index_train`` grid nearest
    to it is converted to a risk score ``1 - S(t)`` and scored with
    scikit-survival's ``cumulative_dynamic_auc``.

    Returns
    -------
    tuple
        ``(mean_auc, per_time_list)``. Each list entry is the ``auc`` array that
        ``cumulative_dynamic_auc`` returns for the single requested time.
    """
    # Nudge all test times down when the test max reaches the training max.
    # NOTE(review): presumably keeps the IPCW estimate fitted on the training
    # labels defined at the test times — confirm.
    if np.max(label_test) >= np.max(label_train):
        label_test = label_test - 1e-6
    y_train = Surv.from_arrays(censoring_train, label_train)
    y_test = Surv.from_arrays(censoring_test, label_test)

    distinct_times = np.unique(label_test)
    n_distinct = len(distinct_times)
    eval_times = [distinct_times[int(q * n_distinct)] for q in np.arange(0, 1, 0.2)]

    scores = []
    for t in eval_times:
        nearest = np.argmin(np.abs(t - time_index_train))
        risk = 1 - predict[:, nearest]
        scores.append(cumulative_dynamic_auc(y_train, y_test, risk, times=[t])[0])
    return np.mean(scores), scores


class Metric(object):
    """Score survival predictions overall, per sensitive group, and for fairness.

    ``hparams['metric']`` selects the score:

    * ``'ctd'``   -- Antolini time-dependent concordance (``calculate_td_cindex``)
    * ``'brier'`` -- integrated Brier score (first element of ``calculate_brier_score``)
    * ``'mae'``   -- SurvivalEVAL margin MAE (``calculate_mae``)
    * ``'auc'``   -- mean time-dependent AUC (first element of ``calculate_td_auc``)

    Fairness is the gap between the largest and smallest per-group score.
    """

    # Metric names accepted in ``hparams['metric']``.
    SUPPORTED_METRICS = ('ctd', 'brier', 'mae', 'auc')

    def __init__(self, hparams):
        self.hparams = hparams

    def _evaluate(self, metric, predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test):
        """Return the scalar ``metric`` score for one set of test samples."""
        args = (predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test)
        if metric == 'ctd':
            return calculate_td_cindex(*args)
        if metric == 'brier':
            return calculate_brier_score(*args)[0]
        if metric == 'mae':
            return calculate_mae(*args)
        return calculate_td_auc(*args)[0]

    def __call__(self, predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test, sensitive_attribute):
        """Compute overall, per-group and fairness scores.

        Parameters
        ----------
        predict : array-like
            Predicted survival curves, one row per test sample (rows are
            selected per group with a boolean mask).
        time_index_train, time_index_test, label_train, censoring_train :
            Forwarded unchanged to the selected metric function.
        label_test, censoring_test : array-like
            Test labels/indicators, subset per group alongside ``predict``.
        sensitive_attribute : array-like
            Group label of each test sample. Any sortable labels are accepted;
            groups are visited in sorted order.

        Returns
        -------
        dict
            ``{'per_group': [...], 'accuracy': float, 'fairness': float}``.

        Raises
        ------
        ValueError
            If ``hparams['metric']`` is not one of ``SUPPORTED_METRICS``.
        """
        metric = self.hparams['metric']
        if metric not in self.SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported metric {metric!r}; expected one of {self.SUPPORTED_METRICS}"
            )
        score = {'per_group': []}
        # Overall score on the full test set.
        score['accuracy'] = self._evaluate(metric, predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test)
        # Iterate the actual group labels rather than range(n_groups), which
        # assumed labels 0..n-1 and mis-handled e.g. {1, 2}.
        groups = np.asarray(sensitive_attribute)
        for group in np.unique(groups):
            idx = groups == group
            score['per_group'].append(self._evaluate(metric, predict[idx], time_index_train, time_index_test, label_train, label_test[idx], censoring_train, censoring_test[idx]))
        score['fairness'] = max(score['per_group']) - min(score['per_group'])
        return score
