{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccd443eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edited modules\n",
    "# (e.g. the local `datasets` module) are re-imported before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15222933",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "\n",
    "\n",
    "import imodels\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import shap\n",
    "\n",
    "from datasets import DATASETS_CLASSIFICATION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd5344f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _class_one_shap_values(model, X):\n",
    "    \"\"\"Return the TreeExplainer SHAP values of `model` on `X` for class index 1.\"\"\"\n",
    "    values = shap.TreeExplainer(model).shap_values(X)\n",
    "    if isinstance(values, list):\n",
    "        # Some shap releases return one (n_samples, n_features) array per class.\n",
    "        return values[1]\n",
    "    values = np.asarray(values)\n",
    "    if values.ndim == 3:\n",
    "        # Other shap releases return a single (n_samples, n_features, n_classes)\n",
    "        # array; plain `[1]` would pick sample 1 there instead of class 1.\n",
    "        return values[:, :, 1]\n",
    "    # NOTE(review): a 2-D result is assumed to already be single-output values.\n",
    "    return values\n",
    "\n",
    "\n",
    "def feature_variability(dataset_name, database_name, N=100, save=False, random_state=None):\n",
    "    \"\"\"Plot the mean per-feature SHAP variability of a random forest and its HSTreeClassifierCV wrap.\n",
    "\n",
    "    The dataset is loaded through imodels and 50 samples are held out for\n",
    "    explanation. N times, a forest is fitted on a random 67% subsample of the\n",
    "    remaining rows and a copy of it is wrapped in `imodels.HSTreeClassifierCV`\n",
    "    and fitted on the same subsample. For each fit, the standard deviation of\n",
    "    the class-1 SHAP values across the held-out samples is computed per\n",
    "    feature. These are averaged over the N fits, and the first (at most 10)\n",
    "    features are drawn as a grouped bar chart.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Dataset identifier passed to `get_clean_dataset`; also\n",
    "            used as the figure name and as the saved file name.\n",
    "        database_name: Data source identifier passed to `get_clean_dataset`.\n",
    "        N: Number of subsample-and-refit repetitions; must be at least 1.\n",
    "        save: If True, write the figure to ../graphs/claim_4/SHAP_variability/.\n",
    "        random_state: None (default, unseeded as before), an int seed, or a\n",
    "            `numpy.random.RandomState`, used for the splits and the forests.\n",
    "            Randomness inside HSTreeClassifierCV itself is not seeded here.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If N is smaller than 1.\n",
    "    \"\"\"\n",
    "    if N < 1:\n",
    "        raise ValueError(\"N must be at least 1\")\n",
    "\n",
    "    # Passing None through keeps the original behaviour (global numpy RNG).\n",
    "    if random_state is None or isinstance(random_state, np.random.RandomState):\n",
    "        rng = random_state\n",
    "    else:\n",
    "        rng = np.random.RandomState(random_state)\n",
    "\n",
    "    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)\n",
    "    cols = np.array(cols)\n",
    "    # An integer test_size holds out exactly 50 samples; their labels are not needed.\n",
    "    X_train, X_test, y_train, _ = train_test_split(X, y, test_size=50, random_state=rng)\n",
    "\n",
    "    variability_rf = np.zeros(len(cols))\n",
    "    variability_hsrf = np.zeros(len(cols))\n",
    "    idx_all = np.arange(X_train.shape[0])\n",
    "\n",
    "    for _ in tqdm(range(N)):\n",
    "        # Draw a fresh 67% subsample of the training rows for this repetition.\n",
    "        idx_train, _unused = train_test_split(idx_all, test_size=0.33, random_state=rng)\n",
    "        X_t, y_t = X_train[idx_train], y_train[idx_train]\n",
    "\n",
    "        RF = RandomForestClassifier(n_estimators=50, random_state=rng)\n",
    "        RF.fit(X_t, y_t)\n",
    "        # Wrap a copy so the plain forest explained below is left untouched.\n",
    "        hsRF = imodels.HSTreeClassifierCV(deepcopy(RF))\n",
    "        hsRF.fit(X_t, y_t)\n",
    "\n",
    "        # Spread of each feature's attribution across the held-out samples.\n",
    "        variability_rf += _class_one_shap_values(RF, X_test).std(axis=0)\n",
    "        variability_hsrf += _class_one_shap_values(hsRF.estimator_, X_test).std(axis=0)\n",
    "\n",
    "    variability_rf /= N\n",
    "    variability_hsrf /= N\n",
    "\n",
    "    # Only the first n_take features (in column order) are plotted.\n",
    "    n_take = min(len(cols), 10)\n",
    "    xaxis = np.arange(n_take)\n",
    "    plt.figure(dataset_name)\n",
    "    plt.clf()\n",
    "    plt.bar(xaxis - 0.2, variability_rf[:n_take], width=0.4, color=\"firebrick\", label=\"RF\")\n",
    "    plt.bar(xaxis + 0.2, variability_hsrf[:n_take], width=0.4, color=\"black\", label=\"hsRF\")\n",
    "    plt.xticks(xaxis, cols[:n_take], rotation=90)\n",
    "    plt.legend()\n",
    "    plt.ylabel(\"SHAP Variability\")\n",
    "    if save:\n",
    "        plt.savefig(\"../graphs/claim_4/SHAP_variability/\" + dataset_name, bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e69edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw (without saving) the SHAP variability chart for every classification dataset.\n",
    "for dataset_name, database_name in DATASETS_CLASSIFICATION.values():\n",
    "    feature_variability(\n",
    "        dataset_name=dataset_name,\n",
    "        database_name=database_name,\n",
    "        save=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85d89ec9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
