{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e7c6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension; mode 2 reloads imported modules\n",
    "# before each cell executes, so edits to local modules are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc684c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "import imodels\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import sklearn\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.inspection import DecisionBoundaryDisplay\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7a2df44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decision_boundary(X, y, model, model_name, cols, filename, save=False):\n",
    "    \"\"\"Plot a fitted classifier's decision regions over two features, with the labelled points.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray\n",
    "        Feature matrix; column 0 is drawn on the x-axis and column 1 on the y-axis.\n",
    "    y : ndarray\n",
    "        Labels aligned with the rows of ``X``; rows with label 0 and label 1 are\n",
    "        scattered as two separate series.\n",
    "    model : fitted estimator\n",
    "        Classifier whose ``predict`` output defines the plotted regions.\n",
    "    model_name : str\n",
    "        Name shown in the plot title.\n",
    "    cols : sequence of str\n",
    "        Axis labels: ``cols[0]`` for x, ``cols[1]`` for y.\n",
    "    filename : str\n",
    "        Path passed to ``savefig`` when ``save`` is true.\n",
    "    save : bool, default=False\n",
    "        Whether to write the figure to ``filename``.\n",
    "    \"\"\"\n",
    "    # Local imports: do not rely on ``sklearn.metrics`` having been loaded as a\n",
    "    # side effect of some other sklearn import.\n",
    "    from pathlib import Path\n",
    "\n",
    "    from sklearn.metrics import roc_auc_score\n",
    "\n",
    "    # Create the axes explicitly and hand it to from_estimator. Without ``ax=``\n",
    "    # the display opens its own figure, so a figure created beforehand stays blank.\n",
    "    fig, ax = plt.subplots()\n",
    "    DecisionBoundaryDisplay.from_estimator(\n",
    "        model, X, grid_resolution=200, response_method=\"predict\", alpha=0.5, cmap=\"coolwarm\", ax=ax\n",
    "    )\n",
    "    ax.scatter(X[y==0, 0], X[y==0, 1], s=3, c=\"dodgerblue\", marker=\"s\", alpha=0.3, label=0)\n",
    "    ax.scatter(X[y==1, 0], X[y==1, 1], s=3, c=\"orangered\", marker=\"^\", alpha=0.3, label=1)\n",
    "    ax.set_xlabel(cols[0])\n",
    "    ax.set_ylabel(cols[1])\n",
    "\n",
    "    # AUC is a ranking metric: score with the positive-class probability when the\n",
    "    # model exposes it. Hard labels reduce the ROC curve to one operating point.\n",
    "    if hasattr(model, \"predict_proba\"):\n",
    "        scores = model.predict_proba(X)[:, 1]\n",
    "    else:\n",
    "        scores = model.predict(X)\n",
    "    ax.set_title(f\"{model_name} (AUC {roc_auc_score(y, scores):.3f})\")\n",
    "    ax.legend()\n",
    "\n",
    "    if save:\n",
    "        # Make sure the target directory exists so savefig does not fail on a fresh checkout.\n",
    "        Path(filename).parent.mkdir(parents=True, exist_ok=True)\n",
    "        fig.savefig(filename, bbox_inches=None, facecolor=\"white\", edgecolor=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a32e1f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_importance(X, y, cols, N=10):\n",
    "    \"\"\"Rank features by the impurity importance averaged over ``N`` decision trees.\n",
    "\n",
    "    Fits ``N`` ``DecisionTreeClassifier`` models (Gini criterion) on ``(X, y)``,\n",
    "    averages their ``feature_importances_``, prints up to ten of the\n",
    "    highest-scoring features, and returns the names of the top two.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray of shape (n_samples, n_features)\n",
    "        Training features.\n",
    "    y : array-like\n",
    "        Training labels.\n",
    "    cols : ndarray of str\n",
    "        Feature names; indexed with an integer array, so it must be an ndarray.\n",
    "    N : int, default=10\n",
    "        Number of trees to fit and average over.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray\n",
    "        The two entries of ``cols`` with the highest mean importance.\n",
    "    \"\"\"\n",
    "    importance_sum = np.zeros(X.shape[1])\n",
    "    for _ in range(N):\n",
    "        fitted = DecisionTreeClassifier(criterion='gini').fit(X, y)\n",
    "        importance_sum += fitted.feature_importances_\n",
    "    mean_importance = importance_sum / N\n",
    "\n",
    "    # Feature indices ordered from most to least important.\n",
    "    ranking = np.argsort(mean_importance)[::-1]\n",
    "\n",
    "    print(f\"     Feature importance\\n{34*'='}\")\n",
    "    for idx in ranking[:10]:\n",
    "        print(f\"{cols[idx]:>25s} | {mean_importance[idx]:.3f}\")\n",
    "    return cols[ranking[:2]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d846253",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_simpler_boundary(dataset_name, database_name, cols_name=None, save=False):\n",
    "    \"\"\"Plot random-forest decision boundaries on two features, before and after ``HSTreeClassifierCV``.\n",
    "\n",
    "    Loads the dataset through imodels, keeps two feature columns, splits into\n",
    "    train/test, fits a ``RandomForestClassifier`` and an\n",
    "    ``imodels.HSTreeClassifierCV`` wrapped around a copy of it, prints the\n",
    "    selected ``reg_param``, and draws both decision boundaries on the test split.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_name : str\n",
    "        Dataset identifier passed to ``get_clean_dataset``; also used in output file names.\n",
    "    database_name : str\n",
    "        Data source identifier passed to ``get_clean_dataset``.\n",
    "    cols_name : sequence of str, optional\n",
    "        Names of the two features to use. When omitted (or empty), the two\n",
    "        features ranked highest by ``feature_importance`` are used and the\n",
    "        output file names get a ``_reproduced`` suffix.\n",
    "    save : bool, default=False\n",
    "        Forwarded to ``decision_boundary`` to write the figures to disk.\n",
    "    \"\"\"\n",
    "    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)\n",
    "    cols = np.array(cols)\n",
    "\n",
    "    # Explicit None/length test instead of truthiness, so an ndarray of names\n",
    "    # (e.g. the return value of feature_importance) is accepted too.\n",
    "    if cols_name is not None and len(cols_name) > 0:\n",
    "        _save_name = \"\"\n",
    "    else:\n",
    "        # No pair supplied: fall back to the two most important features.\n",
    "        _save_name = \"_reproduced\"\n",
    "        cols_name = feature_importance(X, y, cols)\n",
    "\n",
    "    # Select the two columns once, after the branch, for both cases.\n",
    "    pair_idx = [np.where(cols == name)[0][0] for name in (cols_name[0], cols_name[1])]\n",
    "    new_X = X[:, pair_idx]\n",
    "    X_train, X_test, y_train, y_test = train_test_split(new_X, y, test_size=0.33)\n",
    "\n",
    "    # Train the random forest, then fit the HS wrapper on a copy so RF itself stays untouched.\n",
    "    RF = RandomForestClassifier(n_estimators=50)\n",
    "    RF.fit(X_train, y_train)\n",
    "    hsRF = imodels.HSTreeClassifierCV(deepcopy(RF))\n",
    "    hsRF.fit(X_train, y_train)\n",
    "    print(f\"Optimal lambda: {hsRF.reg_param}\")\n",
    "\n",
    "    out_prefix = \"../graphs/claim_4/boundaries/\" + dataset_name\n",
    "    decision_boundary(X_test, y_test, RF, \"RF\", cols_name, out_prefix + \"_RF\" + _save_name, save)\n",
    "    decision_boundary(X_test, y_test, hsRF, \"hsRF\", cols_name, out_prefix + \"_hsRF\" + _save_name, save)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a2f7f9",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"heart\", \"imodels\", [\"att_8\", \"att_10\"])\n",
    "test_simpler_boundary(\"heart\", \"imodels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b779da71",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"breast_cancer\", \"imodels\", [\"age\", \"tumor-size\"])\n",
    "test_simpler_boundary(\"breast_cancer\", \"imodels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1782db47",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"haberman\", \"imodels\", [\"Age_of_patient_at_time_of_operation\", \"Number_of_positive_axillary_nodes_detected\"])\n",
    "test_simpler_boundary(\"haberman\", \"imodels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e710b7df",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"ionosphere\", \"pmlb\", [\"X_4\", \"X_6\"])\n",
    "test_simpler_boundary(\"ionosphere\", \"pmlb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df3fb734",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"diabetes\", \"pmlb\", [\"A2\", \"A6\"])\n",
    "test_simpler_boundary(\"diabetes\", \"pmlb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2ece4ce",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"german\", \"pmlb\", [\"Credit\", \"Age\"])\n",
    "test_simpler_boundary(\"german\", \"pmlb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dee9e0c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"juvenile_clean\", \"imodels\", [\"friends_broken_in_steal:1\", \"fr_suggest_agnts_law:2\"])\n",
    "test_simpler_boundary(\"juvenile_clean\", \"imodels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "850061c4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG before the two runs below.\n",
    "np.random.seed(42)\n",
    "# Run once with an explicit feature pair, then once with no pair given\n",
    "# (test_simpler_boundary then picks the pair via feature_importance).\n",
    "test_simpler_boundary(\"compas_two_year_clean\", \"imodels\", [\"age\", \"priors_count\"])\n",
    "test_simpler_boundary(\"compas_two_year_clean\", \"imodels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e16f0209",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:16:33) [MSC v.1929 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
