{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "077880ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import feature_groups\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import AUROC,Accuracy\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "import os\n",
    "from os import path\n",
    "from dime.data_utils import DenseDatasetSelected, get_group_matrix, get_xy, data_split\n",
    "from dime.utils import MaskLayerGrouped\n",
    "from dime import MaskingPretrainer, CMIEstimator\n",
    "from pytorch_lightning import Trainer\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd11790a",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6c13be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics for the two-class task; reused by the evaluation cells below\n",
    "auc_metric = AUROC(task='multiclass', num_classes=2)\n",
    "acc_metric = Accuracy(task='multiclass', num_classes=2)\n",
    "intub_feature_names = feature_groups.intub_feature_names\n",
    "intub_feature_groups = feature_groups.intub_feature_groups\n",
    "# NOTE: pinned to CUDA device index 1; adjust the index for your machine\n",
    "device = torch.device('cuda', 1)\n",
    "\n",
    "# Column positions (compared as strings) to exclude; an empty list keeps every feature\n",
    "cols_to_drop = []\n",
    "if cols_to_drop is not None:\n",
    "    # enumerate gives every name its own position, so repeated names are filtered\n",
    "    # correctly and the pass is linear (list.index always returns the first match)\n",
    "    intub_feature_names = [name for idx, name in enumerate(intub_feature_names) if str(idx) not in cols_to_drop]\n",
    "\n",
    "# Load dataset\n",
    "dataset = DenseDatasetSelected('data/intub.csv', cols_to_drop=cols_to_drop)\n",
    "d_in = dataset.X.shape[1]  # 121\n",
    "d_out = len(np.unique(dataset.Y))  # 2\n",
    "feature_groups_dict, feature_groups_mask = get_group_matrix(intub_feature_names, intub_feature_groups)\n",
    "num_groups = len(feature_groups_mask)\n",
    "# Fixed generator seed so the 80/10/10 split is reproducible\n",
    "train_dataset, val_dataset, test_dataset = random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(0))\n",
    "\n",
    "# Find mean/std for normalizing, using the training split only\n",
    "x, y = get_xy(train_dataset)\n",
    "mean = np.mean(x, axis=0)\n",
    "# Per-feature std must come from the inputs x (it was previously taken over the labels y)\n",
    "std = np.std(x, axis=0)\n",
    "# A constant feature has std 0; divide it by 1 so it becomes 0 after centering instead of NaN\n",
    "std = np.where(std == 0, 1.0, std)\n",
    "\n",
    "# Normalize via the original dataset (the splits are subsets of `dataset`)\n",
    "dataset.X = (dataset.X - mean) / std\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec9db45a",
   "metadata": {},
   "source": [
    "# Set up networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a438b433",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture hyperparameters shared by both networks\n",
    "hidden = 128\n",
    "dropout = 0.3\n",
    "d_in = dataset.X.shape[1]  # 121\n",
    "d_out = len(np.unique(dataset.Y))  # 2\n",
    "print(d_out)\n",
    "\n",
    "\n",
    "def _build_mlp(out_features):\n",
    "    \"\"\"Return an MLP with two hidden ReLU+Dropout layers and `out_features` outputs.\n",
    "\n",
    "    The input width is d_in + num_groups.\n",
    "    \"\"\"\n",
    "    layers = [\n",
    "        nn.Linear(d_in + num_groups, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, out_features),\n",
    "    ]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "# Outcome predictor: d_out outputs\n",
    "predictor = _build_mlp(d_out).to(device)\n",
    "\n",
    "# CMI predictor: one output per feature group\n",
    "value_network = _build_mlp(num_groups).to(device)\n",
    "\n",
    "# Evaluation loaders iterate in a fixed order (no shuffling)\n",
    "_eval_loader_kwargs = dict(batch_size=128, shuffle=False, pin_memory=True, num_workers=4)\n",
    "test_dataloader = DataLoader(test_dataset, **_eval_loader_kwargs)\n",
    "val_dataloader = DataLoader(val_dataset, **_eval_loader_kwargs)\n",
    "\n",
    "# Grouped masking layer built from the feature/group matrix\n",
    "mask_layer = MaskLayerGrouped(append=True, group_matrix=torch.tensor(feature_groups_mask))\n",
    "\n",
    "# 16-bit Lightning trainer on the GPU selected by `device`\n",
    "trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdd03282",
   "metadata": {},
   "source": [
    "# Evaluate Penalized Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ddf9e59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory up front so the pickle dump below cannot fail on a fresh checkout\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "for trial in range(5):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # Placeholder: replace with the checkpoint path for this trial\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,\n",
    "                                                             value_network=value_network,\n",
    "                                                             predictor=predictor,\n",
    "                                                             mask_layer=mask_layer,\n",
    "                                                             lr=1e-3,\n",
    "                                                             max_features=30,\n",
    "                                                             eps=0.1,\n",
    "                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "                                                             val_loss_fn=auc_metric,\n",
    "                                                             eps_decay=0.2,\n",
    "                                                             eps_steps=10,\n",
    "                                                             patience=5,\n",
    "                                                             feature_costs=None\n",
    "                                                            ).to(device)\n",
    "    # Per-lambda accumulators; the names now match the appends inside the loop\n",
    "    # (they were initialised as *_lambda but appended to as *_lamda -> NameError)\n",
    "    avg_num_features_lamda = []\n",
    "    accuracy_scores_lamda = []\n",
    "    all_masks_lamda = []\n",
    "    # Evaluation Mode lambda penalty\n",
    "    lamda_values = [0.000001, 0.00001, 0.00007, 0.0003, 0.0005] + [0.0012, 0.016]\n",
    "\n",
    "    for lamda in lamda_values:\n",
    "        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)\n",
    "\n",
    "        y = metric_dict['y']\n",
    "        pred = metric_dict['pred']\n",
    "        # Bind the metric to the name used below (it was stored as `auc_score` but read as `accuracy_score`)\n",
    "        accuracy_score = acc_metric(pred, y)\n",
    "        final_masks = np.array(metric_dict['mask'])\n",
    "        # Mean number of selected mask entries per sample; computed once and reused\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "        accuracy_scores_lamda.append(accuracy_score)\n",
    "        avg_num_features_lamda.append(avg_num_features)\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "\n",
    "        # The value comes from acc_metric, so it is reported as accuracy rather than AUROC\n",
    "        print(f\"Lambda={lamda}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}\")\n",
    "        all_masks_lamda.append(final_masks)\n",
    "\n",
    "    with open(f'results/intub_lamda_ours_trial_{trial}.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ed40419",
   "metadata": {},
   "source": [
    "# Evaluate with Budget Constrained Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b6ba0bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "# Not populated in this cell; kept so existing names stay defined\n",
    "num_selections = []\n",
    "feature_index = []\n",
    "# One row of per-column selection frequencies per evaluated budget\n",
    "freq = []\n",
    "\n",
    "# Create the output directory up front so the pickle dump below cannot fail on a fresh checkout\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "for trial in range(1):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # Placeholder: replace with the checkpoint path for this trial\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,\n",
    "                                                             value_network=value_network,\n",
    "                                                             predictor=predictor,\n",
    "                                                             mask_layer=mask_layer,\n",
    "                                                             lr=1e-3,\n",
    "                                                             max_features=40,\n",
    "                                                             eps=0.1,\n",
    "                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "                                                             val_loss_fn=auc_metric,\n",
    "                                                             eps_decay=0.2,\n",
    "                                                             eps_steps=10,\n",
    "                                                             patience=5,\n",
    "                                                             feature_costs=None,\n",
    "                                                             cmi_scaling='bounded'\n",
    "                                                            ).to(device)\n",
    "    avg_num_features_budget = []\n",
    "    accuracy_scores_budget = []\n",
    "    all_masks_budget = []\n",
    "    max_budget_values = range(1, 16) #[1, 3, 5, 10, 15, 20, 25]\n",
    "    # All-zero leading row, sized from num_groups instead of a hard-coded 36\n",
    "    # (presumably the nothing-selected baseline -- TODO confirm)\n",
    "    freq.append([0] * num_groups)\n",
    "    for budget in max_budget_values:\n",
    "        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)\n",
    "\n",
    "        y = metric_dict_budget['y']\n",
    "        pred = metric_dict_budget['pred']\n",
    "        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))\n",
    "        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1) # accuracy per class\n",
    "        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]\n",
    "        print(\"Rebalanced accuracy: {}\".format(cls_avg))\n",
    "\n",
    "        accuracy_score = acc_metric(pred.float(), y)\n",
    "        final_masks = np.array(metric_dict_budget['mask'])\n",
    "        # Mean number of selected mask entries per sample; computed once and reused\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "        accuracy_scores_budget.append(accuracy_score)\n",
    "        avg_num_features_budget.append(avg_num_features)\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "        # The value comes from acc_metric, so it is reported as accuracy rather than AUROC\n",
    "        print(f\"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}\")\n",
    "        # Fraction of samples in which each mask column was selected\n",
    "        freq.append(list(np.sum(final_masks, axis=0) / final_masks.shape[0]))\n",
    "        all_masks_budget.append(final_masks)\n",
    "\n",
    "    with open(f'results/intub_ours_trial_{trial}_auc.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4488e9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column-wise selection counts for the masks of the most recent evaluation\n",
    "final_masks.sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34719ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "# Acquisition cost per mask column (multiplied against the masks below)\n",
    "feature_costs = [1, 1, 3, 3, 3, 3, 4, 4, 5, 4, 4, 6, 6, 6, 6, 7, 7, 7, 2, 8, 8, 9, 8, 7, 7, 7, 8, 10, 9, 7, 10, 10, 10, 10, 6, 8]\n",
    "# Build the array once instead of on every call\n",
    "feature_costs_arr = np.array(feature_costs)\n",
    "# One row of per-column selection frequencies per evaluated budget\n",
    "freq = []\n",
    "\n",
    "# Create the output directory up front so the pickle dump below cannot fail on a fresh checkout\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "for trial in range(1):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # Placeholder: replace with the checkpoint path for this trial\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(trained_model_path,\n",
    "                                                             value_network=value_network,\n",
    "                                                             predictor=predictor,\n",
    "                                                             mask_layer=mask_layer,\n",
    "                                                             lr=1e-3,\n",
    "                                                             max_features=40,\n",
    "                                                             eps=0.1,\n",
    "                                                             loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "                                                             val_loss_fn=auc_metric,\n",
    "                                                             eps_decay=0.2,\n",
    "                                                             eps_steps=10,\n",
    "                                                             patience=5,\n",
    "                                                             feature_costs=feature_costs_arr\n",
    "                                                            ).to(device)\n",
    "    avg_num_features_budget = []\n",
    "    accuracy_scores_budget = []\n",
    "    all_masks_budget = []\n",
    "    max_budget_values = range(1, 15)#[1, 3, 5, 10, 15, 20, 25, 30]\n",
    "    for budget in max_budget_values:\n",
    "        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=feature_costs_arr, budget=budget)\n",
    "\n",
    "        y = metric_dict_budget['y']\n",
    "        pred = metric_dict_budget['pred']\n",
    "        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))\n",
    "        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1) # accuracy per class\n",
    "        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]\n",
    "        print(\"Rebalanced accuracy: {}\".format(cls_avg))\n",
    "\n",
    "        accuracy_score = acc_metric(pred.float(), y)\n",
    "        final_masks = np.array(metric_dict_budget['mask'])\n",
    "        # Unweighted count of selected mask entries per sample (used as the results key)\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "        # Cost-weighted total per sample; this is what was previously printed as 'Avg. num features'\n",
    "        avg_cost = np.mean(np.sum(final_masks * feature_costs_arr, axis=1))\n",
    "        accuracy_scores_budget.append(accuracy_score)\n",
    "        avg_num_features_budget.append(avg_num_features)\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "        # The value comes from acc_metric, so it is reported as accuracy rather than AUROC\n",
    "        print(f\"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}, Avg. cost={avg_cost}\")\n",
    "        # Fraction of samples in which each mask column was selected\n",
    "        freq.append(list(np.sum(final_masks, axis=0) / final_masks.shape[0]))\n",
    "\n",
    "        all_masks_budget.append(final_masks)\n",
    "\n",
    "    # NOTE(review): same path as the uniform-cost budget cell, so this overwrites its results -- confirm intended\n",
    "    with open(f'results/intub_ours_trial_{trial}_auc.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:test_DIME_2]",
   "language": "python",
   "name": "conda-env-test_DIME_2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
