{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from data import *\n",
    "import pickle\n",
    "from dpdt.utils import average_traj_length_in_mdp, Data\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Benchmark datasets as (label, loader) pairs, transposed into the two parallel\n",
    "# lists used by the evaluation cells: names_[k] labels the data from functions_[k].\n",
    "# Keeping each label next to its loader prevents the two lists from drifting apart.\n",
    "names_, functions_ = map(\n",
    "    list,\n",
    "    zip(\n",
    "        (\"avila\", get_avila_data),\n",
    "        (\"bank\", get_bank_data),\n",
    "        (\"bean\", get_bean_data),\n",
    "        (\"bidding\", get_bidding_data),\n",
    "        (\"eeg\", get_eeg_data),\n",
    "        (\"fault\", get_fault_data),\n",
    "        (\"htru\", get_htru_data),\n",
    "        (\"magic\", get_magic_data),\n",
    "        (\"occupancy\", get_occupancy_data),\n",
    "        (\"page\", get_page_data),\n",
    "        (\"raisin\", get_raisin_data),\n",
    "        (\"rice\", get_rice_data),\n",
    "        (\"room\", get_room_data),\n",
    "        (\"segment\", get_segment_data),\n",
    "        (\"skin\", get_skin_data),\n",
    "        (\"wilt\", get_wilt_data),\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avila $58.0 / 57.9 \\%$ 0.74\n",
      "bank $98.0 / 97.8 \\%$ 0.082\n",
      "bean $85.0 / 85.1 \\%$ 1.686\n",
      "bidding $99.3 / 99.0 \\%$ 0.139\n",
      "eeg $69.4 / 71.0 \\%$ 0.733\n",
      "fault $65.7 / 64.8 \\%$ 0.323\n",
      "htru $98.0 / 98.2 \\%$ 1.017\n",
      "magic $82.7 / 82.0 \\%$ 1.277\n",
      "occupancy $99.3 / 93.4 \\%$ 0.246\n",
      "page $97.0 / 96.0 \\%$ 0.36\n",
      "raisin $88.3 / 87.2 \\%$ 0.111\n",
      "rice $93.5 / 92.1 \\%$ 0.239\n",
      "room $99.2 / 99.0 \\%$ 0.504\n",
      "segment $88.2 / 84.0 \\%$ 0.14\n",
      "skin $96.7 / 96.7 \\%$ 5.341\n",
      "wilt $99.5 / 80.4 \\%$ 0.14\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the saved cart3_selector_depth3 policy on every dataset. Prints one\n",
    "# LaTeX-ready row per dataset: name, train / test score in percent (first value\n",
    "# returned by average_traj_length_in_mdp) and the stored time value (3 decimals).\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted policy files.\n",
    "model_tag = \"cart3_selector_depth3\"\n",
    "# Last argument forwarded to average_traj_length_in_mdp; a single value is evaluated.\n",
    "# NOTE(review): its meaning is defined in dpdt.utils -- confirm before changing it.\n",
    "zeta = -1\n",
    "for f, dataset in enumerate(names_):\n",
    "    S, Y = functions_[f](test=False)\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    # Column-wise minima followed by column-wise maxima, padded by 1e-3 so the\n",
    "    # box strictly contains the training inputs (assumes rows are samples).\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "\n",
    "    with open(\"saved_policies/{}_{}.pkl\".format(dataset, model_tag), \"rb\") as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    train, _ = average_traj_length_in_mdp(S, Y, policy, init_obs, zeta)\n",
    "    test, _ = average_traj_length_in_mdp(S_test, Y_test, policy, init_obs, zeta)\n",
    "    time_ = np.load(\n",
    "        \"results_npz/time_{}_{}.npy\".format(dataset, model_tag),\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    # The doubled backslash prints a literal backslash before % (LaTeX escape)\n",
    "    # without an invalid escape sequence, which Python 3.12+ flags as a SyntaxWarning.\n",
    "    print(\n",
    "        dataset,\n",
    "        \"${} / {} \\\\%$\".format(np.round(train * 100, 1), np.round(test * 100, 1)),\n",
    "        np.round(time_, 3),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avila $58.5 / 58.2 \\%$ 9.425\n",
      "bank $98.0 / 97.8 \\%$ 0.439\n",
      "bean $85.6 / 84.9 \\%$ 29.52\n",
      "bidding $99.3 / 99.0 \\%$ 0.974\n",
      "eeg $70.3 / 70.0 \\%$ 14.236\n",
      "fault $68.0 / 65.3 \\%$ 5.698\n",
      "htru $98.0 / 97.9 \\%$ 19.541\n",
      "magic $82.9 / 82.2 \\%$ 25.078\n",
      "occupancy $99.4 / 93.9 \\%$ 2.629\n",
      "page $97.0 / 96.1 \\%$ 5.66\n",
      "raisin $88.5 / 88.3 \\%$ 1.888\n",
      "rice $93.7 / 92.7 \\%$ 4.403\n",
      "room $99.2 / 99.0 \\%$ 5.245\n",
      "segment $88.2 / 84.0 \\%$ 1.512\n",
      "skin $96.7 / 96.7 \\%$ 64.922\n",
      "wilt $99.5 / 79.2 \\%$ 1.139\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the saved cart5_selector_depth3 policy on every dataset. Prints one\n",
    "# LaTeX-ready row per dataset: name, train / test score in percent (first value\n",
    "# returned by average_traj_length_in_mdp) and the stored time value (3 decimals).\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted policy files.\n",
    "model_tag = \"cart5_selector_depth3\"\n",
    "# Last argument forwarded to average_traj_length_in_mdp; a single value is evaluated.\n",
    "# NOTE(review): its meaning is defined in dpdt.utils -- confirm before changing it.\n",
    "zeta = -1\n",
    "for f, dataset in enumerate(names_):\n",
    "    S, Y = functions_[f](test=False)\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    # Column-wise minima followed by column-wise maxima, padded by 1e-3 so the\n",
    "    # box strictly contains the training inputs (assumes rows are samples).\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "\n",
    "    with open(\"saved_policies/{}_{}.pkl\".format(dataset, model_tag), \"rb\") as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    train, _ = average_traj_length_in_mdp(S, Y, policy, init_obs, zeta)\n",
    "    test, _ = average_traj_length_in_mdp(S_test, Y_test, policy, init_obs, zeta)\n",
    "    time_ = np.load(\n",
    "        \"results_npz/time_{}_{}.npy\".format(dataset, model_tag),\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    # The doubled backslash prints a literal backslash before % (LaTeX escape)\n",
    "    # without an invalid escape sequence, which Python 3.12+ flags as a SyntaxWarning.\n",
    "    print(\n",
    "        dataset,\n",
    "        \"${} / {} \\\\%$\".format(np.round(train * 100, 1), np.round(test * 100, 1)),\n",
    "        np.round(time_, 3),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avila $58.0 / 57.9 \\%$ 3.216\n",
      "bank $98.0 / 97.8 \\%$ 0.285\n",
      "bean $85.0 / 85.1 \\%$ 9.152\n",
      "bidding $99.3 / 99.0 \\%$ 0.577\n",
      "eeg $70.0 / 69.8 \\%$ 4.27\n",
      "fault $65.7 / 64.8 \\%$ 1.74\n",
      "htru $98.0 / 97.9 \\%$ 5.685\n",
      "magic $82.7 / 82.0 \\%$ 7.507\n",
      "occupancy $99.3 / 93.4 \\%$ 1.191\n",
      "page $97.0 / 96.0 \\%$ 1.803\n",
      "raisin $88.3 / 87.2 \\%$ 0.573\n",
      "rice $93.6 / 92.1 \\%$ 1.24\n",
      "room $99.2 / 99.0 \\%$ 2.587\n",
      "segment $88.2 / 84.0 \\%$ 0.591\n",
      "skin $96.7 / 96.7 \\%$ 25.406\n",
      "wilt $99.5 / 79.2 \\%$ 0.584\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the saved cart4_selector_depth3 policy on every dataset. Prints one\n",
    "# LaTeX-ready row per dataset: name, train / test score in percent (first value\n",
    "# returned by average_traj_length_in_mdp) and the stored time value (3 decimals).\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted policy files.\n",
    "model_tag = \"cart4_selector_depth3\"\n",
    "# Last argument forwarded to average_traj_length_in_mdp; a single value is evaluated.\n",
    "# NOTE(review): its meaning is defined in dpdt.utils -- confirm before changing it.\n",
    "zeta = -1\n",
    "for f, dataset in enumerate(names_):\n",
    "    S, Y = functions_[f](test=False)\n",
    "    S_test, Y_test = functions_[f](test=True)\n",
    "    data = Data(np.array(S, dtype=np.float64), np.array(Y, dtype=np.uint8))\n",
    "    # Column-wise minima followed by column-wise maxima, padded by 1e-3 so the\n",
    "    # box strictly contains the training inputs (assumes rows are samples).\n",
    "    init_obs = np.concatenate(\n",
    "        (data.x.min(axis=0) - 1e-3, data.x.max(axis=0) + 1e-3), dtype=np.float64\n",
    "    )\n",
    "\n",
    "    with open(\"saved_policies/{}_{}.pkl\".format(dataset, model_tag), \"rb\") as pol:\n",
    "        policy = pickle.load(pol)\n",
    "\n",
    "    train, _ = average_traj_length_in_mdp(S, Y, policy, init_obs, zeta)\n",
    "    test, _ = average_traj_length_in_mdp(S_test, Y_test, policy, init_obs, zeta)\n",
    "    time_ = np.load(\n",
    "        \"results_npz/time_{}_{}.npy\".format(dataset, model_tag),\n",
    "        allow_pickle=True,\n",
    "    )[0]\n",
    "    # The doubled backslash prints a literal backslash before % (LaTeX escape)\n",
    "    # without an invalid escape sequence, which Python 3.12+ flags as a SyntaxWarning.\n",
    "    print(\n",
    "        dataset,\n",
    "        \"${} / {} \\\\%$\".format(np.round(train * 100, 1), np.round(test * 100, 1)),\n",
    "        np.round(time_, 3),\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
