{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook setup: imports, project paths and NumPy print options.\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "from os.path import exists, join\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Absolute path of the current working directory.\n",
    "path_to_file = str(pathlib.Path().resolve())\n",
    "# Two levels above the working directory; the cells below expect the\n",
    "# HelperFiles/ and Experiments/ folders to live here.\n",
    "dir_path = join(path_to_file, \"../../\")\n",
    "\n",
    "# Make the modules in HelperFiles/ importable before importing from them.\n",
    "sys.path.append(join(dir_path, \"HelperFiles\"))\n",
    "\n",
    "# Import only the name this notebook uses instead of a wildcard import, so it is\n",
    "# clear where calc_max_fwer comes from and nothing else is pulled into the namespace.\n",
    "from helper import calc_max_fwer\n",
    "\n",
    "# Print small values in fixed-point rather than scientific notation.\n",
    "np.set_printoptions(suppress=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SHAP Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "####################\n",
      "Method:  StableSHAP \tK: 2 \tAlpha: 0.1\n",
      "Max FWER (%)\n",
      "[[8. 6. 6. 4. 0.]\n",
      " [2. 2. 6. 2. 4.]]\n",
      "####################\n",
      "Method:  StableSHAP \tK: 2 \tAlpha: 0.2\n",
      "Max FWER (%)\n",
      "[[16. 14. 14.  8. 10.]\n",
      " [16.  0. 10.  2. 12.]]\n",
      "####################\n",
      "Method:  StableSHAP \tK: 5 \tAlpha: 0.1\n",
      "Max FWER (%)\n",
      "[[ 6. 10. 10.  4.  6.]\n",
      " [ 6. 10. 10.  4.  4.]]\n",
      "####################\n",
      "Method:  StableSHAP \tK: 5 \tAlpha: 0.2\n",
      "Max FWER (%)\n",
      "[[14. 18. 20. 12. 20.]\n",
      " [14. 16. 20. 16.  4.]]\n",
      "####################\n",
      "Method:  SPRT-SHAP \tK: 2 \tAlpha: 0.1\n",
      "Max FWER (%)\n",
      "[[ 0.  0. nan  0.  0.]\n",
      " [ 0.  0.  0.  0.  0.]]\n",
      "####################\n",
      "Method:  SPRT-SHAP \tK: 2 \tAlpha: 0.2\n",
      "Max FWER (%)\n",
      "[[ 2.  2. nan  0.  4.]\n",
      " [ 0.  2.  0.  0.  0.]]\n",
      "####################\n",
      "Method:  SPRT-SHAP \tK: 5 \tAlpha: 0.1\n",
      "Max FWER (%)\n",
      "[[nan nan nan nan nan]\n",
      " [ 0.  0.  0.  0.  0.]]\n",
      "####################\n",
      "Method:  SPRT-SHAP \tK: 5 \tAlpha: 0.2\n",
      "Max FWER (%)\n",
      "[[nan nan nan nan  0.]\n",
      " [ 8.  2.  2.  2.  0.]]\n"
     ]
    }
   ],
   "source": [
    "# Tabulate the maximum FWER (in percent) for each SHAP-based method, K and alpha.\n",
    "# Each printed matrix has one row per guarantee and one column per dataset;\n",
    "# NaN marks a combination whose result file is missing or for which\n",
    "# calc_max_fwer returned None.\n",
    "datasets = [\"census\", \"bank\", \"brca\", \"credit\", \"breast_cancer\"]\n",
    "Ks = [2, 5]\n",
    "alphas = [0.1, 0.2]\n",
    "methods = [\"stableshap\", \"sprtshap\"]\n",
    "guarantees = [\"rank\", \"set\"]\n",
    "for method in methods:\n",
    "    for K in Ks:\n",
    "        for alpha in alphas:\n",
    "            # Rows are indexed by guarantee, so size the matrix by len(guarantees)\n",
    "            # rather than len(Ks); the two only happen to have the same length.\n",
    "            max_mat = np.full((len(guarantees), len(datasets)), np.nan)\n",
    "            for i, guarantee in enumerate(guarantees):\n",
    "                data_dir = join(dir_path, \"Experiments\", \"Results\", \"Top_K\", guarantee, \"alpha_\" + str(alpha))\n",
    "                for j, dataset in enumerate(datasets):\n",
    "                    fname = method + \"_\" + dataset + \"_K\" + str(K)\n",
    "                    path = join(data_dir, fname)\n",
    "                    if exists(path):\n",
    "                        # Result files are pickles; unpickling can execute arbitrary code,\n",
    "                        # so only load files produced by this project.\n",
    "                        with open(path, \"rb\") as fp:\n",
    "                            results = pickle.load(fp)\n",
    "                        max_fwer = calc_max_fwer(results)\n",
    "                        if max_fwer is not None:\n",
    "                            # presumably a fraction; scaled by 100 to match the percent header below\n",
    "                            max_mat[i, j] = np.round(max_fwer*100)\n",
    "            print(\"#\"*20)\n",
    "            methodName = \"StableSHAP\" if method==\"stableshap\" else \"SPRT-SHAP\"\n",
    "            print(\"Method: \", methodName, \"\\tK:\", str(K), \"\\tAlpha:\", alpha)\n",
    "\n",
    "            print(\"Max FWER (%)\")\n",
    "            print(max_mat)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LIME Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "####################\n",
      "Method: LIME \tK:  2\n",
      "Max FWER (%)\n",
      "[[0. 2. 2. 6. 2.]\n",
      " [2. 0. 0. 8. 2.]]\n",
      "####################\n",
      "Method: LIME \tK:  5\n",
      "Max FWER (%)\n",
      "[[nan  0. nan  0.  0.]\n",
      " [nan  0. nan  0.  0.]]\n"
     ]
    }
   ],
   "source": [
    "# Maximum FWER (in percent) for LIME under the \"rank\" guarantee.\n",
    "# One matrix is printed per K: rows follow `alphas`, columns follow `datasets`.\n",
    "# Entries stay NaN when the result file is absent or calc_max_fwer returns None.\n",
    "# Relies on `Ks` and `datasets` defined in the SHAP results cell.\n",
    "method = \"lime\"\n",
    "guarantee = \"rank\"\n",
    "alphas = [0.1, 0.2]\n",
    "for K in Ks:\n",
    "    n_rows, n_cols = len(alphas), len(datasets)\n",
    "    max_mat = np.full((n_rows, n_cols), np.nan)\n",
    "    for row, alpha in enumerate(alphas):\n",
    "        data_dir = join(dir_path, \"Experiments\", \"Results\", \"Top_K\", guarantee, f\"alpha_{alpha}\")\n",
    "        for col, dataset in enumerate(datasets):\n",
    "            path = join(data_dir, f\"{method}_{dataset}_K{K}\")\n",
    "            # Skip combinations that were never run.\n",
    "            if not exists(path):\n",
    "                continue\n",
    "            with open(path, \"rb\") as fp:\n",
    "                results = pickle.load(fp)\n",
    "            max_fwer = calc_max_fwer(results)\n",
    "            # Leave the NaN placeholder when no value is available.\n",
    "            if max_fwer is None:\n",
    "                continue\n",
    "            max_mat[row, col] = np.round(max_fwer * 100)\n",
    "\n",
    "    print(\"#\" * 20)\n",
    "    print(\"Method: LIME\", \"\\tK: \", K)\n",
    "    print(\"Max FWER (%)\", max_mat, sep=\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rankings",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
