{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1c527ffc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n"
     ]
    }
   ],
   "source": [
    "# All imports used by this notebook.\n",
    "# NOTE(review): argparse, Flatten, Add, Multiply, Lambda and Sequential are not used in any cell below.\n",
    "import pickle\n",
    "import numpy as np\n",
    "import argparse\n",
    "import shap\n",
    "\n",
    "from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda\n",
    "from tensorflow.python.keras.layers.normalization import BatchNormalization\n",
    "from tensorflow.python.keras import regularizers\n",
    "from tensorflow.python.keras.models import Model, Sequential\n",
    "\n",
    "\n",
    "from utils.explanations import calculate_robust_astute_sampled\n",
    "\n",
    "\n",
    "# Seed NumPy's global RNG once so the np.random sampling inside shap_explainer is\n",
    "# reproducible on a fresh-kernel Run All. The stream is shared by every call, so the\n",
    "# drawn samples also depend on the order in which classifiers are processed.\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45d79b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-architecture config for the Keras black boxes:\n",
    "# (hidden-layer widths, output-layer name, weight-file suffix).\n",
    "# Layer names must match those in the saved .hdf5 files, because the weights\n",
    "# are loaded with load_weights(..., by_name=True).\n",
    "_NN_ARCHITECTURES = {\n",
    "    '2layer': ([200, 200], 'dense4', '_blackbox.hdf5'),\n",
    "    '4layer': ([50, 50, 50, 50], 'dense5', '_blackbox_extra.hdf5'),\n",
    "    'linear': ([200], 'dense4', '_blackbox_linear.hdf5'),\n",
    "}\n",
    "\n",
    "\n",
    "def _load_pickle(path):\n",
    "    \"\"\"Unpickle and return the object stored at `path`, closing the file afterwards.\n",
    "\n",
    "    NOTE: pickle.load can execute arbitrary code; only use it on trusted, locally produced files.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as fh:\n",
    "        return pickle.load(fh)\n",
    "\n",
    "\n",
    "def _build_blackbox(input_shape, hidden_units, activation, output_name):\n",
    "    \"\"\"Build a Dense -> BatchNormalization MLP with a 2-way softmax output.\n",
    "\n",
    "    Hidden Dense layers are named 'dense1', 'dense2', ... in order and the output\n",
    "    layer is named `output_name`; every Dense layer uses an L2(1e-3) kernel\n",
    "    regularizer. BatchNormalization layers keep Keras' auto-generated names.\n",
    "    \"\"\"\n",
    "    model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "    net = model_input\n",
    "    for layer_idx, units in enumerate(hidden_units, start=1):\n",
    "        net = Dense(units, activation=activation, name='dense' + str(layer_idx),\n",
    "                    kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "        net = BatchNormalization()(net)  # Add batchnorm for stability.\n",
    "    preds = Dense(2, activation='softmax', name=output_name,\n",
    "                  kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    return Model(model_input, preds)\n",
    "\n",
    "\n",
    "def shap_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, classifier):\n",
    "    \"\"\"Return absolute SHAP attributions of a trained black box on the validation set.\n",
    "\n",
    "    Loads `data/<datatype>.pk`, restores the requested classifier from `models/`,\n",
    "    wraps it in a SHAP explainer (GradientExplainer for the Keras models,\n",
    "    KernelExplainer for the SVM) and evaluates it with\n",
    "    calculate_robust_astute_sampled on x_val (astuteness is not computed).\n",
    "\n",
    "    Args:\n",
    "        datatype: dataset name; selects the data file and the model files.\n",
    "        ball_r, epsilon, exponentiate: forwarded to calculate_robust_astute_sampled.\n",
    "        prop_points: fraction of x_val to use; num_points = int(prop_points * len(x_val)).\n",
    "        classifier: one of '2layer', '4layer', 'linear' or 'svm'.\n",
    "\n",
    "    Returns:\n",
    "        np.abs of the explanation returned by calculate_robust_astute_sampled.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `classifier` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    if classifier != 'svm' and classifier not in _NN_ARCHITECTURES:\n",
    "        raise ValueError('Unknown classifier: {!r}'.format(classifier))\n",
    "\n",
    "    data_dict = _load_pickle('data/' + datatype + '.pk')\n",
    "    x_train, x_val = data_dict['x_train'], data_dict['x_val']\n",
    "\n",
    "    if classifier == 'svm':\n",
    "        svm_model = _load_pickle('models/' + datatype + '_svm.pk')\n",
    "        # Summarise a random 0.1% of the training set with 50 k-means centroids to\n",
    "        # serve as KernelExplainer's background data.\n",
    "        training_indices = np.random.choice(len(x_train), int(0.001 * len(x_train)), replace=False)\n",
    "        explainer = shap.KernelExplainer(svm_model.predict_proba,\n",
    "                                         shap.kmeans(x_train[training_indices], 50))\n",
    "        extra_kwargs = {'NN': False}\n",
    "    else:\n",
    "        hidden_units, output_name, weights_suffix = _NN_ARCHITECTURES[classifier]\n",
    "        if classifier == 'linear':\n",
    "            activation = None\n",
    "        else:\n",
    "            activation = 'relu' if datatype in ['orange_skin', 'XOR'] else 'selu'\n",
    "        bbox_model = _build_blackbox(data_dict['input_shape'], hidden_units,\n",
    "                                     activation, output_name)\n",
    "        # NOTE(review): by_name=True silently skips layers whose names are not in the\n",
    "        # file. The BatchNormalization layers are auto-named, and those names likely depend\n",
    "        # on how many were already created in this Keras session -- confirm they match\n",
    "        # the saved weights for every classifier after the first one in a run.\n",
    "        bbox_model.load_weights('models/' + datatype + weights_suffix, by_name=True)\n",
    "        background = x_train[np.random.choice(len(x_train), 100, replace=False)]\n",
    "        explainer = shap.GradientExplainer(bbox_model, background)\n",
    "        extra_kwargs = {}\n",
    "\n",
    "    explanation = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                  explainer=explainer,\n",
    "                                                  explainer_type='shap',\n",
    "                                                  explanation_type='attribution',\n",
    "                                                  ball_r=ball_r,\n",
    "                                                  epsilon=epsilon,\n",
    "                                                  num_points=int(prop_points * len(x_val)),\n",
    "                                                  exponentiate=exponentiate,\n",
    "                                                  calculate_astuteness=False,\n",
    "                                                  **extra_kwargs)\n",
    "    return np.abs(explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13362eea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-09-29 20:35:24.718566: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
      "2021-09-29 20:35:24.721092: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3600000000 Hz\n",
      "2021-09-29 20:35:24.721596: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x55b180371800 executing computations on platform Host. Devices:\n",
      "2021-09-29 20:35:24.721605: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>\n",
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8644fc758b14ea798ed2c9dc46472c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b222ad88e3043909f34e0ed8b022ddf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89a0af225dba4439a466bc47d27abf33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "345d3aec4d004edcb1adf818249ce8ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Experiment configuration: compute SHAP attributions for every classifier and\n",
    "# save them under explained_weights/shap/.\n",
    "ball_radius = 2\n",
    "epsilon = 0.05\n",
    "prop_points = 0.05\n",
    "run_times = 1\n",
    "exponentiate = 0\n",
    "classifiers = ['2layer', 'linear', '4layer', 'svm']\n",
    "output_dir = os.path.join('explained_weights', 'shap')\n",
    "\n",
    "# np.savetxt does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "for datatype in ['switch']:\n",
    "    for classifier in classifiers:\n",
    "        for run_idx in range(run_times):\n",
    "            # A '.gz' suffix makes np.savetxt write gzip-compressed text.\n",
    "            fname = os.path.join(output_dir,\n",
    "                                 'shap_{}_{}_{}.gz'.format(datatype, classifier, run_idx))\n",
    "            explanation = shap_explainer(datatype=datatype,\n",
    "                                         ball_r=ball_radius,\n",
    "                                         epsilon=epsilon,\n",
    "                                         prop_points=prop_points,\n",
    "                                         exponentiate=exponentiate,\n",
    "                                         classifier=classifier)\n",
    "            np.savetxt(X=explanation, fname=fname, delimiter=',')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
