{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d1329f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef3d6c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import imodels\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import shap\n",
    "\n",
    "from datasets import DATASETS_CLASSIFICATION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c7d451",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shap_clusters(dataset_name, database_name, save=False):\n",
    "    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)\n",
    "    cols = np.array(cols)\n",
    "    \n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=50)\n",
    "    idx_all = list(range(X_train.shape[0]))\n",
    "    idx_train, _ = train_test_split(idx_all, test_size=0.33)\n",
    "    X_t, y_t = X_train[idx_train], y_train[idx_train]\n",
    "    \n",
    "    RF = RandomForestClassifier(n_estimators=50)\n",
    "    RF.fit(X_t, y_t)\n",
    "    hsRF = imodels.HSTreeClassifierCV(deepcopy(RF), reg_param_list=[50, 100, 200, 300, 400, 500])\n",
    "    hsRF.fit(X_t, y_t)\n",
    "    print(f\"Optimal lambda: {hsRF.reg_param}\")\n",
    "    \n",
    "    fig = plt.figure(dataset_name+\"_RF\")\n",
    "    plt.clf()\n",
    "    shap_values_RF = shap.TreeExplainer(RF).shap_values(X_test)[1]\n",
    "    shap.summary_plot(shap_values_RF, X_test, feature_names=cols, max_display=10, sort=False, show=False)\n",
    "    plt.title(\"RF\")\n",
    "    plt.xlabel(\"SHAP value\")\n",
    "    if save:\n",
    "        plt.savefig(\"../graphs/claim_4/SHAP_interpretation/\"+dataset_name+\"_RF\", bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")\n",
    "    \n",
    "    fig = plt.figure(dataset_name+\"_hsRF\")\n",
    "    plt.clf()\n",
    "    shap_values_hsRF = shap.TreeExplainer(hsRF.estimator_).shap_values(X_test)[1]\n",
    "    shap.summary_plot(shap_values_hsRF, X_test, feature_names=cols, max_display=10, sort=False, show=False)\n",
    "    plt.title(\"hsRF\")\n",
    "    plt.xlabel(\"SHAP value\")\n",
    "    if save:\n",
    "        plt.savefig(\"../graphs/claim_4/SHAP_interpretation/\"+dataset_name+\"_hsRF\", bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cb61f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for (dataset_name, database_name) in DATASETS_CLASSIFICATION.values():\n",
    "    shap_clusters(dataset_name, database_name, save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "effc662e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:16:33) [MSC v.1929 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
