{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b953d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import feature_groups\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import AUROC, Accuracy\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "import os\n",
    "from os import path\n",
    "from dime.data_utils import ROSMAPDataset, get_group_matrix, get_xy, data_split\n",
    "from dime import MaskingPretrainer, CMIEstimator\n",
    "from pytorch_lightning import Trainer\n",
    "from dime.utils import MaskLayerGrouped\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c1f15b9",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e850bbea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics reused by the policy-evaluation cells further down.\n",
    "auc_metric = AUROC(task='multiclass', num_classes=2)\n",
    "acc_metric = Accuracy(task='multiclass', num_classes=2)\n",
    "\n",
    "# CUDA device with index 1; the networks are moved onto it later.\n",
    "device = torch.device('cuda', 1)\n",
    "\n",
    "use_apoe=False\n",
    "apoe_columns = ['apoe4_1copy','apoe4_2copies']\n",
    "rosmap_feature_groups = feature_groups.rosmap_feature_groups\n",
    "rosmap_feature_names = feature_groups.rosmap_feature_names\n",
    "if not use_apoe:\n",
    "    # Exclude the two APOE columns from the feature list.\n",
    "    rosmap_feature_names = [name for name in rosmap_feature_names if name not in apoe_columns]\n",
    "\n",
    "feature_groups_dict, feature_groups_mask = get_group_matrix(rosmap_feature_names, rosmap_feature_groups)\n",
    "num_groups = len(feature_groups_mask)\n",
    "\n",
    "# cols_to_drop is compared against stringified positions in the feature list.\n",
    "cols_to_drop = []\n",
    "if cols_to_drop is not None:\n",
    "    kept_names = []\n",
    "    for name in rosmap_feature_names:\n",
    "        if str(rosmap_feature_names.index(name)) not in cols_to_drop:\n",
    "            kept_names.append(name)\n",
    "    rosmap_feature_names = kept_names\n",
    "\n",
    "# Load the train / val / test splits with identical options.\n",
    "train_dataset, val_dataset, test_dataset = [\n",
    "    ROSMAPDataset('./data', split=split_name, cols_to_drop=cols_to_drop, use_apoe=use_apoe)\n",
    "    for split_name in ('train', 'val', 'test')\n",
    "]\n",
    "d_in = train_dataset.X.shape[1]\n",
    "d_out = len(np.unique(train_dataset.Y))\n",
    "\n",
    "# Per-feature costs: column 0 holds the feature name, column 1 the cost.\n",
    "df = pd.read_csv(\"./data/rosmap_feature_costs.csv\", header=None)\n",
    "cost_rows = df if use_apoe else df[~df[0].isin(apoe_columns)]\n",
    "feature_costs = cost_rows[1].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7da91c28",
   "metadata": {},
   "source": [
    "# Set up Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f467e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture hyperparameters\n",
    "hidden = 128\n",
    "dropout = 0.3\n",
    "d_in = train_dataset.X.shape[1]\n",
    "d_out = len(np.unique(train_dataset.Y))\n",
    "print(d_out)\n",
    "\n",
    "\n",
    "def build_mlp(n_outputs):\n",
    "    # Two hidden Linear+ReLU+Dropout stages followed by a linear head, moved to `device`.\n",
    "    # The input width is d_in + num_groups (presumably features plus the appended\n",
    "    # group mask, since the mask layer below uses append=True - confirm).\n",
    "    layers = [\n",
    "        nn.Linear(d_in + num_groups, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, n_outputs),\n",
    "    ]\n",
    "    return nn.Sequential(*layers).to(device)\n",
    "\n",
    "\n",
    "# Outcome Predictor: one output per class\n",
    "predictor = build_mlp(d_out)\n",
    "\n",
    "# CMI Predictor: one output per feature group\n",
    "value_network = build_mlp(num_groups)\n",
    "\n",
    "# Evaluation loaders share the same settings and are not shuffled.\n",
    "loader_options = dict(batch_size=128, shuffle=False, pin_memory=True, num_workers=4)\n",
    "test_dataloader = DataLoader(test_dataset, **loader_options)\n",
    "val_dataloader = DataLoader(val_dataset, **loader_options)\n",
    "\n",
    "mask_layer = MaskLayerGrouped(append=True, group_matrix=torch.tensor(feature_groups_mask))\n",
    "\n",
    "# Trainer runs on the same GPU index as `device`, with precision=16.\n",
    "trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90f42ff",
   "metadata": {},
   "source": [
    "# Evaluate Penalized Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b17e6331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the lambda-penalized policy on the test set for each trial and\n",
    "# pickle a {avg. number of selected features: AUROC} mapping under results/.\n",
    "for trial in range(1):\n",
    "    results_dict = {\"acc\": {}}\n",
    "\n",
    "    # Placeholder: point this at the trained checkpoint for this trial.\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "    # Use CMIEstimator (imported from dime in the first cell); the previously\n",
    "    # referenced GreedyCMIEstimatorPL is never imported in this notebook.\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(\n",
    "        trained_model_path,\n",
    "        value_network=value_network,\n",
    "        predictor=predictor,\n",
    "        mask_layer=mask_layer,\n",
    "        lr=1e-3,\n",
    "        max_features=15,\n",
    "        eps=0.05,\n",
    "        loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "        val_loss_fn=auc_metric,\n",
    "        eps_decay=0.2,\n",
    "        eps_steps=10,\n",
    "        patience=5,\n",
    "        feature_costs=None\n",
    "    ).to(device)\n",
    "\n",
    "    avg_num_features_lambda = []\n",
    "    accuracy_scores_lambda = []\n",
    "    all_masks_lambda = []\n",
    "\n",
    "    # Evaluation Mode lambda penalty\n",
    "    lamda_values = [0.000001, 0.00001, 0.00007, 0.0003, 0.0005] + [0.004, 0.016, 0.07]\n",
    "\n",
    "    for lamda in lamda_values:\n",
    "        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)\n",
    "\n",
    "        y = metric_dict['y']\n",
    "        pred = metric_dict['pred']\n",
    "        auc_score = auc_metric(pred.float(), y)\n",
    "        final_masks = np.array(metric_dict['mask'])\n",
    "        # Mean over samples of the per-sample mask sum.\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "        accuracy_scores_lambda.append(auc_score)\n",
    "        avg_num_features_lambda.append(avg_num_features)\n",
    "        # The 'acc' key is kept for compatibility; the stored value is the AUROC.\n",
    "        results_dict['acc'][avg_num_features] = auc_score\n",
    "\n",
    "        print(f\"Lambda={lamda}, AUROC={auc_score}, Avg. num features={avg_num_features}\")\n",
    "        all_masks_lambda.append(final_masks)\n",
    "\n",
    "    # Make sure the output directory exists so the results are not lost after inference.\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    with open(f'results/rosmap_lambda_ours_trial_{trial}.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa78dc10",
   "metadata": {},
   "source": [
    "# Evaluate Budget Constrained Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd9fe3b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "\n",
    "# Evaluate the budget-constrained policy (no feature costs) on the test set.\n",
    "# `freq` collects one row of per-feature selection frequencies per budget and\n",
    "# is consumed by the heatmap cell that follows.\n",
    "freq = []\n",
    "\n",
    "for trial in range(0, 1):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # Placeholder: point this at the trained checkpoint for this trial.\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(\n",
    "        trained_model_path,\n",
    "        value_network=value_network,\n",
    "        predictor=predictor,\n",
    "        mask_layer=mask_layer,\n",
    "        lr=1e-3,\n",
    "        max_features=15,\n",
    "        eps=0.05,\n",
    "        loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "        val_loss_fn=auc_metric,\n",
    "        eps_decay=0.2,\n",
    "        eps_steps=10,\n",
    "        patience=5,\n",
    "        feature_costs=None\n",
    "        # cmi_scaling='positive'\n",
    "    ).to(device)\n",
    "    avg_num_features_budget = []\n",
    "    accuracy_scores_budget = []\n",
    "    all_masks_budget = []\n",
    "    max_budget_values = list(range(1, 15, 1))\n",
    "    # Leading all-zero row (presumably the zero-budget baseline - confirm).\n",
    "    # NOTE(review): 43 is hard-coded; it must equal the mask width (presumably num_groups).\n",
    "    freq.append([0] * 43)\n",
    "\n",
    "    for budget in max_budget_values:\n",
    "        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)\n",
    "\n",
    "        y = metric_dict_budget['y']\n",
    "        pred = metric_dict_budget['pred']\n",
    "        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))\n",
    "        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1)  # accuracy per class\n",
    "        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]\n",
    "        print(\"Rebalanced accuracy: {}\".format(cls_avg))\n",
    "        accuracy_score = acc_metric(pred.float(), y)\n",
    "        final_masks = np.array(metric_dict_budget['mask'])\n",
    "        # Mean over samples of the per-sample mask sum.\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "\n",
    "        accuracy_scores_budget.append(accuracy_score)\n",
    "        avg_num_features_budget.append(avg_num_features)\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "        # The reported score comes from acc_metric, so label it as accuracy (not AUROC).\n",
    "        print(f\"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}\")\n",
    "        # Fraction of samples for which each mask column was selected.\n",
    "        freq.append(list(np.sum(final_masks, axis=0) / final_masks.shape[0]))\n",
    "\n",
    "        all_masks_budget.append(final_masks)\n",
    "\n",
    "    # Make sure the output directory exists so the results are not lost after inference.\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    # NOTE(review): the cost-aware evaluation cell later in this notebook writes to the\n",
    "    # same path, so on Run All this file is overwritten - consider a distinct file name.\n",
    "    with open(f'results/rosmap_ours_costs_inference_trial_{trial}.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25eefc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# One heatmap row per entry of `freq`, one column per mask position.\n",
    "df = pd.DataFrame(np.array(freq))\n",
    "plt.rcParams.update({'font.size': 17})\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(df, cmap=\"YlGnBu\", ax=ax)\n",
    "ax.set_xlabel(\"Feature Index\")\n",
    "ax.set_ylabel(\"Avg. # Features\")\n",
    "ax.set_title(\"ROSMAP\")\n",
    "fig.savefig(\"ROSMAP_Selection_Freq.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ddc7add",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "\n",
    "# Evaluate the budget-constrained policy with explicit per-feature costs.\n",
    "# These hard-coded costs replace the `feature_costs` list read from the CSV in the\n",
    "# dataset-loading cell; there is one entry per mask column (43 in total).\n",
    "feature_costs = np.array([10, 10, 10, 10, 10, 600, 600, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 60, 300, 120, 300, 180, 300, 300, 180, 180, 60, 180, 900, 450, 180, 1200, 600, 120, 300, 180, 180, 180, 60, 60, 60, 600])\n",
    "\n",
    "for trial in range(0, 1):\n",
    "    results_dict = {\"acc\": {}}\n",
    "    # Placeholder: point this at the trained checkpoint for this trial.\n",
    "    trained_model_path = f\"<path_to_trained_model>\"\n",
    "\n",
    "    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(\n",
    "        trained_model_path,\n",
    "        value_network=value_network,\n",
    "        predictor=predictor,\n",
    "        mask_layer=mask_layer,\n",
    "        lr=1e-3,\n",
    "        max_features=15,\n",
    "        eps=0.05,\n",
    "        loss_fn=nn.CrossEntropyLoss(reduction='none'),\n",
    "        val_loss_fn=auc_metric,\n",
    "        eps_decay=0.2,\n",
    "        eps_steps=10,\n",
    "        patience=5,\n",
    "        feature_costs=feature_costs\n",
    "    ).to(device)\n",
    "    avg_num_features_budget = []\n",
    "    accuracy_scores_budget = []\n",
    "    all_masks_budget = []\n",
    "    # NOTE(review): budgets 1..14 are all below the smallest cost above (10); confirm\n",
    "    # the units in which `inference` interprets `budget` when costs are supplied.\n",
    "    max_budget_values = list(range(1, 15, 1))\n",
    "\n",
    "    for budget in max_budget_values:\n",
    "        metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=feature_costs, budget=budget)\n",
    "\n",
    "        y = metric_dict_budget['y']\n",
    "        pred = metric_dict_budget['pred']\n",
    "        conf_matrix = metrics.confusion_matrix(y, np.argmax(pred, axis=1))\n",
    "        cls_acc = np.diag(conf_matrix) / np.sum(conf_matrix, 1)  # accuracy per class\n",
    "        cls_avg = np.sum(cls_acc) / conf_matrix.shape[0]\n",
    "        print(\"Rebalanced accuracy: {}\".format(cls_avg))\n",
    "        accuracy_score = acc_metric(pred.float(), y)\n",
    "        final_masks = np.array(metric_dict_budget['mask'])\n",
    "        # Mean over samples of the per-sample mask sum.\n",
    "        avg_num_features = np.mean(np.sum(final_masks, axis=1))\n",
    "        # Mean over samples of the total cost of the selected mask columns.\n",
    "        avg_cost = np.mean(np.sum(final_masks * feature_costs, axis=1))\n",
    "\n",
    "        accuracy_scores_budget.append(accuracy_score)\n",
    "        avg_num_features_budget.append(avg_num_features)\n",
    "        results_dict['acc'][avg_num_features] = accuracy_score\n",
    "        # Label the values for what they are: accuracy (not AUROC) and average cost.\n",
    "        print(f\"Budget={budget}, Accuracy={accuracy_score}, Avg. num features={avg_num_features}, Avg. cost={avg_cost}\")\n",
    "\n",
    "        all_masks_budget.append(final_masks)\n",
    "\n",
    "    # Make sure the output directory exists so the results are not lost after inference.\n",
    "    os.makedirs('results', exist_ok=True)\n",
    "    # NOTE(review): the cost-free budget evaluation cell earlier in this notebook writes\n",
    "    # to this same path; its results are overwritten here.\n",
    "    with open(f'results/rosmap_ours_costs_inference_trial_{trial}.pkl', 'wb') as f:\n",
    "        pickle.dump(results_dict, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1598d1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:test_DIME_2]",
   "language": "python",
   "name": "conda-env-test_DIME_2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
