{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook analyzes the results obtained by running `experiments_cnsc.py`.\n",
    "\n",
    "It computes performance metrics, analyzes the evolution of the likelihood as a function of the number of clusters (if available), and displays the obtained treatment response clusters (for the selected methodology)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "sys.path.append('../DeepSurvivalMachines/')\n",
    "sys.path.append('../auton-survival/')\n",
    "\n",
    "from cnsc import datasets\n",
    "from experiment import Experiment\n",
    "from simulation.generate import generate, compute_cif\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Hide the four axes spines and set the display / export resolutions\n",
    "custom_params = {\"axes.spines.{}\".format(side): False for side in (\"right\", \"top\", \"left\", \"bottom\")}\n",
    "custom_params.update({\"figure.dpi\": 700, \"savefig.dpi\": 300})\n",
    "sns.set_theme(style = \"whitegrid\", rc = custom_params, font_scale = 1.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change this to analyze the results of another dataset\n",
    "dataset = 'SEER'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../Results_cnsc/' # Path where the data is saved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, a, t, e, covariates = datasets.load_dataset(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "horizons = [0.25, 0.5, 0.75] # Horizons to evaluate the models\n",
    "times_eval = np.quantile(t[e > 0], horizons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pycox.evaluation import EvalSurv\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score\n",
    "\n",
    "def evaluate(survival_pred, a = a, t = t, e = e, times_eval = times_eval):\n",
    "    \"\"\"\n",
    "        Evaluate the factual predictions of a survival model on each cross-validation fold.\n",
    "\n",
    "        The factual survival of a patient is the 'treated' prediction if the patient\n",
    "        received the treatment and the 'untreated' prediction otherwise.\n",
    "\n",
    "        Args:\n",
    "            survival_pred (pd.DataFrame): Predictions with two-level columns. The 'Use' column\n",
    "                holds the fold in which each row is held out (folds are assumed to be labelled\n",
    "                0..K-1) and the 'treated' / 'untreated' blocks hold one column per time.\n",
    "            a (array-like): Treatment indicator, aligned with the rows of survival_pred.\n",
    "            t (np.ndarray): Observed times, aligned with the rows of survival_pred.\n",
    "            e (np.ndarray): Event indicator (0 for censored), aligned with the rows of survival_pred.\n",
    "            times_eval (array-like): Horizons at which the time-dependent metrics are computed.\n",
    "\n",
    "        Returns:\n",
    "            pd.DataFrame: Indexed by (Fold, Metric) with an 'Overall' column (time-dependent\n",
    "                concordance 'CIS' and, when computable, integrated Brier score 'BRS') and one\n",
    "                column per horizon for which 'CIS', 'BRS' and 'ROCS' could be computed.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    folds = survival_pred[('Use',)].values.flatten()\n",
    "    a = pd.Series(a, index = survival_pred.index).astype(bool)\n",
    "    # Select the prediction matching the received treatment, then restore the original row order\n",
    "    factual_survival = pd.concat([survival_pred['treated'].loc[a], survival_pred['untreated'].loc[~a]], axis = 0).loc[survival_pred.index]\n",
    "    factual_survival.columns = factual_survival.columns.astype(float)\n",
    "    times = factual_survival.columns.unique()\n",
    "    results = {}\n",
    "\n",
    "    for fold in np.arange(np.unique(folds).shape[0]):\n",
    "        res = {}\n",
    "        # Subselect given fold\n",
    "        train, test = folds != fold, folds == fold\n",
    "        e_train, t_train = e[train], t[train]\n",
    "        e_test,  t_test  = e[test], t[test]\n",
    "        survival_train = factual_survival[train]\n",
    "        survival_fold = factual_survival[test]\n",
    "\n",
    "        # Censoring distribution estimated on the training folds and reused for the test fold\n",
    "        km = EvalSurv(survival_train.T, t_train, e_train, censor_surv = 'km')\n",
    "        test_eval = EvalSurv(survival_fold.T, t_test, e_test, censor_surv = km)\n",
    "\n",
    "        res['Overall'] = {'CIS': test_eval.concordance_td()}\n",
    "        try:\n",
    "            res['Overall']['BRS'] = test_eval.integrated_brier_score(times.to_numpy())\n",
    "        except Exception as error:\n",
    "            # Best effort: report the failure instead of silently dropping the metric\n",
    "            warnings.warn('Fold {}: integrated Brier score not computed ({})'.format(fold, error))\n",
    "\n",
    "        if len(times_eval) > 0:\n",
    "            et_train = np.array(list(zip(e_train, t_train)), # For estimation censoring\n",
    "                                dtype = [('e', bool), ('t', float)])\n",
    "            et_test = np.array(list(zip(e_test, t_test)), # For measure performance for given outcome\n",
    "                               dtype = [('e', bool), ('t', float)])\n",
    "            # Keep the test points observed before the last training time, and all censored ones\n",
    "            selection = (t_test < t_train.max()) | (e_test == 0)\n",
    "            et_test = et_test[selection]\n",
    "            survival_selected = survival_fold[selection]\n",
    "\n",
    "            # Closest predicted time to each evaluation horizon\n",
    "            indexes = [np.argmin(np.abs(times - te)) for te in times_eval]\n",
    "            briers = brier_score(et_train, et_test, survival_selected.iloc[:, indexes], times_eval)[1]\n",
    "            for te, brier, index in zip(times_eval, briers, indexes):\n",
    "                risk = 1 - survival_selected.iloc[:, index]\n",
    "                try:\n",
    "                    # The horizon is given to both metrics: it truncates the concordance and is the\n",
    "                    # (required) evaluation time of the cumulative / dynamic AUC\n",
    "                    res[te] = {\n",
    "                        'CIS': concordance_index_ipcw(et_train, et_test, risk, te)[0],\n",
    "                        'BRS': brier,\n",
    "                        'ROCS': cumulative_dynamic_auc(et_train, et_test, risk, te)[0][0]}\n",
    "                except Exception as error:\n",
    "                    warnings.warn('Fold {} - horizon {}: metrics not computed ({})'.format(fold, te, error))\n",
    "\n",
    "        results[fold] = pd.DataFrame.from_dict(res)\n",
    "    results = pd.concat(results)\n",
    "    results.index.set_names(['Fold', 'Metric'], inplace = True)\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods to analyze, mapped to their display names\n",
    "# TODO: Add your method in the list for nicer display\n",
    "dict_name = {'cnsc': 'CNSC', 'cmhe': 'CMHE'}\n",
    "\n",
    "# Open file and compute performance\n",
    "predictions, results = {}, {}\n",
    "for file_name in sorted(os.listdir(path)):\n",
    "    if dataset in file_name and '.csv' in file_name:\n",
    "        model = file_name[file_name.index('_') + 1: file_name.index('.')]\n",
    "\n",
    "        # Load every listed method: the sections below read predictions[method_display]\n",
    "        if model in dict_name:\n",
    "            print(\"Opening :\", file_name, ' - ', model)\n",
    "            predictions[model] = pd.read_csv(os.path.join(path, file_name), header = [0, 1], index_col = 0)\n",
    "            results[model] = evaluate(predictions[model])\n",
    "\n",
    "if not results:\n",
    "    raise FileNotFoundError(\"No prediction file found for dataset '{}' in '{}'\".format(dataset, path))\n",
    "\n",
    "# Rename\n",
    "results = pd.concat(results).rename(dict_name)\n",
    "# `level` has to be given by keyword (keyword-only in pandas >= 2.0)\n",
    "results.index.set_names('Model', level = 0, inplace = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute average performance across fold and models\n",
    "def _mean_std(group):\n",
    "    \"\"\"Summarize every column of the group as 'mean (std)' with three decimals.\"\"\"\n",
    "    formatted = [\"{:.3f} ({:.3f})\".format(m, s) for m, s in zip(group.mean(), group.std())]\n",
    "    return pd.Series(formatted, index = group.columns)\n",
    "\n",
    "table = results.groupby(['Model', 'Metric']).apply(_mean_std)\n",
    "table = table.sort_index(level = 0, sort_remaining = False)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(table.to_latex())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Likelihood evolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the outcome of the clustering method\n",
    "method_display = 'cmhe' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models in family (one pickle per number of clusters k)\n",
    "family_prefix = '{}_{}+k='.format(dataset, method_display)\n",
    "likelihood = {}\n",
    "for file_name in os.listdir(path):\n",
    "    if family_prefix not in file_name or '.pickle' not in file_name:\n",
    "        continue\n",
    "    if '_' in file_name[file_name.index('k='):]:\n",
    "        continue # Skip the files with an underscore after 'k='\n",
    "\n",
    "    model = int(file_name[file_name.rindex('k=') + 2: file_name.index('.')])\n",
    "    print(\"Likelihood Computation :\", file_name, ' - ', model)\n",
    "\n",
    "    model_pickle = Experiment.load(path + file_name)\n",
    "    likelihood[model] = model_pickle.likelihood(x, t, e, a)\n",
    "\n",
    "likelihood = pd.DataFrame.from_dict(likelihood, orient = 'index')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average likelihood for each number of clusters (rows), with a 1.96 standard error band\n",
    "likelihood_sorted = likelihood.sort_index()\n",
    "mean = likelihood_sorted.mean(axis = 1)\n",
    "# Standard error from the number of available values per row instead of a hard-coded 5\n",
    "# (assumes one column per fold - TODO confirm against Experiment.likelihood)\n",
    "std = 1.96 * likelihood_sorted.std(axis = 1) / np.sqrt(likelihood_sorted.count(axis = 1))\n",
    "\n",
    "mean.plot()\n",
    "plt.fill_between(std.index, mean + std, mean - std, alpha = 0.3)\n",
    "plt.grid(alpha = .3)\n",
    "\n",
    "plt.xlabel(r'Number of clusters $K$')\n",
    "plt.ylabel(r'Negative Log Likelihood')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze the outcome of the clustering method\n",
    "method_display = 'cmhe' "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assignment = {}\n",
    "horizon = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average predicted survival per assigned cluster for each fold ('-' treated, '--' untreated)\n",
    "ax = None\n",
    "pred_all = predictions[method_display]\n",
    "fold_labels = pred_all.Use.values.flatten()\n",
    "# Iterate over the folds present in the predictions instead of a hard-coded 5\n",
    "for i in np.unique(fold_labels[~pd.isnull(fold_labels)]):\n",
    "    pred = pred_all[fold_labels == i]\n",
    "\n",
    "    # Most likely cluster of each patient\n",
    "    assignment[i] = pred.Assignment.idxmax(axis = 1)\n",
    "    for treat in ['treated', 'untreated']:\n",
    "        # Copy before relabelling the columns to leave the loaded predictions untouched\n",
    "        pred_treat = pred[treat].copy()\n",
    "        pred_treat.columns = pred_treat.columns.map(float)\n",
    "        ax = pred_treat.groupby(assignment[i]).mean().T.plot(ax = ax, ls = '--' if treat == 'untreated' else '-')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Survival Predictions')\n",
    "plt.grid(alpha = 0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the associated experiment - only works when predict_cluster is available\n",
    "pickle_name = method_display + '.pickle'\n",
    "for file_name in os.listdir(path):\n",
    "    if dataset not in file_name or pickle_name not in file_name:\n",
    "        continue\n",
    "    print(\"Cluster Computation :\", file_name)\n",
    "\n",
    "    model_pickle = Experiment.load(path + file_name)\n",
    "    clusters = model_pickle.clusters(model_pickle.times)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the cluster curves of every fold on the same axes\n",
    "time_axis = model_pickle.times / 12\n",
    "ax = None\n",
    "for fold in clusters:\n",
    "    fold_curves = pd.DataFrame(clusters[fold], index = time_axis)\n",
    "    ax = fold_curves.plot(ax = ax)\n",
    "plt.title('Estimated Cluster')\n",
    "plt.xlabel('Time (in years)')\n",
    "plt.ylabel('Survival Predictions')\n",
    "plt.grid(alpha = 0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "te_cluster, ordering, rmst = {}, {}, {}\n",
    "# Sort treatment effect\n",
    "for fold in clusters:\n",
    "    # One row per cluster, one column per time (divided by 12 to match the axis in years)\n",
    "    te = pd.DataFrame(clusters[fold], index = model_pickle.times / 12).T.rename_axis('Cluster')\n",
    "    # Align the clusters across folds: relabel them by increasing average value\n",
    "    ordering[fold] = {i: j for j, i in enumerate(te.T.mean().sort_values().index)}\n",
    "    te_cluster[fold] = te.rename(index = ordering[fold])\n",
    "    # Restricted mean: integral of the curve before the horizon (rectangle rule: sum * step)\n",
    "    step = te.columns[1] - te.columns[0]\n",
    "    rmst[fold] = te_cluster[fold].loc[:, te.columns < horizon].sum(axis = 1) * step\n",
    "\n",
    "n_folds = len(te_cluster)\n",
    "te_cluster = pd.concat(te_cluster, names = ['Fold'])\n",
    "mean = te_cluster.groupby('Cluster').mean().T\n",
    "# The deviation is measured across folds: the standard error divides by the number of folds\n",
    "std = 1.96 * te_cluster.groupby('Cluster').std().T / np.sqrt(n_folds)\n",
    "ax = mean.plot(legend = False)\n",
    "for k in mean.columns:\n",
    "    ax.fill_between(mean.index, mean[k] + std[k], mean[k] - std[k], alpha = 0.25)\n",
    "\n",
    "ax.grid(alpha = 0.3)\n",
    "ax.legend(title = 'Clusters', loc='center left', bbox_to_anchor=(1, 0.5))\n",
    "plt.ylabel(r'$\\tau_k(t)$')\n",
    "plt.xlabel('Time (in years)')\n",
    "plt.ylim(0,0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('RMST at {} years'.format(horizon))\n",
    "# One row per fold, one column per cluster\n",
    "rmst_folds = pd.concat(rmst, axis = 1).T\n",
    "rmst_folds.mean(), rmst_folds.std()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cluster Exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What is the distribution of probability to be part of a given cluster ?\n",
    "clusters_assignment = {}\n",
    "for fold in clusters:\n",
    "    selection = (predictions[method_display].Use == fold).values.flatten()\n",
    "    # Copy before relabelling the columns to leave the loaded predictions untouched\n",
    "    fold_assignment = predictions[method_display][selection].Assignment.copy()\n",
    "    fold_assignment.columns = fold_assignment.columns.astype(float)\n",
    "    # Keep the clusters that are the most likely one for at least one patient,\n",
    "    # then relabel them with the ordering aligned across folds\n",
    "    unique = fold_assignment.idxmax(axis = 1).unique()\n",
    "    clusters_assignment[fold] = fold_assignment[unique].rename(columns = ordering[fold])\n",
    "\n",
    "clusters_assignment = pd.concat(clusters_assignment, axis = 0)\n",
    "for cluster in clusters_assignment.columns:\n",
    "    clusters_assignment[cluster].plot.hist(alpha = 0.5, bins = 100)\n",
    "plt.xlabel('Probability cluster')\n",
    "plt.grid(alpha = 0.3)\n",
    "plt.legend(title = 'Clusters')\n",
    "plt.show()\n",
    "\n",
    "# Distribution maximally assigned\n",
    "most_likely = clusters_assignment.apply(lambda x: np.argmax(x), axis = 1)\n",
    "# One panel per group of patients instead of a layout hard-coded to three clusters\n",
    "axes = clusters_assignment.groupby(most_likely).boxplot(layout = (1, most_likely.nunique()), figsize = (7, 3), grid = True)\n",
    "for ax in axes:\n",
    "    ax.grid(alpha = 0.3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cluster Importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only available for NTC\n",
    "pickle_name = method_display + '.pickle'\n",
    "for file_name in os.listdir(path):\n",
    "    if dataset not in file_name or pickle_name not in file_name:\n",
    "        continue\n",
    "    print(\"Importance Computation :\", file_name)\n",
    "\n",
    "    model_pickle = Experiment.load(path + file_name)\n",
    "    importance = model_pickle.importance(x, t, e, a, n = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the first element returned for each fold, stack the folds as columns and flip the sign\n",
    "# (`axis` has to be given by keyword: it is keyword-only in pandas >= 2.0)\n",
    "importance = - pd.concat({fold: pd.Series(importance[fold][0]) for fold in importance}, axis = 1, names = ['Fold'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display importance of features obtained by test\n",
    "covariate_names = {\n",
    "    'CS mets at dx (2004-2015)': 'Distant lymph nodes',\n",
    "    'Derived HER2 Recode (2010+)': 'HER2 status',\n",
    "    'ER Status Recode Breast Cancer (1990+)': 'ER status',\n",
    "    'CS lymph nodes (2004-2015)': 'Lymph nodes',\n",
    "    'PR Status Recode Breast Cancer (1990+)': 'PR status',\n",
    "    'CS extension (2004-2015)': 'Tumor extension',\n",
    "    'Regional nodes positive (1988+)': 'Positive lymph nodes',\n",
    "    'Age recode with <1 year olds': 'Age',\n",
    "    'Grade (thru 2017)': 'Grade',\n",
    "    'CS tumor size (2004-2015)': 'Tumor size',\n",
    "}\n",
    "importance.index = covariates.to_series().replace(covariate_names).values\n",
    "\n",
    "# Ten largest average values, with the deviation across folds as error bars\n",
    "top_importance = importance.mean(axis = 1).sort_values().tail(10)\n",
    "top_importance.plot.barh(xerr = importance.std(axis = 1))\n",
    "plt.ylabel('Covariate')\n",
    "plt.xlabel('Likelihood change')\n",
    "plt.grid(alpha = 0.3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Characteristics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reopen non normalised data\n",
    "x, a, t, e, covariates = datasets.load_dataset(dataset, standardisation = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binarized cluster (note that it is the aligned number)\n",
    "# Most likely cluster of each patient, indexed by patient only (fold level dropped) and sorted\n",
    "binarized = (\n",
    "    clusters_assignment\n",
    "    .idxmax(axis = 1)\n",
    "    .astype(int)\n",
    "    .droplevel(0)\n",
    "    .sort_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(x, columns = covariates).groupby(binarized).apply(lambda x:  pd.Series([\"{:.3f} ({:.3f})\".format(mean, std) for mean, std in zip(x.mean(), x.std())], index = x.columns)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(x, columns = covariates)['Derived HER2 Recode (2010+)'].groupby(binarized).value_counts(normalize = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw export: keep the first record of each patient\n",
    "df = pd.read_csv('../data/export.csv')\n",
    "df = df.groupby('Patient ID').first().drop(columns= ['Site recode ICD-O-3/WHO 2008'])\n",
    "\n",
    "# Encode using dictionary to remove missing data\n",
    "# Assign the result back: an in-place replace on a selected column is not guaranteed to update df\n",
    "surgery_col = \"RX Summ--Surg Prim Site (1998+)\"\n",
    "df[surgery_col] = df[surgery_col].replace('126', np.nan)\n",
    "\n",
    "# Remove not grades\n",
    "grades = ['Well differentiated; Grade I', 'Moderately differentiated; Grade II',\n",
    "    'Poorly differentiated; Grade III', 'Undifferentiated; anaplastic; Grade IV']\n",
    "df = df[df[\"Grade (thru 2017)\"].isin(grades)]\n",
    "\n",
    "categorical_col = [\"Race and origin recode (NHW, NHB, NHAIAN, NHAPI, Hispanic)\", \"Laterality\", \n",
    "    \"Diagnostic Confirmation\", \"Histology recode - broad groupings\", \n",
    "    \"Radiation recode\", \"ER Status Recode Breast Cancer (1990+)\", \"PR Status Recode Breast Cancer (1990+)\",\n",
    "    \"Histologic Type ICD-O-3\", \"ICD-O-3 Hist/behav, malignant\", \"Sequence number\", \"Derived HER2 Recode (2010+)\",\n",
    "    \"CS extension (2004-2015)\", \"CS lymph nodes (2004-2015)\", \"CS mets at dx (2004-2015)\", \"Origin recode NHIA (Hispanic, Non-Hisp)\"]\n",
    "\n",
    "# Remove patients without surgery\n",
    "df = df[df[surgery_col] != '00']\n",
    "\n",
    "# Remove patients without chemo\n",
    "df = df[df[\"Chemotherapy recode (yes, no/unk)\"] == 'Yes']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.reset_index(inplace = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['Derived HER2 Recode (2010+)'].groupby(binarized).value_counts(normalize = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(a).groupby(binarized).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.Series(a).groupby(binarized).size() / len(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "survival",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f1b50223f39b64c0c24545f474e3e7d2d3b4b121fe045100fc03a3926bb649af"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
