{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code is released exclusively for review purposes with the following terms:\n",
    "PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE \n",
    "CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE \n",
    "REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Compare explanations for various examples.\"\"\"\n",
    "\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"../utilities/\")\n",
    "import os\n",
    "import pandas as pd\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.utils import check_random_state\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.neighbors import kneighbors_graph\n",
    "import yaml\n",
    "import pickle\n",
    "import joblib\n",
    "from itertools import product\n",
    "from scipy.stats import sem, pearsonr\n",
    "from scipy.sparse import vstack\n",
    "import scipy\n",
    "\n",
    "# fname_lime_exp\n",
    "from utils import (fname_data, fname_model, fname_exp, \n",
    "                   fname_base_perts, fname_env_perts, fname_preds,\n",
    "                   train_perturbation, scale_data,\n",
    "                  compute_weights)\n",
    "from helpers import lrg_lsq_sparse, lrg_lsq\n",
    "pd.options.display.max_rows = 100\n",
    "\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"Minimal stand-in for a parsed command-line namespace.\n",
    "\n",
    "    All fields default to None; start_ex/end_ex bound the slice of test\n",
    "    examples that is processed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config_fname=None, dataset_key=None, model_key=None,\n",
    "                 pert_key=None, start_ex=None, end_ex=None):\n",
    "        # Experiment identifiers (config file, dataset, model, perturbation).\n",
    "        self.config_fname, self.dataset_key, self.model_key = (\n",
    "            config_fname, dataset_key, model_key)\n",
    "        self.pert_key = pert_key\n",
    "        # Optional [start_ex:end_ex] bounds on the test examples.\n",
    "        self.start_ex, self.end_ex = start_ex, end_ex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset under evaluation; selects the configuration branch used below.\n",
    "dataset_name=\"IRIS\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "pert_keys = [\"Base_Perturbations\"]\n",
    "\n",
    "# Metrics that compare each test point with its neighbours in the k-NN graph.\n",
    "generalized_metrics = [ \"pred_cons\", \"coef_cons\", \"gen_fid\", \"unidir\"]\n",
    "classwise_gen_metrics = False\n",
    "\n",
    "if dataset_name==\"IRIS\":\n",
    "    # Grid swept over the per-setting config files (cnt, num_envs, kernel width).\n",
    "    cnts = [10, 20, 30, 40, 50]\n",
    "    num_envs1 = [2, 3, 4, 5]\n",
    "    kernel_widths = [0.1, 0.2, 0.5, 1.0, 1.5]\n",
    "    config_fname_base=\"IRIS/config_iris_\"\n",
    "    dataset_key=\"IRIS\"\n",
    "    model_key=\"RFC\"\n",
    "    results_fname=\"results_all/iris_rfc_res.pkl\"\n",
    "\n",
    "    metrics_to_compute = [\"fid\", \"pred_cons\", \"coef_cons\",\n",
    "                          \"gen_fid\", \"unidir\", \"stability\", \"compute_time\"]\n",
    "else:\n",
    "    # Fail fast: every later cell needs the grid and keys defined above.\n",
    "    raise ValueError(\"No experiment configuration for dataset_name={!r}\".format(dataset_name))\n",
    "\n",
    "# Neighbourhood sizes k are only needed when a neighbour-based metric is requested.\n",
    "if len(set(generalized_metrics).intersection(set(metrics_to_compute))) > 0:\n",
    "    ks = list(np.arange(1, 21))\n",
    "else:\n",
    "    ks = [0]\n",
    "\n",
    "combinations = [\n",
    "    cnts,\n",
    "    num_envs1,\n",
    "    kernel_widths,\n",
    "    ks\n",
    "]\n",
    "\n",
    "# Index tuples are (explainer, perturbation, cnt, num_envs, kernel_width, k).\n",
    "combinations1 = [[\"LIME\", \"LIME_smooth\", \"LINEX\"], [\"Base_Perturbations\"]] + combinations\n",
    "combinations_df_prod1 = list(product(*combinations1))\n",
    "\n",
    "combinations2 = [[\"LIME\", \"LINEX\"], [\"MeLime_Perturbations\"]] + combinations\n",
    "combinations_df_prod2 = list(product(*combinations2))\n",
    "\n",
    "combinations3 = [[\"LIME\",\"LINEX\"], [\"MAPLE\"]] + combinations\n",
    "combinations_df_prod3 = list(product(*combinations3))\n",
    "\n",
    "# SHAP is evaluated only at the first grid point.\n",
    "combinations4 = ([[\"SHAP\"], [\"Base_Perturbations\"]] +\n",
    "                 [[cnts[0]], [num_envs1[0]], [kernel_widths[0]], ks])\n",
    "combinations_df_prod4 = list(product(*combinations4))\n",
    "\n",
    "# Full sweep over every explainer/perturbation pairing.\n",
    "all_combinations_df_prod = (combinations_df_prod1+combinations_df_prod2+\n",
    "                            combinations_df_prod3+combinations_df_prod4)\n",
    "\n",
    "# Only the Base_Perturbations LIME/LIME_smooth/LINEX runs are tabulated here.\n",
    "combinations_df_prod = combinations_df_prod1\n",
    "\n",
    "res_df = pd.DataFrame(index=pd.MultiIndex.from_tuples(combinations_df_prod),\n",
    "                      columns=[\"fid\", \"pred_cons\", \"coef_cons\",\n",
    "                               \"gen_fid\", \"unidir\", \"stability\", \"compute_time\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>fid</th>\n",
       "      <th>pred_cons</th>\n",
       "      <th>coef_cons</th>\n",
       "      <th>gen_fid</th>\n",
       "      <th>unidir</th>\n",
       "      <th>stability</th>\n",
       "      <th>compute_time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">LIME</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">Base_Perturbations</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">10</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">2</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">0.1</th>\n",
       "      <th>1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">LINEX</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">Base_Perturbations</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">50</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">5</th>\n",
       "      <th rowspan=\"5\" valign=\"top\">1.5</th>\n",
       "      <th>16</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6000 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      fid pred_cons coef_cons gen_fid unidir  \\\n",
       "LIME  Base_Perturbations 10 2 0.1 1   NaN       NaN       NaN     NaN    NaN   \n",
       "                                  2   NaN       NaN       NaN     NaN    NaN   \n",
       "                                  3   NaN       NaN       NaN     NaN    NaN   \n",
       "                                  4   NaN       NaN       NaN     NaN    NaN   \n",
       "                                  5   NaN       NaN       NaN     NaN    NaN   \n",
       "...                                   ...       ...       ...     ...    ...   \n",
       "LINEX Base_Perturbations 50 5 1.5 16  NaN       NaN       NaN     NaN    NaN   \n",
       "                                  17  NaN       NaN       NaN     NaN    NaN   \n",
       "                                  18  NaN       NaN       NaN     NaN    NaN   \n",
       "                                  19  NaN       NaN       NaN     NaN    NaN   \n",
       "                                  20  NaN       NaN       NaN     NaN    NaN   \n",
       "\n",
       "                                     stability compute_time  \n",
       "LIME  Base_Perturbations 10 2 0.1 1        NaN          NaN  \n",
       "                                  2        NaN          NaN  \n",
       "                                  3        NaN          NaN  \n",
       "                                  4        NaN          NaN  \n",
       "                                  5        NaN          NaN  \n",
       "...                                        ...          ...  \n",
       "LINEX Base_Perturbations 50 5 1.5 16       NaN          NaN  \n",
       "                                  17       NaN          NaN  \n",
       "                                  18       NaN          NaN  \n",
       "                                  19       NaN          NaN  \n",
       "                                  20       NaN          NaN  \n",
       "\n",
       "[6000 rows x 7 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the (still empty) results table indexed by the experiment grid.\n",
    "res_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'results_all/iris_rfc_res.pkl'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configured output path for the results pickle.\n",
    "results_fname"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IRIS/config_iris_0.1_2_10.yaml\n"
     ]
    }
   ],
   "source": [
    "# Load the config of the first grid point to read dataset/model settings\n",
    "# (data/model file names and the prediction target).\n",
    "config_fname = \"{}{}_{}_{}.yaml\".format(config_fname_base, kernel_widths[0], num_envs1[0], cnts[0])\n",
    "args = Args(config_fname, dataset_key, model_key, pert_keys[0])\n",
    "print(config_fname)\n",
    "\n",
    "# Parse the YAML config; the context manager closes the file handle.\n",
    "with open(os.path.join(\"config\", args.config_fname)) as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the pickled train/test split and feature metadata.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "datafname = fname_data(config, dataset_key)+\".pkl\"\n",
    "\n",
    "dirname = os.path.join(\"data\", dataset_key, \"input\")\n",
    "with open(os.path.join(dirname, datafname), \"rb\") as data_file:\n",
    "    ((X_train0, X_test0, y_train, y_test, w_train, w_test),\n",
    "     categorical_feature_names, numerical_feature_names,\n",
    "     categorical_feature_inds, numerical_feature_inds,\n",
    "     colnames_onehot, colnames_orig) = pickle.load(data_file)\n",
    "n_data_all = len(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_key == \"IRIS\":\n",
    "    # Wrap raw arrays in DataFrames labelled with the original column names.\n",
    "    if isinstance(X_train0, np.ndarray):\n",
    "        X_train0 = pd.DataFrame(data=X_train0, columns=colnames_orig)\n",
    "    if isinstance(X_test0, np.ndarray):\n",
    "        X_test0 = pd.DataFrame(data=X_test0, columns=colnames_orig)\n",
    "\n",
    "    # Load black box model; given a path, joblib opens and closes the file itself.\n",
    "    modelfname = fname_model(config, model_key,\n",
    "                            dataset_key)+\".pkl\"\n",
    "    dirname = os.path.join(\"data\", dataset_key, \"models\")\n",
    "    bb_model = joblib.load(os.path.join(dirname, modelfname))\n",
    "\n",
    "    # \"train\" the perturbation model - collect data stats\n",
    "    feat_mean, feat_std, cat_freqs, cond_prob_predictor = train_perturbation(\n",
    "                                        X_train0,\n",
    "                                        categorical_feature_names,\n",
    "                                        cond_prob_train=False)\n",
    "\n",
    "    # Keep the requested slice of test examples, one-hot encode and standardize.\n",
    "    # NOTE(review): y_test / n_data_all are not sliced, so they only stay aligned\n",
    "    # with X_test while args.start_ex and args.end_ex are None.\n",
    "    X_test0 = X_test0.iloc[args.start_ex:args.end_ex, :]\n",
    "    X_test1 = pd.get_dummies(X_test0, prefix_sep=\"=\")\n",
    "    X_test = scale_data(X_test1,\n",
    "                        numerical_feature_inds,\n",
    "                        feat_mean,\n",
    "                        feat_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_key == \"IRIS\":\n",
    "    # black box predictions for the test data\n",
    "    # Preds.cls == \"reg\": use the model's predict(); otherwise Preds.cls is the\n",
    "    # predict_proba column (class index) whose probability is explained.\n",
    "    # NOTE(review): the classification predict_fn reads the global `config` at\n",
    "    # call time, so reloading `config` later changes the column it returns.\n",
    "    if config[\"Preds\"][\"cls\"] == \"reg\":\n",
    "        def predict_fn(x):\n",
    "            return bb_model.predict(x)\n",
    "    else:\n",
    "        def predict_fn(x):\n",
    "            return bb_model.predict_proba(x)[:, config[\"Preds\"][\"cls\"]].ravel()\n",
    "    # Black-box outputs on the training data and the (possibly sliced) test data.\n",
    "    ytrain_pred = predict_fn(X_train0)\n",
    "    y_pred = predict_fn(X_test0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant offset (mean training-set prediction), one entry per test example;\n",
    "# presumably the SHAP base value -- confirm where it is consumed.\n",
    "if dataset_key == \"IRIS\":\n",
    "    # Reuse ytrain_pred (= predict_fn(X_train0)) instead of re-running the model.\n",
    "    shap_mean_offset = ytrain_pred.mean() * np.ones(len(y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# K-nearest-neighbour graphs over the test points, one per k in ks, stored\n",
    "# as [rowinds, colinds] of the graph's nonzero entries.\n",
    "# TODO(review): the graphs are not symmetrised -- decide whether they should be.\n",
    "if len(set(generalized_metrics).intersection(set(metrics_to_compute))) > 0:\n",
    "\n",
    "    knn = {}\n",
    "\n",
    "    # Stack per-example rows once so both branches see a single matrix.\n",
    "    X_tests = np.vstack(X_test) if isinstance(X_test, list) else X_test\n",
    "\n",
    "    if classwise_gen_metrics:\n",
    "        # Neighbours are searched only among test points with the same label;\n",
    "        # the per-class indices are mapped back to full test-set positions.\n",
    "        # NOTE(review): X_tests[class_inds] assumes positional row indexing\n",
    "        # (ndarray/sparse); a DataFrame would need .iloc here.\n",
    "        for k in ks:\n",
    "            rowinds = []\n",
    "            colinds = []\n",
    "            for cls in np.unique(y_test):\n",
    "                class_inds = np.where(y_test == cls)[0]\n",
    "                A0 = kneighbors_graph(X_tests[class_inds], k)\n",
    "                rowinds0, colinds0 = A0.nonzero()\n",
    "                rowinds.append(class_inds[rowinds0])\n",
    "                colinds.append(class_inds[colinds0])\n",
    "\n",
    "            knn[k] = [np.hstack(rowinds), np.hstack(colinds)]\n",
    "    else:\n",
    "        for k in ks:\n",
    "            A0 = kneighbors_graph(X_tests, k)\n",
    "            rowinds, colinds = A0.nonzero()\n",
    "            knn[k] = [rowinds, colinds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define metric functions here\n",
    "def sign_consistency(c1, c2, thresh=np.inf):\n",
    "    \"\"\"Fraction of entries whose thresholded signs agree between c1 and c2.\n",
    "\n",
    "    Entries with |value| < thresh are set to 0 before taking signs; the\n",
    "    inputs themselves are copied, not modified.\n",
    "\n",
    "    NOTE: with the default thresh=np.inf every finite entry is zeroed, so for\n",
    "    finite inputs the result is always 1.0 -- pass an explicit threshold.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    c1, c2 : np.ndarray\n",
    "        Coefficient arrays compared elementwise.\n",
    "    thresh : float\n",
    "        Magnitude below which an entry counts as zero.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Mean elementwise sign agreement.\n",
    "    \"\"\"\n",
    "    def _thresholded_sign(coefs):\n",
    "        trimmed = coefs.copy()\n",
    "        trimmed[np.abs(trimmed) < thresh] = 0.0\n",
    "        return np.sign(trimmed)\n",
    "\n",
    "    return np.mean(_thresholded_sign(c1) == _thresholded_sign(c2))\n",
    "\n",
    "def class_attribution_consistency(X, y, exp, cls):\n",
    "    \"\"\"Mean per-label Pearson correlation between mean features and mean attributions.\n",
    "\n",
    "    For every distinct label in y, the rows of X and of exp with that label are\n",
    "    averaged and the Pearson r between the two mean vectors is computed; the\n",
    "    result is the mean of these correlations over labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray, scipy sparse matrix, array-like, or list of rows\n",
    "        Feature matrix; a list is stacked with np.vstack.\n",
    "    y : array-like or list\n",
    "        Labels aligned with the rows of X and exp.\n",
    "    exp : np.ndarray, scipy sparse matrix, array-like, or list of rows\n",
    "        Explanation coefficients, one row per example.\n",
    "    cls : ignored\n",
    "        Kept for call compatibility; all labels in y are always evaluated.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Mean Pearson r over labels.\n",
    "    \"\"\"\n",
    "    Xs = np.vstack(X) if isinstance(X, list) else X\n",
    "    ys = np.asarray(y)\n",
    "    exps = np.vstack(exp) if isinstance(exp, list) else exp\n",
    "\n",
    "    def _label_mean(mat, mask):\n",
    "        # Sparse means come back as a 1 x d matrix; flatten to a 1-D array.\n",
    "        if scipy.sparse.issparse(mat):\n",
    "            return np.asarray(mat[mask].mean(axis=0)).ravel()\n",
    "        return np.asarray(mat)[mask].mean(axis=0)\n",
    "\n",
    "    correlations = []\n",
    "    # A distinct loop name avoids shadowing the (unused) `cls` argument.\n",
    "    for label in np.unique(ys):\n",
    "        mask = ys == label\n",
    "        correlations.append(pearsonr(_label_mean(Xs, mask), _label_mean(exps, mask))[0])\n",
    "\n",
    "    return np.mean(correlations)\n",
    "\n",
    "def ind_predictions(X, exp, offset=0):\n",
    "    \"\"\"Evaluate each example's own linear explanation on that example.\n",
    "\n",
    "    Returns sum_j X[i, j] * exp[i, j] + offset[i] for every row i.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray, scipy sparse matrix, list of rows, or other array-like\n",
    "        Inputs, one example per row.\n",
    "    exp : np.ndarray or list of rows\n",
    "        Explanation coefficients aligned with X.\n",
    "    offset : scalar or array-like, default 0\n",
    "        Intercept added to each prediction; a scalar applies to every row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        1-D array with one prediction per example.\n",
    "    \"\"\"\n",
    "    if isinstance(X, list):\n",
    "        # Broadcasting lets a scalar offset (the default) pair with every row.\n",
    "        offsets = np.broadcast_to(offset, (len(X),))\n",
    "        return np.array([np.sum(Xi * expi) + offi\n",
    "                         for Xi, expi, offi in zip(X, exp, offsets)])\n",
    "\n",
    "    exps = np.vstack(exp) if isinstance(exp, list) else exp\n",
    "\n",
    "    if scipy.sparse.issparse(X):\n",
    "        # Elementwise product keeps X sparse; flatten the n x 1 row sums.\n",
    "        return np.asarray(X.multiply(exps).sum(axis=1)).ravel() + offset\n",
    "\n",
    "    return np.sum(np.asarray(X) * exps, axis=1) + offset\n",
    "        \n",
    "\n",
    "def prediction_consistency(y, rowinds, colinds, agg_type=\"mean\"):\n",
    "    \"\"\"Absolute differences between the values at paired indices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    y : np.ndarray\n",
    "        Values indexed along the first axis (predictions, or stacked\n",
    "        coefficient rows).\n",
    "    rowinds, colinds : array-like of int\n",
    "        Paired indices, e.g. the nonzero entries of a k-NN graph.\n",
    "    agg_type : {\"mean\", \"none\"}\n",
    "        \"mean\" returns the mean absolute difference; \"none\" returns them all.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If agg_type is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if agg_type not in (\"mean\", \"none\"):\n",
    "        raise ValueError(\"agg_type must be 'mean' or 'none', got {!r}\".format(agg_type))\n",
    "    diffs = np.abs(y[rowinds] - y[colinds])\n",
    "    return np.mean(diffs) if agg_type == \"mean\" else diffs\n",
    "\n",
    "def coefficient_consistency(exp, rowinds, colinds, agg_type=\"mean\"):\n",
    "    \"\"\"Consistency of explanation coefficients between paired examples.\n",
    "\n",
    "    exp (a list of per-example coefficient rows, or a 2-D array) is stacked with\n",
    "    np.vstack and passed to prediction_consistency, so the result is the\n",
    "    absolute coefficient difference between rows rowinds and colinds,\n",
    "    aggregated according to agg_type.\n",
    "    \"\"\"\n",
    "    return prediction_consistency(np.vstack(exp), rowinds, colinds,\n",
    "                                  agg_type=agg_type)\n",
    "\n",
    "def generalized_fidelity(X, y, exp, rowinds, colinds, offset=0, agg_type=\"mean\"):\n",
    "    \"\"\"Error of each example's explanation when applied to a paired neighbour.\n",
    "\n",
    "    For every pair (r, c), the explanation of example r (plus offset[r])\n",
    "    predicts example c, and the absolute error against y[c] is recorded.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray, scipy sparse matrix, list of rows, or other array-like\n",
    "        Inputs, one example per row.\n",
    "    y : np.ndarray\n",
    "        Reference value per example.\n",
    "    exp : np.ndarray or list of rows\n",
    "        Explanation coefficients per example.\n",
    "    rowinds, colinds : array-like of int\n",
    "        Pairs (explained example, neighbour), e.g. k-NN graph nonzeros.\n",
    "    offset : scalar or array-like, default 0\n",
    "        Per-example intercept indexed by rowinds; a scalar applies to all.\n",
    "    agg_type : {\"mean\", \"none\"}\n",
    "        \"mean\" returns the mean absolute error; \"none\" returns all errors.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If agg_type is not one of the supported values.\n",
    "    \"\"\"\n",
    "    if agg_type not in (\"mean\", \"none\"):\n",
    "        raise ValueError(\"agg_type must be 'mean' or 'none', got {!r}\".format(agg_type))\n",
    "\n",
    "    Xs = np.vstack(X) if isinstance(X, list) else X\n",
    "    exps = np.vstack(exp) if isinstance(exp, list) else exp\n",
    "\n",
    "    # A scalar offset (the default) cannot be indexed, so only index arrays.\n",
    "    offset_arr = np.asarray(offset)\n",
    "    row_offset = offset_arr[rowinds] if offset_arr.ndim else offset_arr\n",
    "\n",
    "    if scipy.sparse.issparse(Xs):\n",
    "        products = Xs[colinds].multiply(exps[rowinds])\n",
    "        y_pred_other = np.asarray(products.sum(axis=1)).ravel() + row_offset\n",
    "    else:\n",
    "        y_pred_other = np.sum(np.asarray(Xs)[colinds] * exps[rowinds], axis=1) + row_offset\n",
    "\n",
    "    errors = np.abs(y[colinds] - y_pred_other)\n",
    "    return np.mean(errors) if agg_type == \"mean\" else errors\n",
    "\n",
    "def fidelity(y, y_exp):\n",
    "    \"\"\"Mean absolute error between reference values y and explanation outputs y_exp.\"\"\"\n",
    "    return mean_absolute_error(y_true=y, y_pred=y_exp)\n",
    "\n",
    "# Other utility functions go here\n",
    "def _load_explanation_pickle(method, config, args):\n",
    "    \"\"\"Unpickle the explanation file saved for `method` (\"LIME\" or \"LINEX\").\n",
    "\n",
    "    The file is data/<args.dataset_key>/explanations/<name>.pkl, where <name> is\n",
    "    fname_exp(config, method, \"Env_Perturbations\", args.pert_key,\n",
    "    args.model_key, args.dataset_key).\n",
    "    \"\"\"\n",
    "    exp_fname = fname_exp(config, method, \"Env_Perturbations\",\n",
    "                          args.pert_key, args.model_key,\n",
    "                          args.dataset_key)+\".pkl\"\n",
    "    dirname = os.path.join(\"data\", args.dataset_key, \"explanations\")\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "    with open(os.path.join(dirname, exp_fname), \"rb\") as exp_file:\n",
    "        return pickle.load(exp_file)\n",
    "\n",
    "\n",
    "def load_explanations(exp_type, config, args):\n",
    "    \"\"\"Load precomputed explanations and their recorded compute time.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    exp_type : {\"LINEX\", \"LIME\", \"LIME_smooth\"}\n",
    "        \"LIME_smooth\" averages the per-environment LIME explanations stored\n",
    "        in the LIME file.\n",
    "    config : mapping\n",
    "        Parsed experiment config, forwarded to fname_exp.\n",
    "    args : Args\n",
    "        Supplies pert_key, model_key and dataset_key.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple or None\n",
    "        (explanations, compute_time), with compute_time 0.0 when the file has\n",
    "        no timing entry; None for any other exp_type.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    TypeError\n",
    "        If the stored \"lime_envs\" entry is neither a list nor an ndarray.\n",
    "    \"\"\"\n",
    "    if exp_type == \"LINEX\":\n",
    "        linex_explanations = _load_explanation_pickle(\"LINEX\", config, args)\n",
    "        return (linex_explanations[\"linex\"],\n",
    "                linex_explanations[\"linex_time\"] if \"linex_time\" in linex_explanations else 0.0)\n",
    "\n",
    "    if exp_type == \"LIME\":\n",
    "        lime_explanations = _load_explanation_pickle(\"LIME\", config, args)\n",
    "        return (lime_explanations[\"lime_base\"],\n",
    "                lime_explanations[\"lime_base_time\"] if \"lime_base_time\" in lime_explanations else 0.0)\n",
    "\n",
    "    if exp_type == \"LIME_smooth\":\n",
    "        lime_explanations = _load_explanation_pickle(\"LIME\", config, args)\n",
    "        lime_envs_exp = lime_explanations[\"lime_envs\"]\n",
    "\n",
    "        # NOTE(review): the list form is averaged over axis 0 of each element and\n",
    "        # the ndarray form over axis 2; presumably both are the environment axis --\n",
    "        # confirm against the code that writes these files.\n",
    "        if isinstance(lime_envs_exp, list):\n",
    "            lime_smooth_exp = [lime_env_exps.mean(axis=0) for lime_env_exps in lime_envs_exp]\n",
    "        elif isinstance(lime_envs_exp, np.ndarray):\n",
    "            lime_smooth_exp = lime_envs_exp.mean(axis=2)\n",
    "        else:\n",
    "            raise TypeError(\"Unsupported type for 'lime_envs': {}\".format(type(lime_envs_exp).__name__))\n",
    "\n",
    "        return (lime_smooth_exp,\n",
    "                lime_explanations[\"lime_envs_time\"] if \"lime_envs_time\" in lime_explanations else 0.0)\n",
    "\n",
    "    # No loader for other explainer types.\n",
    "    return None\n",
    "    \n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base_Perturbations LINEX\n",
      "IRIS/config_iris_0.1_2_10.yaml\n",
      "IRIS/config_iris_0.2_2_10.yaml\n",
      "IRIS/config_iris_0.5_2_10.yaml\n",
      "IRIS/config_iris_1.0_2_10.yaml\n",
      "IRIS/config_iris_1.5_2_10.yaml\n",
      "IRIS/config_iris_0.1_3_10.yaml\n",
      "IRIS/config_iris_0.2_3_10.yaml\n",
      "IRIS/config_iris_0.5_3_10.yaml\n",
      "IRIS/config_iris_1.0_3_10.yaml\n",
      "IRIS/config_iris_1.5_3_10.yaml\n",
      "IRIS/config_iris_0.1_4_10.yaml\n",
      "IRIS/config_iris_0.2_4_10.yaml\n",
      "IRIS/config_iris_0.5_4_10.yaml\n",
      "IRIS/config_iris_1.0_4_10.yaml\n",
      "IRIS/config_iris_1.5_4_10.yaml\n",
      "IRIS/config_iris_0.1_5_10.yaml\n",
      "IRIS/config_iris_0.2_5_10.yaml\n",
      "IRIS/config_iris_0.5_5_10.yaml\n",
      "IRIS/config_iris_1.0_5_10.yaml\n",
      "IRIS/config_iris_1.5_5_10.yaml\n",
      "IRIS/config_iris_0.1_2_20.yaml\n",
      "IRIS/config_iris_0.2_2_20.yaml\n",
      "IRIS/config_iris_0.5_2_20.yaml\n",
      "IRIS/config_iris_1.0_2_20.yaml\n",
      "IRIS/config_iris_1.5_2_20.yaml\n",
      "IRIS/config_iris_0.1_3_20.yaml\n",
      "IRIS/config_iris_0.2_3_20.yaml\n",
      "IRIS/config_iris_0.5_3_20.yaml\n",
      "IRIS/config_iris_1.0_3_20.yaml\n",
      "IRIS/config_iris_1.5_3_20.yaml\n",
      "IRIS/config_iris_0.1_4_20.yaml\n",
      "IRIS/config_iris_0.2_4_20.yaml\n",
      "IRIS/config_iris_0.5_4_20.yaml\n",
      "IRIS/config_iris_1.0_4_20.yaml\n",
      "IRIS/config_iris_1.5_4_20.yaml\n",
      "IRIS/config_iris_0.1_5_20.yaml\n",
      "IRIS/config_iris_0.2_5_20.yaml\n",
      "IRIS/config_iris_0.5_5_20.yaml\n",
      "IRIS/config_iris_1.0_5_20.yaml\n",
      "IRIS/config_iris_1.5_5_20.yaml\n",
      "IRIS/config_iris_0.1_2_30.yaml\n",
      "IRIS/config_iris_0.2_2_30.yaml\n",
      "IRIS/config_iris_0.5_2_30.yaml\n",
      "IRIS/config_iris_1.0_2_30.yaml\n",
      "IRIS/config_iris_1.5_2_30.yaml\n",
      "IRIS/config_iris_0.1_3_30.yaml\n",
      "IRIS/config_iris_0.2_3_30.yaml\n",
      "IRIS/config_iris_0.5_3_30.yaml\n",
      "IRIS/config_iris_1.0_3_30.yaml\n",
      "IRIS/config_iris_1.5_3_30.yaml\n",
      "IRIS/config_iris_0.1_4_30.yaml\n",
      "IRIS/config_iris_0.2_4_30.yaml\n",
      "IRIS/config_iris_0.5_4_30.yaml\n",
      "IRIS/config_iris_1.0_4_30.yaml\n",
      "IRIS/config_iris_1.5_4_30.yaml\n",
      "IRIS/config_iris_0.1_5_30.yaml\n",
      "IRIS/config_iris_0.2_5_30.yaml\n",
      "IRIS/config_iris_0.5_5_30.yaml\n",
      "IRIS/config_iris_1.0_5_30.yaml\n",
      "IRIS/config_iris_1.5_5_30.yaml\n",
      "IRIS/config_iris_0.1_2_40.yaml\n",
      "IRIS/config_iris_0.2_2_40.yaml\n",
      "IRIS/config_iris_0.5_2_40.yaml\n",
      "IRIS/config_iris_1.0_2_40.yaml\n",
      "IRIS/config_iris_1.5_2_40.yaml\n",
      "IRIS/config_iris_0.1_3_40.yaml\n",
      "IRIS/config_iris_0.2_3_40.yaml\n",
      "IRIS/config_iris_0.5_3_40.yaml\n",
      "IRIS/config_iris_1.0_3_40.yaml\n",
      "IRIS/config_iris_1.5_3_40.yaml\n",
      "IRIS/config_iris_0.1_4_40.yaml\n",
      "IRIS/config_iris_0.2_4_40.yaml\n",
      "IRIS/config_iris_0.5_4_40.yaml\n",
      "IRIS/config_iris_1.0_4_40.yaml\n",
      "IRIS/config_iris_1.5_4_40.yaml\n",
      "IRIS/config_iris_0.1_5_40.yaml\n",
      "IRIS/config_iris_0.2_5_40.yaml\n",
      "IRIS/config_iris_0.5_5_40.yaml\n",
      "IRIS/config_iris_1.0_5_40.yaml\n",
      "IRIS/config_iris_1.5_5_40.yaml\n",
      "IRIS/config_iris_0.1_2_50.yaml\n",
      "IRIS/config_iris_0.2_2_50.yaml\n",
      "IRIS/config_iris_0.5_2_50.yaml\n",
      "IRIS/config_iris_1.0_2_50.yaml\n",
      "IRIS/config_iris_1.5_2_50.yaml\n",
      "IRIS/config_iris_0.1_3_50.yaml\n",
      "IRIS/config_iris_0.2_3_50.yaml\n",
      "IRIS/config_iris_0.5_3_50.yaml\n",
      "IRIS/config_iris_1.0_3_50.yaml\n",
      "IRIS/config_iris_1.5_3_50.yaml\n",
      "IRIS/config_iris_0.1_4_50.yaml\n",
      "IRIS/config_iris_0.2_4_50.yaml\n",
      "IRIS/config_iris_0.5_4_50.yaml\n",
      "IRIS/config_iris_1.0_4_50.yaml\n",
      "IRIS/config_iris_1.5_4_50.yaml\n",
      "IRIS/config_iris_0.1_5_50.yaml\n",
      "IRIS/config_iris_0.2_5_50.yaml\n",
      "IRIS/config_iris_0.5_5_50.yaml\n",
      "IRIS/config_iris_1.0_5_50.yaml\n",
      "IRIS/config_iris_1.5_5_50.yaml\n",
      "Base_Perturbations LIME\n",
      "IRIS/config_iris_0.1_2_10.yaml\n",
      "IRIS/config_iris_0.2_2_10.yaml\n",
      "IRIS/config_iris_0.5_2_10.yaml\n",
      "IRIS/config_iris_1.0_2_10.yaml\n",
      "IRIS/config_iris_1.5_2_10.yaml\n",
      "IRIS/config_iris_0.1_3_10.yaml\n",
      "IRIS/config_iris_0.2_3_10.yaml\n",
      "IRIS/config_iris_0.5_3_10.yaml\n",
      "IRIS/config_iris_1.0_3_10.yaml\n",
      "IRIS/config_iris_1.5_3_10.yaml\n",
      "IRIS/config_iris_0.1_4_10.yaml\n",
      "IRIS/config_iris_0.2_4_10.yaml\n",
      "IRIS/config_iris_0.5_4_10.yaml\n",
      "IRIS/config_iris_1.0_4_10.yaml\n",
      "IRIS/config_iris_1.5_4_10.yaml\n",
      "IRIS/config_iris_0.1_5_10.yaml\n",
      "IRIS/config_iris_0.2_5_10.yaml\n",
      "IRIS/config_iris_0.5_5_10.yaml\n",
      "IRIS/config_iris_1.0_5_10.yaml\n",
      "IRIS/config_iris_1.5_5_10.yaml\n",
      "IRIS/config_iris_0.1_2_20.yaml\n",
      "IRIS/config_iris_0.2_2_20.yaml\n",
      "IRIS/config_iris_0.5_2_20.yaml\n",
      "IRIS/config_iris_1.0_2_20.yaml\n",
      "IRIS/config_iris_1.5_2_20.yaml\n",
      "IRIS/config_iris_0.1_3_20.yaml\n",
      "IRIS/config_iris_0.2_3_20.yaml\n",
      "IRIS/config_iris_0.5_3_20.yaml\n",
      "IRIS/config_iris_1.0_3_20.yaml\n",
      "IRIS/config_iris_1.5_3_20.yaml\n",
      "IRIS/config_iris_0.1_4_20.yaml\n",
      "IRIS/config_iris_0.2_4_20.yaml\n",
      "IRIS/config_iris_0.5_4_20.yaml\n",
      "IRIS/config_iris_1.0_4_20.yaml\n",
      "IRIS/config_iris_1.5_4_20.yaml\n",
      "IRIS/config_iris_0.1_5_20.yaml\n",
      "IRIS/config_iris_0.2_5_20.yaml\n",
      "IRIS/config_iris_0.5_5_20.yaml\n",
      "IRIS/config_iris_1.0_5_20.yaml\n",
      "IRIS/config_iris_1.5_5_20.yaml\n",
      "IRIS/config_iris_0.1_2_30.yaml\n",
      "IRIS/config_iris_0.2_2_30.yaml\n",
      "IRIS/config_iris_0.5_2_30.yaml\n",
      "IRIS/config_iris_1.0_2_30.yaml\n",
      "IRIS/config_iris_1.5_2_30.yaml\n",
      "IRIS/config_iris_0.1_3_30.yaml\n",
      "IRIS/config_iris_0.2_3_30.yaml\n",
      "IRIS/config_iris_0.5_3_30.yaml\n",
      "IRIS/config_iris_1.0_3_30.yaml\n",
      "IRIS/config_iris_1.5_3_30.yaml\n",
      "IRIS/config_iris_0.1_4_30.yaml\n",
      "IRIS/config_iris_0.2_4_30.yaml\n",
      "IRIS/config_iris_0.5_4_30.yaml\n",
      "IRIS/config_iris_1.0_4_30.yaml\n",
      "IRIS/config_iris_1.5_4_30.yaml\n",
      "IRIS/config_iris_0.1_5_30.yaml\n",
      "IRIS/config_iris_0.2_5_30.yaml\n",
      "IRIS/config_iris_0.5_5_30.yaml\n",
      "IRIS/config_iris_1.0_5_30.yaml\n",
      "IRIS/config_iris_1.5_5_30.yaml\n",
      "IRIS/config_iris_0.1_2_40.yaml\n",
      "IRIS/config_iris_0.2_2_40.yaml\n",
      "IRIS/config_iris_0.5_2_40.yaml\n",
      "IRIS/config_iris_1.0_2_40.yaml\n",
      "IRIS/config_iris_1.5_2_40.yaml\n",
      "IRIS/config_iris_0.1_3_40.yaml\n",
      "IRIS/config_iris_0.2_3_40.yaml\n",
      "IRIS/config_iris_0.5_3_40.yaml\n",
      "IRIS/config_iris_1.0_3_40.yaml\n",
      "IRIS/config_iris_1.5_3_40.yaml\n",
      "IRIS/config_iris_0.1_4_40.yaml\n",
      "IRIS/config_iris_0.2_4_40.yaml\n",
      "IRIS/config_iris_0.5_4_40.yaml\n",
      "IRIS/config_iris_1.0_4_40.yaml\n",
      "IRIS/config_iris_1.5_4_40.yaml\n",
      "IRIS/config_iris_0.1_5_40.yaml\n",
      "IRIS/config_iris_0.2_5_40.yaml\n",
      "IRIS/config_iris_0.5_5_40.yaml\n",
      "IRIS/config_iris_1.0_5_40.yaml\n",
      "IRIS/config_iris_1.5_5_40.yaml\n",
      "IRIS/config_iris_0.1_2_50.yaml\n",
      "IRIS/config_iris_0.2_2_50.yaml\n",
      "IRIS/config_iris_0.5_2_50.yaml\n",
      "IRIS/config_iris_1.0_2_50.yaml\n",
      "IRIS/config_iris_1.5_2_50.yaml\n",
      "IRIS/config_iris_0.1_3_50.yaml\n",
      "IRIS/config_iris_0.2_3_50.yaml\n",
      "IRIS/config_iris_0.5_3_50.yaml\n",
      "IRIS/config_iris_1.0_3_50.yaml\n",
      "IRIS/config_iris_1.5_3_50.yaml\n",
      "IRIS/config_iris_0.1_4_50.yaml\n",
      "IRIS/config_iris_0.2_4_50.yaml\n",
      "IRIS/config_iris_0.5_4_50.yaml\n",
      "IRIS/config_iris_1.0_4_50.yaml\n",
      "IRIS/config_iris_1.5_4_50.yaml\n",
      "IRIS/config_iris_0.1_5_50.yaml\n",
      "IRIS/config_iris_0.2_5_50.yaml\n",
      "IRIS/config_iris_0.5_5_50.yaml\n",
      "IRIS/config_iris_1.0_5_50.yaml\n",
      "IRIS/config_iris_1.5_5_50.yaml\n",
      "Base_Perturbations LIME_smooth\n",
      "IRIS/config_iris_0.1_2_10.yaml\n",
      "IRIS/config_iris_0.2_2_10.yaml\n",
      "IRIS/config_iris_0.5_2_10.yaml\n",
      "IRIS/config_iris_1.0_2_10.yaml\n",
      "IRIS/config_iris_1.5_2_10.yaml\n",
      "IRIS/config_iris_0.1_3_10.yaml\n",
      "IRIS/config_iris_0.2_3_10.yaml\n",
      "IRIS/config_iris_0.5_3_10.yaml\n",
      "IRIS/config_iris_1.0_3_10.yaml\n",
      "IRIS/config_iris_1.5_3_10.yaml\n",
      "IRIS/config_iris_0.1_4_10.yaml\n",
      "IRIS/config_iris_0.2_4_10.yaml\n",
      "IRIS/config_iris_0.5_4_10.yaml\n",
      "IRIS/config_iris_1.0_4_10.yaml\n",
      "IRIS/config_iris_1.5_4_10.yaml\n",
      "IRIS/config_iris_0.1_5_10.yaml\n",
      "IRIS/config_iris_0.2_5_10.yaml\n",
      "IRIS/config_iris_0.5_5_10.yaml\n",
      "IRIS/config_iris_1.0_5_10.yaml\n",
      "IRIS/config_iris_1.5_5_10.yaml\n",
      "IRIS/config_iris_0.1_2_20.yaml\n",
      "IRIS/config_iris_0.2_2_20.yaml\n",
      "IRIS/config_iris_0.5_2_20.yaml\n",
      "IRIS/config_iris_1.0_2_20.yaml\n",
      "IRIS/config_iris_1.5_2_20.yaml\n",
      "IRIS/config_iris_0.1_3_20.yaml\n",
      "IRIS/config_iris_0.2_3_20.yaml\n",
      "IRIS/config_iris_0.5_3_20.yaml\n",
      "IRIS/config_iris_1.0_3_20.yaml\n",
      "IRIS/config_iris_1.5_3_20.yaml\n",
      "IRIS/config_iris_0.1_4_20.yaml\n",
      "IRIS/config_iris_0.2_4_20.yaml\n",
      "IRIS/config_iris_0.5_4_20.yaml\n",
      "IRIS/config_iris_1.0_4_20.yaml\n",
      "IRIS/config_iris_1.5_4_20.yaml\n",
      "IRIS/config_iris_0.1_5_20.yaml\n",
      "IRIS/config_iris_0.2_5_20.yaml\n",
      "IRIS/config_iris_0.5_5_20.yaml\n",
      "IRIS/config_iris_1.0_5_20.yaml\n",
      "IRIS/config_iris_1.5_5_20.yaml\n",
      "IRIS/config_iris_0.1_2_30.yaml\n",
      "IRIS/config_iris_0.2_2_30.yaml\n",
      "IRIS/config_iris_0.5_2_30.yaml\n",
      "IRIS/config_iris_1.0_2_30.yaml\n",
      "IRIS/config_iris_1.5_2_30.yaml\n",
      "IRIS/config_iris_0.1_3_30.yaml\n",
      "IRIS/config_iris_0.2_3_30.yaml\n",
      "IRIS/config_iris_0.5_3_30.yaml\n",
      "IRIS/config_iris_1.0_3_30.yaml\n",
      "IRIS/config_iris_1.5_3_30.yaml\n",
      "IRIS/config_iris_0.1_4_30.yaml\n",
      "IRIS/config_iris_0.2_4_30.yaml\n",
      "IRIS/config_iris_0.5_4_30.yaml\n",
      "IRIS/config_iris_1.0_4_30.yaml\n",
      "IRIS/config_iris_1.5_4_30.yaml\n",
      "IRIS/config_iris_0.1_5_30.yaml\n",
      "IRIS/config_iris_0.2_5_30.yaml\n",
      "IRIS/config_iris_0.5_5_30.yaml\n",
      "IRIS/config_iris_1.0_5_30.yaml\n",
      "IRIS/config_iris_1.5_5_30.yaml\n",
      "IRIS/config_iris_0.1_2_40.yaml\n",
      "IRIS/config_iris_0.2_2_40.yaml\n",
      "IRIS/config_iris_0.5_2_40.yaml\n",
      "IRIS/config_iris_1.0_2_40.yaml\n",
      "IRIS/config_iris_1.5_2_40.yaml\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IRIS/config_iris_0.1_3_40.yaml\n",
      "IRIS/config_iris_0.2_3_40.yaml\n",
      "IRIS/config_iris_0.5_3_40.yaml\n",
      "IRIS/config_iris_1.0_3_40.yaml\n",
      "IRIS/config_iris_1.5_3_40.yaml\n",
      "IRIS/config_iris_0.1_4_40.yaml\n",
      "IRIS/config_iris_0.2_4_40.yaml\n",
      "IRIS/config_iris_0.5_4_40.yaml\n",
      "IRIS/config_iris_1.0_4_40.yaml\n",
      "IRIS/config_iris_1.5_4_40.yaml\n",
      "IRIS/config_iris_0.1_5_40.yaml\n",
      "IRIS/config_iris_0.2_5_40.yaml\n",
      "IRIS/config_iris_0.5_5_40.yaml\n",
      "IRIS/config_iris_1.0_5_40.yaml\n",
      "IRIS/config_iris_1.5_5_40.yaml\n",
      "IRIS/config_iris_0.1_2_50.yaml\n",
      "IRIS/config_iris_0.2_2_50.yaml\n",
      "IRIS/config_iris_0.5_2_50.yaml\n",
      "IRIS/config_iris_1.0_2_50.yaml\n",
      "IRIS/config_iris_1.5_2_50.yaml\n",
      "IRIS/config_iris_0.1_3_50.yaml\n",
      "IRIS/config_iris_0.2_3_50.yaml\n",
      "IRIS/config_iris_0.5_3_50.yaml\n",
      "IRIS/config_iris_1.0_3_50.yaml\n",
      "IRIS/config_iris_1.5_3_50.yaml\n",
      "IRIS/config_iris_0.1_4_50.yaml\n",
      "IRIS/config_iris_0.2_4_50.yaml\n",
      "IRIS/config_iris_0.5_4_50.yaml\n",
      "IRIS/config_iris_1.0_4_50.yaml\n",
      "IRIS/config_iris_1.5_4_50.yaml\n",
      "IRIS/config_iris_0.1_5_50.yaml\n",
      "IRIS/config_iris_0.2_5_50.yaml\n",
      "IRIS/config_iris_0.5_5_50.yaml\n",
      "IRIS/config_iris_1.0_5_50.yaml\n",
      "IRIS/config_iris_1.5_5_50.yaml\n"
     ]
    }
   ],
   "source": [
    "# Recompute the requested metrics for every (explanation type, perturbation,\n",
    "# config) combination already indexed in `res_df`, writing results in place.\n",
    "for pert_key in pert_keys:\n",
    "\n",
    "    # Base perturbations are scored with the LINEX, LIME and LIME_smooth\n",
    "    # explainers; any other pert_key reuses `exp_types` as left by an earlier\n",
    "    # cell (hidden state -- NOTE(review): NameError if undefined at that point).\n",
    "    if pert_key == \"Base_Perturbations\":\n",
    "        exp_types = [\"LINEX\", \"LIME\", \"LIME_smooth\"]\n",
    "\n",
    "    for exp_type in exp_types:\n",
    "\n",
    "        print(pert_key, exp_type)\n",
    "\n",
    "        # Unique (cnt, num_envs, kernel_width) triples: the first three index\n",
    "        # levels left under (exp_type, pert_key); the last level is k.\n",
    "        fu_lot = sorted({idx[:3] for idx in res_df.loc[(exp_type, pert_key)].index})\n",
    "\n",
    "        for (cnt, num_envs, kernel_width) in fu_lot:\n",
    "            config_fname = config_fname_base+str(kernel_width)+\"_\"+str(num_envs)+\"_\"+str(cnt)+\".yaml\"\n",
    "            args = Args(config_fname, dataset_key, model_key, pert_key)\n",
    "\n",
    "            print(config_fname)\n",
    "\n",
    "            # Load the config file; the context manager closes the handle\n",
    "            # deterministically instead of relying on garbage collection.\n",
    "            with open(os.path.join(\"config\", args.config_fname)) as config_file:\n",
    "                config = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "            # Load the explanations\n",
    "            exp, exp_time = load_explanations(exp_type, config, args)\n",
    "            # SHAP predictions are shifted by `shap_mean_offset`; every other\n",
    "            # explainer uses a zero offset.\n",
    "            if exp_type == \"SHAP\":\n",
    "                offset = shap_mean_offset\n",
    "            else:\n",
    "                offset = np.zeros(y_test.shape)\n",
    "\n",
    "            # Predictions of the explanations on the test set\n",
    "            y_pred_exp = ind_predictions(X_test, exp, offset=offset)\n",
    "\n",
    "            # k-independent metrics: computed once per config, stored per k below\n",
    "            if \"fid\" in metrics_to_compute:\n",
    "                fid = fidelity(y_pred, y_pred_exp)\n",
    "            if \"stability\" in metrics_to_compute:\n",
    "                # CAC metric\n",
    "                stab = class_attribution_consistency(X_test, y_test, exp, cls=0)\n",
    "\n",
    "            # Stack list-valued explanations once per config, not once per k\n",
    "            if \"unidir\" in metrics_to_compute:\n",
    "                exps = np.vstack(exp) if isinstance(exp, list) else exp\n",
    "\n",
    "            for k in ks:\n",
    "                row_key = (exp_type, pert_key, cnt, num_envs, kernel_width, k)\n",
    "\n",
    "                if \"stability\" in metrics_to_compute:\n",
    "                    res_df.loc[row_key, \"stability\"] = stab\n",
    "                if \"fid\" in metrics_to_compute:\n",
    "                    res_df.loc[row_key, \"fid\"] = fid\n",
    "                if \"compute_time\" in metrics_to_compute:\n",
    "                    res_df.loc[row_key, \"compute_time\"] = exp_time\n",
    "\n",
    "                # Metrics that depend on the nearest-neighbour index pairs\n",
    "                if set(generalized_metrics) & set(metrics_to_compute):\n",
    "                    rowinds = knn[k][0]\n",
    "                    colinds = knn[k][1]\n",
    "\n",
    "                    # Prediction consistency\n",
    "                    if \"pred_cons\" in metrics_to_compute:\n",
    "                        res_df.loc[row_key, \"pred_cons\"] = prediction_consistency(y_pred_exp, rowinds, colinds)\n",
    "                    # Coefficient consistency\n",
    "                    if \"coef_cons\" in metrics_to_compute:\n",
    "                        res_df.loc[row_key, \"coef_cons\"] = coefficient_consistency(exp, rowinds, colinds)\n",
    "                    # Generalized fidelity\n",
    "                    if \"gen_fid\" in metrics_to_compute:\n",
    "                        res_df.loc[row_key, \"gen_fid\"] = generalized_fidelity(\n",
    "                            X_test, y_pred, exp, rowinds, colinds, offset=offset)\n",
    "                    # Unidirectionality (sign consistency) with coefficient threshold 1e-4\n",
    "                    if \"unidir\" in metrics_to_compute:\n",
    "                        res_df.loc[row_key, \"unidir\"] = sign_consistency(\n",
    "                            exps[rowinds], exps[colinds], thresh=1e-4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the metrics table to `results_fname` as a pickle\n",
    "pd.to_pickle(res_df, results_fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
