{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from data import *\n",
    "import pickle\n",
    "from dpdt.utils import extract_tree, Data, average_traj_length_in_mdp\n",
    "import matplotlib.pyplot as plt\n",
    "from baselines import cart_study_post_pruning\n",
    "from time import time\n",
    "\n",
    "# Benchmark registry: every dataset name sits next to its loader so the\n",
    "# two can never drift out of step.\n",
    "_DATASETS = [\n",
    "    (\"avila\", get_avila_data),\n",
    "    (\"bank\", get_bank_data),\n",
    "    (\"bean\", get_bean_data),\n",
    "    (\"bidding\", get_bidding_data),\n",
    "    (\"eeg\", get_eeg_data),\n",
    "    (\"fault\", get_fault_data),\n",
    "    (\"htru\", get_htru_data),\n",
    "    (\"magic\", get_magic_data),\n",
    "    (\"occupancy\", get_occupancy_data),\n",
    "    (\"page\", get_page_data),\n",
    "    (\"raisin\", get_raisin_data),\n",
    "    (\"rice\", get_rice_data),\n",
    "    (\"room\", get_room_data),\n",
    "    (\"segment\", get_segment_data),\n",
    "    (\"skin\", get_skin_data),\n",
    "    (\"wilt\", get_wilt_data),\n",
    "]\n",
    "\n",
    "# Parallel lists: functions_[i] is the loader for the dataset called names_[i].\n",
    "functions_ = [loader for _, loader in _DATASETS]\n",
    "names_ = [name for name, _ in _DATASETS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avila & 66.9 & 65.7 & 60.5 & 58.496 & 2.576 & 1.548 & 4.9 & 4.9 & 4.8\\\\\n",
      "bank & 99.3 & 97.8 & 99.3 & 3.181 & 0.217 & 0.046 & 3.2 & 3.7 & 3.4\\\\\n",
      "bean & 91.1 & 91.1 & 89.9 & 137.345 & 5.578 & 8.343 & 4.6 & 4.9 & 5.0\\\\\n",
      "bidding & 99.2 & 99.2 & 99.2 & 4.725 & 0.406 & 0.138 & 1.4 & 1.4 & 2.3\\\\\n",
      "eeg & 78.0 & 74.6 & 73.0 & 85.959 & 3.233 & 1.414 & 4.6 & 4.8 & 5.0\\\\\n",
      "fault & 71.8 & 72.8 & 57.9 & 41.053 & 1.463 & 0.847 & 5.0 & 4.5 & 4.9\\\\\n",
      "htru & 98.0 & 98.3 & 98.3 & 71.19 & 4.421 & 3.427 & 1.1 & 2.4 & 4.7\\\\\n",
      "magic & 84.5 & 84.8 & 82.5 & 131.668 & 6.707 & 4.924 & 5.0 & 4.8 & 5.0\\\\\n",
      "occupancy & 99.5 & 99.5 & 99.5 & 12.451 & 1.203 & 0.248 & 1.0 & 1.0 & 1.4\\\\\n",
      "page & 97.1 & 97.1 & 96.7 & 38.894 & 1.694 & 0.544 & 3.5 & 5.0 & 4.8\\\\\n",
      "raisin & 87.8 & 91.1 & 90.0 & 10.544 & 0.682 & 0.108 & 3.1 & 2.3 & 4.5\\\\\n",
      "rice & 93.7 & 94.2 & 93.4 & 25.159 & 1.437 & 0.547 & 1.6 & 1.7 & 3.6\\\\\n",
      "room & 99.2 & 99.4 & 99.4 & 27.466 & 1.771 & 0.377 & 2.5 & 2.3 & 4.1\\\\\n",
      "segment & 93.5 & 93.1 & 87.4 & 10.156 & 0.812 & 0.291 & 3.7 & 3.9 & 3.9\\\\\n",
      "skin & 99.5 & 99.2 & 98.6 & 366.356 & 19.487 & 3.367 & 3.8 & 3.8 & 4.2\\\\\n",
      "wilt & 87.2 & 84.8 & 87.6 & 5.131 & 0.517 & 0.185 & 4.1 & 3.2 & 3.9\\\\\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# NOTE(review): only the *index* of each zeta is handed to\n",
    "# average_traj_length_in_mdp below (as the code always did), never the zeta\n",
    "# value itself -- confirm that this is what dpdt.utils expects.\n",
    "zetas = np.linspace(-1, 0, 1000)\n",
    "\n",
    "\n",
    "def split_in_half(X, y):\n",
    "    \"\"\"Cut (X, y) at the midpoint of X.\n",
    "\n",
    "    Returns ((X_first, y_first), (X_second, y_second)). The first half is\n",
    "    used to select a model, the second half to report its score.\n",
    "    \"\"\"\n",
    "    mid = len(X) // 2\n",
    "    return (X[:mid], y[:mid]), (X[mid:], y[mid:])\n",
    "\n",
    "\n",
    "def evaluate_dpdt(dataset, selector, select_set, report_set, init_obs):\n",
    "    \"\"\"Evaluate the saved policy of `dataset` for one selector (e.g. \"cart3\").\n",
    "\n",
    "    The zeta index with the highest score on `select_set` is chosen (the\n",
    "    first index wins ties) and then evaluated on `report_set`.\n",
    "\n",
    "    Returns a tuple of:\n",
    "      * the score on `report_set` times 100, rounded to 1 decimal,\n",
    "      * the first entry of the matching results_npz/time_*.npy file,\n",
    "        rounded to 3 decimals,\n",
    "      * the second value returned by average_traj_length_in_mdp on\n",
    "        `report_set`, rounded to 1 decimal.\n",
    "    \"\"\"\n",
    "    stem = \"{}_{}_selector_depth5\".format(dataset, selector)\n",
    "    # NOTE(security): pickle.load can execute arbitrary code; only load\n",
    "    # policy files that were produced locally by a trusted run.\n",
    "    with open(\"saved_policies/\" + stem + \".pkl\", \"rb\") as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    X_sel, y_sel = select_set\n",
    "    X_rep, y_rep = report_set\n",
    "    # max() returns the first maximal index and always returns *an* index,\n",
    "    # so the report step never receives None even when every score is 0.\n",
    "    best_z = max(\n",
    "        range(len(zetas)),\n",
    "        key=lambda z: average_traj_length_in_mdp(X_sel, y_sel, policy, init_obs, z)[0],\n",
    "    )\n",
    "    score, lengths = average_traj_length_in_mdp(X_rep, y_rep, policy, init_obs, best_z)\n",
    "    elapsed = np.load(\"results_npz/time_\" + stem + \".npy\", allow_pickle=True)[0]\n",
    "    return np.round(score * 100, 1), np.round(elapsed, 3), np.round(lengths, 1)\n",
    "\n",
    "\n",
    "def evaluate_cart(S, Y, select_set, report_set, max_depth=5):\n",
    "    \"\"\"Fit one entropy CART tree per cost-complexity alpha and score the best.\n",
    "\n",
    "    The timed section covers the pruning-path computation and every fit.\n",
    "    The tree with the highest accuracy on `select_set` is chosen (the first\n",
    "    one wins ties) and then evaluated on `report_set`.\n",
    "\n",
    "    Returns a tuple of:\n",
    "      * accuracy on `report_set` in percent, rounded to 1 decimal,\n",
    "      * elapsed fitting time in seconds, rounded to 3 decimals,\n",
    "      * mean number of nodes on the decision path minus one, rounded to\n",
    "        1 decimal.\n",
    "    \"\"\"\n",
    "    ts = time()\n",
    "    path = DecisionTreeClassifier(\n",
    "        criterion=\"entropy\", max_depth=max_depth, random_state=0\n",
    "    ).cost_complexity_pruning_path(S, Y)\n",
    "    clfs = [\n",
    "        DecisionTreeClassifier(\n",
    "            random_state=0,\n",
    "            ccp_alpha=ccp_alpha,\n",
    "            criterion=\"entropy\",\n",
    "            max_depth=max_depth,\n",
    "        ).fit(S, Y)\n",
    "        for ccp_alpha in path.ccp_alphas\n",
    "    ]\n",
    "    te = time()\n",
    "\n",
    "    X_sel, y_sel = select_set\n",
    "    X_rep, y_rep = report_set\n",
    "    best_clf = max(clfs, key=lambda clf: clf.score(X_sel, y_sel))\n",
    "    score = np.round(best_clf.score(X_rep, y_rep) * 100, 1)\n",
    "    # decision_path flags every node a sample goes through; the row sum is\n",
    "    # the number of visited nodes, hence the -1 to count traversed splits.\n",
    "    node_indicator = best_clf.decision_path(X_rep)\n",
    "    lengths = np.round(node_indicator.sum(axis=1).mean() - 1, 1)\n",
    "    return score, np.round(te - ts, 3), lengths\n",
    "\n",
    "\n",
    "# One LaTeX table row per dataset:\n",
    "# name & score cart3 & score cart2 & score CART & three times & three lengths\n",
    "for load_data, dataset in zip(functions_, names_):\n",
    "    # Load once per dataset; all three evaluations share the same data.\n",
    "    S, Y = load_data()\n",
    "    S_test, Y_test = load_data(test=True)\n",
    "\n",
    "    # Initial observation: per-feature min and max of the training data,\n",
    "    # each widened by 1e-3.\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "\n",
    "    select_set, report_set = split_in_half(S_test, Y_test)\n",
    "\n",
    "    a, time_a, a_lengths = evaluate_dpdt(dataset, \"cart3\", select_set, report_set, init_obs)\n",
    "    b, time_b, b_lengths = evaluate_dpdt(dataset, \"cart2\", select_set, report_set, init_obs)\n",
    "    c, time_c, c_lengths = evaluate_cart(S, Y, select_set, report_set)\n",
    "\n",
    "    print(\"{} & {} & {} & {} & {} & {} & {} & {} & {} & {}\\\\\\\\\".format(dataset, a, b, c, time_a, time_b, time_c, a_lengths, b_lengths, c_lengths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
