{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "from src.utils.env_export import create_json\n",
    "from src.utils.khalili_env import get_env\n",
    "from src.policy.jas_voc_policy import JAS_voc_policy\n",
    "from src.utils.mouselab_jas import MouselabJas\n",
    "import numpy as np\n",
    "import random\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from src.utils.data_classes import Action\n",
    "from src.policy.po_uct import POUCT, POUCT_policy\n",
    "from simulation import run_episode\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the environment and its accompanying config with the project helper get_env().\n",
    "# `env` is reused by every cell below; `cfg` is not referenced again in this notebook.\n",
    "env, cfg = get_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The two policies that are run on `env` in the episode cells below.\n",
    "# POUCT policy: 5000 steps, rollout depth 0, exploration coefficient 10.\n",
    "pouct = POUCT_policy(steps=5000, rollout_depth=0, exploration_coeff=10)\n",
    "# JAS VOC policy. NOTE(review): cost_weight looks like a fitted/optimised value;\n",
    "# its provenance is not recorded in this notebook - TODO document where it comes from.\n",
    "mgps = JAS_voc_policy(cost_weight=0.5798921379230035)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 6.0.1 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"350pt\" height=\"476pt\"\n",
       " viewBox=\"0.00 0.00 350.00 476.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-472 346,-472 346,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-450\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-446.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.0</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-374.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.23</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M149.75,-438.67C125.4,-426.83 85.28,-407.33 57.57,-393.86\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"58.86,-390.6 48.33,-389.37 55.8,-396.89 58.86,-390.6\"/>\n",
       "</g>\n",
       "<!-- 7 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>7</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-374.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.9</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;7 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>0&#45;&gt;7</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M156.43,-434.83C146.25,-424.94 132.48,-411.55 120.97,-400.36\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"123.41,-397.85 113.8,-393.38 118.53,-402.87 123.41,-397.85\"/>\n",
       "</g>\n",
       "<!-- 13 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>13</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-374.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.81</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;13 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>0&#45;&gt;13</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-431.7C171,-423.98 171,-414.71 171,-406.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-406.1 171,-396.1 167.5,-406.1 174.5,-406.1\"/>\n",
       "</g>\n",
       "<!-- 19 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>19</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-374.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.01</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;19 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;19</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M185.57,-434.83C195.75,-424.94 209.52,-411.55 221.03,-400.36\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"223.47,-402.87 228.2,-393.38 218.59,-397.85 223.47,-402.87\"/>\n",
       "</g>\n",
       "<!-- 25 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>25</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-374.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.85</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;25 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>0&#45;&gt;25</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M192.25,-438.67C216.6,-426.83 256.72,-407.33 284.43,-393.86\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"286.2,-396.89 293.67,-389.37 283.14,-390.6 286.2,-396.89\"/>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>2</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-306\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-302.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.6</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-359.7C27,-351.98 27,-342.71 27,-334.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-334.1 27,-324.1 23.5,-334.1 30.5,-334.1\"/>\n",
       "</g>\n",
       "<!-- 8 -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>8</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-306\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-302.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.82</text>\n",
       "</g>\n",
       "<!-- 7&#45;&gt;8 -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>7&#45;&gt;8</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-359.7C99,-351.98 99,-342.71 99,-334.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-334.1 99,-324.1 95.5,-334.1 102.5,-334.1\"/>\n",
       "</g>\n",
       "<!-- 14 -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>14</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-306\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-302.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.2</text>\n",
       "</g>\n",
       "<!-- 13&#45;&gt;14 -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>13&#45;&gt;14</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-359.7C171,-351.98 171,-342.71 171,-334.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-334.1 171,-324.1 167.5,-334.1 174.5,-334.1\"/>\n",
       "</g>\n",
       "<!-- 20 -->\n",
       "<g id=\"node22\" class=\"node\">\n",
       "<title>20</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-306\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-302.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.76</text>\n",
       "</g>\n",
       "<!-- 19&#45;&gt;20 -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>19&#45;&gt;20</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M243,-359.7C243,-351.98 243,-342.71 243,-334.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"246.5,-334.1 243,-324.1 239.5,-334.1 246.5,-334.1\"/>\n",
       "</g>\n",
       "<!-- 26 -->\n",
       "<g id=\"node27\" class=\"node\">\n",
       "<title>26</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-306\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-302.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.35</text>\n",
       "</g>\n",
       "<!-- 25&#45;&gt;26 -->\n",
       "<g id=\"edge26\" class=\"edge\">\n",
       "<title>25&#45;&gt;26</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M315,-359.7C315,-351.98 315,-342.71 315,-334.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.5,-334.1 315,-324.1 311.5,-334.1 318.5,-334.1\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>3</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5.19</text>\n",
       "</g>\n",
       "<!-- 2&#45;&gt;3 -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>2&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-287.7C27,-279.98 27,-270.71 27,-262.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-262.1 27,-252.1 23.5,-262.1 30.5,-262.1\"/>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>4</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.82</text>\n",
       "</g>\n",
       "<!-- 3&#45;&gt;4 -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>3&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-215.7C27,-207.98 27,-198.71 27,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-190.1 27,-180.1 23.5,-190.1 30.5,-190.1\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>5</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5.15</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;5 -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>4&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-143.7C27,-135.98 27,-126.71 27,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-118.1 27,-108.1 23.5,-118.1 30.5,-118.1\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>6</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.79</text>\n",
       "</g>\n",
       "<!-- 5&#45;&gt;6 -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>5&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M27,-71.7C27,-63.98 27,-54.71 27,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-46.1 27,-36.1 23.5,-46.1 30.5,-46.1\"/>\n",
       "</g>\n",
       "<!-- 9 -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>9</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.34</text>\n",
       "</g>\n",
       "<!-- 8&#45;&gt;9 -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>8&#45;&gt;9</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-287.7C99,-279.98 99,-270.71 99,-262.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-262.1 99,-252.1 95.5,-262.1 102.5,-262.1\"/>\n",
       "</g>\n",
       "<!-- 10 -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>10</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.45</text>\n",
       "</g>\n",
       "<!-- 9&#45;&gt;10 -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>9&#45;&gt;10</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-215.7C99,-207.98 99,-198.71 99,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-190.1 99,-180.1 95.5,-190.1 102.5,-190.1\"/>\n",
       "</g>\n",
       "<!-- 11 -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>11</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.97</text>\n",
       "</g>\n",
       "<!-- 10&#45;&gt;11 -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>10&#45;&gt;11</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-143.7C99,-135.98 99,-126.71 99,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-118.1 99,-108.1 95.5,-118.1 102.5,-118.1\"/>\n",
       "</g>\n",
       "<!-- 12 -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>12</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"99\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.71</text>\n",
       "</g>\n",
       "<!-- 11&#45;&gt;12 -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>11&#45;&gt;12</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M99,-71.7C99,-63.98 99,-54.71 99,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-46.1 99,-36.1 95.5,-46.1 102.5,-46.1\"/>\n",
       "</g>\n",
       "<!-- 15 -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>15</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.53</text>\n",
       "</g>\n",
       "<!-- 14&#45;&gt;15 -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>14&#45;&gt;15</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-287.7C171,-279.98 171,-270.71 171,-262.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-262.1 171,-252.1 167.5,-262.1 174.5,-262.1\"/>\n",
       "</g>\n",
       "<!-- 16 -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>16</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.0</text>\n",
       "</g>\n",
       "<!-- 15&#45;&gt;16 -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>15&#45;&gt;16</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-215.7C171,-207.98 171,-198.71 171,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-190.1 171,-180.1 167.5,-190.1 174.5,-190.1\"/>\n",
       "</g>\n",
       "<!-- 17 -->\n",
       "<g id=\"node20\" class=\"node\">\n",
       "<title>17</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.86</text>\n",
       "</g>\n",
       "<!-- 16&#45;&gt;17 -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>16&#45;&gt;17</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-143.7C171,-135.98 171,-126.71 171,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-118.1 171,-108.1 167.5,-118.1 174.5,-118.1\"/>\n",
       "</g>\n",
       "<!-- 18 -->\n",
       "<g id=\"node21\" class=\"node\">\n",
       "<title>18</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"171\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"171\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.28</text>\n",
       "</g>\n",
       "<!-- 17&#45;&gt;18 -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>17&#45;&gt;18</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171,-71.7C171,-63.98 171,-54.71 171,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.5,-46.1 171,-36.1 167.5,-46.1 174.5,-46.1\"/>\n",
       "</g>\n",
       "<!-- 21 -->\n",
       "<g id=\"node23\" class=\"node\">\n",
       "<title>21</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5.83</text>\n",
       "</g>\n",
       "<!-- 20&#45;&gt;21 -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>20&#45;&gt;21</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M243,-287.7C243,-279.98 243,-270.71 243,-262.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"246.5,-262.1 243,-252.1 239.5,-262.1 246.5,-262.1\"/>\n",
       "</g>\n",
       "<!-- 22 -->\n",
       "<g id=\"node24\" class=\"node\">\n",
       "<title>22</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5.06</text>\n",
       "</g>\n",
       "<!-- 21&#45;&gt;22 -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>21&#45;&gt;22</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M243,-215.7C243,-207.98 243,-198.71 243,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"246.5,-190.1 243,-180.1 239.5,-190.1 246.5,-190.1\"/>\n",
       "</g>\n",
       "<!-- 23 -->\n",
       "<g id=\"node25\" class=\"node\">\n",
       "<title>23</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.28</text>\n",
       "</g>\n",
       "<!-- 22&#45;&gt;23 -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>22&#45;&gt;23</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M243,-143.7C243,-135.98 243,-126.71 243,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"246.5,-118.1 243,-108.1 239.5,-118.1 246.5,-118.1\"/>\n",
       "</g>\n",
       "<!-- 24 -->\n",
       "<g id=\"node26\" class=\"node\">\n",
       "<title>24</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"243\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"243\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.5</text>\n",
       "</g>\n",
       "<!-- 23&#45;&gt;24 -->\n",
       "<g id=\"edge25\" class=\"edge\">\n",
       "<title>23&#45;&gt;24</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M243,-71.7C243,-63.98 243,-54.71 243,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"246.5,-46.1 243,-36.1 239.5,-46.1 246.5,-46.1\"/>\n",
       "</g>\n",
       "<!-- 27 -->\n",
       "<g id=\"node28\" class=\"node\">\n",
       "<title>27</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.61</text>\n",
       "</g>\n",
       "<!-- 26&#45;&gt;27 -->\n",
       "<g id=\"edge27\" class=\"edge\">\n",
       "<title>26&#45;&gt;27</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M315,-287.7C315,-279.98 315,-270.71 315,-262.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.5,-262.1 315,-252.1 311.5,-262.1 318.5,-262.1\"/>\n",
       "</g>\n",
       "<!-- 28 -->\n",
       "<g id=\"node29\" class=\"node\">\n",
       "<title>28</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.84</text>\n",
       "</g>\n",
       "<!-- 27&#45;&gt;28 -->\n",
       "<g id=\"edge28\" class=\"edge\">\n",
       "<title>27&#45;&gt;28</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M315,-215.7C315,-207.98 315,-198.71 315,-190.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.5,-190.1 315,-180.1 311.5,-190.1 318.5,-190.1\"/>\n",
       "</g>\n",
       "<!-- 29 -->\n",
       "<g id=\"node30\" class=\"node\">\n",
       "<title>29</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.17</text>\n",
       "</g>\n",
       "<!-- 28&#45;&gt;29 -->\n",
       "<g id=\"edge29\" class=\"edge\">\n",
       "<title>28&#45;&gt;29</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M315,-143.7C315,-135.98 315,-126.71 315,-118.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.5,-118.1 315,-108.1 311.5,-118.1 318.5,-118.1\"/>\n",
       "</g>\n",
       "<!-- 30 -->\n",
       "<g id=\"node31\" class=\"node\">\n",
       "<title>30</title>\n",
       "<ellipse fill=\"grey\" stroke=\"grey\" cx=\"315\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"315\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.35</text>\n",
       "</g>\n",
       "<!-- 29&#45;&gt;30 -->\n",
       "<g id=\"edge30\" class=\"edge\">\n",
       "<title>29&#45;&gt;30</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M315,-71.7C315,-63.98 315,-54.71 315,-46.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"318.5,-46.1 315,-36.1 311.5,-46.1 318.5,-46.1\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x1b3826d8610>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reset the environment with a fixed seed (100), then draw it with revealed=True.\n",
    "# The render call is the last expression of the cell, so its return value\n",
    "# (a graphviz Digraph) is what the notebook displays as the cell output.\n",
    "# NOTE(review): _render is underscore-prefixed, i.e. private by convention - confirm it is meant to be called from here.\n",
    "env.reset(seed=100)\n",
    "env._render(revealed=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(P, C, E)=(0, 4, 1) -> 5.0\n",
      "(P, C, E)=(0, 4, 5) -> 5.0\n",
      "Term reward=3.8236\n"
     ]
    }
   ],
   "source": [
    "# Run one episode of the JAS VOC policy (`mgps`) and keep the actions and observations.\n",
    "# The third positional argument (0) is presumably the episode seed - TODO confirm against run_episode.\n",
    "res, actions, obs_list = run_episode(env, mgps, 0, return_obs=True)\n",
    "\n",
    "# One line per (action, observation) pair, with the action in readable (P, C, E) form.\n",
    "for step_action, step_obs in zip(actions, obs_list):\n",
    "    readable = env.action_to_readable(step_action)\n",
    "    print(f\"(P, C, E)={readable} -> {step_obs}\")\n",
    "\n",
    "# Episode reward, rounded to four decimals.\n",
    "rounded_reward = np.round(res.reward, 4)\n",
    "print(f\"Term reward={rounded_reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(P, C, E)=(3, 4, 4) -> 5.0\n",
      "(P, C, E)=(3, 3, 1) -> 5.0\n",
      "(P, C, E)=(3, 2, 3) -> 1.0\n",
      "(P, C, E)=(0, 4, 4) -> 5.0\n",
      "(P, C, E)=(0, 5, 1) -> 1.0\n",
      "Term reward=3.6576\n"
     ]
    }
   ],
   "source": [
    "# Run one episode of the POUCT policy (`pouct`) and keep the actions and observations.\n",
    "# The third positional argument (0) is presumably the episode seed - TODO confirm against run_episode.\n",
    "res, actions, obs_list = run_episode(env, pouct, 0, return_obs=True)\n",
    "\n",
    "# One line per (action, observation) pair, with the action in readable (P, C, E) form.\n",
    "for step_action, step_obs in zip(actions, obs_list):\n",
    "    readable = env.action_to_readable(step_action)\n",
    "    print(f\"(P, C, E)={readable} -> {step_obs}\")\n",
    "\n",
    "# Episode reward, rounded to four decimals.\n",
    "rounded_reward = np.round(res.reward, 4)\n",
    "print(f\"Term reward={rounded_reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(P, C, E)=(4, 4, 1) -> 4.0\n",
      "(P, C, E)=(1, 4, 1) -> 5.0\n",
      "(P, C, E)=(1, 4, 5) -> 5.0\n",
      "Term reward=3.8216\n"
     ]
    }
   ],
   "source": [
    "# Run one episode of the JAS VOC policy (`mgps`) and keep the actions and observations.\n",
    "# The third positional argument (10) is presumably the episode seed - TODO confirm against run_episode.\n",
    "res, actions, obs_list = run_episode(env, mgps, 10, return_obs=True)\n",
    "\n",
    "# One line per (action, observation) pair, with the action in readable (P, C, E) form.\n",
    "for step_action, step_obs in zip(actions, obs_list):\n",
    "    readable = env.action_to_readable(step_action)\n",
    "    print(f\"(P, C, E)={readable} -> {step_obs}\")\n",
    "\n",
    "# Episode reward, rounded to four decimals.\n",
    "rounded_reward = np.round(res.reward, 4)\n",
    "print(f\"Term reward={rounded_reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(P, C, E)=(1, 4, 5) -> 5.0\n",
      "(P, C, E)=(1, 5, 5) -> 3.0\n",
      "(P, C, E)=(1, 3, 4) -> 1.0\n",
      "(P, C, E)=(2, 4, 1) -> 1.0\n",
      "(P, C, E)=(1, 3, 0) -> 2.0\n",
      "Term reward=3.6388\n"
     ]
    }
   ],
   "source": [
    "# Run one episode of the POUCT policy (`pouct`) and keep the actions and observations.\n",
    "# The third positional argument (10) is presumably the episode seed - TODO confirm against run_episode.\n",
    "res, actions, obs_list = run_episode(env, pouct, 10, return_obs=True)\n",
    "\n",
    "# One line per (action, observation) pair, with the action in readable (P, C, E) form.\n",
    "for step_action, step_obs in zip(actions, obs_list):\n",
    "    readable = env.action_to_readable(step_action)\n",
    "    print(f\"(P, C, E)={readable} -> {step_obs}\")\n",
    "\n",
    "# Episode reward, rounded to four decimals.\n",
    "rounded_reward = np.round(res.reward, 4)\n",
    "print(f\"Term reward={rounded_reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(P, C, E)=(3, 4, 5) -> 3.0\n",
      "(P, C, E)=(4, 4, 5) -> 2.0\n",
      "(P, C, E)=(2, 4, 2) -> 1.0\n",
      "(P, C, E)=(0, 4, 5) -> 5.0\n",
      "(P, C, E)=(0, 5, 1) -> 1.0\n",
      "Term reward=3.6694\n"
     ]
    }
   ],
   "source": [
    "# A second POUCT policy with 10000 steps; rollout depth and exploration coefficient\n",
    "# are the same values used for `pouct` (0 and 10).\n",
    "pouct_big = POUCT_policy(steps=10000, rollout_depth=0, exploration_coeff=10)\n",
    "\n",
    "# Run one episode with it and keep the actions and observations.\n",
    "# The third positional argument (10) is presumably the episode seed - TODO confirm against run_episode.\n",
    "res, actions, obs_list = run_episode(env, pouct_big, 10, return_obs=True)\n",
    "\n",
    "# One line per (action, observation) pair, with the action in readable (P, C, E) form.\n",
    "for step_action, step_obs in zip(actions, obs_list):\n",
    "    readable = env.action_to_readable(step_action)\n",
    "    print(f\"(P, C, E)={readable} -> {step_obs}\")\n",
    "\n",
    "# Episode reward, rounded to four decimals.\n",
    "rounded_reward = np.round(res.reward, 4)\n",
    "print(f\"Term reward={rounded_reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jas",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
