{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Neural Survival Clustering on METABRIC"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we apply Causal Neural Survival Clustering (CNSC) to the METABRIC dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cnsc.datasets import load_dataset\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load METABRIC: covariates (x), treatment (a), times (t), event codes (e) and covariate names (col)\n",
    "(x, a, t, e, col) = load_dataset('METABRIC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reformat the data as pandas objects so that rows can later be selected by index label\n",
    "x = pd.DataFrame(data = x, columns = col)\n",
    "a = pd.Series(a)\n",
    "t = pd.Series(t)\n",
    "e = pd.Series(e)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute horizons at which we evaluate the performance of CNSC\n",
    "\n",
    "Survival predictions are issued at certain time horizons. Here we evaluate the performance\n",
    "of CNSC at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy and PyTorch so that sampling and training are repeatable\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "np.random.seed(seed = 42)\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantile levels and the matching times, computed on rows with an event only (e == 0 is left out)\n",
    "horizons = [0.25, 0.5, 0.75]\n",
    "times = np.quantile(t.loc[e != 0], q = horizons).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Percentage of observed events per risk, for each treatment arm and each time horizon\n",
    "row = '\\t {:.2f} % observed risk {}'\n",
    "for arm in np.unique(a):\n",
    "    selection = (a == arm)\n",
    "    e_arm, t_arm = e[selection], t[selection]\n",
    "    print('-' * 42)\n",
    "    for horizon_time in times:\n",
    "        print('At time {:.2f} months'.format(horizon_time))\n",
    "        for risk in np.unique(e):\n",
    "            # Share of this arm that experienced `risk` strictly before the horizon\n",
    "            before_horizon = (e_arm == risk) & (t_arm < horizon_time)\n",
    "            print(row.format(100 * before_horizon.mean(), risk))\n",
    "    print('Total')\n",
    "    for risk in np.unique(e):\n",
    "        print(row.format(100 * (e_arm == risk).mean(), risk))\n",
    "\n",
    "# Same breakdown over the whole population, ignoring treatment\n",
    "print('-' * 42)\n",
    "print('Overall')\n",
    "for risk in np.unique(e):\n",
    "    print(row.format(100 * (e == risk).mean(), risk))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Splitting the data into train, test and validation sets\n",
    "\n",
    "We will train CNSC on 80% of the data (10% of which is used for the stopping criterion and 10% for model selection) and report performance on the remaining 20% held-out test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# 80/20 train/test split on the row labels; 20% of the training part is then set aside\n",
    "# and halved into `val` (model selection) and `dev` (passed to `fit` as validation data)\n",
    "train, test = train_test_split(x.index, random_state = 42, test_size = 0.2)\n",
    "train, holdout = train_test_split(train, random_state = 42, test_size = 0.2)\n",
    "val, dev = train_test_split(holdout, random_state = 42, test_size = 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale times by the largest training time so that training times are at most 1.\n",
    "# The maximum is computed once: the scaler then no longer depends on the global `t`,\n",
    "# which a later cell rebinds, and repeated calls do not recompute it.\n",
    "t_train_max = t.loc[train].max()\n",
    "minmax = lambda values: values / t_train_max\n",
    "t_ddh = minmax(t)\n",
    "times_ddh = minmax(np.array(times))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting the parameter grid\n",
    "\n",
    "Let's set up the parameter grid to tune the hyper-parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search space: two candidate layer configurations for both `layers` and `layers_surv`;\n",
    "# the number of clusters, the representation size and the activation are fixed\n",
    "layers = [[50] * 2, [50] * 3]\n",
    "param_grid = dict(\n",
    "    layers_surv = layers,\n",
    "    k = [3],\n",
    "    representation = [10],\n",
    "    layers = layers,\n",
    "    act = ['Tanh'],\n",
    ")\n",
    "# Draw 3 configurations from the grid with a fixed seed\n",
    "params = ParameterSampler(param_grid, n_iter = 3, random_state = 42)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training and Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cnsc import CausalNeuralSurvivalClustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = []\n",
    "for param in params:\n",
    "    print(param)\n",
    "\n",
    "    # Every candidate is fitted on the same rows and receives the same `dev` rows as validation data\n",
    "    fit_rows = (x.loc[train].values, t_ddh.loc[train].values, e.loc[train].values, a.loc[train].values)\n",
    "    dev_rows = (x.loc[dev].values, t_ddh.loc[dev].values, e.loc[dev].values, a.loc[dev].values)\n",
    "    select_rows = (x.loc[val].values, t_ddh.loc[val].values, e.loc[val].values, a.loc[val].values)\n",
    "\n",
    "    model = CausalNeuralSurvivalClustering(**param, correct = True, multihead = False)\n",
    "    model.fit(*fit_rows, n_iter = 1000, bs = 250, lr = 0.001, val_data = dev_rows)\n",
    "\n",
    "    # Negative log-likelihood on the `val` rows is the model-selection score\n",
    "    nll = model.compute_nll(*select_rows)\n",
    "\n",
    "    # Keep the score next to the fitted model\n",
    "    models.append([nll, model])\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Discrimination should be 0 when negative gamma as it is possible to predict it given the covariates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry is [nll, model]; keep the one with the smallest NLL on the `val` rows\n",
    "best_model = min(models, key = lambda entry: entry[0])\n",
    "model = best_model[-1]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation\n",
    "\n",
    "We evaluate the discriminative ability of CNSC (Time Dependent Concordance Index and Cumulative Dynamic AUC) as well as its Brier Score on the **factual** distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factual predictions: survival at the selected horizons, using the treatment recorded for each test row\n",
    "out_survival = model.predict_survival(x.loc[test].values, times_ddh.tolist(), a.loc[test].values)\n",
    "out_risk = 1 - out_survival\n",
    "\n",
    "# Structured (event, time) arrays for the sksurv metrics; only e == 1 counts as the event,\n",
    "# every other code is treated as censored\n",
    "surv_dtype = [('e', bool), ('t', float)]\n",
    "et_train = np.array(list(zip(e.loc[train] == 1, t.loc[train])), dtype = surv_dtype)\n",
    "et_test = np.array(list(zip(e.loc[test] == 1, t.loc[test])), dtype = surv_dtype)\n",
    "\n",
    "# Keep the test rows whose time lies strictly below the largest training time\n",
    "selection = (t.loc[test] < t.loc[train].max())\n",
    "et_eval = et_test[selection]\n",
    "\n",
    "# One concordance index and one AUC per horizon; the Brier score is computed for all horizons at once\n",
    "cis = [concordance_index_ipcw(et_train, et_eval, out_risk[:, i][selection], tau)[0] for i, tau in enumerate(times)]\n",
    "brs = brier_score(et_train, et_eval, out_survival[selection], times)[1]\n",
    "roc_auc = [cumulative_dynamic_auc(et_train, et_eval, out_risk[:, i][selection], tau)[0] for i, tau in enumerate(times)]\n",
    "\n",
    "for i, horizon in enumerate(horizons):\n",
    "    print(f\"For {horizon} quantile,\")\n",
    "    print(\"TD Concordance Index:\", cis[i])\n",
    "    print(\"Brier Score:\", brs[i])\n",
    "    print(\"ROC AUC \", roc_auc[i][0], \"\\n\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  Treatment effect evaluation\n",
    "\n",
    "In this section, we evaluate the quality of the treatment effect estimation. We display the KM estimate and the estimate of the model clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 evenly spaced evaluation times from 0 to the largest time in `t`, plus their normalised version\n",
    "eval_times = np.linspace(start = 0, stop = t.max(), num = 100)\n",
    "norm_eval_times = minmax(eval_times)\n",
    "# Spacing between two consecutive evaluation times\n",
    "delta = eval_times[1] - eval_times[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster assignment weights of every test row, indexed by the original row labels\n",
    "alphas = pd.DataFrame(data = model.predict_alphas(x.loc[test].values), index = test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted survival of every test row on the evaluation grid, once with a = 0 and once with a = 1\n",
    "survival_by_arm = {}\n",
    "for value, treatment in enumerate(['untreated', 'treated']):\n",
    "    survival_by_arm[treatment] = pd.DataFrame(\n",
    "        model.predict_survival(x.loc[test].values, norm_eval_times.tolist(), a = value),\n",
    "        columns = eval_times, index = test)\n",
    "estimated_survival = pd.concat(survival_by_arm, axis = 1, names = ['Treatment'])\n",
    "\n",
    "# Cumulative incidence is the complement of survival\n",
    "estimated_cif = 1 - estimated_survival\n",
    "# Despite its name, this holds the per-row difference in incidence (untreated minus treated)\n",
    "estimated_rmse = estimated_cif[('untreated',)] - estimated_cif[('treated',)]\n",
    "\n",
    "# Cluster-level treatment effect on the same grid: one row per cluster after the transpose\n",
    "estimated_cluster_treatment = pd.DataFrame(model.treatment_effect_cluster(norm_eval_times.tolist()).T, columns = eval_times)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Population level"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Estimate the population level treatment effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `plt` is used from here on but was never imported in this notebook\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Mean effect over the test rows; `std` holds 1.96 standard errors, i.e. the half-width of the band\n",
    "mean, std = estimated_rmse.mean(0), 1.96 * estimated_rmse.std(0) / np.sqrt(len(estimated_rmse))\n",
    "ax = mean.rename('Estimate').plot(ls = '-.')\n",
    "# Shade the band in the same colour as the line that was just drawn\n",
    "plt.fill_between(mean.index, mean + std, mean - std, alpha = 0.3, color = ax.get_lines()[-1].get_color())\n",
    "\n",
    "plt.ylabel('Treatment effect')\n",
    "plt.title('Mean outcome')\n",
    "plt.grid(alpha = 0.3)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature importance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Estimate which features most impact the model assignment through a permutation test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the normalised times (`t_ddh`): the model was fitted and scored on that scale, not on raw `t`\n",
    "importance, confidence = model.feature_importance(x.loc[test].values, t_ddh.loc[test].values, e.loc[test].values, a.loc[test].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One bar per covariate, sorted by importance, with `Conf` as error bars\n",
    "# NOTE(review): 'Value' is multiplied by 100 while 'Conf' is not - confirm both are on the same scale\n",
    "importance_table = pd.DataFrame({'Value': 100 * np.array(list(importance.values())), 'Conf': confidence.values()}, index = col)\n",
    "importance_table.sort_values('Value').plot.bar(yerr = 'Conf')\n",
    "plt.ylabel('% change in NLL')\n",
    "plt.xlabel('Covariates')\n",
    "plt.grid(alpha = 0.3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cluster level"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Analyse the clusters by displaying their treatment effects and their differences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the covariates without standardisation so that cluster statistics are on the original scale.\n",
    "# Only `x` and `col` are replaced: `a`, `t` and `e` keep their pandas form, so earlier cells\n",
    "# (e.g. the time normalisation) still work if they are re-run.\n",
    "# NOTE(review): the first load did not pass `path`; confirm that 'data/' is the right location here.\n",
    "x, *_ignored, col = load_dataset('METABRIC', path = 'data/', standardisation = False)\n",
    "x = pd.DataFrame(x, columns = col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard assignment of each test row: position of its largest cluster weight\n",
    "hard_assignment = alphas.apply(lambda row: row.argmax(), axis = 1)\n",
    "for k in range(model.torch_model.k):\n",
    "    alphas_max = (hard_assignment == k)\n",
    "    # The legend shows the cluster size and the sum of `a` over its members\n",
    "    label = 'Cluster {} (n = {}, a = {})'.format(k, alphas_max.sum(), a[test][alphas_max].sum())\n",
    "    ax = estimated_cluster_treatment.loc[k].rename(label).plot()\n",
    "    # Dashed line in the same colour: average individual effect of the rows assigned to this cluster\n",
    "    cluster_colour = ax.lines[-1].get_color()\n",
    "    estimated_rmse[alphas_max].mean(0).rename('Average Effect').plot(ls = '--', color = cluster_colour)\n",
    "plt.ylabel('Treatment effect')\n",
    "# NOTE(review): an earlier cell prints horizons in months - confirm the unit of this axis\n",
    "plt.xlabel('Time (in years)')\n",
    "plt.grid(alpha = 0.3)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import kruskal\n",
    "\n",
    "# Hard cluster of each test row: the column of `alphas` with the largest weight\n",
    "assignment = alphas.loc[test].idxmax(1)\n",
    "x_test = x.loc[test]\n",
    "\n",
    "# Per-cluster \"mean (std)\" of every covariate; covariates as rows, clusters as columns\n",
    "results = x_test.groupby(assignment).apply(lambda group: pd.Series([\"{:.3f} ({:.3f})\".format(m, s) for m, s in zip(group.mean(), group.std())], index = group.columns)).T\n",
    "\n",
    "# Kruskal-Wallis test across every populated cluster (previously hard-coded to clusters 0 and 1,\n",
    "# although the model may have more); the test needs at least two groups, otherwise report NaN\n",
    "clusters = list(results.columns)\n",
    "results['P-Value'] = [kruskal(*[x_test[feature][assignment == cluster] for cluster in clusters]).pvalue if len(clusters) > 1 else float('nan')\n",
    "                      for feature in results.index]\n",
    "results.sort_values('P-Value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('survival')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "vscode": {
   "interpreter": {
    "hash": "f1b50223f39b64c0c24545f474e3e7d2d3b4b121fe045100fc03a3926bb649af"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
