{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib_inline.backend_inline\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "# Figure aesthetics: white/ticks theme, transparent figure and axes backgrounds, retina inline output\n",
    "sns.set_theme(style='white')\n",
    "sns.set_context('notebook', font_scale=1.25)\n",
    "sns.set_style('ticks', rc={'figure.facecolor': 'none', 'axes.facecolor': 'none'})\n",
    "matplotlib_inline.backend_inline.set_matplotlib_formats('retina')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute average SHAP inference times and prediction quality\n",
    "\n",
    "----"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the leading '/' points at the filesystem root; this looks like a scrubbed placeholder path -- confirm\n",
    "SHAP_INFERENCE_CSV = '/df_SHAP_inference_times.csv'\n",
    "df_SHAP_inference = pd.read_csv(SHAP_INFERENCE_CSV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of every numeric column for each number of SHAP re-evaluation samples\n",
    "df_SHAP_inference.groupby('Num. samples').mean(numeric_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference time of SHAP KernelExplainer as a function of its re-evaluation sample count\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "sns.lineplot(data=df_SHAP_inference, x='Num. samples', y='Inference time (s)', marker='o', ax=ax)\n",
    "ax.set_xlabel('Number of re-evaluation samples with SHAP KernelExplainer', labelpad=15)\n",
    "ax.set_ylabel('Inference time (s)', labelpad=15)\n",
    "ax.set_title('SHAP KernelExplainer number of samples vs. time')\n",
    "sns.despine();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP_toshap(nn.Module):\n",
    "    '''Fully connected network that maps feature vectors to predicted Shapley values.\n",
    "\n",
    "    Built from ``n_layers - 1`` hidden blocks of Linear -> ReLU -> Dropout followed by\n",
    "    a final Linear output layer with no activation. Every module lives in one\n",
    "    ``nn.Sequential`` named ``layers``, so state-dict keys look like\n",
    "    ``layers.<idx>.weight``; keep that layout stable so saved checkpoints still load.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size, n_layers, drop_prob):\n",
    "        super().__init__()\n",
    "        modules = []\n",
    "        in_features = input_size\n",
    "        for _ in range(n_layers - 1):\n",
    "            modules.extend([\n",
    "                nn.Linear(in_features, hidden_size),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Dropout(drop_prob),\n",
    "            ])\n",
    "            in_features = hidden_size\n",
    "        # Output projection: no activation, so outputs are unbounded\n",
    "        modules.append(nn.Linear(in_features, output_size))\n",
    "        self.layers = nn.Sequential(*modules)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Run ``x`` through the stacked layers and return the raw outputs.'''\n",
    "        return self.layers(x)\n",
    "\n",
    "def count_parameters(model):\n",
    "    '''Return the number of scalar parameters in ``model`` that have requires_grad set.'''\n",
    "    trainable = (p for p in model.parameters() if p.requires_grad)\n",
    "    return sum(param.numel() for param in trainable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset pairing each feature row with its Shapley-value row\n",
    "class ShapDataset(Dataset):\n",
    "    '''Map-style dataset yielding ``(features, shapley_values)`` pairs as float tensors.\n",
    "\n",
    "    ``Feat`` and ``Shap`` are both indexed row-wise with ``[index, :]``. ``len()`` uses\n",
    "    ``Feat`` only, so both are presumably expected to have the same number of rows -- TODO confirm.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, Feat, Shap):\n",
    "        self.Feat, self.Shap = Feat, Shap\n",
    "\n",
    "    def __len__(self):\n",
    "        '''Number of samples, taken from the first dimension of ``Feat``.'''\n",
    "        n_samples = self.Feat.shape[0]\n",
    "        return n_samples\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        '''Return row ``index`` of ``Feat`` and of ``Shap``, each converted with ``.float()``.'''\n",
    "        feat_row = self.Feat[index, :]\n",
    "        shap_row = self.Shap[index, :]\n",
    "        return feat_row.float(), shap_row.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load test data (replace the <...> placeholders with real paths)\n",
    "path = '< CHECKPOINT_PATH >'\n",
    "shap_test = torch.load('< SHAP_PATH >')\n",
    "feats_test = torch.load('< FEATURE_PATH>')\n",
    "NUM_FEATS = 13  # number of features per sample -- presumably equals shap_test.shape[1]; TODO confirm\n",
    "\n",
    "test_set = ShapDataset(feats_test, shap_test)\n",
    "\n",
    "# Evaluation loader: fixed order and no dropped final partial batch,\n",
    "# so every test sample contributes to the reported error.\n",
    "test_loader = DataLoader(\n",
    "    test_set,\n",
    "    batch_size=128,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model trained on 1000 samples\n",
      "Evaluating model: 1000\n",
      "Model trained on 1500 samples\n",
      "Evaluating model: 1500\n",
      "Model trained on 2000 samples\n",
      "Evaluating model: 2000\n",
      "Model trained on 2500 samples\n",
      "Evaluating model: 2500\n",
      "Model trained on 3000 samples\n",
      "Evaluating model: 3000\n"
     ]
    }
   ],
   "source": [
    "TRAIN_SET_SIZES = [1000, 1500, 2000, 2500, 3000]\n",
    "# Training time read from each checkpoint, and wall-clock time to predict the whole test loader\n",
    "train_times = np.zeros(len(TRAIN_SET_SIZES))\n",
    "inf_times = np.zeros(len(TRAIN_SET_SIZES))\n",
    "# One mean-squared error per (model, test batch)\n",
    "mean_squared_errors = torch.zeros((len(TRAIN_SET_SIZES), len(test_loader)))\n",
    "\n",
    "for i, size in enumerate(TRAIN_SET_SIZES):\n",
    "    print(f'Model trained on {size} samples')\n",
    "    print(f'Evaluating model: {size}')\n",
    "\n",
    "    # Load the checkpoint that belongs to this training-set size\n",
    "    best_model = f'_best_model_trainsize{size}.pth'\n",
    "    checkpoint = torch.load(\n",
    "        f'{path}{best_model}',\n",
    "        map_location=torch.device('cpu')\n",
    "    )\n",
    "\n",
    "    # Input width comes from the features, output width from the SHAP targets\n",
    "    model = MLP_toshap(\n",
    "        input_size=feats_test.shape[1],\n",
    "        hidden_size=checkpoint['hidden_size'],\n",
    "        output_size=shap_test.shape[1],\n",
    "        n_layers=checkpoint['n_layers'],\n",
    "        drop_prob=checkpoint['drop_prob'],\n",
    "    )\n",
    "    model.load_state_dict(checkpoint['model_state_dict'])\n",
    "    model.eval()  # disable dropout for evaluation\n",
    "\n",
    "    train_times[i] = checkpoint['train_time']\n",
    "\n",
    "    start = time.perf_counter()\n",
    "    with torch.no_grad():\n",
    "        for s, (feats, shap_actual) in enumerate(test_loader):\n",
    "            shap_pred = model(feats)\n",
    "            mean_squared_errors[i, s] = (shap_pred - shap_actual).pow(2).mean()\n",
    "    inf_times[i] = time.perf_counter() - start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TOTAL_SAMPLES = 8887  # dataset size; the x-axis fractions are relative to it\n",
    "avg_of_one_instance = 1.6753819346427918  # SHAP KernelExplainer time per instance, presumably seconds -- TODO record provenance\n",
    "# Must line up element-wise with train_times, which is ordered by TRAIN_SET_SIZES\n",
    "SHAP_SAMPLES = np.array(TRAIN_SET_SIZES)\n",
    "\n",
    "# Baseline: explain every sample directly with SHAP\n",
    "SHAP_PACKAGE_INF_TIMES = TOTAL_SAMPLES * avg_of_one_instance\n",
    "# Ours: presumably SHAP labels for the training subset + network training time\n",
    "# NOTE(review): the leading avg_of_one_instance term may have been meant to be the network's\n",
    "# inference time (inf_times, computed above but unused); kept unchanged -- confirm the cost model.\n",
    "total_model_time = avg_of_one_instance + train_times + SHAP_SAMPLES * avg_of_one_instance\n",
    "speedup = SHAP_PACKAGE_INF_TIMES / total_model_time\n",
    "train_fraction = SHAP_SAMPLES / TOTAL_SAMPLES\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "ax.plot(train_fraction, speedup, marker='o')\n",
    "ax.set_xticks(train_fraction)\n",
    "ax.set_xlabel('Fraction of the dataset used for training', labelpad=15)\n",
    "ax.set_ylabel('Speed-up relative to SHAP (X times)', labelpad=15)\n",
    "ax.set_title('Speedup of Neural Network Shapley Prediction (Ours) versus SHAP Kernel Explainer', y=1.05)\n",
    "sns.despine();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('rebuttal')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "8b4c61b0c0dfa50a12dd4948992e6e1a04659874706c6edcf781c314bdb1c462"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
