{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from data import *\n",
    "import pickle\n",
    "from dpdt.utils import extract_tree, Data, average_traj_length_in_mdp\n",
    "import matplotlib.pyplot as plt\n",
    "from baselines import cart_study_post_pruning\n",
    "from time import time\n",
    "\n",
    "functions_ = [\n",
    "    get_avila_data,\n",
    "    get_bank_data,\n",
    "    get_bean_data,\n",
    "    get_bidding_data,\n",
    "    get_eeg_data,\n",
    "    get_fault_data,\n",
    "    get_htru_data,\n",
    "    get_magic_data,\n",
    "    get_occupancy_data,\n",
    "    get_page_data,\n",
    "    get_raisin_data,\n",
    "    get_rice_data,\n",
    "    get_room_data,\n",
    "    get_segment_data,\n",
    "    get_skin_data,\n",
    "    get_wilt_data,\n",
    "]\n",
    "names_ = [\n",
    "    \"avila\",\n",
    "    \"bank\",\n",
    "    \"bean\",\n",
    "    \"bidding\",\n",
    "    \"eeg\",\n",
    "    \"fault\",\n",
    "    \"htru\",\n",
    "    \"magic\",\n",
    "    \"occupancy\",\n",
    "    \"page\",\n",
    "    \"raisin\",\n",
    "    \"rice\",\n",
    "    \"room\",\n",
    "    \"segment\",\n",
    "    \"skin\",\n",
    "    \"wilt\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "999\n",
      "999\n",
      "avila & 57.6 & 57.9 & 60.5 & 3.216 & 9.425 & 1.267 & 3.0 & 3.0 & 4.826786740754934\\\\\n",
      "999\n",
      "999\n",
      "bank & 100.0 & 100.0 & 99.3 & 0.285 & 0.439 & 0.037 & 3.0 & 3.0 & 3.36231884057971\\\\\n",
      "999\n",
      "999\n",
      "bean & 85.2 & 85.2 & 89.9 & 9.152 & 29.52 & 6.914 & 3.0 & 3.0 & 4.961820851688693\\\\\n",
      "999\n",
      "999\n",
      "bidding & 98.6 & 98.6 & 99.2 & 0.577 & 0.974 & 0.12 & 3.0 & 3.0 & 2.334913112164297\\\\\n",
      "968\n",
      "961\n",
      "eeg & 70.6 & 70.6 & 73.0 & 4.27 & 14.236 & 1.407 & 2.637516688918558 & 2.5714285714285716 & 4.977970627503338\\\\\n",
      "967\n",
      "999\n",
      "fault & 60.0 & 61.5 & 57.9 & 1.74 & 5.698 & 0.9 & 2.594871794871795 & 3.0 & 4.943589743589744\\\\\n",
      "999\n",
      "999\n",
      "htru & 98.0 & 98.0 & 98.3 & 5.685 & 19.541 & 2.95 & 3.0 & 3.0 & 4.681005586592179\\\\\n",
      "999\n",
      "999\n",
      "magic & 81.8 & 82.2 & 82.5 & 7.507 & 25.078 & 4.032 & 3.0 & 3.0 & 4.9879074658254465\\\\\n",
      "991\n",
      "991\n",
      "occupancy & 99.5 & 99.5 & 99.5 & 1.191 & 2.629 & 0.198 & 1.0 & 1.0 & 1.3567402158157513\\\\\n",
      "999\n",
      "999\n",
      "page & 96.4 & 96.4 & 96.7 & 1.803 & 5.66 & 0.436 & 2.9835766423357666 & 3.0 & 4.846715328467154\\\\\n",
      "993\n",
      "999\n",
      "raisin & 88.9 & 88.9 & 90.0 & 0.573 & 1.888 & 0.09 & 2.488888888888889 & 2.588888888888889 & 4.544444444444444\\\\\n",
      "999\n",
      "999\n",
      "rice & 93.2 & 94.5 & 93.4 & 1.24 & 4.403 & 0.43 & 3.0 & 3.0 & 3.561679790026247\\\\\n",
      "998\n",
      "998\n",
      "room & 99.0 & 99.0 & 99.4 & 2.587 & 5.245 & 0.311 & 2.1994076999012835 & 2.1994076999012835 & 4.050345508390918\\\\\n",
      "999\n",
      "999\n",
      "segment & 81.0 & 81.0 & 87.4 & 0.591 & 1.512 & 0.227 & 2.8528138528138527 & 2.8528138528138527 & 3.8571428571428568\\\\\n",
      "999\n",
      "999\n",
      "skin & 96.7 & 96.7 & 98.6 & 25.406 & 64.922 & 2.836 & 3.0 & 3.0 & 4.247939280176284\\\\\n",
      "993\n",
      "993\n",
      "wilt & 83.2 & 83.2 & 87.6 & 0.584 & 1.139 & 0.165 & 1.568 & 1.568 & 3.92\\\\\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "zetas = np.linspace(-1, 0, 1000)\n",
    "for f, dataset in enumerate(names_):\n",
    "    a, b, c = 0, 0, 0\n",
    "    ##################################################\n",
    "    S, Y = functions_[f]()\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    # DPDT\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "    #dpdt3\n",
    "    with open(\n",
    "        \"saved_policies/{}_cart4_selector_depth3\".format(dataset) + \".pkl\", \"rb\"\n",
    "    ) as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    test_1X, test_2X = S_test[:int(len(S_test)/2)], S_test[int(len(Y_test)/2):]\n",
    "    test_1Y, test_2Y = Y_test[:int(len(S_test)/2)], Y_test[int(len(Y_test)/2):]\n",
    "    best_score = 0\n",
    "    best_mod = None\n",
    "    best_lengths = 0\n",
    "    for z, zeta in enumerate(zetas):\n",
    "        s, lenghts = average_traj_length_in_mdp(test_1X, test_1Y, policy, init_obs, z)\n",
    "        if s >= best_score:\n",
    "            best_mod = z\n",
    "            best_score = s\n",
    "            best_lengths = lenghts\n",
    "    print(best_mod)\n",
    "    a, a_lengths = average_traj_length_in_mdp(test_2X, test_2Y, policy, init_obs, best_mod)\n",
    "    a = np.round(a * 100, 1)\n",
    "    time_a = np.load(\n",
    "        \"results_npz/time_\" + dataset + \"_cart4_selector_depth3.npy\",\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    time_a = np.round(time_a, 3)\n",
    "    ##########################################\n",
    "    S, Y = functions_[f]()\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    # DPDT\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "    with open(\n",
    "        \"saved_policies/{}_cart5_selector_depth3\".format(dataset) + \".pkl\", \"rb\"\n",
    "    ) as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    test_1X, test_2X = S_test[:int(len(S_test)/2)], S_test[int(len(Y_test)/2):]\n",
    "    test_1Y, test_2Y = Y_test[:int(len(S_test)/2)], Y_test[int(len(Y_test)/2):]\n",
    "    best_score = 0\n",
    "    best_mod = None\n",
    "    best_lengths = 0\n",
    "    for z, zeta in enumerate(zetas):\n",
    "        s, lengths = average_traj_length_in_mdp(test_1X, test_1Y, policy, init_obs, z)\n",
    "        if s >= best_score:\n",
    "            best_mod = z\n",
    "            best_score = s\n",
    "            best_lengths = lenghts\n",
    "    print(best_mod)\n",
    "    b, b_lengths = average_traj_length_in_mdp(test_2X, test_2Y, policy, init_obs, best_mod)\n",
    "    b = np.round(b*100, 1)\n",
    "    time_b = np.load(\n",
    "        \"results_npz/time_\" + dataset + \"_cart5_selector_depth3.npy\",\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    time_b = np.round(time_b, 3)\n",
    "    ####################################################################\n",
    "    S, Y = functions_[f]()\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    ts = time()\n",
    "    clf = DecisionTreeClassifier(\n",
    "            criterion=\"entropy\", max_depth=5, random_state=0\n",
    "        )\n",
    "    path = clf.cost_complexity_pruning_path(S, Y)\n",
    "    ccp_alphas, impurities = path.ccp_alphas, path.impurities\n",
    "    clfs = []\n",
    "    for ccp_alpha in ccp_alphas:\n",
    "        clf = DecisionTreeClassifier(\n",
    "            random_state=0,\n",
    "            ccp_alpha=ccp_alpha,\n",
    "            criterion=\"entropy\",\n",
    "            max_depth=3,\n",
    "        )\n",
    "        clf.fit(S, Y)\n",
    "        clfs.append(clf)\n",
    "    te = time()\n",
    "    \n",
    "    test_1X, test_2X = S_test[:int(len(S_test)/2)], S_test[int(len(Y_test)/2):]\n",
    "    test_1Y, test_2Y = Y_test[:int(len(S_test)/2)], Y_test[int(len(Y_test)/2):]\n",
    "    best_score = 0\n",
    "    best_mod = None\n",
    "    for i, clf in enumerate(clfs):\n",
    "        s = clf.score(test_1X, test_1Y)\n",
    "        if s > best_score:\n",
    "            best_mod = clf\n",
    "            best_score = s\n",
    "    c = np.round(best_mod.score(test_2X, test_2Y) * 100, 1)\n",
    "    time_c = np.round(te-ts, 3)\n",
    "    node_indicator = best_mod.decision_path(test_2X)\n",
    "    c_lenghts = node_indicator.sum(axis=1).mean() - 1\n",
    "\n",
    "    print(\"{} & {} & {} & {} & {} & {} & {} & {} & {} & {}\\\\\\\\\".format(dataset,a,b,c, time_a, time_b, time_c,a_lengths, b_lengths, c_lenghts))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
