{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook analyzes the results obtained by running `run_all.sh`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from generate import *\n",
    "from experiment import Experiment\n",
    "\n",
    "import seaborn as sns\n",
    "# Figure theme: white grid, no spine around the axes, high resolution on screen and on disk\n",
    "custom_params = {\n",
    "    \"axes.spines.right\": False,\n",
    "    \"axes.spines.top\": False,\n",
    "    \"axes.spines.left\": False,\n",
    "    \"axes.spines.bottom\": False,\n",
    "    \"figure.dpi\": 300,\n",
    "    'savefig.dpi': 300,\n",
    "}\n",
    "sns.set_theme(style = \"whitegrid\", rc = custom_params, font_scale = 1.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path where the data is saved\n",
    "path = 'Results_ntc/'\n",
    "\n",
    "# Treatment assignment: 'obs' or 'rand'\n",
    "mode = 'obs'\n",
    "# Seed of the experiment to analyze (default experiment: 42)\n",
    "random_seed = 42\n",
    "# Type of simulation: 'linear', 'small', 'size', 'homogenous' or 'treat' (use '' for the main results)\n",
    "simulation = 'linear'\n",
    "# Parameter associated with the simulation - 'size': 300 or 30000, 'treat': 0.25 or 0.75, main '': 3 or 5\n",
    "parameter = 3\n",
    "\n",
    "# Associated file name: the result files of the simulation contain this string\n",
    "root = 'generate{}_'.format(simulation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import adjusted_rand_score\n",
    "\n",
    "### Utils: The evaluation metrics used\n",
    "def evaluate(clusters_pred, te_cluster, survival_pred, survival_gt, a, t, e, times, groups, n_folds = 5):\n",
    "    \"\"\"\n",
    "    Evaluate, on each fold, the recovered clusters and the estimated treatment effects.\n",
    "\n",
    "    Args:\n",
    "        clusters_pred (pd.DataFrame): Assignment of each sample (rows) to each cluster (columns).\n",
    "        te_cluster: Treatment effect of each cluster, read as te_cluster[fold][:, cluster] (None if not available).\n",
    "        survival_pred (pd.DataFrame): Predictions with 'treated', 'untreated', 'Use' (fold) and 'Assignment' columns.\n",
    "        survival_gt (pd.DataFrame): Ground truth with 'treated' and 'untreated' columns, indexed like the predictions.\n",
    "        a, t, e, times: Unused, kept for compatibility with existing calls.\n",
    "        groups (pd.Series): Ground truth group of each sample, indexed like the predictions.\n",
    "        n_folds (int): Number of folds to evaluate; folds are numbered from 0 in the 'Use' column.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: One row per (fold, metric), one column for 'Population' and one per group.\n",
    "    \"\"\"\n",
    "    folds = survival_pred[('Use',)]\n",
    "    survival_pred = survival_pred.drop(columns = ['Use', 'Assignment'])\n",
    "    te_gt = survival_gt['treated'] - survival_gt['untreated']\n",
    "    results = {}\n",
    "\n",
    "    # Compute performance for each fold\n",
    "    for fold in np.arange(n_folds):\n",
    "        # Subselect all variables in fold\n",
    "        in_fold = (folds == fold).values\n",
    "        clusters_pred_fold, survival_pred_fold = clusters_pred[in_fold], survival_pred[in_fold]\n",
    "        groups_fold = groups.loc[clusters_pred_fold.index]\n",
    "        te_fold = survival_pred_fold['treated'] - survival_pred_fold['untreated']\n",
    "        te_gt_fold = te_gt.loc[survival_pred_fold.index]\n",
    "\n",
    "        # Evaluate quality cluster at the population level\n",
    "        results_fold = {}\n",
    "        results_fold['Population'] = {\n",
    "            \"Rand_index\": adjusted_rand_score(groups_fold, clusters_pred_fold.apply(lambda x: x.argmax(), 1)),\n",
    "            \"MSE_Mean_TE\": mse_mean(te_fold, te_gt_fold)}\n",
    "\n",
    "        # At the group level\n",
    "        for group in groups.unique():\n",
    "            selection = groups_fold == group\n",
    "            if te_cluster is None or not selection.any():\n",
    "                # Nothing to compare: no cluster estimate, or group absent from this fold\n",
    "                results_fold[group] = {\"MSE_Cluster_TE\": np.nan}\n",
    "                continue\n",
    "            alpha_max = clusters_pred_fold[selection].mean(0).argmax() # Find closest cluster in expectation\n",
    "            cluster_te = te_cluster[fold][:, alpha_max]\n",
    "            results_fold[group] = {\"MSE_Cluster_TE\": mse_cluster(te_gt_fold[selection], cluster_te)}\n",
    "\n",
    "        # pd.concat only accepts Series / DataFrame: build the frame from the nested dict\n",
    "        # (one row per metric, one column for the population and for each group)\n",
    "        results[fold] = pd.DataFrame(results_fold)\n",
    "\n",
    "    return pd.concat(results)\n",
    "\n",
    "def mse_mean(pred, gt):\n",
    "    \"\"\"\n",
    "    Mean absolute difference between the column averages of pred and gt.\n",
    "    Despite its name, no error is squared.\n",
    "    \"\"\"\n",
    "    return np.abs(pred.mean(0).values - gt.mean(0).values).mean()\n",
    "\n",
    "def mse_cluster(pred, mean):\n",
    "    \"\"\"\n",
    "    Mean absolute difference between the column averages of pred and the given mean.\n",
    "    Despite its name, no error is squared.\n",
    "    \"\"\"\n",
    "    return np.abs(pred.mean(0).values - mean).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute performances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display name of each model, keyed by the suffix used in the result file names\n",
    "dict_name = {\n",
    "    'ntc': 'NTC',\n",
    "    'ntc+uncorrect': 'NTC (Unadjusted)',\n",
    "    'cmhe+g': 'CMHE (Treatment)',\n",
    "    'cmhe+k': 'CMHE (Survival)',\n",
    "    'cmhe+gk': 'CMHE (Combine)',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open file and compute performance\n",
    "treated, untreated, clusters, results, te_cluster = {}, {}, {}, {}, {}\n",
    "predictions = None # Stays None when no result file matches the configuration\n",
    "\n",
    "# Cluster centers used to generate the associated ground truth\n",
    "if parameter == 5:\n",
    "    centers = ([0, 2.25], [-2.25, -1], [2.25, -1], [-3, 3], [4, 4])\n",
    "else:\n",
    "    centers = ([0, 2.25], [-2.25, -1], [2.25, -1])\n",
    "\n",
    "for file_name in sorted(os.listdir(path)):\n",
    "    # Only open the predictions (csv) of the selected simulation, seed, assignment and parameter\n",
    "    if root not in file_name or '.csv' not in file_name: continue\n",
    "    if str(random_seed) not in file_name: continue\n",
    "    if mode not in file_name: continue\n",
    "    if ('{}+'.format(random_seed) in file_name) and not('+{}_'.format(parameter) in file_name): continue # Naming convention (parameter will follow seed)\n",
    "\n",
    "    # The model is the part of the file name between the last '_' and the extension\n",
    "    model = file_name[file_name.rindex('_') + 1: file_name.rindex('.')]\n",
    "    model = dict_name.get(model, model)\n",
    "    print(\"Opening :\", file_name, ' - ', model, ' - ', random_seed, ' - ', mode)\n",
    "\n",
    "    if model not in results:\n",
    "        results[model], treated[model], untreated[model], clusters[model], te_cluster[model] = {}, {}, {}, {}, {}\n",
    "\n",
    "    predictions = pd.read_csv(os.path.join(path, file_name), header = [0, 1], index_col = 0).dropna()\n",
    "    treated[model][random_seed] = predictions[['treated']].droplevel(0, axis = 1)\n",
    "    untreated[model][random_seed] = predictions[['untreated']].droplevel(0, axis = 1)\n",
    "    clusters[model][random_seed] = predictions[['Assignment']].droplevel(0, axis = 1)\n",
    "\n",
    "    # Evaluation times are the column names of the predictions, converted to float\n",
    "    times = treated[model][random_seed].columns.astype(float)\n",
    "\n",
    "    # Generate associated ground truth\n",
    "    if 'linear' in root:\n",
    "        x, a, t, e, (cluster_centers, parameters, outcomes, assignement) = generate_linear(random_seed, mode = mode)\n",
    "        cifs = compute_cif_linear(x, outcomes.cluster, cluster_centers, parameters, times)\n",
    "    else:\n",
    "        x, a, t, e, (cluster_centers, parameters, outcomes, assignement) = generate(\n",
    "            random_seed, mode = mode, centers = centers,\n",
    "            homogenous = 'homogenous' in root,\n",
    "            proportions = [0.625, 0.25, 0.125] if 'small' in root else None,\n",
    "            size = parameter if 'size' in root else 3000,\n",
    "            percentage_treatment = parameter if 'treat' in root else 0.5)\n",
    "        cifs = compute_cif(x, outcomes.cluster, cluster_centers, parameters, times)\n",
    "\n",
    "    # Cluster treatment effects are only available when the fitted model was saved next to the predictions\n",
    "    # NOTE(review): Experiment.load presumably unpickles the file - only open result files you trust\n",
    "    model_file = os.path.join(path, file_name.replace('.csv', '.pickle'))\n",
    "    if os.path.isfile(model_file):\n",
    "        model_pickle = Experiment.load(model_file)\n",
    "        te_cluster[model][random_seed] = model_pickle.clusters(times)\n",
    "    else:\n",
    "        te_cluster[model][random_seed] = None\n",
    "\n",
    "    # Evaluate\n",
    "    results[model][random_seed] = evaluate(clusters[model][random_seed], te_cluster[model][random_seed], predictions, 1 - cifs,\n",
    "                                           a, t, e, times, outcomes.cluster)\n",
    "\n",
    "if predictions is None:\n",
    "    raise FileNotFoundError(\"No result file matching '{}' (seed {}, mode '{}') found in '{}'\".format(root, random_seed, mode, path))\n",
    "\n",
    "# Ground truth treatment effect of each cluster on each fold (folds of the last opened file)\n",
    "te_gt = cifs['untreated'] - cifs['treated']\n",
    "te_cluster['GT'] = {random_seed: {}}\n",
    "for fold in range(5):\n",
    "    in_fold = np.ravel((predictions[('Use',)] == fold).values)\n",
    "    index = predictions.index[in_fold] # Select by label: rows with missing values were dropped from predictions\n",
    "    te_cluster['GT'][random_seed][fold] = te_gt.loc[index].groupby(outcomes.cluster.loc[index]).mean().T.values\n",
    "\n",
    "results = pd.concat({model: pd.concat(results[model], names = ['Seed']) for model in results})\n",
    "results.index = results.index.set_names(['Model', 'Fold', 'Metric'], level = [0, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean (standard deviation) across folds of every metric, for each model and seed\n",
    "results.groupby(['Model', 'Seed', 'Metric']).apply(\n",
    "    lambda scores: pd.Series(\n",
    "        [\"{:.3f} ({:.3f})\".format(mean, std) for mean, std in zip(scores.mean(), scores.std())],\n",
    "        index = scores.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute average performance across fold and models\n",
    "table = (\n",
    "    results[results.index.get_level_values('Metric') == 'MSE_Cluster_TE']\n",
    "    .dropna(axis = \"columns\") # Remove the columns on which this metric is not defined\n",
    "    .groupby(['Model', 'Seed', 'Metric'])\n",
    "    .apply(lambda scores: pd.Series(\n",
    "        [\"{:.3f} ({:.3f})\".format(mean, std) for mean, std in zip(scores.mean(), scores.std())],\n",
    "        index = scores.columns))\n",
    "    .sort_index(level = 0, sort_remaining = False)\n",
    ")\n",
    "print(table.to_latex())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display estimated treatment effects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per model: estimated cluster treatment effects against the ground truth ('GT')\n",
    "for model in te_cluster:\n",
    "    try:\n",
    "        # Display both ground truth and current method. A distinct loop name keeps `model`\n",
    "        # intact for the title, and dict.fromkeys avoids drawing 'GT' twice on its own figure\n",
    "        for name in dict.fromkeys([model, 'GT']):\n",
    "            te_folds = te_cluster[name][random_seed]\n",
    "            if te_folds is None: continue # No saved model: nothing to display for this method\n",
    "            mean_te = {cluster: [] for cluster in range(te_folds[0].shape[1])}\n",
    "            for fold in te_folds:\n",
    "                order = np.argsort(te_folds[fold].mean(0)) # Order by mean treatment effect\n",
    "                for key, value in enumerate(order):\n",
    "                    mean_te[key].append(te_folds[fold][:, value])\n",
    "            ax = None\n",
    "            for cluster in mean_te:\n",
    "                cluster_te = pd.DataFrame(mean_te[cluster]) # One row per fold\n",
    "                mean = cluster_te.mean() # Average across folds\n",
    "                std = 1.96 * cluster_te.std() / np.sqrt(len(cluster_te)) # Estimate normal confidence interval\n",
    "                ax = mean.plot(ax = ax, label = name, ls = ':' if name == 'GT' else '-', color = 'tab:grey' if name == 'GT' else None)\n",
    "                ax.fill_between(mean.index, mean + std, mean - std, alpha = 0.25, color = ax.get_lines()[-1].get_color())\n",
    "\n",
    "        plt.xlabel('Time')\n",
    "        plt.ylabel('Treatment effect')\n",
    "        # NOTE(review): the x axis is the row position in the treatment effect arrays and is cut to [-1, 1] - confirm this is intended\n",
    "        plt.xlim(-1, 1)\n",
    "        plt.title(model)\n",
    "        plt.show()\n",
    "    except Exception as error:\n",
    "        # Best effort: report the failure and move on to the next model\n",
    "        print(\"Could not display\", model, \":\", repr(error))\n",
    "        plt.close() # Do not leak a partially drawn figure into the next one"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "survival",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f1b50223f39b64c0c24545f474e3e7d2d3b4b121fe045100fc03a3926bb649af"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
