{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from data import *\n",
    "import pickle\n",
    "from dpdt.utils import extract_tree, Data, average_traj_length_in_mdp\n",
    "import matplotlib.pyplot as plt\n",
    "from baselines import cart_study_post_pruning\n",
    "\n",
    "functions_ = [\n",
    "    # get_bank_data,\n",
    "    get_eeg_data,\n",
    "    get_fault_data,\n",
    "    # get_raisin_data,\n",
    "]\n",
    "names_ = [\n",
    "    # \"bank\",\n",
    "    \"eeg\",\n",
    "    \"fault\",\n",
    "    # \"raisin\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "zetas = np.linspace(-1, 0, 1000)\n",
    "policy_folder = \"\"\n",
    "for f, dataset in enumerate(names_):\n",
    "    S, Y = functions_[f]()\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    # DPDT\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "    #dpdt3\n",
    "    with open(\n",
    "        \"saved_policies/{}_cart3_selector_depth5\".format(dataset) + \".pkl\", \"rb\"\n",
    "    ) as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    scores, depths, nodes, lengths = (\n",
    "        np.zeros(zetas.shape[0], dtype=np.float64),\n",
    "        np.zeros(zetas.shape[0], dtype=np.uint8),\n",
    "        np.zeros(zetas.shape[0], dtype=np.uint8),\n",
    "        np.zeros(zetas.shape[0], dtype=np.float64),\n",
    "    )\n",
    "\n",
    "\n",
    "    scores_test = np.zeros(zetas.shape[0], dtype=np.float64)\n",
    "\n",
    "\n",
    "    for i, zeta in enumerate(zetas):\n",
    "        scores[i], lengths[i] = average_traj_length_in_mdp(S, Y, policy, init_obs, i)\n",
    "        _, nodes[i], depths[i] = extract_tree(policy, init_obs, zeta=i)\n",
    "        scores_test[i], _ = average_traj_length_in_mdp(S_test, Y_test, policy, init_obs, i)\n",
    "    time_ = np.load(\n",
    "        \"results_npz/time_\" + dataset + \"_cart3_selector_depth5.npy\",\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    # CART\n",
    "    (\n",
    "        scores_cart,\n",
    "        scores_cart_test,\n",
    "        depths_cart,\n",
    "        nodes_cart,\n",
    "        time_cart,\n",
    "        lengths_cart,\n",
    "    ) = cart_study_post_pruning(S, Y, max_depth=5, S_test=S_test, Y_test=Y_test)\n",
    "    \n",
    "    #dpdt2\n",
    "    with open(\n",
    "        \"saved_policies/{}_cart2_selector_depth5\".format(dataset) + \".pkl\", \"rb\"\n",
    "    ) as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    scores_, depths, nodes_, lengths_ = (\n",
    "        np.zeros(zetas.shape[0], dtype=np.float64),\n",
    "        np.zeros(zetas.shape[0], dtype=np.uint8),\n",
    "        np.zeros(zetas.shape[0], dtype=np.uint8),\n",
    "        np.zeros(zetas.shape[0], dtype=np.float64),\n",
    "    )\n",
    "\n",
    "\n",
    "    scores_test_ = np.zeros(zetas.shape[0], dtype=np.float64)\n",
    "\n",
    "\n",
    "    for i, zeta in enumerate(zetas):\n",
    "        scores_[i], lengths_[i] = average_traj_length_in_mdp(S, Y, policy, init_obs, i)\n",
    "        _, nodes_[i], depths[i] = extract_tree(policy, init_obs, zeta=i)\n",
    "        scores_test_[i], _ = average_traj_length_in_mdp(S_test, Y_test, policy, init_obs, i)\n",
    "    time__ = np.load(\n",
    "        \"results_npz/time_\" + dataset + \"_cart2_selector_depth5.npy\",\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "\n",
    "    plt.plot(lengths, scores, linewidth=4, label=\"DPDT-3 Train in \" + str(np.round(time_,3)) + \"s\", alpha = 1, c = \"green\")\n",
    "    # plt.plot(lengths, scores_test, linewidth=4, label=\"DPDT-3 Test\", alpha = 1, c = \"green\")\n",
    "    \n",
    "\n",
    "    plt.plot(lengths_, scores_, linewidth=4, label=\"DPDT-2 Train in \" + str(np.round(time__,3)) + \"s\", alpha = 1, c = \"magenta\")\n",
    "    # plt.plot(lengths_, scores_test_, linewidth=4, label=\"DPDT-2 Test\", alpha = 1, c = \"magenta\")\n",
    "    \n",
    "\n",
    "    plt.plot(\n",
    "        lengths_cart,\n",
    "        scores_cart,\n",
    "        linewidth=4,\n",
    "        label=\"CART-PP Train in \" + str(np.round(time_cart,3)) + \"s\",\n",
    "        alpha=1, c = \"blue\"\n",
    "    )\n",
    "\n",
    "    # plt.plot(\n",
    "    #     lengths_cart,\n",
    "    #     scores_cart_test,\n",
    "    #     linewidth=4,\n",
    "    #     label=\"CART-PP Test\",\n",
    "    #     alpha=1, c = \"blue\"\n",
    "    # )\n",
    "    plt.xlabel(\"average tests per sample\", fontdict={\"size\": 14})\n",
    "    plt.ylabel(\"accuracy\", fontdict={\"size\": 14})\n",
    "    plt.grid()\n",
    "    plt.title(dataset, fontdict={\"size\": 18})\n",
    "    plt.legend(loc=\"lower right\", prop={\"size\": 12})\n",
    "    plt.savefig(\"{}.pdf\".format(dataset))\n",
    "    plt.clf()\n",
    "\n",
    "\n",
    "\n",
    "    plt.plot(nodes, scores, linewidth=4, label=\"DPDT-3 Train in \" + str(np.round(time_,3)) + \"s\", alpha = 1, c = \"green\")\n",
    "    # plt.plot(nodes, scores_test, linewidth=4, label=\"DPDT-3 Test\", alpha = 1, c = \"green\")\n",
    "    plt.plot(nodes_, scores_, linewidth=4, label=\"DPDT-2 Train in \" + str(np.round(time__,3)) + \"s\", alpha = 1, c = \"magenta\")\n",
    "    # plt.plot(nodes_, scores_test_, linewidth=4, label=\"DPDT-2 Test\", alpha = 1, c = \"magenta\")\n",
    "    \n",
    "    plt.plot(\n",
    "        nodes_cart,\n",
    "        scores_cart,\n",
    "        linewidth=4,\n",
    "        label=\"CART-PP Train in \" + str(np.round(time_cart,3)) + \"s\",\n",
    "        alpha=1, c = \"blue\"\n",
    "    )\n",
    "\n",
    "    # plt.plot(\n",
    "    #     nodes_cart,\n",
    "    #     scores_cart_test,\n",
    "    #     linewidth=4,\n",
    "    #     label=\"CART-PP Test\",\n",
    "    #     alpha=1, c = \"blue\"\n",
    "    # )\n",
    "    plt.xlabel(\"nodes\", fontdict={\"size\": 14})\n",
    "    plt.ylabel(\"accuracy\", fontdict={\"size\": 14})\n",
    "    plt.grid()\n",
    "    plt.title(dataset, fontdict={\"size\": 18})\n",
    "    plt.legend(loc=\"lower right\", prop={\"size\": 12})\n",
    "    plt.savefig(\"{}nodes.pdf\".format(dataset))\n",
    "    plt.clf()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
