{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "abstract-hazard",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.\n",
      "WARNING:root:Limited tf.summary API due to missing TensorBoard installation.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import os, sys\n",
    "sys.path.append('environments/')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import pickle, os, csv, math, time, joblib\n",
    "from joblib import Parallel, delayed\n",
    "import datetime as dt\n",
    "from datetime import date, datetime, timedelta\n",
    "from collections import Counter\n",
    "import copy as cp\n",
    "import tqdm\n",
    "from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier\n",
    "from lightgbm import LGBMRegressor, LGBMClassifier\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error\n",
    "from sklearn.metrics import log_loss, f1_score, precision_score, recall_score, accuracy_score\n",
    "#import matplotlib.pyplot as plt|\n",
    "#import matplotlib.ticker as ticker\n",
    "import collections \n",
    "#import shap\n",
    "import seaborn as sns\n",
    "import random\n",
    "from sklearn.linear_model import LinearRegression\n",
    "np.seterr(all=\"ignore\")\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "import math\n",
    "import statsmodels.api as sm\n",
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "import numpy as np\n",
    "import json\n",
    "import util as util_fqi\n",
    "import sys\n",
    "sys.path.append('models/')\n",
    "from lmmfqi import LMMFQIagent\n",
    "from fqi import FQIagent\n",
    "import gym\n",
    "from gym import spaces\n",
    "from gym.utils import seeding\n",
    "import numpy as np\n",
    "from os import path\n",
    "from os.path import join as pjoin\n",
    "from sepsis_env import SepsisEnv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "quiet-italian",
   "metadata": {},
   "source": [
    "# Generate Tuples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "relevant-morgan",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seeds so the trajectory sample and the train/test split are\n",
    "# reproducible on Restart & Run All.\n",
    "# NOTE(review): assumes SepsisEnv draws from the global `random` / `np.random`\n",
    "# state - confirm in sepsis_env.py.\n",
    "SEED = 0\n",
    "N_TRAJECTORIES = 10  # per group\n",
    "TRAIN_FRACTION = 0.8\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "env = SepsisEnv()\n",
    "fg_tuples = env.generate_tuples(group=\"foreground\", n_trajectories=N_TRAJECTORIES)\n",
    "bg_tuples = env.generate_tuples(group=\"background\", n_trajectories=N_TRAJECTORIES)\n",
    "all_tuples = bg_tuples + fg_tuples\n",
    "# Shuffle so foreground and background tuples are mixed across the split.\n",
    "# NOTE(review): if each tuple is a single transition, tuples from the same\n",
    "# trajectory can land in both train and test.\n",
    "random.shuffle(all_tuples)\n",
    "split = TRAIN_FRACTION\n",
    "n_train = int(split * len(all_tuples))\n",
    "train_tuples = all_tuples[:n_train]\n",
    "test_tuples = all_tuples[n_train:]\n",
    "assert len(train_tuples) + len(test_tuples) == len(all_tuples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "academic-knowing",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/aishwaryamandyam/anaconda3/envs/research/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
      "  warnings.warn(msg, FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:ylabel='Density'>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAD4CAYAAAD2FnFTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAmRklEQVR4nO3deXxV9Z3/8dcnC0kIIWFJ2JMQRDSIAkbRUWt/1gVqC3ZGR7R27IxTpx2ZduqjM2On87Md205bO3XaaZkqnTpdfqVOra1lHCwutQtVKQiILCIhhCXsCZBASMjy+f1xT3xcwwkJ4Z7cm+T9fDzuI2e998Pl5r5zvt9zvsfcHRERkc7Skl2AiIikJgWEiIiEUkCIiEgoBYSIiIRSQIiISKiMZBeQKKNHj/bS0tJklyEi0q+89tprh929MGzdgAmI0tJS1qxZk+wyRET6FTPb2dU6NTGJiEgoBYSIiIRSQIiISCgFhIiIhFJAiIhIKAWEiIiEUkCIiEgoBYSIiISKNCDMbK6ZbTWzSjN7IGT9R83sDTNbb2Yrzaw8WF5qZieD5evN7NEo6xQRkdNFdiW1maUDi4EbgD3AajNb5u6b4zZb6u6PBtvPBx4B5gbrtrv7zKjqExkIlq7aFbr8zjnFfVyJDERRHkFcDlS6e5W7nwKeABbEb+Du9XGzuYBubycikiKiDIgJwO64+T3Bsncws/vMbDvwMPDxuFWTzWydmf3GzK4JewEzu9fM1pjZmkOHDiWydhGRQS/pndTuvtjdpwD/APxTsHgfUOzus4D7gaVmNjxk3yXuXuHuFYWFoYMRiohIL0UZEDXApLj5icGyrjwB3ALg7s3uXhtMvwZsB86PpkwREQkT5XDfq4GpZjaZWDAsBO6M38DMprr7tmD2ZmBbsLwQqHP3NjMrA6YCVRHWKjKghHVeq+NazlZkAeHurWa2CFgBpAOPu/smM3sIWOPuy4BFZnY90AIcAe4Odn8X8JCZtQDtwEfdvS6qWkVE5HSR3jDI3ZcDyzstezBu+hNd7PcU8FSUtYmIyJklvZNaRERSkwJCRERCKSBERCSUAkJEREIpIEREJJQCQkREQikgREQklAJCRERCKSBERCSUAkJEREIpIEREJFSkYzGJyNnTSKySKnQEISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEkoBISIioRQQIiISKtKAMLO5ZrbVzCrN7IGQ9R81szfMbL2ZrTSz8rh1nw7222pmN0VZp4iInC6ygDCzdGAxMA8oB+6ID4DAUnef4e4zgYeBR4J9y4GFwHRgLvAfwfOJiEgfifII4nKg0t2r3P0U8ASwIH4Dd6+Pm80FPJheADzh7s3uvgOoDJ5PRET6SJRjMU0AdsfN7wHmdN7IzO4D7geGANfF7ftqp30nhOx7L3AvQHGxxqoREUmkpHdSu/tid58C/APwT2e57xJ3r3D3isLCwmgKFBEZpKIMiBpgUtz8xGBZV54AbunlviIikmBRBsRqYKqZTTazIcQ6nZfFb2BmU+Nmbwa2BdPLgIVmlmVmk4GpwB8irFVERDqJrA/C3VvNbBGwAkgHHnf3TWb2ELDG3ZcBi8zseqAFOALcHey7ycx+AmwGWoH73L0tqlpFROR0kd4wyN2XA8s7LXswbvoTZ9j3i8AXo6tORETOJOmd1CIikpoUECIiEkoBISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEkoBISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEkoBISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEirSgDCzuWa21cwqzeyBkPX3m9lmM9tgZi+aWUncujYzWx88lkVZp4iInC4jqic2s3RgMXADsAdYbWbL3H1z3GbrgAp3bzSzjwEPA7cH6066+8yo6hMRkTOL8gjicqDS3avc/RTwBLAgfgN3f8ndG4PZV4GJ
EdYjIiJnIcqAmADsjpvfEyzryj3As3Hz2Wa2xsxeNbNbIqhPRETOILImprNhZncBFcC1cYtL3L3GzMqAX5nZG+6+vdN+9wL3AhQXF/dZvSIig0GURxA1wKS4+YnBsncws+uBzwDz3b25Y7m71wQ/q4BfA7M67+vuS9y9wt0rCgsLE1u9iMggF2VArAammtlkMxsCLATecTaSmc0CHiMWDgfjlo8ws6xgejRwFRDfuS0iIhGLrInJ3VvNbBGwAkgHHnf3TWb2ELDG3ZcBXwWGAU+aGcAud58PXAg8ZmbtxELsy53OfhIRkYhF2gfh7suB5Z2WPRg3fX0X+70MzIiyNhEROTNdSS0iIqEUECIiEkoBISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEkoBISIioRQQIiISSgEhIiKhFBAiIhJKASEiIqEUECIiEkoBISIioRQQIiISqkc3DDKznwHfBZ519/ZoSxIZHJau2pXsEkTOqKdHEP8B3AlsM7Mvm9m0CGsSEZEU0KOAcPcX3P2DwGygGnjBzF42sz83s8woCxQRkeTocR+EmY0CPgz8JbAO+AaxwHg+kspERCSpetoH8XNgGvBD4P3uvi9Y9d9mtiaq4kREJHl6egTxHXcvd/cvdYSDmWUBuHtFVzuZ2Vwz22pmlWb2QMj6+81ss5ltMLMXzawkbt3dZrYteNx9lv8uERE5Rz0NiC+ELHvlTDuYWTqwGJgHlAN3mFl5p83WARXufjHwU+DhYN+RwGeBOcDlwGfNbEQPaxURkQQ4YxOTmY0FJgA5ZjYLsGDVcGBoN899OVDp7lXBcz0BLAA2d2zg7i/Fbf8qcFcwfRPwvLvXBfs+D8wFftyDf5OIiCRAd30QNxHrmJ4IPBK3vAH4x272nQDsjpvfQ+yIoCv3AM+eYd8JnXcws3uBewGKi4u7KUdERM7GGQPC3b8PfN/M/sTdn4qqCDO7C6gArj2b/dx9CbAEoKKiwiMoTURk0Oquiekud/9/QKmZ3d95vbs/ErJbhxpgUtz8xGBZ59e4HvgMcK27N8ft++5O+/76TLWKiEhidddJnRv8HAbkhTzOZDUw1cwmm9kQYCGwLH6DoF/jMWC+ux+MW7UCuNHMRgSd0zcGy0REpI9018T0WPDzn8/2id291cwWEftiTwced/dNZvYQsMbdlwFfJRY+T5oZwC53n+/udWb2eWIhA/BQR4e1iIj0jZ5eKPcwsVNdTwK/BC4GPhk0P3XJ3ZcDyzstezBu+voz7Ps48HhP6hMRkcTr6XUQN7p7PfA+YmMxnQf8XVRFiYhI8vU0IDqONG4GnnT3YxHVIyIiKaJHTUzAM2b2JrEmpo+ZWSHQFF1ZIhLm5Kk2vvmrbXzv5WpyMtOZNHIo8y8ZT3ZmerJLkwGoRwHh7g8E/RDH3L3NzE4QuypaRPpIU0sbtz32Mhtr6rl5xjhyhqTzs7V7qDl6knuumszwHI28L4nV0yMIgAuIXQ8Rv88PElyPiHThwV9sZGNNPY/eNZu5F40DID8nkx+8Us3P19XwZ1eWEJwNKJIQPeqDMLMfAv8KXA1cFjy6HMVVRBLrlxv385M1e/ib6857OxwAphQO48bysWw90MC6XUeTV6AMSD09gqgAyt1dw1mI9LF2d7723FamFObyifdMPW39lVNGsXHvMZ7duI8ZE/PJTO/xfcBEzqinn6SNwNgoCxGRcK/vPsq2g8e5/4ZpZIR8+aeZcUP5GE6camPtriNJqFAGqp4eQYwGNpvZH4CO8ZJw9/mRVCUiALg7L209xIXjhjPvoq7/Rps8KpdJI3L43bbDXFY6kjT1RUgC9DQgPhdlESISbsfhExw+3syn511AWlrXX/pmxjVTC1n6h11s2VfP9PH5fVilDFQ9amJy998Qu4I6M5heDayNsC4RAVZX15GdmcZ7Z4zrdtvy8cPJy85QZ7UkTE/PYvoIsVuCPhYsmgA8HVFNIgI0NreycW89MyeNIGdI
9xfCpZlxycQCtu5voLG5tQ8qlIGup53U9wFXAfUA7r4NKIqqKBGB12uO0dbuXFba89uxz5xUQJs7b+zVaDhy7noaEM3ufqpjJrhYTqe8ikRoY80xivKyGJef0+N9xuVnU5SXpWYmSYieBsRvzOwfgRwzuwF4Evif6MoSGdwamlqoPnyCiyacXWezmXHxxHx21TXS0NQSUXUyWPQ0IB4ADgFvAH9F7B4P/xRVUSKD3eZ99TicdUAAXDB2OABb9zckuCoZbHo6WF+7mT0NPO3uh6ItSUQ21hxj9LAsxuRlnfW+4/Kzyc/J5M39DVSUjoygOhkszngEYTGfM7PDwFZgq5kdMrMHz7SfiPReU0sbOw6fYPr44b0afM/MuGBsHtsONtDS1h5BhTJYdNfE9EliZy9d5u4j3X0kMAe4ysw+GXl1IoPQtoPHaXeYNiav189xwdjhtLQ5VYdOJLAyGWy6a2L6EHCDux/uWODuVWZ2F/Ac8G9RFicyGL11oIHszDQmjRza6+coK8wlM93YdrCBaWN7HzQCS1ftOm3ZnXOKk1BJ3+vuCCIzPhw6BP0Q3d6dxMzmmtlWM6s0swdC1r/LzNaaWauZ3dppXZuZrQ8ey7p7LZGBwN15a38DU4vySD/D0BrdyUxPo2RULpUHjyewOhlsuguIU71ch5mlA4uBeUA5cIeZlXfabBfwYWBpyFOcdPeZwUODAsqgsO9YEw3NrefUvNThvMJhHGxopl6nu0ovdRcQl5hZfcijAZjRzb6XA5XuXhVcZPcEnW5T6u7V7r4BUE+aCLDtQOzU1Kljhp3zc00pij1H1SEdRUjvnDEg3D3d3YeHPPLcvbsmpgnA7rj5PcGynso2szVm9qqZ3RK2gZndG2yz5tAhnX0r/d/2QycYOzybvOxzv7/0uPxscjLTqTyojmrpnVS+9VSJu1cAdwJfN7MpnTdw9yXuXuHuFYWFhX1foUgCtbS1U117gimFuQl5vjQzphTmsv3QcXQzSOmNKAOiBpgUNz8xWNYj7l4T/KwCfg3MSmRxIqlmV10jre1OWeG5Ny91KCscxrGTLRxtVD+EnL0oA2I1MNXMJpvZEGAh0KOzkcxshJllBdOjiV2LsTmySkVSwPZDx0kzmDw6MUcQACWjYqfKVteqmUnOXmQB4e6twCJgBbAF+Im7bzKzh8xsPoCZXWZme4DbgMfMbFOw+4XAGjN7HXgJ+LK7KyBkQNt+8DgTCnLIzuz+3g89NWZ4NlkZaeysbUzYc8rg0dNbjvaKuy8nNrBf/LIH46ZXE2t66rzfy3R/lpTIgNHc2kbN0ZNcMzWxfWlpZpSMGqojCOmVVO6kFhk0dtedpN0T27zUoWRULgcbmjnaeMZLl0ROo4AQSQE7a09gQPE5DK/RlY5+iNd2Hkn4c8vApoAQSQE7axsZm5+d0P6HDhMLhpJuxupqBYScHQWESJK1tTu76hrf/ks/0YZkpDG+IJs11XWRPL8MXAoIkSTbf6yJU23tlIxKfP9Dh9JRuWzYc4ymlrbIXkMGHgWESJJ1nGFUGmFAlIzK5VRbO2/UHIvsNWTgUUCIJNnO2hMUDM0kP+fcx1/qSnHQfLVazUxyFhQQIknk7uysbYz06AFgWFYGUwpzWaOOajkLCgiRJKo7cYqG5tbIOqjjXVY6kjXVdbS3a+A+6RkFhEgSdQyBEWUHdYfZJSOob2plu+4PIT2kgBBJouraE2RnplGUlxX5a80uHgHAul1HI38tGRgUECJJtLO2kZKRuaRZ7+8/3VNlo3PJz8lk7S71Q0jPKCBEkuR4cyuHjjdT2gf9DwBpacbMSQUKCOkxBYRIkuzqw/6HDrOLR7Dt4HHqm3QDIemeAkIkSXbWniA9zZgwIqfPXnN2SQHu8Pruo332mtJ/KSBEkqS69gQTC3LITO+7X8NLJhVgBmt3Hu2z15T+SwEhkgQnT7Wx92hTnzYvAQzPzmRq0TDW7VY/hHRPASGSBK/vOUqbe591UMeb
XTyCdbuO6oI56ZYCQiQJOobeLk5SQBw72ULVYd2GVM5MASGSBKurjzBmeBZDh0R6W/hQs0sKAHS6q3Qr0oAws7lmttXMKs3sgZD17zKztWbWama3dlp3t5ltCx53R1mnSF9qa3fW7jzS5/0PHcpGD2N4doauqJZuRRYQZpYOLAbmAeXAHWZW3mmzXcCHgaWd9h0JfBaYA1wOfNbMRkRVq0hf2rq/gYbm1qT0P0BwwVzxCNbpCEK6EeURxOVApbtXufsp4AlgQfwG7l7t7huA9k773gQ87+517n4EeB6YG2GtIn1mzc5Y/0OyjiAAZhcXsPVAAw26YE7OIMqAmADsjpvfEyxL2L5mdq+ZrTGzNYcOHep1oSJ9aXX1EcblZ1MQ4Q2CujO7eATusGGP7jAnXevXndTuvsTdK9y9orCwMNnliHTL3Vm9o46K0pFYHwzQ15VLJhUAsHanmpmka1EGRA0wKW5+YrAs6n1FUlbN0ZPsr2+ioiS5XWr5ObEL5nQmk5xJlAGxGphqZpPNbAiwEFjWw31XADea2Yigc/rGYJlIv/aHHbH+h8snj0xyJcEFc7uP4q4L5iRcZAHh7q3AImJf7FuAn7j7JjN7yMzmA5jZZWa2B7gNeMzMNgX71gGfJxYyq4GHgmUi/drq6jqGZ2cwbUxeskthdkkBRxtb2KEL5qQLkV6l4+7LgeWdlj0YN72aWPNR2L6PA49HWZ9IX1sV9D+kpSWv/6HDrOAOc2t3HaWscFiSq5FU1K87qUX6k8PHm6k6dCIlmpcAziscRl52hvohpEsKCJE+0jH+0mWlqREQHXeY0xXV0hUFhEgfWbWjjuzMNGZMyE92KW+bXTyCrfvrOd7cmuxSJAUpIET6yOrqOmZNGsGQjNT5tZtVXEC7wwbdYU5CpM4nVWQAa2hqYfPeei5Lkf6HDrMmxTqq1ykgJIQCQqQPvLbzCO0Oc1IsIPKHZnJe0TBdUS2hFBAifWB1dR0Zacas4oJkl3Ka2cUFumBOQikgRPrAH3bUMX1CflJuENSd2cUjqDtxShfMyWkUECIRa2pp4/Xdx1KuealDx3UZq3ZosAJ5JwWESMTW7jrCqbZ2Lk+R6x86mzw6l6K8LF6tqk12KZJiFBAiEft95WHS04wrpoxKdimhzIwrykbxalWt+iHkHRQQIhFbue0wsyYVMCwr9fofOlxRNooD9c1U1zYmuxRJIQoIkQgda2xhQ80xrp46OtmlnNEVZbHmLzUzSTwFhEiEXt5+GHe4+rzUDojJo3MpVD+EdKKAEInQysrDDMvKePsWn6lK/RASRgEhEqGVlYe5omwkmemp/6t2RdlI9UPIO6T+p1akn9pd18jO2kauSvHmpQ5XlMXOslIzk3RQQIhEZGXlYQCuSfEO6g5l6oeQThQQIhFZWXmYMcOzmNJPbuepfgjpTAEhEoH2duflysNcfV4hZsm//3RPdfRDaFwmgYgDwszmmtlWM6s0swdC1meZ2X8H61eZWWmwvNTMTprZ+uDxaJR1iiTaxr3HONLYwtVTU/Pq6a50nI7727cOJbkSSQWRBYSZpQOLgXlAOXCHmZV32uwe4Ii7nwf8G/CVuHXb3X1m8PhoVHWKROGFLQdJM7j2/KJkl3JWSkblMnl0Lr9WQAjRHkFcDlS6e5W7nwKeABZ02mYB8P1g+qfAe6w/HY+LdOHFLQe4tGQEI3OHJLuUs3bt+YW8WlVLU0tbskuRJIsyICYAu+Pm9wTLQrdx91bgGNBxTD7ZzNaZ2W/M7JoI6xRJqL1HT7Jpbz3XXzgm2aX0yrXTCmlqadfw35KyndT7gGJ3nwXcDyw1s+GdNzKze81sjZmtOXRIh8SSGl7ccgCA9/TTgLiybBRZGWn8euvBZJciSRZlQNQAk+LmJwbLQrcxswwgH6h192Z3rwVw99eA7cD5nV/A3Ze4e4W7VxQWFkbwTxA5e89tPkDpqKFMKcxNdim9kp2Zzh9NGcWLWw7qdNdBLsqAWA1M
NbPJZjYEWAgs67TNMuDuYPpW4Ffu7mZWGHRyY2ZlwFSgKsJaRRKi7sQpXt5ey7wZ4/rV6a2d3Th9LLvqGtmyryHZpUgSRRYQQZ/CImAFsAX4ibtvMrOHzGx+sNl3gVFmVkmsKanjVNh3ARvMbD2xzuuPursaRCXlrdi0n7Z25+YZ45Jdyjm5/sIxmMX+PTJ4RXoHE3dfDizvtOzBuOkm4LaQ/Z4CnoqyNpEoPLNhL5NH5zJ9/GldZv1KYV4WFSUjWLFpP5+84bTWXRkkUrWTWqTfOXy8mVe213JzP29e6nDT9LG8ub+BnbW6qnqwUkCIJMiy9Xtpd3jfJf27eanD3IvGArF/lwxOCgiRBHnytT1cPDGfC8b27+alDhNHDOXyySP5+foanc00SCkgRBJgY80xtuyr57ZLJya7lIT6wKwJVB06wRs1x5JdiiSBAkIkAZ5cs5shGWnMv6TzYAH923tnjGNIeho/W9v5EiYZDBQQIufoeHMrP1tXw7yLxpI/NDPZ5SRUfk4m15cX8fT6Go3NNAhFepqryGDw1Gt7aGhq5cN/VJrsUiJx15wSlr+xn2c27OPWAdaE1lOHGppZt/sI1YcbOdTQxBf/dzMFQ4cwbWwe111QxM0zxjGiHw7M2B0FhMg5aG93/uv3O5hVXMCs4hHJLicSV04ZxXlFw/jhK9WDLiDeOtDAD16p5s39DaQZTCjIoXz8cGZMKKD2RDOv7z7Kr948yL8s38KHrijhr//PeeTnDJyjSAWEyDl4fssBqmsbuf/GackuJTJmxoeuKOGzyzaxbteRARuE8Zpb2/jmi5V8+zfbyUw3rr+wiIrSkQzPjn353zmnGAB3Z/O+er7z2yqW/K6Kn62r4Qu3XMRN08cms/yEUR+ESC+1tztff2EbJaOGMu+igfGF0JU/uXQi+TmZLH5pe7JLidzBhiZuf+xVvvVSJR+YNYFP3TCN6y4Y83Y4xDMzpo/P5+sLZ/GL+66iKC+Lv/rha3z+mc20tLUnofrEUkCI9NJzm/ezZV89n3jPVDLTB/av0rCsDP7iqsm8sOUAm/YO3FNeN9YcY8G3fs/W/Q18+4Oz+dfbLmFoVs8aWi6eWMDP//oq7r6yhO+u3MEdS15l/7GmiCuO1sD+VItEpKWtna899xZlhbnMv2R8ssvpEx++qpS8rAz+/cVtyS4lEs++sY/bHn0FA376sSuZ14sBF4dkpPHPCy7iGwtnsmlvPe//1kpe33004bX2FQWESC98/+Vqth08zgNzLyBjgB89dMjPyeQvryljxaYDvLK9NtnlJIy78+8vbuNjP1rLBePyeHrRVUwfn39Oz7lg5gSevu8qsjLSuH3JKzz7xr4EVdu3BscnWySB9h9r4t+ef4vrLijihvL+ede43vqra8uYUJDDP//PJloHQBt7U0sbH39iPY88/xYfmDWBH3/kCoryshPy3NPG5vH0fVdRPm44H/vRWha/VNnvhixRQIichfZ251NPvk6bO597//QBMWrr2cjOTOf/vu9C3tzfwJLf9e97eB2ob+L2x17hmQ17+fu503jkTy8hOzM9oa8xelgWSz9yBfMvGc9XV2zl7366gVOt/SdYdZqryFlY8rsqVlYe5kt/PIPiUUOTXU5S3DR9LO+dMZZHnnuLq88bzcUTC5Jd0ll7becR7vvRWuqbWnjsrku5McLTUrMz0/nGwpmUFeby9Re2sauukcfuurRfXFinIwiRHnph8wEe/uWbvHfGWBZeNqn7HQYoM+NLH7iYorwsFi1dx+HjzckuqcfcncdX7uD2x14hM8P46Uf/KNJw6GBm/O315/ONhTNZv/so8xevZN2uI5G/7rlSQIj0wB921LHox2u5aEI+X731kkHXtNRZ/tBMFn9wNgcbmrjne6s50dya7JK6deTEKe5bupaHntnMu6cV8cyiayjv4zv/LZg5gSfuvYL2drj10VdY/FIlbe2p2y+hgBDpxvObD/Ch765ifH4O3737MnJ7
eF78QDereATfumM2b9Qc44P/uYq6E6eSXVIod+fJNbu57mu/ZsWmA3x63gV8588uTdrAirOLR7D8E9cw76KxfHXFVhYueSVlry3RJz0FLV21K3R5x+X90jdOtbbztee3suS3VVw8sYDH765g1LCsZJeVUq4vH8O377qUj/94Hbcs/j1fXziT2Sk0FMfru4/yL8u3sGpHHZeWjOCLH7goJW7olJ+TyTfvmMW7pxXxxf/dzPu+uZLbLp3Ip26cRtHwxJxFlQgKCJFO2tqdFZv28/Av36S6tpE75xTzf28uJ2dIYs9wGShumj6WpR+5go//eB23fvtl7rqihL+5biqFeckJ06aWNp7bfICfrN7NysrDFAzN5Et/PIPbKyaRlpY6TYNmxq2XTuSG8jF861fb+N7L1fxi/V7ed/F47pxTzOzigqQ3ZUYaEGY2F/gGkA78p7t/udP6LOAHwKVALXC7u1cH6z4N3AO0AR939xVR1iqDW3u78+b+Bp7bvJ+n1u5hd91JphYN43t/fhnvnlaU7PJS3qUlI/jl317DV375Jj9atYv/Xr2bm2eM4/0zx3Nl2aiEnz7a2f5jTazaUcvLlbU8u3Ef9U2tTCjI4e/nTuPPrixlWAo3C+bnZPKZm8v54JwS/nNlFT9fW8NTa/dw/phhvHtaEVefN5rLSkcm5Q+UyN41M0sHFgM3AHuA1Wa2zN03x212D3DE3c8zs4XAV4DbzawcWAhMB8YDL5jZ+e7eJ3cscXfaHdrdY4/22HSbO94Obe60tLXT1NJGc2vsZ1NLO82t7/wZW9729vqmljZOdky3ttEcP9/SxqnWdsyg/mQrHX84mEGaGRlpaazYtJ+czHRyhgSPzPS357Mz08nOTCMrI/YzOyOdrMy02PKMd64bkpFGWpqRbkaaGWlpsddIN8OMpP/Vci46/u/a2j328NjP1rZ2jje3cuxkC/UnW6lvamHfsSZ2HD7OjsMn2LS3nqONLZjBnMkj+fS8C7mxfMyguUo6EfKyM/nCLTO45+oyHl+5g6fX1fCzdTVkphvnFeVx4bg8po3Jo2h4FqOHZTEydwg5melkZaaTlRH7XKab0drxfxf3f3jyVBvHTp7iaGMLRxtbONJ4it11jVQdPkHVoRPUHD0ZqyErg+suLOJPKyZxZdmolDpi6E7p6Fy+cMsMHph3IcvW72XZ6zX81+93sOS3VWSkGSWjhjK1KI8pRbmMGZ5N4bAsRudlMTw7k4KhmYyJoGkqyli9HKh09yoAM3sCWADEB8QC4HPB9E+Bb1ns22kB8IS7NwM7zKwyeL5XEl1k7fFmrvrKr94ZAhGcVJBmsfOhczJjX+ZZwZd47Ms9jYKcTIZkpOEOu+oacXj7qst2d1ranKONp9gXhMrJU7FQaTzVSqJPgkgLQiktzTAg1fPCgzBva/ezfi+GZ2dQVjiMm8rHUlE6gmunFSbsStrBavLoXD5/y0V85uYLWbWjjpe3H2bLvgZ+t+1wQm9dOiwrg7LCXC4tGcGfX1XKFWWjuHDccNL7USiEGZaVwZ1zirlzTjGNp1pZXX2E1Tvq2HawgbcONvD8lgOnnfl0ycR8frHo6oTXEmVATAB2x83vAeZ0tY27t5rZMWBUsPzVTvuedrNfM7sXuDeYPW5mW8+yxtHA4bPcpy+kal0wAGt7I4JCQpzT+/bBBBRwhufot/+nm/qwkHgfTLH3bCdgf/P27NnWVtLVitRtmOsBd18CLOnt/ma2xt0rElhSQqRqXaDaeku19U6q1paqdUFia4uygbUGiL/cdGKwLHQbM8sA8ol1VvdkXxERiVCUAbEamGpmk81sCLFO52WdtlkG3B1M3wr8ymMN78uAhWaWZWaTganAHyKsVUREOomsiSnoU1gErCB2muvj7r7JzB4C1rj7MuC7wA+DTug6YiFCsN1PiHVotwL3RXQGU6+bpyKWqnWBaust1dY7qVpbqtYFCazN+tv45CIi0jd0kreIiIRSQIiISKhB
FxBm9lUze9PMNpjZz82sIG7dp82s0sy2mtlNSajtNjPbZGbtZlYRt7zUzE6a2frg8Wiq1BasS+r71qmWz5lZTdx79d4k1zM3eF8qzeyBZNbSmZlVm9kbwfu0Jsm1PG5mB81sY9yykWb2vJltC34mZRTALmpLic+ZmU0ys5fMbHPw+/mJYHli3jt3H1QP4EYgI5j+CvCVYLoceB3IAiYD24H0Pq7tQmAa8GugIm55KbAxye9bV7Ul/X3rVOfngE8l+3MW1JIevB9lwJDgfSpPdl1x9VUDo5NdR1DLu4DZ8Z9z4GHggWD6gY7f1RSpLSU+Z8A4YHYwnQe8FfxOJuS9G3RHEO7+nLt33N3kVWLXWEDc8B7uvgPoGN6jL2vb4u5nezV4nzhDbUl/31LY28PNuPspoGO4GenE3X9L7EzGeAuA7wfT3wdu6cuaOnRRW0pw933uvjaYbgC2EBt1IiHv3aALiE7+Ang2mA4bGuS04T2SaLKZrTOz35jZNckuJk4qvm+LgibEx5PVLBFIxfcmngPPmdlrwbA1qWaMu+8LpvcDY5JZTIhU+ZwBsaZoYBawigS9d/16qI2umNkLQNiNZj/j7r8ItvkMsWssfpRqtYXYBxS7e62ZXQo8bWbT3b0+BWrrc2eqE/g28HliX36fB75G7A8BOd3V7l5jZkXA82b2ZvDXcspxdzezVDonP6U+Z2Y2DHgK+Ft3r48fkflc3rsBGRDufv2Z1pvZh4H3Ae/xoJGOPhreo7vautinGWgOpl8zs+3A+UBCOxZ7UxtJGBalp3Wa2XeAZ6KspRspPWSMu9cEPw+a2c+JNYmlUkAcMLNx7r7PzMYBB5NdUAd3P9AxnezPmZllEguHH7n7z4LFCXnvBl0Tk8VuYvT3wHx3b4xblbLDe5hZocXur4GZlRGrrSq5Vb0tpd634JehwweAjV1t2wd6MtxMUphZrpnldUwTO3kjme9VmPiheO4GUukoNiU+ZxY7VPgusMXdH4lblZj3Ltm98Eno9a8k1i68Png8GrfuM8TOOtkKzEtCbR8g1k7dDBwAVgTL/4TYyMbrgbXA+1OltlR43zrV+UNiI3hvCH5JxiW5nvcSO7NkO7GmuqTV0qmuMmJnVb0efLaSWhvwY2JNqS3B5+weYkP/vwhsA14ARqZQbSnxOQOuJtbMtSHuO+29iXrvNNSGiIiEGnRNTCIi0jMKCBERCaWAEBGRUAoIEREJpYAQEZFQCggREQmlgBARkVD/HwTikXxMMLVhAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of element 3 of every tuple (presumably the reward - TODO:\n",
    "# confirm against SepsisEnv.generate_tuples).\n",
    "# `sns.distplot` is deprecated (it raised a FutureWarning); `histplot` with\n",
    "# stat=\"density\" and kde=True draws the equivalent density histogram + KDE.\n",
    "ax = sns.histplot([x[3] for x in all_tuples], stat=\"density\", kde=True)\n",
    "ax.set(xlabel=\"tuple[3]\", title=\"Distribution of tuple[3] over all tuples\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ready-bullet",
   "metadata": {},
   "source": [
    "# Train Agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cutting-mauritius",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N actions:  25\n",
      "Learning policy\n",
      "Run 0 :\n",
      "Initialize: get batch, set initial Q\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [08:32<00:00,  1.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Learn policy\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAZIklEQVR4nO3df5RV5X3v8fdHfghEoGaYtsYhDmnIRdQGycSYa8l1WaOYWtQaA6TeKtqSZkn11iaK8S5jzO1aEm9S7Q0rS2rV1FixmhiRciU3GFLTLC2DQeWH1JHEMKDNBK2aWpSB7/3j7INnD3uGM8PsGWaez2utWZ797Ofs8+zZeD7zPM/+oYjAzMzSdcRgN8DMzAaXg8DMLHEOAjOzxDkIzMwS5yAwM0vcyMFuQG9NmjQpmpubB7sZZmZDyvr1638ZEY1F64ZcEDQ3N9Pa2jrYzTAzG1IkvdjdOg8NmZklzkFgZpY4B4GZWeKG3ByBmaVpz549tLe3s3v37sFuymFtzJgxNDU1MWrUqLrf4yAwsyGhvb2d8ePH09zcjKTBbs5hKSLYtWsX7e3tTJkype73eWjIzIaE3bt309DQ4BDogSQaGhp63WtyEJjZkOEQOLi+/I6SCYJ1P3uFr31vK2937hvsppiZHVaSCYKnXnyVv36sjc59DgIz65ujjjrqgLIbb7wRSbS1te0vu/XWW5G0/+LX5uZmZs2alXvfjBkzOPHEEwFYu3YtEydOZMaMGcyYMYMzzzxzf71ly5Yxbdo0pk2bRktLC2vXrt2/7vTTT++XC2yTCQIzs7KcdNJJLF++fP/yAw88wAknnJCr88Ybb7B9+3YAtmzZcsA2Zs2axYYNG9iwYQPf//73AVi5ciW33347P/rRj3juuedYtmwZF198MTt27OjX9icXBH4gm5n1t/PPP5+HH34YgBdeeIGJEycyadKkXJ1PfepT3H///QDcd999zJ8//6DbXbJkCbfccsv+bc2cOZMFCxawdOnSfm1/MqePeo7JbPj40iOb2Lzz9X7d5vT3TOCLv3/CwSsWmDBhApMnT2bjxo08/PDDzJ07l7vuuitX58ILL2TBggV87nOf45FHHuHee+/lnnvu2b/+8ccfZ8aMGQBcdNFFXH/99WzatIkPfehDue20tLQcsO1DlUwQVLlDYGZlmDdvHsuXL2f16tWsWbPmgC/rhoYGjj76aJYvX87xxx/PuHHjcutnzZrFypUrB7LJ+yUTBMJdArPhoq9/uZfp3HPP5fOf/zwtLS1MmDChsM7cuXO54ooruPvuu+va5vTp01m/fj1nnHHG/rL169fT0tLSH03eL5kgMDMr07hx41iyZAkf+MAHuq1zwQUX8NJLL3H22Wezc+fOg27zmmuu4dprr+XRRx+loaGBDRs28NBDD/HYY4/1Z9PTC4LwbLGZ9dGbb75JU1PT/uWrr746t37evHk9vn/8+PFce+21dX/enDlz2LlzJ6eddhqdnZ28/PLLPP300zQ2Fj5fps801L4YW1paoi/nzd7x+Db+1z9u4dkbz2L8mPpvxmRmh4ctW7Zw/PHHD3YzBk1nZycLFixg3759fOtb3+rxCuKi35Wk9RFROKZUao9A0mzgNmAEcEdE3Nxl/aXALUD1pNivR8QdZbZpaMWemVnFyJEjc2cZ9eu2S9kqIGkEsBT4ONAOrJO0IiI2d6l6f0QsKqsdZmbWszIvKDsFaIuIbRHxNrAcOK/EzzOzYW6oDWUPhr78jsoMgmOB7TXL7VlZVxdKekbSg5ImF21I0kJJrZJaOzo6DqlR/ndkNjSNGTOGXbt2OQx6UH0ewZgxY3r1vsE+a+gR4L6IeEvSZ4BvAmd0rRQRy4BlUJks7ssH+fa1ZkNbU1MT7e3tHOofg8Nd9QllvVFmEOwAav/Cb+KdSWEAImJXzeIdwFdKbE/2oaV/gpmVYNSoUb166pbVr8yhoXXAVElTJI0G5gEraitIOqZm
cQ5w4C35+on7A2ZmxUrrEUREp6RFwGoqp4/eGRGbJN0EtEbECuBKSXOATuAV4NKy2mNmZsVKnSOIiFXAqi5lN9S8vg64rsw2HNAmjw2ZmeUk8zwCzxWbmRVLJgiqfOaZmVleMkHgDoGZWbFkgsDMzIolFwQeGTIzy0smCHxlsZlZsWSCoMr3KTEzy0smCNwhMDMrlkwQmJlZseSCwANDZmZ5yQSBR4bMzIolEwRmZlYsuSDwSUNmZnnpBIFPGzIzK5ROEGR8G2ozs7xkgsD9ATOzYskEgZmZFUsvCDwyZGaWk0wQeK7YzKxYMkFQ5Q6BmVleMkEgTxebmRVKJgjMzKxYckHgK4vNzPKSCQJPFpuZFUsmCKp8ZbGZWV4yQeAOgZlZsWSCwMzMiiUXBJ4sNjPLKzUIJM2WtFVSm6TFPdS7UFJIaimvLWVt2cxsaCstCCSNAJYC5wDTgfmSphfUGw9cBTxZVltquUNgZpZXZo/gFKAtIrZFxNvAcuC8gnpfBpYAu0tsi68sNjPrRplBcCywvWa5PSvbT9JMYHJE/GOJ7TAzsx4M2mSxpCOArwF/UUfdhZJaJbV2dHQc0ueGZ4vNzHLKDIIdwOSa5aasrGo8cCKwVtLPgFOBFUUTxhGxLCJaIqKlsbGxb63xyJCZWaEyg2AdMFXSFEmjgXnAiurKiHgtIiZFRHNENANPAHMiorXENvn0UTOzLkoLgojoBBYBq4EtwD9ExCZJN0maU9bndscdAjOzYiPL3HhErAJWdSm7oZu6p5fZFjMzK5bclcVmZpaXTBDIlxabmRVKJgjMzKxYckHgs4bMzPKSCQIPDJmZFUsmCKr8hDIzs7xkgsBzxWZmxZIJAjMzK5ZcEHiy2MwsL5kg8NCQmVmxZIKgyh0CM7O8ZILATygzMyuWTBCYmVmx5ILATygzM8tLJgg8WWxmViyZIKhyf8DMLC+5IDAzszwHgZlZ4uoKAknHSTozez1W0vhym1UezxWbmeUdNAgk/QnwIHB7VtQEfLfENpXCTygzMytWT4/gCuA04HWAiHge+PUyG1UudwnMzGrVEwRvRcTb1QVJIxmC36buD5iZFasnCH4o6QvAWEkfBx4AHim3WWZmNlDqCYLFQAfwLPAZYFVEXF9qq0rkyWIzs7yRddT5s4i4DfibaoGkq7KyIcNzxWZmxerpEVxSUHZpP7djwLhDYGaW122PQNJ84NPAFEkralaNB14pu2H9zbehNjMr1tPQ0I+Bl4BJwFdryt8AnimzUWZmNnC6DYKIeBF4EfhoXzcuaTZwGzACuCMibu6y/k+pXKewF/gVsDAiNvf18+rhyWIzs7x6riw+VdI6Sb+S9LakvZJer+N9I4ClwDnAdGC+pOldqv19RJwUETOArwBf6/0u1MeTxWZmxeqZLP46MB94HhgL/DGVL/iDOQVoi4ht2QVpy4HzaitERG2gvIsBmMsNTxebmeXUddO5iGgDRkTE3oi4C5hdx9uOBbbXLLdnZTmSrpD0ApUewZX1tKcv3CEwMytWTxC8KWk0sEHSVyT9eZ3vq0tELI2I3wKuBf5nUR1JCyW1Smrt6Ojor482MzPq+0L/71QmexcB/wFMBi6s4307srpVTVlZd5YD5xetiIhlEdESES2NjY11fHT3PFlsZpZ30CuLs7OHAP4T+FIvtr0OmCppCpUAmEfluoT9JE3N7mYK8HtU5iFK4cliM7NiBw0CSecCXwaOy+oLiIiY0NP7IqJT0iJgNZUexZ0RsUnSTUBrRKwAFmUPvNkDvErxVcxmZlaieu41dCvwB8CzEb0bWImIVcCqLmU31Ly+qjfb6w8eGjIzy6tnjmA7sLG3IXD48diQmVmRenoE1wCrJP0QeKtaGBGlXfxVJl9HYGaWV08Q/CWV2z+MAUaX25zyeLLYzKxYPUHwnog4sfSWmJnZoKhnjmCVpLNKb8kAGeozHWZm/a2eIPgs8Kik/5T0uqQ36rnp3OHGI0NmZsXquaBs
/EA0xMzMBkdPTyibFhHPSZpZtD4iniqvWf1Pni02MyvUU4/gamAh+aeTVQVwRiktMjOzAdXTE8oWZi/PiYjdteskjSm1VSXyZLGZWV49k8U/rrPssOaBITOzYj3NEfwmlQfJjJV0Mu98l04Axg1A20rhK4vNzPJ6miM4G7iUynMEvso7QfAG8IVym9X/PFdsZlaspzmCbwLflHRhRHx7ANtkZmYDqJ45giZJE1Rxh6SnhvKVxp4sNjPLqycILouI14GzgAYqj668udRWlcBDQ2ZmxeoJgupX6CeAv4uITQzhk3DcITAzy6snCNZL+h6VIFgtaTywr9xm9T8N3ewyMytVPbehvhyYAWyLiDclNQALSm2VmZkNmG57BJIuBoiIfcDYiPj3bHkX8LEBaV0JhvwTN83M+llPQ0NX17z+P13WXVZCW8rlkSEzs0I9BYG6eV20PGS4P2BmltdTEEQ3r4uWD3tDNrnMzErW02TxNEnPUPkO/a3sNdny+0pvmZmZDYieguD4AWvFAPJcsZlZXk/3GnpxIBtSNj+hzMysWD0XlJmZ2TCWYBB4bMjMrNZBryyWNA54f7a4NSLeKrdJ5fDAkJlZsZ6uLB4l6VagHbgLuBvYJmlxtn7GwTYuabakrZLaqu/rsv5qSZslPSNpjaTj+rgfdfNksZlZXk9DQ18FjgKOi4gPRcRMKmcSvU/SN4CHetqwpBHAUuAcYDowX9L0LtV+ArRExG8DDwJf6dtuHJznis3MivU0NPQJYGrU3JwnIl6X9Fngl1S+4HtyCtAWEdsAJC0HzgM212zvBzX1nwAu7l3zzczsUPXUI9gXBXdoi4i9QEdEPHGQbR8LbK9Zbs/KunM58H+LVkhaKKlVUmtHR8dBPrZnHhkyM8vrKQg2S/qjroXZXUm39Gcjsm22ALcUrY+IZRHREhEtjY2NffsMTxebmRXqaWjoCuA7ki4D1mdlLcBY4II6tr0DmFyz3JSV5Ug6E7ge+G8DcUaSJ4vNzPJ6urJ4B/ARSWcAJ2TFqyJiTZ3bXgdMlTSFSgDMAz5dW0HSycDtwOyI+EVvG98bniw2Myt20OsIIuIx4LHebjgiOiUtAlYDI4A7I2KTpJuA1ohYQWUo6CjggewWED+PiDm9/SwzM+u7eh5V2WcRsQpY1aXshprXZ5b5+d20aaA/0szssJbMLSY8MmRmViyZIKhyf8DMLC+dIHCXwMysUDpBYGZmhZILAs8Vm5nlJRMEvrLYzKxYMkFQFZ4uNjPLSSYIfGWxmVmxZILAzMyKpRcEHhkyM8tJJgg8MmRmViyZIKhyh8DMLC+ZIJBni83MCiUTBGZmViy5IPCVxWZmeckEgUeGzMyKJRMEZmZWLLkg8C0mzMzykgkCjwyZmRVLJgiqPFlsZpaXTBB4stjMrFgyQWBmZsWSCwKPDJmZ5SUUBB4bMjMrklAQVIRni83McpIJAk8Wm5kVSyYIzMysWHJB4IEhM7O8UoNA0mxJWyW1SVpcsP5jkp6S1Cnpk6W2pcyNm5kNYaUFgaQRwFLgHGA6MF/S9C7Vfg5cCvx9We04gLsEZmY5I0vc9ilAW0RsA5C0HDgP2FytEBE/y9btK7EdZJ9R9keYmQ1JZQ4NHQtsr1luz8p6TdJCSa2SWjs6OvqlcWZmVjEkJosjYllEtERES2Nj46Fty2NDZmY5ZQbBDmByzXJTVjYoPDBkZlaszCBYB0yVNEXSaGAesKLEz6uLLyw2M8srLQgiohNYBKwGtgD/EBGbJN0kaQ6ApA9LagcuAm6XtKms9niu2MysWJlnDRERq4BVXcpuqHm9jsqQkZmZDZIhMVncnzw0ZGaWl0wQyNPFZmaFkgmCKncIzMzykgkCTxabmRVLJgjMzKxYckHgJ5SZmeUlFwRmZpbnIDAzS1xyQeCBITOzvGSCwGcNmZkVSyYIqjxXbGaWl0wQ+MpiM7NiyQSBmZkVSzAIPDZkZlYrmSDwZLGZWbFkgqDKk8VmZnnJ
BIF7BGZmxZIJAjMzK5ZcEHhkyMwsL5kg8HUEZmbFkgmCKk8Wm5nlJRMEniw2MyuWTBCYmVmx5IIgPF1sZpaTTBB4ZMjMrFgyQVDlyWIzs7xkgsCTxWZmxZIJAjMzK1ZqEEiaLWmrpDZJiwvWHynp/mz9k5Kay2wP+MpiM7OuRpa1YUkjgKXAx4F2YJ2kFRGxuaba5cCrEfF+SfOAJcDckloEwJX3/YRRR4h3HTmSPXv3MWtqI6NHumNkZukqLQiAU4C2iNgGIGk5cB5QGwTnATdmrx8Evi5JEeVO6X723qdyy+OPHEnj+CMZcYQnEszs8HXl707l9z/4nn7fbplBcCywvWa5HfhId3UiolPSa0AD8MvaSpIWAgsB3vve9/apMcc1jOMPP/Je3nx7L+NGj+BXb3WyaefrjDxCTPvN8ezZG77GwMwOaxPHjiplu2UGQb+JiGXAMoCWlpY+fVuPGnEEf3nBSf3aLjOz4aDMwfEdwOSa5aasrLCOpJHARGBXiW0yM7MuygyCdcBUSVMkjQbmASu61FkBXJK9/iTwWNnzA2Zmllfa0FA25r8IWA2MAO6MiE2SbgJaI2IF8LfAPZLagFeohIWZmQ2gUucIImIVsKpL2Q01r3cDF5XZBjMz65lPoDczS5yDwMwscQ4CM7PEOQjMzBKnoXa2pqQO4MU+vn0SXa5aToD3OQ3e5zQcyj4fFxGNRSuGXBAcCkmtEdEy2O0YSN7nNHif01DWPntoyMwscQ4CM7PEpRYEywa7AYPA+5wG73MaStnnpOYIzMzsQKn1CMzMrAsHgZlZ4pIJAkmzJW2V1CZp8WC3p79ImizpB5I2S9ok6aqs/N2S/p+k57P/Hp2VS9JfZ7+HZyTNHNw96BtJIyT9RNLKbHmKpCez/bo/u/U5ko7Mltuy9c2D2vA+kvRrkh6U9JykLZI+msAx/vPs3/RGSfdJGjMcj7OkOyX9QtLGmrJeH1tJl2T1n5d0SdFndSeJIJA0AlgKnANMB+ZLmj64reo3ncBfRMR04FTgimzfFgNrImIqsCZbhsrvYGr2sxD4xsA3uV9cBWypWV4C/FVEvB94Fbg8K78ceDUr/6us3lB0G/BoREwDPkhl34ftMZZ0LHAl0BIRJ1K5lf08hudxvhuY3aWsV8dW0ruBL1J5HPApwBer4VGXiBj2P8BHgdU1y9cB1w12u0ra14eBjwNbgWOysmOArdnr24H5NfX31xsqP1SedrcGOANYCYjK1ZYjux5vKs/D+Gj2emRWT4O9D73c34nAT7u2e5gf4+rzzN+dHbeVwNnD9TgDzcDGvh5bYD5we015rt7BfpLoEfDOP6qq9qxsWMm6wycDTwK/EREvZateBn4jez0cfhe3AtcA+7LlBuDfI6IzW67dp/37m61/Las/lEwBOoC7suGwOyS9i2F8jCNiB/C/gZ8DL1E5busZ3se5Vm+P7SEd81SCYNiTdBTwbeB/RMTrteui8ifCsDhPWNK5wC8iYv1gt2UAjQRmAt+IiJOB/+CdoQJgeB1jgGxY4zwqIfge4F0cOHyShIE4tqkEwQ5gcs1yU1Y2LEgaRSUE7o2I72TF/ybpmGz9McAvsvKh/rs4DZgj6WfAcirDQ7cBvyap+sS92n3av7/Z+onAroFscD9oB9oj4sls+UEqwTBcjzHAmcBPI6IjIvYA36Fy7Ifzca7V22N7SMc8lSBYB0zNzjgYTWXSacUgt6lfSBKVZz9viYiv1axaAVTPHLiEytxBtfyPsrMPTgVeq+mCHvYi4rqIaIqIZirH8bGI+EPgB8Ans2pd97f6e/hkVn9I/eUcES8D2yX9l6zod4HNDNNjnPk5cKqkcdm/8eo+D9vj3EVvj+1q4CxJR2e9qbOysvoM9iTJAE7GfAL4V+AF4PrBbk8/7tfvUOk2PgNsyH4+QWV8dA3wPPB94N1ZfVE5g+oF4FkqZ2UM+n70cd9PB1Zmr98H/AvQBjwAHJmVj8mW27L17xvsdvdx
X2cArdlx/i5w9HA/xsCXgOeAjcA9wJHD8TgD91GZB9lDpfd3eV+OLXBZtv9twILetMG3mDAzS1wqQ0NmZtYNB4GZWeIcBGZmiXMQmJklzkFgZpY4B4ElS9Kvsv82S/p0P2/7C12Wf9yf2zfrTw4Cs8oNv3oVBDVXt3YnFwQR8V972SazAeMgMIObgVmSNmT3wB8h6RZJ67J7vn8GQNLpkh6XtILKVa5I+q6k9dl98xdmZTcDY7Pt3ZuVVXsfyra9UdKzkubWbHut3nnmwL3ZFbVmpTvYXzVmKVgMfC4izgXIvtBfi4gPSzoS+GdJ38vqzgROjIifZsuXRcQrksYC6yR9OyIWS1oUETMKPusPqFwl/EFgUvaef8rWnQycAOwE/pnKvXV+1N87a9aVewRmBzqLyv1cNlC5pXcDlQeBAPxLTQgAXCnpaeAJKjf9mkrPfge4LyL2RsS/AT8EPlyz7faI2EflViHN/bAvZgflHoHZgQT8WUTkbtol6XQqt4CuXT6TygNR3pS0lso9b/rqrZrXe/H/nzZA3CMwgzeA8TXLq4HPZrf3RtIHsgfBdDWRyuMR35Q0jcqjQqv2VN/fxePA3GweohH4GJWbpJkNGv/FYVa5o+febIjnbirPN2gGnsombDuA8wve9yjwp5K2UHlk4BM165YBz0h6Kiq3ya56iMojFp+mctfYayLi5SxIzAaF7z5qZpY4Dw2ZmSXOQWBmljgHgZlZ4hwEZmaJcxCYmSXOQWBmljgHgZlZ4v4/udNeK+f2Nq0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train the LMM-FQI agent, using the whole training set as one batch.\n",
    "# The two group policies are stored in lmm_agent.piE_foreground and\n",
    "# lmm_agent.piE_background.\n",
    "lmm_agent = LMMFQIagent(\n",
    "    train_tuples=train_tuples,\n",
    "    test_tuples=test_tuples,\n",
    "    estimator='linnet',\n",
    "    gamma=0.5,\n",
    "    state_dim=46,\n",
    "    batch_size=len(train_tuples),\n",
    "    iters=1000,\n",
    ")\n",
    "Q_dist = lmm_agent.runFQI(repeats=1)\n",
    "\n",
    "# Plot the Q-estimate trace returned by runFQI against iteration.\n",
    "plt.plot(Q_dist, label=\"LMMFQI\")\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Q Estimate\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "spanish-stewart",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       [1., 1., 1., ..., 1., 1., 1.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarise the LMM-FQI Q-table instead of dumping the (truncated) array:\n",
    "# show its shape and the distinct values it holds. The saved run printed\n",
    "# only 1s and 0s, which is worth checking for a degenerate fit.\n",
    "qtable = np.asarray(lmm_agent.Qtable)\n",
    "qtable.shape, np.unique(qtable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rolled-adelaide",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: plain FQI (estimator='gbm'), using the whole training set as one\n",
    "# batch and the same gamma / iteration count as the LMM-FQI agent.\n",
    "fqi_agent = FQIagent(\n",
    "    train_tuples=train_tuples,\n",
    "    test_tuples=test_tuples,\n",
    "    state_dim=46,\n",
    "    gamma=0.5,\n",
    "    batch_size=len(train_tuples),\n",
    "    iters=1000,\n",
    "    estimator='gbm',\n",
    ")\n",
    "Q_dist = fqi_agent.runFQI(repeats=1)\n",
    "\n",
    "# Plot the Q-estimate trace returned by runFQI against iteration.\n",
    "plt.plot(Q_dist, label=\"FQI\")\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Q Estimate\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "productive-observer",
   "metadata": {},
   "source": [
    "# Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "advisory-collaboration",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _oracle_step(env, candidate_actions):\n",
    "    \"\"\"Greedy one-step oracle: commit the candidate with the highest immediate reward.\n",
    "\n",
    "    Each candidate is tried on a deep copy of `env`, so the real environment is\n",
    "    stepped exactly once, with the winning action. (The previous inline version\n",
    "    stepped the real environment once per candidate, advancing it several times\n",
    "    per step and leaving it in the state reached by the *last* candidate.)\n",
    "    NOTE(review): assumes `env` supports copy.deepcopy; for a stochastic env the\n",
    "    committed reward can differ from the trial reward.\n",
    "\n",
    "    Returns:\n",
    "        (action, next_obs, reward) from the committed step on `env`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `candidate_actions` is empty.\n",
    "    \"\"\"\n",
    "    best_action, best_reward = None, None\n",
    "    for candidate in candidate_actions:\n",
    "        candidate = np.asarray([candidate])\n",
    "        _, trial_reward, _, _ = cp.deepcopy(env).step(candidate)\n",
    "        if best_reward is None or trial_reward > best_reward:\n",
    "            best_action, best_reward = candidate, trial_reward\n",
    "    if best_action is None:\n",
    "        raise ValueError(\"oracle_actions must contain at least one action\")\n",
    "    ns, reward, _, _ = env.step(best_action)\n",
    "    return best_action, ns, reward\n",
    "\n",
    "\n",
    "def _plot_validation(val_rewards, alg_actions, n_steps, ds, plot):\n",
    "    \"\"\"Plot FQI vs. LMMFQI actions, or cumulative rewards when plot == 'reward'.\"\"\"\n",
    "    x = list(range(n_steps))\n",
    "    if plot == 'reward':\n",
    "        series = {'FQI': util_fqi.cumulative_reward(val_rewards['fqi']),\n",
    "                  'LMMFQI': util_fqi.cumulative_reward(val_rewards['lmmfqi'])}\n",
    "        ylabel = \"Cumulative Reward\"\n",
    "    else:\n",
    "        series = {'FQI': alg_actions['fqi'], 'LMMFQI': alg_actions['lmmfqi']}\n",
    "        ylabel = \"Action\"\n",
    "    # Oracle and random rollouts are collected but, as before, not plotted.\n",
    "    _, ax = plt.subplots()\n",
    "    for label, y in series.items():\n",
    "        ax.plot(x, y, label=label, alpha=0.7)\n",
    "    ax.set(xlabel=\"Step\", ylabel=ylabel, title=f\"{ylabel} for ds: {ds}\")\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def validate_agent(sepsis_env, ds='foreground', plot='action', n_steps=100,\n",
    "                   reward_scale=10, oracle_actions=(-2, -1, 0, 1, 2),\n",
    "                   fqi=None, lmm=None, seed=None):\n",
    "    \"\"\"Roll out FQI, LMMFQI, a one-step oracle and a random policy, then plot.\n",
    "\n",
    "    Each algorithm runs for `n_steps` steps from `sepsis_env.reset()`; only FQI\n",
    "    and LMMFQI are plotted.\n",
    "\n",
    "    Args:\n",
    "        sepsis_env: environment with the old gym API (`reset()` returns an\n",
    "            observation, `step()` returns a 4-tuple); reset and stepped in\n",
    "            place. (Previously ignored in favour of an undefined global `pend`.)\n",
    "        ds: 'foreground' uses `lmm.piE_foreground` with group [1]; any other\n",
    "            value uses `lmm.piE_background` with group [0].\n",
    "        plot: 'reward' plots cumulative reward; any other value plots actions.\n",
    "        n_steps: steps per algorithm (previously hard-coded to 100).\n",
    "        reward_scale: divisor applied to every algorithm's recorded reward\n",
    "            (previously only fqi/lmmfqi rewards were divided by 10).\n",
    "        oracle_actions: candidate actions tried by the oracle. NOTE(review): the\n",
    "            default looks copied from a pendulum example; the LMMFQI training\n",
    "            log reports 25 actions for SepsisEnv, so this likely needs changing.\n",
    "        fqi, lmm: trained agents; default to the notebook globals `fqi_agent`\n",
    "            and `lmm_agent`.\n",
    "        seed: if not None, `sepsis_env.seed(seed)` is called before every\n",
    "            reset so all algorithms start from the same initial state.\n",
    "\n",
    "    Returns:\n",
    "        (val_rewards, alg_actions): dicts mapping each algorithm name to its\n",
    "        per-step scaled rewards and actions.\n",
    "    \"\"\"\n",
    "    fqi = fqi_agent if fqi is None else fqi\n",
    "    lmm = lmm_agent if lmm is None else lmm\n",
    "    if ds == 'foreground':\n",
    "        lmm_policy, group = lmm.piE_foreground, [1]\n",
    "    else:\n",
    "        lmm_policy, group = lmm.piE_background, [0]\n",
    "\n",
    "    algos = ['fqi', 'lmmfqi', 'oracle', 'random']\n",
    "    val_rewards = {alg: [] for alg in algos}\n",
    "    alg_actions = {alg: [] for alg in algos}\n",
    "\n",
    "    for alg in algos:\n",
    "        print(\"Alg: \", alg)\n",
    "        if seed is not None:\n",
    "            sepsis_env.seed(seed)\n",
    "        # Track the observation returned by reset()/step() (gym Env contract)\n",
    "        # instead of the private `_get_obs()` of the `pend` env this code was\n",
    "        # apparently adapted from.\n",
    "        obs = sepsis_env.reset()\n",
    "        for _ in range(n_steps):\n",
    "            s = np.asarray(obs).reshape((1, -1))\n",
    "            if alg == 'fqi':\n",
    "                action = fqi.piE.predict(s)\n",
    "                ns, reward, _, _ = sepsis_env.step(action)\n",
    "            elif alg == 'lmmfqi':\n",
    "                action = lmm_policy.predict(s, group)\n",
    "                ns, reward, _, _ = sepsis_env.step(action)\n",
    "            elif alg == 'random':\n",
    "                action = np.rint(sepsis_env.action_space.sample())\n",
    "                ns, reward, _, _ = sepsis_env.step(action)\n",
    "            elif alg == 'oracle':\n",
    "                action, ns, reward = _oracle_step(sepsis_env, oracle_actions)\n",
    "            else:\n",
    "                raise ValueError(\"Invalid algorithm selected\")\n",
    "            # NOTE(review): `done` is ignored, as before; stepping a gym env\n",
    "            # past termination is undefined.\n",
    "            alg_actions[alg].append(np.atleast_1d(action)[0])\n",
    "            val_rewards[alg].append(reward / reward_scale)\n",
    "            obs = ns\n",
    "\n",
    "    _plot_validation(val_rewards, alg_actions, n_steps, ds, plot)\n",
    "    return val_rewards, alg_actions"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
