{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DPDT vs. post-pruned CART\n",
    "\n",
    "Prints one LaTeX table row per dataset comparing two saved DPDT policies with a cost-complexity-pruned CART baseline: test accuracy, runtime and average decision-path length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# NOTE(review): star import; the get_*_data loaders listed below are the names used from `data`.\n",
    "from data import *\n",
    "import pickle\n",
    "from dpdt.utils import extract_tree, Data, average_traj_length_in_mdp\n",
    "import matplotlib.pyplot as plt\n",
    "from baselines import cart_study_post_pruning\n",
    "from time import time\n",
    "\n",
    "functions_ = [\n",
    "    get_avila_data,\n",
    "    get_bank_data,\n",
    "    get_bean_data,\n",
    "    get_bidding_data,\n",
    "    get_eeg_data,\n",
    "    get_fault_data,\n",
    "    get_htru_data,\n",
    "    get_magic_data,\n",
    "    get_occupancy_data,\n",
    "    get_page_data,\n",
    "    get_raisin_data,\n",
    "    get_rice_data,\n",
    "    get_room_data,\n",
    "    get_segment_data,\n",
    "    get_skin_data,\n",
    "    get_wilt_data,\n",
    "]\n",
    "names_ = [\n",
    "    \"avila\",\n",
    "    \"bank\",\n",
    "    \"bean\",\n",
    "    \"bidding\",\n",
    "    \"eeg\",\n",
    "    \"fault\",\n",
    "    \"htru\",\n",
    "    \"magic\",\n",
    "    \"occupancy\",\n",
    "    \"page\",\n",
    "    \"raisin\",\n",
    "    \"rice\",\n",
    "    \"room\",\n",
    "    \"segment\",\n",
    "    \"skin\",\n",
    "    \"wilt\",\n",
    "]\n",
    "\n",
    "# `functions_[i]` loads dataset `names_[i]`; the evaluation loop relies on this pairing.\n",
    "assert len(functions_) == len(names_), \"functions_ and names_ must be aligned\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "Protocol, per dataset:\n",
    "\n",
    "- The test set is split in half. The first half selects the regularization level (the zeta index for DPDT, `ccp_alpha` for CART); the second half produces the reported scores.\n",
    "- Each printed row contains: dataset, accuracy (%) of DPDT variant 1, DPDT variant 2 and CART; their runtimes; their average decision-path lengths.\n",
    "- DPDT runtimes are loaded from `results_npz/`; the CART runtime covers computing the pruning path and fitting one tree per alpha.\n",
    "\n",
    "**Note:** the CART pruning path used to be computed with `max_depth=5` while the candidate trees were fitted with `max_depth=10`. Both now use `MAX_DEPTH`, so the CART columns must be regenerated. Last recorded output, from before that fix:\n",
    "\n",
    "```text\n",
    "avila & 94.3 & 95.1 & 87.8 & 86.476 & 187.313 & 1.579 & 8.4 & 8.4 & 8.8\\\\\n",
    "bank & 99.3 & 99.3 & 99.3 & 1.664 & 2.174 & 0.028 & 3.3 & 3.3 & 3.4\\\\\n",
    "bean & 91.3 & 90.9 & 91.2 & 102.796 & 309.981 & 8.287 & 5.2 & 4.0 & 6.1\\\\\n",
    "bidding & 99.4 & 99.4 & 99.4 & 1.833 & 3.226 & 0.095 & 2.4 & 2.4 & 2.4\\\\\n",
    "eeg & 83.6 & 83.5 & 82.0 & 85.198 & 229.49 & 2.386 & 8.1 & 8.2 & 9.3\\\\\n",
    "fault & 73.3 & 73.8 & 68.7 & 35.09 & 108.265 & 1.148 & 5.6 & 5.6 & 6.9\\\\\n",
    "htru & 97.6 & 98.0 & 98.1 & 45.941 & 123.689 & 4.234 & 2.2 & 1.2 & 3.4\\\\\n",
    "magic & 85.4 & 84.9 & 84.8 & 146.253 & 391.594 & 7.021 & 5.8 & 5.9 & 8.1\\\\\n",
    "occupancy & 99.5 & 99.5 & 99.5 & 6.847 & 15.608 & 0.226 & 1.0 & 1.0 & 1.4\\\\\n",
    "page & 96.5 & 96.9 & 96.5 & 22.526 & 58.102 & 0.713 & 4.5 & 6.2 & 7.7\\\\\n",
    "raisin & 85.6 & 86.7 & 88.9 & 8.717 & 19.652 & 0.115 & 2.1 & 2.1 & 6.5\\\\\n",
    "rice & 93.4 & 93.2 & 93.7 & 20.18 & 44.867 & 0.626 & 1.8 & 1.8 & 3.0\\\\\n",
    "room & 99.3 & 99.6 & 99.6 & 5.186 & 8.55 & 0.318 & 2.3 & 4.1 & 4.1\\\\\n",
    "segment & 97.0 & 97.0 & 94.8 & 9.796 & 22.562 & 0.286 & 5.1 & 5.1 & 5.0\\\\\n",
    "skin & 99.9 & 99.9 & 99.8 & 120.576 & 308.577 & 2.94 & 6.3 & 6.2 & 5.4\\\\\n",
    "wilt & 86.0 & 86.0 & 84.8 & 2.274 & 3.583 & 0.151 & 4.3 & 4.4 & 4.4\\\\\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Regularization grid of the saved DPDT policies. The policies are queried by\n",
    "# the *index* into this grid, not by the zeta value itself.\n",
    "# NOTE(review): confirm that average_traj_length_in_mdp expects an index.\n",
    "ZETAS = np.linspace(-1, 0, 1000)\n",
    "# Matches the depth of the saved DPDT policies (`*_depth10`); used for CART.\n",
    "MAX_DEPTH = 10\n",
    "# File-name suffixes of the two saved DPDT runs (table columns a and b).\n",
    "DPDT_VARIANTS = (\n",
    "    \"cart_adapt_2_more_selector_depth10\",\n",
    "    \"cart_adapt_22_more_selector_depth10\",\n",
    ")\n",
    "\n",
    "\n",
    "def split_in_half(X, y):\n",
    "    \"\"\"Split a test set into a selection half and a held-out reporting half.\n",
    "\n",
    "    Returns ``(X_sel, X_out, y_sel, y_out)``; the selection half gets the\n",
    "    first ``len(X) // 2`` samples.\n",
    "    \"\"\"\n",
    "    assert len(X) == len(y), \"features and labels must have the same length\"\n",
    "    half = len(X) // 2\n",
    "    return X[:half], X[half:], y[:half], y[half:]\n",
    "\n",
    "\n",
    "def evaluate_dpdt(dataset, variant, init_obs, sel_X, sel_Y, out_X, out_Y, zetas=ZETAS):\n",
    "    \"\"\"Score the saved DPDT policy ``saved_policies/<dataset>_<variant>.pkl``.\n",
    "\n",
    "    The zeta index with the best score on the selection half is kept (first\n",
    "    one on ties) and re-evaluated on the held-out half.\n",
    "\n",
    "    Returns ``(score * 100, mean length, training time)`` as returned by\n",
    "    ``average_traj_length_in_mdp`` / stored in ``results_npz``, rounded to\n",
    "    1, 1 and 3 decimals.\n",
    "    \"\"\"\n",
    "    # NOTE(review): pickle.load can execute arbitrary code -- only load\n",
    "    # policy files produced by this project.\n",
    "    with open(\"saved_policies/{}_{}.pkl\".format(dataset, variant), \"rb\") as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    # Start from -inf (not 0) so an index is always selected, even if every score is 0.\n",
    "    best_z, best_score = None, -np.inf\n",
    "    for z in range(len(zetas)):\n",
    "        score, _ = average_traj_length_in_mdp(sel_X, sel_Y, policy, init_obs, z)\n",
    "        if score > best_score:\n",
    "            best_z, best_score = z, score\n",
    "\n",
    "    accuracy, lengths = average_traj_length_in_mdp(out_X, out_Y, policy, init_obs, best_z)\n",
    "    train_time = np.load(\n",
    "        \"results_npz/time_{}_{}.npy\".format(dataset, variant), allow_pickle=True\n",
    "    )[0]\n",
    "    return np.round(accuracy * 100, 1), np.round(lengths, 1), np.round(train_time, 3)\n",
    "\n",
    "\n",
    "def evaluate_cart(S, Y, sel_X, sel_Y, out_X, out_Y, max_depth=MAX_DEPTH):\n",
    "    \"\"\"CART baseline with cost-complexity post-pruning.\n",
    "\n",
    "    One tree is fitted per effective alpha of the pruning path; the tree with\n",
    "    the best accuracy on the selection half (first one on ties) is scored on\n",
    "    the held-out half.\n",
    "\n",
    "    Returns ``(accuracy * 100, mean decision-path length, fit time in s)``\n",
    "    rounded to 1, 1 and 3 decimals.\n",
    "    \"\"\"\n",
    "    start = time()\n",
    "    # The pruning path must come from a tree with the same depth limit as the\n",
    "    # trees fitted below; a shallower tree's alphas are not their pruning path.\n",
    "    path = DecisionTreeClassifier(\n",
    "        criterion=\"entropy\", max_depth=max_depth, random_state=0\n",
    "    ).cost_complexity_pruning_path(S, Y)\n",
    "    clfs = [\n",
    "        DecisionTreeClassifier(\n",
    "            random_state=0, ccp_alpha=ccp_alpha, criterion=\"entropy\", max_depth=max_depth\n",
    "        ).fit(S, Y)\n",
    "        for ccp_alpha in path.ccp_alphas\n",
    "    ]\n",
    "    fit_time = time() - start\n",
    "\n",
    "    # max() keeps the first maximum, i.e. the first alpha reaching the best score.\n",
    "    best = max(clfs, key=lambda clf: clf.score(sel_X, sel_Y))\n",
    "    accuracy = best.score(out_X, out_Y)\n",
    "    # decision_path marks every node a sample visits, root and leaf included,\n",
    "    # so the number of tests performed is (visited nodes - 1).\n",
    "    visited = best.decision_path(out_X).sum(axis=1).mean()\n",
    "    return np.round(accuracy * 100, 1), np.round(visited - 1, 1), np.round(fit_time, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset, load_data in zip(names_, functions_):\n",
    "    # Load each dataset once and share it between the three methods.\n",
    "    S, Y = load_data()\n",
    "    S_test, Y_test = load_data(test=True)\n",
    "    sel_X, out_X, sel_Y, out_Y = split_in_half(S_test, Y_test)\n",
    "\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    # Axis-0 min/max of the training features widened by 1e-3, concatenated as\n",
    "    # [lower bounds, upper bounds]; passed to DPDT as its initial observation.\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "\n",
    "    a, a_lengths, time_a = evaluate_dpdt(\n",
    "        dataset, DPDT_VARIANTS[0], init_obs, sel_X, sel_Y, out_X, out_Y\n",
    "    )\n",
    "    b, b_lengths, time_b = evaluate_dpdt(\n",
    "        dataset, DPDT_VARIANTS[1], init_obs, sel_X, sel_Y, out_X, out_Y\n",
    "    )\n",
    "    c, c_lengths, time_c = evaluate_cart(S, Y, sel_X, sel_Y, out_X, out_Y)\n",
    "\n",
    "    # One LaTeX table row: accuracies, runtimes, then mean path lengths.\n",
    "    print(\n",
    "        \"{} & {} & {} & {} & {} & {} & {} & {} & {} & {}\\\\\\\\\".format(\n",
    "            dataset, a, b, c, time_a, time_b, time_c, a_lengths, b_lengths, c_lengths\n",
    "        )\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
