{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eab5c2b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "316d8287",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.metrics import roc_auc_score, r2_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier\n",
    "\n",
    "import imodels\n",
    "#from imodels import HSTreeRegressorCV, HSTreeClassifierCV\n",
    "#from imodels import HSTreeRegressor, HSTreeClassifier\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from datasets import DATASETS_REGRESSION, DATASETS_CLASSIFICATION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a739863e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment(X, y, max_leaf_nodes, type_=\"classification\", N=10, random_state=None):\n",
    "    \"\"\"Score CART, leaf-based-shrinkage CART and CV hierarchical-shrinkage CART.\n",
    "\n",
    "    For each of ``N`` random 67/33 train/test splits a CART tree is fit on the\n",
    "    training part, deep copies of it are wrapped in the two imodels shrinkage\n",
    "    estimators, and all three models are scored on the held-out part.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, y : array-like\n",
    "        Features and targets, passed straight to ``train_test_split``.\n",
    "    max_leaf_nodes : int\n",
    "        Leaf budget handed to the CART tree.\n",
    "    type_ : {\"classification\", \"regression\"}\n",
    "        Selects the estimators and the metric (ROC AUC or R2).\n",
    "    N : int\n",
    "        Number of random splits.\n",
    "    random_state : int or None\n",
    "        Optional seed. ``None`` (the default) keeps the unseeded behaviour.\n",
    "        When given, a single ``RandomState`` drives the splits and the CART fits.\n",
    "        NOTE(review): randomness inside the imodels wrappers, if any, is not seeded here.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of three lists\n",
    "        Per-split scores of CART, leaf-based CART and CV-shrunk CART, in that order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``type_`` is neither \"classification\" nor \"regression\".\n",
    "    \"\"\"\n",
    "    if type_ == \"classification\":\n",
    "        metric = roc_auc_score\n",
    "        tree_model = DecisionTreeClassifier\n",
    "        lbstree_model = imodels.HSTreeClassifier\n",
    "        hstree_model = imodels.HSTreeClassifierCV\n",
    "    elif type_ == \"regression\":\n",
    "        metric = r2_score\n",
    "        tree_model = DecisionTreeRegressor\n",
    "        lbstree_model = imodels.HSTreeRegressor\n",
    "        hstree_model = imodels.HSTreeRegressorCV\n",
    "    else:\n",
    "        # Fail early with a clear message instead of an unbound-name error further down.\n",
    "        raise ValueError(f\"type_ must be 'classification' or 'regression', got {type_!r}\")\n",
    "\n",
    "    # One shared generator so successive splits differ while the run stays reproducible;\n",
    "    # None falls back to the global NumPy state exactly as before.\n",
    "    rng = None if random_state is None else np.random.RandomState(random_state)\n",
    "\n",
    "    score_CART, score_lbsCART, score_hsCART = [], [], []\n",
    "\n",
    "    for _ in range(N):\n",
    "        X_train, X_test, y_train, y_test = train_test_split(\n",
    "            X, y, test_size=0.33, random_state=rng\n",
    "        )\n",
    "\n",
    "        CART = tree_model(max_leaf_nodes=max_leaf_nodes, random_state=rng)\n",
    "        CART.fit(X_train, y_train)\n",
    "        # NOTE(review): lbsCART is never fit explicitly; presumably the imodels wrapper\n",
    "        # shrinks an already-fitted tree on construction - confirm against the installed version.\n",
    "        lbsCART = lbstree_model(deepcopy(CART), shrinkage_scheme_=\"leaf_based\")\n",
    "        hsCART = hstree_model(deepcopy(CART), reg_params=[0.1, 1, 10, 25, 50, 100])\n",
    "        hsCART.fit(X_train, y_train)\n",
    "\n",
    "        if type_ == \"classification\":\n",
    "            # ROC AUC is computed from the probability of the second class column.\n",
    "            y_pred_CART = CART.predict_proba(X_test)[:, 1]\n",
    "            y_pred_lbsCART = lbsCART.predict_proba(X_test)[:, 1]\n",
    "            y_pred_hsCART = hsCART.predict_proba(X_test)[:, 1]\n",
    "        else:\n",
    "            y_pred_CART = CART.predict(X_test)\n",
    "            y_pred_lbsCART = lbsCART.predict(X_test)\n",
    "            y_pred_hsCART = hsCART.predict(X_test)\n",
    "\n",
    "        score_CART.append(metric(y_test, y_pred_CART))\n",
    "        score_lbsCART.append(metric(y_test, y_pred_lbsCART))\n",
    "        score_hsCART.append(metric(y_test, y_pred_hsCART))\n",
    "\n",
    "    return score_CART, score_lbsCART, score_hsCART"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa3c4ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def better_same_worse(dataset_name, database_name, type_=\"classification\", N=100, ROPE=0.005):\n",
    "    \"\"\"Count bootstrap draws where CV shrinkage beats, ties or loses to leaf-based shrinkage.\n",
    "\n",
    "    For every leaf budget ``experiment`` is run with ``N`` splits. The leaf-based\n",
    "    and CV-shrinkage score vectors are then resampled independently with\n",
    "    replacement and compared draw by draw; absolute differences of at most\n",
    "    ``ROPE`` count as \"same\".\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_name, database_name\n",
    "        Forwarded to ``imodels.util.data_util.get_clean_dataset``.\n",
    "    type_ : str\n",
    "        Forwarded to ``experiment``.\n",
    "    N : int\n",
    "        Number of splits per leaf budget and number of bootstrap draws.\n",
    "    ROPE : float\n",
    "        Half-width of the band in which two scores are treated as equal.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (better, same, worse, df)\n",
    "        The three counts summed over all leaf budgets, and a DataFrame with\n",
    "        columns ``dataset``, ``n_leaves`` and ``score`` holding the leaf-based\n",
    "        shrinkage score of every split.\n",
    "    \"\"\"\n",
    "    # The loader also returns column names, which are not needed here.\n",
    "    X, y, _ = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)\n",
    "\n",
    "    leaf_nodes = [2, 4, 8, 12, 16, 20, 24, 28, 30, 32]\n",
    "\n",
    "    frames = []\n",
    "    SAME, BETTER, WORSE = [], [], []\n",
    "\n",
    "    for max_leaf_nodes in tqdm(leaf_nodes):\n",
    "        _, score_lbsCART, score_hsCART = experiment(X, y, max_leaf_nodes, type_, N=N)\n",
    "        score_lbsCART = np.asarray(score_lbsCART)\n",
    "        score_hsCART = np.asarray(score_hsCART)\n",
    "\n",
    "        # Collect the per-budget frames and concatenate once after the loop instead of\n",
    "        # growing a DataFrame (starting from an empty one) on every iteration.\n",
    "        frames.append(pd.DataFrame(\n",
    "            {\"dataset\": [dataset_name] * N, \"n_leaves\": [max_leaf_nodes] * N, \"score\": score_lbsCART}\n",
    "        ))\n",
    "\n",
    "        idx_lbsCART = np.random.choice(N, size=N, replace=True)\n",
    "        idx_hsCART = np.random.choice(N, size=N, replace=True)\n",
    "\n",
    "        # Positive values mean the CV-shrinkage draw scored higher.\n",
    "        diff = score_hsCART[idx_hsCART] - score_lbsCART[idx_lbsCART]\n",
    "        SAME.append((np.abs(diff) <= ROPE).sum())\n",
    "        BETTER.append((diff > ROPE).sum())\n",
    "        WORSE.append((diff < -ROPE).sum())\n",
    "\n",
    "    df = pd.concat(frames, ignore_index=True)\n",
    "    return np.sum(BETTER), np.sum(SAME), np.sum(WORSE), df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fac2082",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_better_same_worse(DATASETS, type_=\"classification\", save=False):\n",
    "    \"\"\"Draw a stacked better/same/worse bar per dataset and return all collected scores.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    DATASETS : dict\n",
    "        Maps a display name to a ``(dataset_name, database_name)`` pair.\n",
    "    type_ : str\n",
    "        Forwarded to ``better_same_worse``; also used as the saved figure's file name.\n",
    "    save : bool\n",
    "        When true the figure is written below ``../../figures/HS_vs_LBS/``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        The score frames returned by ``better_same_worse``, concatenated over all datasets.\n",
    "\n",
    "    Side effects: ``tmp.csv`` in the working directory is rewritten after every dataset.\n",
    "    \"\"\"\n",
    "    BETTER, SAME, WORSE = [], [], []\n",
    "    names = []\n",
    "    frames = []\n",
    "    for save_name, (dataset_name, database_name) in DATASETS.items():\n",
    "        better, same, worse, df_dataset = better_same_worse(dataset_name, database_name, type_)\n",
    "        BETTER.append(better)\n",
    "        SAME.append(same)\n",
    "        WORSE.append(worse)\n",
    "        names.append(save_name)\n",
    "        frames.append(df_dataset)\n",
    "        # Dump everything gathered so far (presumably a checkpoint for long runs).\n",
    "        pd.concat(frames, ignore_index=True).to_csv(\"tmp.csv\", index=False)\n",
    "\n",
    "    # Concatenate once; keep the empty-but-typed result when no dataset was given.\n",
    "    if frames:\n",
    "        df = pd.concat(frames, ignore_index=True)\n",
    "    else:\n",
    "        df = pd.DataFrame(columns=[\"dataset\", \"n_leaves\", \"score\"])\n",
    "\n",
    "    SAME = np.array(SAME)\n",
    "    BETTER = np.array(BETTER)\n",
    "    WORSE = np.array(WORSE)\n",
    "    names = np.array(names)\n",
    "\n",
    "    # Turn the raw counts into per-dataset fractions that sum to one.\n",
    "    total = SAME + BETTER + WORSE\n",
    "    SAME = np.divide(SAME, total)\n",
    "    BETTER = np.divide(BETTER, total)\n",
    "    WORSE = np.divide(WORSE, total)\n",
    "\n",
    "    # Use a dedicated figure so repeated calls never stack bars on a previous plot.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.bar(names, BETTER, color=\"#56B4E9\", label=\"better\")\n",
    "    ax.bar(names, SAME, bottom=BETTER, color=\"#009E73\", label=\"same\")\n",
    "    ax.bar(names, WORSE, bottom=BETTER + SAME, color=\"#E69F00\", label=\"worse\")\n",
    "    ax.axhline(0.5, color=\"red\", linestyle=\"--\")\n",
    "    ax.tick_params(axis=\"x\", labelrotation=45)\n",
    "    ax.set_ylabel(\"probability of HS being better than LBS\")\n",
    "    ax.legend()\n",
    "    if save:\n",
    "        fig.savefig(\"../../figures/HS_vs_LBS/\" + type_, bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b4e79f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Better/same/worse comparison over the classification datasets; the figure is not saved.\n",
    "df_classification = plot_better_same_worse(\n",
    "    DATASETS=DATASETS_CLASSIFICATION, type_=\"classification\", save=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d207fdb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Better/same/worse comparison over the regression datasets; the figure is saved to disk.\n",
    "df_regression = plot_better_same_worse(\n",
    "    DATASETS=DATASETS_REGRESSION, type_=\"regression\", save=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56585fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lbs_vs_hs(dataset_name, database_name, type_=\"classification\", save=False, save_name=None):\n",
    "    \"\"\"Plot mean score (with std error bars) against leaf count for the three tree variants.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_name, database_name\n",
    "        Forwarded to ``imodels.util.data_util.get_clean_dataset``.\n",
    "    type_ : str\n",
    "        Forwarded to ``experiment``; also picks the y-axis label (AUC or R2).\n",
    "    save : bool\n",
    "        When true the figure is written below ``../graphs/claim_2/HS_vs_LBS/``.\n",
    "    save_name : str or None\n",
    "        Suffix of the saved file name; falls back to ``dataset_name`` when omitted.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    # The loader also returns column names, which are not needed here.\n",
    "    X, y, _ = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)\n",
    "\n",
    "    leaf_nodes = [2, 4, 8, 12, 16, 20, 24, 28, 30, 32]\n",
    "\n",
    "    # One (label, colour, means, stds) record per model, in the order experiment() returns scores.\n",
    "    curves = [\n",
    "        (\"CART\", \"lightsalmon\", [], []),\n",
    "        (\"CART (LBS)\", \"goldenrod\", [], []),\n",
    "        (\"hsCART\", \"firebrick\", [], []),\n",
    "    ]\n",
    "\n",
    "    for max_leaf_nodes in leaf_nodes:\n",
    "        scores = experiment(X, y, max_leaf_nodes, type_)\n",
    "        for (_, _, means, stds), model_scores in zip(curves, scores):\n",
    "            means.append(np.mean(model_scores))\n",
    "            stds.append(np.std(model_scores))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    for label, color, means, stds in curves:\n",
    "        ax.errorbar(leaf_nodes, means, yerr=stds, color=color, label=label)\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(\"Number of Leaves\")\n",
    "    ax.set_ylabel(\"AUC\" if type_ == \"classification\" else \"R2\")\n",
    "    if save:\n",
    "        # Without this fallback an omitted save_name produced a file literally named \"<type>_None\".\n",
    "        file_stem = dataset_name if save_name is None else save_name\n",
    "        fig.savefig(f\"../graphs/claim_2/HS_vs_LBS/{type_}_{file_stem}\", bbox_inches=\"tight\", facecolor=\"white\", edgecolor=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b227aa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One score-vs-leaves figure per classification dataset, saved under its dictionary key.\n",
    "for save_name, (dataset_name, database_name) in tqdm(DATASETS_CLASSIFICATION.items()):\n",
    "    lbs_vs_hs(\n",
    "        dataset_name,\n",
    "        database_name,\n",
    "        type_=\"classification\",\n",
    "        save=True,\n",
    "        save_name=save_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f05cef1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One score-vs-leaves figure per regression dataset, saved under its dictionary key.\n",
    "for save_name, (dataset_name, database_name) in tqdm(DATASETS_REGRESSION.items()):\n",
    "    lbs_vs_hs(\n",
    "        dataset_name,\n",
    "        database_name,\n",
    "        type_=\"regression\",\n",
    "        save=True,\n",
    "        save_name=save_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff3f286",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
