{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e16c7b4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
     "import argparse\n",
     "import os\n",
     "import pickle\n",
     "import numpy as np\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from tensorflow.python.keras.layers import Dense, Input, Flatten, Add, Multiply, Lambda\n",
    "from tensorflow.python.keras.layers.normalization import BatchNormalization\n",
    "from tensorflow.python.keras import regularizers\n",
    "from tensorflow.python.keras.models import Model, Sequential\n",
    "from tensorflow.python.keras import optimizers\n",
    "from tensorflow.python.keras.callbacks import ModelCheckpoint\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "from utils.explanations import calculate_robust_astute_sampled\n",
    "\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8070334e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_all_weights(model, all_layer_weights):\n",
    "    count = 0\n",
    "    for layer in model.layers:\n",
    "        if type(layer) is Dense:\n",
    "            count += 1\n",
    "    if count == len(all_layer_weights):\n",
    "        c = 0\n",
    "        for layer in model.layers:\n",
    "            if type(layer) is Dense:\n",
    "                layer.set_weights(all_layer_weights[c])\n",
    "                c += 1\n",
    "        return model\n",
    "    else:\n",
    "        print(\"models don't match\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aac00ec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rise_explainer(datatype, ball_r, epsilon, prop_points, exponentiate, lambda_names, all_layer_weights):\n",
    "    blackbox_path = 'models/' + datatype + '_blackbox.hdf5'\n",
    "    rice_pd = pd.read_excel('data/Rice_Osmancik_Cammeo_Dataset.xlsx')\n",
    "    data = rice_pd.values[:, :-1]\n",
    "    labels = rice_pd.values[:, -1]\n",
    "    labels[labels == 'Cammeo'] = 0\n",
    "    labels[labels == 'Osmancik'] = 1\n",
    "    x_train, x_val, y_train, y_val = train_test_split(data, labels, test_size=0.33, random_state=42)\n",
    "    x_train = StandardScaler().fit_transform(x_train)\n",
    "    x_val = StandardScaler().fit_transform(x_val)\n",
    "    input_shape = x_train.shape[-1]\n",
    "    \n",
    "    activation = 'relu'\n",
    "\n",
    "    model_input = Input(shape=(input_shape,), dtype='float32')\n",
    "\n",
    "    net = Dense(32, activation=activation, name='dense1',\n",
    "                kernel_regularizer=regularizers.l2(1e-3))(model_input)\n",
    "    net = Dense(32, activation=activation, name='dense2',\n",
    "                kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    net = Dense(32, activation=activation, name='dense3',\n",
    "                kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    net = Dense(32, activation=activation, name='dense4',\n",
    "                kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    preds = Dense(1, activation='sigmoid', name='dense5',\n",
    "                  kernel_regularizer=regularizers.l2(1e-3))(net)\n",
    "    bbox_model = Model(model_input, preds)\n",
    "    bbox_model = set_all_weights(bbox_model, all_layer_weights)\n",
    "    pred_model = Model(model_input, preds)\n",
    "\n",
    "    explanation = calculate_robust_astute_sampled(data=x_val,\n",
    "                                                      explainer=pred_model,\n",
    "                                                      explainer_type='rise',\n",
    "                                                      explanation_type='attribution',\n",
    "                                                      ball_r=ball_r,\n",
    "                                                      epsilon=epsilon,\n",
    "                                                      num_points=int(prop_points * len(x_val)),\n",
    "                                                      exponentiate=exponentiate,\n",
    "                                                      calculate_astuteness=False)\n",
    "\n",
    "    del pred_model\n",
    "    return np.abs(explanation)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1b76847",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/zulqarnain/anaconda3/envs/old_tf/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-18 13:59:55.321748: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
      "2022-05-18 13:59:55.350105: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3600000000 Hz\n",
      "2022-05-18 13:59:55.350895: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x55ec2da9cd30 executing computations on platform Host. Devices:\n",
      "2022-05-18 13:59:55.350926: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>\n"
     ]
    }
   ],
   "source": [
    "ball_radius = 2\n",
    "epsilon = 0.05\n",
    "prop_points = 0.05\n",
    "run_times = 5\n",
    "exponentiate = 0\n",
    "lambda_dense_list = [float(0.7), float(1), float(\"inf\")]\n",
    "lambda_names = ['Regularized High', 'Regularized Low', 'Not Regularized']\n",
    "for datatype in ['rice']:\n",
    "    for c in range(len(lambda_names)):\n",
    "        for i in range(run_times):\n",
    "            fname = 'explained_weights/rise/' + 'rise_' + datatype + '_' + str(c) + '_' + str(i) + '_lip.gz'\n",
    "            all_layer_weights = pickle.load(open('extracted_weights/rice_l2_' + str(c) + '.pk', 'rb'))\n",
    "\n",
    "            explanation = rise_explainer(datatype=datatype,\n",
    "                                           ball_r=ball_radius,\n",
    "                                           epsilon=epsilon,\n",
    "                                           prop_points=prop_points,\n",
    "                                           exponentiate=exponentiate,\n",
    "                                           lambda_names=lambda_names[c],\n",
    "                                           all_layer_weights=all_layer_weights)\n",
    "            np.savetxt(X=explanation, fname=fname, delimiter=',')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a53d72ff",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
