{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "import warnings\n",
    "from os.path import join\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# All paths are built relative to the current working directory.\n",
    "path_to_file = str(pathlib.Path().resolve())\n",
    "dir_path = join(path_to_file, \"../../\")\n",
    "retro_path = join(path_to_file, \"..\", \"Results\", \"Retrospective\")\n",
    "\n",
    "# HelperFiles has to be on sys.path before the project-local import below.\n",
    "sys.path.append(join(dir_path, \"HelperFiles\"))\n",
    "import helper\n",
    "\n",
    "# Levels passed as `alpha` to the tests; each method gets one result column per level.\n",
    "alphas = [0.05, 0.1, 0.2]\n",
    "\n",
    "# Suppress every warning for the rest of the session.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterate through all datasets with both methods, at multiple alphas\n",
    "\n",
    "The displayed matrices have the same layout as Table 1 in the paper, restricted to the ranking methods: one row per dataset and, for each method, one column per alpha."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 2.  6. 10.  4.  4.  8.]\n",
      " [ 2.  4. 12.  2.  4.  8.]\n",
      " [ 2.  8. 10.  2.  2. 10.]\n",
      " [ 2.  8. 10.  2.  2.  6.]\n",
      " [ 2.  8. 10.  4.  6. 10.]]\n",
      "[[0.004 0.011 0.031 0.003 0.005 0.026]\n",
      " [0.005 0.011 0.027 0.002 0.005 0.026]\n",
      " [0.004 0.012 0.026 0.001 0.005 0.017]\n",
      " [0.001 0.011 0.033 0.001 0.002 0.012]\n",
      " [0.003 0.007 0.023 0.003 0.005 0.011]]\n"
     ]
    }
   ],
   "source": [
    "# Summarise the retrospective FWERs of every (dataset, method) pair.\n",
    "# Rows index datasets; each method owns a contiguous group of len(alphas) columns.\n",
    "datasets = [\"census\", \"bank\", \"brca\", \"credit\", \"breast_cancer\"]\n",
    "methods = ['ss', 'kernelshap']\n",
    "cols_per_method = len(alphas)\n",
    "max_mat = np.empty((len(datasets), len(methods)*cols_per_method))\n",
    "avg_mat = np.empty((len(datasets), len(methods)*cols_per_method))\n",
    "for dataset_idx, dataset in enumerate(datasets):\n",
    "    for method_idx, method in enumerate(methods):\n",
    "        # NOTE: pickle.load can execute arbitrary code; only open trusted result files.\n",
    "        with open(join(retro_path, method+\"_\"+dataset), 'rb') as f:\n",
    "            retro_results = pickle.load(f)\n",
    "        shap_vals = retro_results[\"shap_vals\"]\n",
    "        N_verified_all = retro_results[\"N_verified\"]\n",
    "        N_pts, N_runs, N_alphas = N_verified_all.shape\n",
    "\n",
    "        # Per-run rankings, and the ranking of the run-averaged SHAP values for each point.\n",
    "        all_ranks = helper.shap_vals_to_ranks(shap_vals, abs=True)\n",
    "        avg_shap = np.mean(shap_vals, axis=1)\n",
    "        avg_ranks = np.array([helper.get_ranking(avg_shap[i], abs=True) for i in range(N_pts)])\n",
    "\n",
    "        # Reduce over axis 1 so one value per alpha remains\n",
    "        # (presumably axis 1 indexes the points - confirm in helper.calc_all_retro_fwers).\n",
    "        fwers = helper.calc_all_retro_fwers(N_verified_all, all_ranks, avg_ranks)\n",
    "        max_fwers = np.round(np.nanmax(fwers, axis=1), 3)\n",
    "        avg_fwers = np.round(np.nanmean(fwers, axis=1), 3)\n",
    "\n",
    "        # Column offset follows the method's position in `methods`, so the layout\n",
    "        # stays correct if `methods` or `alphas` are edited.\n",
    "        col_start_idx = method_idx*cols_per_method\n",
    "        max_mat[dataset_idx, col_start_idx:col_start_idx+cols_per_method] = max_fwers\n",
    "        avg_mat[dataset_idx, col_start_idx:col_start_idx+cols_per_method] = avg_fwers\n",
    "# The maximum is printed in percent, the average as a raw fraction.\n",
    "print(max_mat*100)\n",
    "print(avg_mat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stability of top K set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# False-rejection rates of the top-K set test, summarised per dataset, method and alpha.\n",
    "# Rows index datasets; each method owns a contiguous group of len(alphas) columns.\n",
    "K = 5\n",
    "cols_per_method = len(alphas)\n",
    "max_mat_set = np.empty((len(datasets), len(methods)*cols_per_method))\n",
    "avg_mat_set = np.empty((len(datasets), len(methods)*cols_per_method))\n",
    "for dataset_idx, dataset in enumerate(datasets):\n",
    "    for method_idx, method in enumerate(methods):\n",
    "        # NOTE: pickle.load can execute arbitrary code; only open trusted result files.\n",
    "        with open(join(retro_path, method+\"_\"+dataset), 'rb') as f:\n",
    "            retro_results = pickle.load(f)\n",
    "        shap_vals = retro_results[\"shap_vals\"]\n",
    "        shap_vars = retro_results[\"shap_vars\"]\n",
    "        # Size the loops from this file's own array (indexed [point, run, :] below)\n",
    "        # instead of relying on N_pts / N_runs left over from another cell.\n",
    "        N_pts, N_runs = shap_vals.shape[:2]\n",
    "\n",
    "        all_ranks = helper.shap_vals_to_ranks(shap_vals, abs=True)\n",
    "        avg_shap = np.mean(shap_vals, axis=1)\n",
    "        avg_ranks = np.array([helper.get_ranking(avg_shap[i], abs=True) for i in range(N_pts)])\n",
    "\n",
    "        max_fwers = []\n",
    "        avg_fwers = []\n",
    "        for alpha in alphas:\n",
    "            fwers_all = []\n",
    "            for i in range(N_pts):\n",
    "                num_false_rejections = 0\n",
    "                # Reference set: first K entries of the ranking of the run-averaged SHAP values.\n",
    "                true_top_K_set = np.sort(avg_ranks[i,:K])\n",
    "                for j in range(N_runs):\n",
    "                    ss_vals, ss_vars = shap_vals[i,j,:], shap_vars[i,j,:]\n",
    "                    result = helper.test_top_k_set(ss_vals, ss_vars, K=K, alpha=alpha, abs=True)\n",
    "                    if result==\"reject\":\n",
    "                        # A rejection counts as false when this run's top-K set differs from the reference set.\n",
    "                        est_top_K_set = np.sort(all_ranks[i,j,:K])\n",
    "                        if not np.array_equal(true_top_K_set, est_top_K_set):\n",
    "                            num_false_rejections += 1\n",
    "                fwers_all.append(num_false_rejections/N_runs)\n",
    "            # Worst-case and mean rate over the points for this alpha.\n",
    "            max_fwers.append(np.round(np.nanmax(fwers_all), 3))\n",
    "            avg_fwers.append(np.round(np.nanmean(fwers_all), 3))\n",
    "\n",
    "        # Column offset follows the method's position in `methods`.\n",
    "        col_start_idx = method_idx*cols_per_method\n",
    "        max_mat_set[dataset_idx, col_start_idx:col_start_idx+cols_per_method] = max_fwers\n",
    "        avg_mat_set[dataset_idx, col_start_idx:col_start_idx+cols_per_method] = avg_fwers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 4.  4. 10.  4.  6. 12.]\n",
      " [ 0.  2.  8.  0.  0.  2.]\n",
      " [ 2.  4.  8.  2.  4.  6.]\n",
      " [ 2.  6.  8.  2.  4. 10.]\n",
      " [ 0.  4.  6.  0.  0.  2.]]\n",
      "[[0.2 0.7 1.7 0.2 0.7 2.1]\n",
      " [0.  0.3 1.3 0.  0.  0.1]\n",
      " [0.2 0.5 1.6 0.1 0.2 1.3]\n",
      " [0.2 0.7 2.1 0.1 0.5 1.3]\n",
      " [0.  0.2 0.5 0.  0.  0.1]]\n"
     ]
    }
   ],
   "source": [
    "# Show both top-K set summaries in percent: the maximum first, then the average.\n",
    "for fwer_mat in (max_mat_set, avg_mat_set):\n",
    "    print(fwer_mat*100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Great: even in the worst case, the FWER is always controlled!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rankings",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
