{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSPI on the chain walk environment\n",
    "\n",
    "## Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:00.063484Z",
     "iopub.status.busy": "2022-09-19T11:21:00.063286Z",
     "iopub.status.idle": "2022-09-19T11:21:00.620418Z",
     "shell.execute_reply": "2022-09-19T11:21:00.619922Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "import jax\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Load the experiment configuration; the context manager closes the file handle\n",
    "# (the previous bare open() call left it to the garbage collector).\n",
    "with open(\"parameters.json\") as parameters_file:\n",
    "    parameters = json.load(parameters_file)\n",
    "\n",
    "# Environment\n",
    "n_states = parameters[\"n_states\"]\n",
    "n_actions = 2  # fixed in the notebook, not read from parameters.json\n",
    "# The misspelling \"sucess\" matches the key in parameters.json and the name used by later cells.\n",
    "sucess_probability = parameters[\"sucess_probability\"]\n",
    "gamma = parameters[\"gamma\"]\n",
    "env_seed = parameters[\"env_seed\"]\n",
    "\n",
    "# Sample collection: every (state, action) pair is sampled n_repetitions times\n",
    "n_repetitions = parameters[\"n_repetitions\"]\n",
    "n_samples = n_states * n_actions * n_repetitions\n",
    "\n",
    "# Training\n",
    "max_bellman_iterations = parameters[\"max_bellman_iterations\"]\n",
    "\n",
    "# Random keys\n",
    "env_key = jax.random.PRNGKey(env_seed)\n",
    "# NOTE(review): the Q-table is built with zero_initializer=True, so this key presumably\n",
    "# has no influence on the initial weights; it is only a copy of env_key - confirm.\n",
    "dummy_q_network_key = env_key.copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:00.622900Z",
     "iopub.status.busy": "2022-09-19T11:21:00.622698Z",
     "iopub.status.idle": "2022-09-19T11:21:00.974915Z",
     "shell.execute_reply": "2022-09-19T11:21:00.974411Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pbo.environments.chain_walk import ChainWalkEnv\n",
    "\n",
    "\n",
    "# Integer indices of the discrete states and actions.\n",
    "states, actions = np.arange(n_states), np.arange(n_actions)\n",
    "\n",
    "# Box edges are shifted by half a unit so that each integer index lies strictly inside one box.\n",
    "states_boxes, actions_boxes = [np.arange(size + 1) - 0.5 for size in (n_states, n_actions)]\n",
    "\n",
    "env = ChainWalkEnv(env_key, n_states, sucess_probability, gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Samples on the mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:00.978330Z",
     "iopub.status.busy": "2022-09-19T11:21:00.978179Z",
     "iopub.status.idle": "2022-09-19T11:21:02.109923Z",
     "shell.execute_reply": "2022-09-19T11:21:02.109432Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from pbo.sample_collection.replay_buffer import ReplayBuffer\n",
    "\n",
    "\n",
    "replay_buffer = ReplayBuffer()\n",
    "\n",
    "# Visit every (state, action) pair of the mesh. The jax arrays are immutable, so they are\n",
    "# built once per state / per action instead of twice per repetition.\n",
    "for state in states:\n",
    "    state_array = jnp.array([state])\n",
    "\n",
    "    for action in actions:\n",
    "        action_array = jnp.array([action])\n",
    "\n",
    "        # Need to repeat the samples to capture the randomness\n",
    "        for _ in range(n_repetitions):\n",
    "            env.reset(state_array)\n",
    "            # The fourth return value is unused; it gets its own name so that it no longer\n",
    "            # overwrites the loop variable.\n",
    "            next_state, reward, absorbing, _info = env.step(action_array)\n",
    "\n",
    "            replay_buffer.add(state_array, action_array, reward, next_state, absorbing)\n",
    "\n",
    "replay_buffer.cast_to_jax_array()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:02.111993Z",
     "iopub.status.busy": "2022-09-19T11:21:02.111855Z",
     "iopub.status.idle": "2022-09-19T11:21:02.948310Z",
     "shell.execute_reply": "2022-09-19T11:21:02.947834Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ8AAAEYCAYAAACDV/v0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAArxklEQVR4nO3deZxcVZ3+8c/Dvq8BhADCKCqKioKgDiigIiAK44gCjiwDIriBDA44Krih8BuVcUcEBBdQBkQyEjZBwFFAghMhiEJkkYQ1hB0Ekjy/P85pU6l0d1V1uquru593XveVqnvvufdUdXd965x77vfINhEREd201GhXICIiJp4En4iI6LoEn4iI6LoEn4iI6LoEn4iI6LoEn4iI6LoEn4iICUTS6ZIekDSjYd1ekm6WtEDS1oOU3UXSnyXNlHRMw/pNJV1X1/9U0nKt6pHgExExsZwB7NK0bgbwTuDqgQpJWhr4FrAr8FJgH0kvrZtPBE6y/ULgYeCgVpVI8ImImEBsXw3MbVp3i+0/tyi6DTDT9u22nwV+AuwhScBOwLl1vzOBPVvVY5lOKx4REd3x1h1X9kNz53dU5oYbn7kZ+FvDqlNsnzIM1ZkM3N3wfBawLbA28IjteQ3rJ7c6WIJPRESPmjN3PtddsmFHZZZd/y9/sz3gdZtekeATEdGzzHwvGO1K9JkNbNTwfMO67iFgDUnL1NZP3/pB5ZpPRESPMrAAd7SMoOuBzerItuWAvYEpLtmpfwW8q+63P3BBq4Ml+ERE9LAFHf5rRdLZwDXAiyXNknSQpH+SNAt4HXChpEvqvhtImgpQWzUfBi4BbgHOsX1zPezRwJGSZlKuAZ3Wsh6ZUiEioje96pXL+aqLntdRmdUn331DrvlERMQSGeGutFGT4BMR0aMMzE/wiYiIbkvLJyIiusrA/HF6XT7BJyKih/XMXT7DLMEnIqJHGeeaT0REdJlh/viMPQk+ERG9qmQ4GJ8SfCIiepaYj0a7EiMiwSciokcZWJBut4iI6La0fCIioqtKhoMEn4iI6LIFTvCJiIguSssnIiK6zoj543TatQSfiIgelm63iIjoqnS7RUTEKBDznW63iIjoopJeJ8EnIiK6LN1uERHRVXa63SIiYhQsSMsnIiK6qYx2G58tn/H5qiIixoXS7dbJ0vKI0umSHpA0o2HdWpIuk3Rb/X/NfsrtKGl6w/I3SXvWbWdIuqNh25at6pHgExHRo/pGu3WytOEMYJemdccAl9veDLi8Pl+0LvavbG9pe0tgJ+Ap4NKGXT7et9329FaVSPCJiOhh862OllZsXw3MbVq9B3BmfXwmsGeLw7wLuMj2Ux2+nL9L8ImI6FF9ud06WYBJkqY1LIe0car1bN9bH98HrNdi/72Bs5vWHS/pRkknSVq+1Qkz4CAiooct6Hyo9RzbWw/1fLYtacD5UyWtD7wcuKRh9ScoQWs54BTgaOBzg50nLZ+IiB7VN9qtw5bPUNxfg0pfcHlgkH3fDZxv+7m/19O+18UzwPeBbVqdMMEnIqJHmc6u97RzzWcAU4D96+P9gQsG2XcfmrrcGgKXKNeLZixebFEJPhERPWy4R7tJOhu4BnixpFmSDgJOAN4i6TbgzfU5kraWdGpD2U2AjYCrmg77Y0k3ATcBk4AvtKpHrvlERPQom2FPr2N7nwE2vamffacBBzc8vxOY3M9+O3VajwSfiIiepaTXiYiI7jLD3/LpFQk+ERE9bLzmdkvwiYjoUUYsGPoItp6W4BMR0cPS8omIiK4yQ8pwMCYk+ERE9CxlGu2IiOiutHwiImJUpOUTERFdZSstn4iI6L7cZBoREV1VptFOt1tERHSV0vKJiIjuKqPd0vKJiIguS4aDiIjoquR2i4iIUdHO7KRjUYJPRESPKjOZpuUTERFdlm63iIjoqnLNJ91uERHRZcntFhERXZX7fCIiYhSM32638fmqIiLGiQWoo6UVSadLekDS
jIZ1a0m6TNJt9f81Byg7X9L0ukxpWL+ppOskzZT0U0nLtapHgk9ERI/qG2rdydKGM4BdmtYdA1xuezPg8vq8P0/b3rIu72hYfyJwku0XAg8DB7WqRIJPREQPW+ClOlpasX01MLdp9R7AmfXxmcCe7dZPkoCdgHM7KZ/gExER69m+tz6+D1hvgP1WkDRN0rWS9qzr1gYesT2vPp8FTG51wgw4iIjoUUPM7TZJ0rSG56fYPqXtc9qW5AE2P9/2bEn/AFwh6Sbg0U4rCAk+ERE9bQiTyc2xvXWHZe6XtL7teyWtDzzQ3062Z9f/b5d0JfAq4DxgDUnL1NbPhsDsVidMt1tERI/qu8+nk2WIpgD718f7Axc07yBpTUnL18eTgH8E/mjbwK+Adw1WvlmCT0REDxvuAQeSzgauAV4saZakg4ATgLdIug14c32OpK0lnVqLbg5Mk/QHSrA5wfYf67ajgSMlzaRcAzqtVT3S7RYR0auWrDXT/yHtfQbY9KZ+9p0GHFwf/xZ4+QDHvB3YppN6JPhERPQoM6RrPmNCgk9ERA9LbreIiOiqJBaNiIhRkeATERFdNcSbTMeEBJ+IiB6WAQcREdFdTrdbRER0WQYcRETEqEjwiYiIrsqAg4iIGBVO8ImIiG7LaLeIiOgqZ7RbRESMhnS7RUREl2XAQUREjIK0fCIioqtyk2lERHSfy6CD8SjBJyKih2WodUREdJUZv9d8lhrtCsTEJukzkn402vUYaZL+Q9Kpg2x/r6RLu1mnGAvKaLdOlrEiwWeCkrSdpN9KelTSXEm/kfSa0a7XeCBpB0mzGtfZ/qLtg+v2TSRZ0jIN239se+du1zV6n93ZMlak220CkrQa8AvgMOAcYDlge+CZ0azXSJG0jO153TpXN84TE0e63WI8eRGA7bNtz7f9tO1Lbd8IIOkFkq6Q9JCkOZJ+LGmNvsKS7pT0cUk3SnpS0mmS1pN0kaTHJf1S0pp1375v+YdIukfSvZKOGqhikl5bW2SPSPqDpB0ath0g6fZ6jjskvXeAY3xG0rmSfiTpMeAASavXet4rabakL0hauuG4v5H0zdoS/JOkNzUc70BJt9Tz3i7pAw3bdpA0S9LRku4DzgYuAjaQ9ERdNmjqXry6/v9I3f66Wof/bTju6yVdX+tzvaTXN2y7UtLna50fl3SppEmD/sRjTCqtGXW0tCLpdEkPSJrRsG4tSZdJuq3+v2Y/5baUdI2km+vf/nsatp1R/yan12XLVvVI8JmYbgXmSzpT0q79/KIJ+BKwAbA5sBHwmaZ9/hl4CyWQvZ3ygfsfwDqU36uPNu2/I7AZsDNwtKQ3N1dK0mTgQuALwFrAUcB5ktaRtDLwdWBX26sCrwemD/Ia9wDOBdYAfgycAcwDXgi8qtbj4Ib9twX+AkwCjgN+Jmmtuu0BYHdgNeBA4CRJr24o+7xa3+cD+wG7AvfYXqUu9zTV7Q31/zXq9mua3oe16vvwdWBt4KvAhZLWbtht31qXdSkt1wEDeoxtI3DN5wxgl6Z1xwCX294MuLw+b/YUsJ/tl9Xy/9X4pRT4uO0t6zK9VSUSfCYg248B21EG03wPeFDSFEnr1e0zbV9m+xnbD1I+/N7YdJhv2L7f9mzg18B1tv/P9t+A8ykf8I0+a/tJ2zcB3wf26adq/wJMtT3V9gLblwHTgN3q9gXAFpJWtH2v7ZsHeZnX2P657QWUoLEbcEStwwPAScDeDfs/APyX7eds/xT4M/C2+n5caPsvLq4CLqV0U/ZZABxX36+nB6lTu94G3Gb7h7bn2T4b+BMlyPf5vu1b6/nOAbYchvNGDxruaz62rwbmNq3eAzizPj4T2LOfcrfavq0+vofyN7POUF9Xgs8EZfsW2wfY3hDYgtLK+S+A2oX2k9o99RjwI0qLoNH9DY+f7uf5Kk37393w+K56vmbPB/aqXW6PSHqEEiTXt/0k8B7gUOBeSRdKeskgL7HxfM8Hlq3l+o77XUqroc9se5E/
3b/XsbYOr1UZmPEIJZA1vh8P1qA7XDao5290FzC54fl9DY+fYvH3O8aJ4e52G8B6tu+tj+8D1htsZ0nbUFrcf2lYfXztjjtJ0vKtTpjgE9j+E6UpvkVd9UVKq+jltlejtEiW9KrnRg2PNwaau6KgBIwf2l6jYVnZ9gm1npfYfguwPqUl8L1BztcYSO6mDKaY1HDc1Wr3QZ/Jkhpf48bAPfWP6Dzgy5Q/0DWAqSz6fjR/32z1/bPV9nsoAbPRxsDsFuVinDGdBZ4afCZJmtawHNLROcuXsAF/RyWtD/wQOLD2LAB8AngJ8BpKF/TRrc6T4DMBSXqJpH+TtGF9vhGlG+zausuqwBPAo/U6zMeH4bSflrSSpJdRrlX8tJ99fgS8XdJbJS0taYV6QX/D2hrbo177eabWb0E/x1hM/UZ3KfAVSatJWkplUEVjV+K6wEclLStpL8q1rqmUb3fLAw8C8yTtSrleNJj7gbUlrT7A9gdr3f9hgO1TgRdJ2lfSMvXC7kspIxRbUhkQckA7+0bvc4cLMMf21g3LKW2c5v4aVPqCywP97aQyUvZC4JO2+z4vqN3gtv0MpVt9m1YnTPCZmB6nXGC/TtKTlKAzA/i3uv2zwKuBRym/aD8bhnNeBcykXMz8su3Fbqi0fTel7/k/KB/Qd1MC31J1OZLSKphLuQZ1WAfn348SSP4IPEwZjLB+w/brKAMi5gDHA++y/ZDtxymDJ86p5fYFpgx2otqSPBu4vXbzbdC0/al6jt/U7a9t2v4QZYDDvwEPAf8O7G57TqsXKWk5yiCFa1vtG2PACIx2G8AUYP/6eH/gguYd6u/W+cAPbJ/btK0vcIlyvWhGc/nFjuexdFdSjDmSNgHuAJbt1r02naqthINtbzfadVlSkrYDPmS7vwEdMcas8ILJ3vjEQzsqc9tex95ge+uBtks6G9iBct3yfsrozp9TvmBtTLm++G7bcyVtDRxq+2BJ/0Jp1TQO9DnA9nRJV1AGH4gyCvVQ208MVs/cEBcxjtj+X+B/W+4YY8Zw32Q6yBeTNzWvsD2NekuC7R9Rusb7O+ZOndYjwSciooeN186pBJ8YUbbvZMlHyo0o22dQRvtF9JTxnNU6wSciolcZGKfBJ6PdYsKRdLKkTw+y3ZJeOEznOkPSF4bpWFdKOrj1njGejNes1gk+41BNEni+StLPuyTt27DtlTUx4BxJRzasX1bSdfWenzGnkw9m24fa/vwI1GGR5KARw2IIN/qMBel2G5++BTxLSZGxJSUp5R9qLrQvUZJQ3gjcKOks2/dR7qE5r95rMywkLW17/nAdL2LiWaJ7d3paWj7jTM0A8M/Ap20/UYfeTgHeV3fZFLiiJgS9DdhY0vNrmZPaOP5/S7pPJdX/1TVjQd+2MyR9R9LUevPqjirTCZwn6UGVlOvN2a4bj726pB/Ufe+S9ClJS9Vti8x4qoYJ2SQdT0n0+U2VKQq+qeIkldTxj0m6SdIWDfX8QsOxPq4y1cI9kv61qU7LS/qypL9Kur922a3YT903B04GXlfr8EjD5jVVctE9XluXL2go9xKVFPZzJf1Z0rtb/AheIOl39TVdoIWZt5H0jtqqfaS2BDev64+u512mPj+s7rdCi3NFLxinLZ8En/HnRcA827c2rPsD0BckZgA7q6TW2YSSGPBrlHToz7Vx/IsomQDWBX5Pma6g0b6Uu/dXBX4L/E89/2TKfQRHSHrrAMf+BrA6Je3MGylZCQ5sVSHbn6Rk1v5wnaLgw5QUOG+gvB+rA++mZAtYhKRdKC3Bt9TX1TzVwwn1GFtSpmOYDBzbTx1uoSQ9vabWYY2GzXtTskasScnycHw998rAZcBZlPdzb+Dbkl46yMvdD/hXSnaGeZRpF5D0IkpWhSMoN/tNBf5H5a70/6SkJPqUpM0oufv+ZZiTocZI6F6Gg65L8Bl/VgEea1r3KCUYQPmgPYzSGvoY8I+UdDt31G/SV6nkNuuX7dNtP15zOH0GeKUW
zWF2ge3f1ISDLwfWsf0528/avp2SDHTv5uOqTOy2N/CJevw7ga+wsMXWqefqa34JJZPHLQ1Zexu9mzI9wYyaOfszDXUScAjwMdtza6qdL/ZX/xbOt/27muHhxyyc/mB34E7b369TJ/wfJYnpgO8/JfFqX10/Dby7vnfvAS50mQrjOUoi1BWB19efxX6UNEFTgP9XzxVjwTht+eSaz/jzBGX+mkarUQIMtu+izo8jaSXgGkor4RuUZJ8XAjMkXW57kTk/6ofc8ZQPx3VYmNhzEiXAweJTGWzQ1AW1NKWV0mwSZdqDxqkEmqcRaJvtKyR9k3L96/mSfgYc5TKXUaMNgBuaztlnHWAl4AYtTHit+ho6MdD0B88Htm16f5ahZAweSPPUFMtS3rtFpmGwvUDS3dT3z/adkn5F+dl/q8P6x6gaO62ZTqTlM/7cCixTu1f6vJJF8zH1ORb4nu37Ka2UabYfBWZRupia7UtJ/PlmSlfWJnX9QNML3A3c0TRFwqq2d2NxcyitlcapBBqnEXiSEgj6PK+p/GLf+Wx/3fZWlIzQL6L/7Nz3svh0D411ehp4WUP9V7c90Nw5nX7vvBu4qun9WcX2YAlTm+v6XK3nItMw1FbbRtT3T9LbgNdRErv+Z4f1jNE0Tls+HQUflVT0zd+qo4fU7pifAZ+TtLKkf6QEjEW+TdfrCjsA36mr7gB2UpnNdDPgr/0cflXKtYOHKIHgiy2q8zvg8XrBe0WVaRK2kPSafuo9n5LY8HhJq9ZBEEeyMJfUdOANkjau3XyfaDrE/TRMUSDpNZK2lbQsJXD9jf6nYDgHOEDSS2tL8LiGOi2gdBOeJGndetzJg1yzuh/YsF5naccvKFMnvE9lqPuytd6bD1LmXxrq+jng3Ib37m2S3lRf879Rfla/lTQJOJWSo2t/yrQV/X0BiF40UYOPpLNU5kBZmXKx+o+ShmN+lxg5H6T09z9AuQh9mBefcvpbwOENQ6E/QbkmcDPwxTr8utkPKF07sylTEwyatr8ee3fKNY47KN/QT6W0mvrzEUqguJ2SHPMs4PR6rMso3YI3UrrJmue2+RrwLkkPS/o6pavxe5RpEO6iBMzFvvHbvogyg+sVlMEAVzTtcnRdf63KrK6/BF48QP2voLx/90lqOf1BvYa0M+Ua0j2U7rkTKfMHDeSHlFRA9wErUH5m2P4zZdK/b1De57cDb7f9LHAK5Vrc1Dpdw0HAqZLWBqij87ZvPlH0gL4MB50sY0TLKRUkTbe9paT3UuZ4OQa4wfYrulHBiIiJavlNNvTzjh3w7oR+/fWgowedUqFXtNPttmxtxu8JTKkjacZQ4y4iYgybqN1uwHeBO4GVgatrX3zziKHFSDq93uDXcka7iIgYwDjtdmsZfOqIocm2d6tzdN8F7NjGsc8AdlnSCkZETGRyZ8tY0fI+H0nLU1KvbNK0/+cGK2f7apUplCMiYijGWFdaJ9q5yfQCyg2EN1CGbkZERFeMra60TrQTfDa0PWLdZ5IOoaQwQcstt9Wy6607UqeKiBhxz949a47tdYbtgBO45fNbSS+3fdNIVMD2KZT7EFh+4408+agjRuI0ERFdccfhR93Veq8OTODgsx3lDvA7KN1uApz7fCIiumACB59dh3JgSWdT0rdMkjQLOM72aUM5VkTEhNSX4WAcameo9V3AGtR0HcAadV2rcvvYXt/2srY3TOCJiOjccA+17u8eTElr1UkNb6v/rzlA2f3rPrdJ2r9h/VYqEzbOlPR1NaSBH0g7ud0Op8xBsm5dfiTpI61fYkRELLHhz3BwBovfg3kMcLntzSiZz49pLlRnzT0O2BbYBjiuIUh9B3g/JSnxZv0cfzHtZDg4CNjW9rG2jwVeW08SERFjjO2rgblNq/cAzqyPz6SkU2v2VuCyOrHiw5RZeHeRtD6wmu1rXZKF/mCA8oto55qPgPkNz+czXmc3iojoMV3KWrBew0y/9wHr9bPPZBadzHBWXTe5Pm5e
P6h2gs/3gesknV+f7wnk+k1ERDd0PuBgkqRpDc9Pqbe0tHc629LIh7yWwcf2VyVdSRlyDXBg5n+PiOiCoaXXmTOEKRXul7S+7XtrN9oD/ewzmzKCuc+GwJV1/YZN62fTwoDXfPpmLK0Xme6kzCj5I+Cuui4iIkZad6ZUmEKZ5Zb6/wX97HMJsLOkNetAg52BS2p33WOSXltHue03QPlFDNbyOYsyC+UNLPqSVJ//Q3+FIiJi+Ax3B1h/92ACJwDnSDqIMvPvu+u+WwOH2j7Y9lxJnweur4f6nO2+gQsfpIyiWxG4qC6DGjD42N69/r9px6+ukrQLZXrjpYFTbZ8w1GNFRExIwxx8bO8zwKY39bPvNODghuenU6e272e/LTqpRzv3+Vzezrp+9lka+BYlQ8JLgX0kvbSTykVETHjjdCbTAVs+klYAVqI0zdZk4fDq1WhjGB3lJqSZtm+vx/sJZSz5H5eoxhERE8RYmyCuE4Nd8/kAcASwAeW6T1/weQz4ZhvH7m9M+LadVzEiYgIbp7ndBrvm8zXga5I+YvsbI1WBxvl8gGfuOPyoGYPt38IkYE7Kj9nyvVCHlM/vwJKWf/ESlF3cBGz59FkgaQ3bjwDULrh9bH+7RbnZwEYNz/sd+904n4+kaUMYn/53KT+2y/dCHVI+vwPDUX6oZfs93jgNPu3kdnt/X+ABqDl92sntdj2wmaRNJS0H7E0ZSx4REe2aaAMOGiwtSTVhXN8otuVaFbI9T9KHKTcmLQ2cbvvmJaptRMREMkEHHPS5GPippO/W5x+gjRuIAGxPBaZ2UJ+28w+l/Lgs3wt1SPnRLd8LdRjt8osap8FHtUEz8A7SUpQBAX03IN0IPM/2h0a4bhERE9oKkzfyxocd2VGZ2z595A1Let2tG9qZyXQBcB0lv9s2wE7ALSNbrYiIgOGfybRXDHaT6YuAfeoyB/gpgO0du1O1iIgYrwa75vMn4NfA7rZnAkj6WFdqFRERxRhqzXRisODzTsrw6F9Juhj4CcM4g6mkl1DS7fSl6pkNTLHdtS69WofJwHW2n2hYv4vti9sovw1l7qXra966XYA/1YEWQ6nPD2zvN8Sy21G6RWfYvrSN/bcFbrH9mKQVKXO2v5qS/uiLth9tUf6jwPm27x5sv0HK9w2/v8f2LyXtC7ye0qV7iu3n2jjGP1B+TzeizLB7K3CW7ceGUqeInjPGutI6MeA1H9s/t7038BLgV5RUO+tK+o6knZfkpJKOZmEw+11dBJwt6ZglOXY9/oFt7PNRypwTHwFmSNqjYfMX2yh/HPB14DuSvkRJObQycIykT7ZRfkrT8j/AO/uet1H+dw2P31/PvypwXJvv4enAU/Xx14DVgRPruu+3Uf7zlBlufy3pg5LWaaNMo+8DbwMOl/RDYC/KtcXXAKe2Klx/ficDK9Qyy1OC0LWSduiwLlFJWrcH6rD2aNehp4zT+3xajnZbZOeS3WAv4D22F0u/3cFxbgVe1vzttn4bvtn2ZkM9dj3OX21v3GKfm4DX2X5C0ibAucAPbX9N0v/ZflUb5bekfOjdB2zY0Iq4zvYrWpT/PaWVcSrlV0bA2ZTWALavalH+73WUdD2wm+0HJa0MXGv75S3K32J787662H51w7bptrdsdX5gK+DNwHuAd1ByAJ4N/Mz24y3K32j7FZKWobR6N7A9v05G9Yc23r+bgC1rmZWAqbZ3kLQxcEGrn189xurAJyhTw69L+Tk8QPlSckLjzdWdknSR7V1b7LNaPf+GwEW2z2rY9m3bH2xR/nmUuVgWAMdSvkj9M6X1eHid5Guw8s2TQoryM3wV5bNh7uKlFjvG33sJ6vv5VcqXgRnAx2zf36L8CcCXbc+pc8ecU1/PssB+bfwd/B74GXC27b+0qm8/5bcG/pPyO/gJypeybSit6ENazdosaRXg3ynv+4bAs8BfgJNtn9FpfZqtsMFG3uT9nY12+/Pnxslot0a2H7Z9ypIEnmoBJWFps/Xr
tpYk3TjAchOwXhuHWKqvq832nZTJlXaV9FXa616cZ3u+7aeAv/R19dh+us3XsDXlD/2TwKO2rwSetn1Vqz+4vvqrzCi4NuWD4sF6/ieBeW2Un9HQQvxD/SPsG2jSssurnMoLbF9q+yDKz/PblK7H29us/3KU1tpKlJYXlGC+bBvlYWG38fLAKrVSf+2g/DnAw8AOtteyvTawY113TqvCkl49wLIV5YtJK9+n/K6dB+wt6TxJy9dtr22j/BmULzB3U3onngZ2o1yrPbmN8nMov4N9yzRKN/Tv6+N2NPYSfAW4F3g7JcPJd/stsai32e7Lo/aflC+2LwTeUo/XyprAGpTLA7+T9DFJ/X22DOTbwP8DLgR+C3zX9uqUbuhWKcQAfkz5fX8r8FlKb8j7gB0ltexBaUVMwNFuI+wI4HJJt7Ew8/XGwAuBD7d5jPUoP/CHm9aL8kvUyv2StrQ9HaC2gHanfPMZtNVQPStppRp8tvr7ycu3v5bBpw5hP0nSf9f/76ezn8fqLMw2bi2cf30V2gueB1MSx36K8iF0jaS7KT+PgwctWSxyjtqKnQJMqS2RVk6jDGpZmhKA/1vS7ZQP3Z+0Uf5U4HpJ1wHbU7oMqd1/Lb+xV5vYPrHpddwHnCjpX9sofz1wFf2/32u0Uf4Ftv+5Pv557a69QtI72igLsF5f0l9JH2x4Ld9QmZGylY9TPuQ/bvumepw7PPQJJLduaDGfJGn/wXaulpG0jO15wIq2rwewfWtDIB7Mw7aPAo6StD1ldO7vJd1CaQ21uuFzWdsXAUg60fa59fyXS/pyG+ffpKGF81VJ19v+fP1i90fgP9o4xuDGUEDpxKgEH9sX12/Y27DogIPrbc9v8zC/AFbpCx6NJF3ZRvn9aGoh1D+A/bQwm8Ng3mD7mVquMdgsy8K50FuyPQvYS9LbKNNVtFtukwE2LQD+qY3yjwIH1K6fTSm/C7NadZM0eM8gx35qoG0N+5wkqW/4/j2SfkDpwvue7d8NXrpkXZf0S2Bz4Cu2/1TXPwi8oc3XcJekfwfO7HvdktYDDmDR6UAGcgvwAdu3NW+ogbyV5SUt1ff7Y/t4SbOBq6ktuRYaey5+0LRt6VaFbX+l/gxOqvU9js4/6taVdCQlAK8mLUzFRXs9K98Gptbut4slfY3SjbYTML2Titj+NfBrSR+hBNX30DrbwN9UrmGvTvkSt6ftn0t6I2UQSytPStrO9v/WLw1za10W1C7kJTPGWjOdGK2WT98H9rVLUH7Ab3a2922j/KxBtv2mjfLPDLB+DkNIx277QkrTf4nUD/47Otj/MeAPQzjPrZ2W6ecY9zQ8foRy3a2T8jcDS5Iv8D2U7pWrtPBC+/2UFtxebZT/DAN/wH6kjfL/Q/mQ/WXfCttnSLoPaGcakwskrWL7Cduf6lsp6YXAn9so3/jl5x3AZZQu0E58j9J1CnAmZTqCB+v1qOltnP8btav8MOBFlM+kzYCfUwa1tLLY72H9AntxXVo5lNLttoDSk3KYpDMoX4bbSaB8KHCqpM0ov4v/Cn9vgX+rjfKtjdPg09GAg4iJQtKBttsZ9TduyqsMlnmB7RlLev6h1iHlF7Xi+ht50wM7G3Bwy5fG4YCDiAnksxOtvO2nbfdN5rik5x+OY0z08kAGHESMO5JuHGgTbYyYnOjle6EOY718W8ZQQOlEgk9MZEs6YnKil++FOoz18oMbYzeOdiLBJyayJR0xOdHL90Idxnr5lkaiK03S4ZQBFaKMMP2vpu0fB95bny5DGVW6ju25ku4EHqeMBpw31OtLGXAQEdGjVnzeRn7B+zobcHDzlwcfcCBpC8q9dNtQMjJcDBzqmkC6n/3fTslWsVN9fiflnq6OR/U2yoCDiIgeNgIDDjanpAB7qt7beBUlQe9A9qGkzRpWCT4REb1s+BOLzgC2l7R2zUayGyUp72Lq9l0oKaAaa3SppBskHTKEVwTkmk+MUTUVzb6UfucFwAeA11GmYxg0w4Kk
I9rZL2LUDW3AwSRJjbn5TmlMM2T7FkknApcCT1JuBh4om8Pbgd940SSz29meXW/MvkzSn2xf3WklE3xizJH0OmB34NW2n5E0CViOMtvuj1g4VcRAjmhzv4hRJdpL1NhkTqtBALZPo+RXpCZAHSjjy940dbnZnl3/f0DS+ZRrRx0Hn3S7xVi0PuUPrC+33hzgXZTM2r+S9CsAlbmnpkm6WdJn67qP9rPfzpKukfR7Sf9dk7Mi6QRJf1TJlt5OksmI4TcC8/n0pZNSmYLkncBZ/eyzOvBGyhQjfetWlrRq32NgZ0o3XsfS8omx6FLgWJV5oX4J/NT212uCyx0bRuF8sg4NXZqSRf0VzfvVVtOngDfbflJlosMjJX2LkqD1JbYtaY2uv8oIRixrwXkq07E8B3zI9iOSDgWw3Tcdxz8Bl7pM09JnPeD8mjN1GcrMwe3k0FtMgk+MOS7TX2xFmUphR+Cn6n/21nfXC6LLUFpLLwWa70h/bV3/m/oHtRxwDfAo8DfgNEm/oNzPEdF9IxB8bG/fz7qTm56fQZkzqnHd7cArh6MOCT4xJtXMxVcCV9asyItMYyFpU+Ao4DW2H66Zilfo51ACLrO9z2IbpG2AN1G69D5MyUAd0V3j9FbMXPOJMUfSi2sK+z5bAndR7rruS++/GmUkz6Mqc/Q0TmnduN+1wD/WaQj6+rRfVK/7rG57KvAxhunbXkRHOrzHJ4lFI0bWKpTZOtegTAg4EziEcjPcxZLusb2jpP+jzJZ6N9A4R9MpTfsdAJythTNnfooSoC6QtAKlddTZbeYRw2UMBZROJL1ORESPWmndjfzivTr73jP922NjPp+0fCIietk4bR8k+ERE9LCxdB2nEwk+ERG9KvP5RETEqEjwiYiIbhLpdouIiNGQ4BMREd2mcXo7TIJPRESvyoCDiIgYDbnmExER3ZfgExER3ZaWT0REdF+CT0REdNUYmyahEwk+ERG9LMEnIiK6KRkOIiJidOQm04iI6La0fCIioruS4SAiIkaDFox2DUZGgk9ERC8bpy2fpUa7AhERMTC5s6WtY0qHS5oh6WZJR/SzfQdJj0qaXpdjG7btIunPkmZKOmaorystn4iIXmWGfbSbpC2A9wPbAM8CF0v6he2ZTbv+2vbuTWWXBr4FvAWYBVwvaYrtP3Zaj7R8IiJ62Ai0fDYHrrP9lO15wFXAO9uszjbATNu3234W+Amwx1BeV4JPREQvc4cLTJI0rWE5pOmIM4DtJa0taSVgN2Cjfs78Okl/kHSRpJfVdZOBuxv2mVXXdSzdbhERPWqIGQ7m2N56oI22b5F0InAp8CQwHZjftNvvgefbfkLSbsDPgc06rskg0vKJiOhVdudLW4f1aba3sv0G4GHg1qbtj9l+oj6eCiwraRIwm0VbSRvWdR1LyyciooeNRIYDSevafkDSxpTrPa9t2v484H7blrQNpaHyEPAIsJmkTSlBZ29g36HUIcEnIqKXjcx9PudJWht4DviQ7UckHQpg+2TgXcBhkuYBTwN72zYwT9KHgUuApYHTbd88lAok+ERE9LCRaPnY3r6fdSc3PP4m8M0Byk4Fpi5pHRJ8IiJ6lYEF4zPFQYJPREQvG5+xJ8EnIqKXZUqFiIjovkwmFxER3ZaWT0REdFcmk4uIiG4r6XXGZ/RJ8ImI6GWZyTQiIrotLZ+IiOiuXPOJiIjuaz9T9ViT4BMR0cMy1DoiIrovLZ+IiOgqgzLaLSIiui4tn4iI6LrxGXsSfCIielnu84mIiO5L8ImIiK4ySa8TERHdJZxut4iIGAUJPhER0XUJPhER0VXj+JrPUqNdgYiIGJjsjpa2jikdLmmGpJslHdHP9vdKulHSTZJ+K+mVDdvurOunS5o21NeVlk9ERC8b5m43SVsA7we2AZ4FLpb0C9szG3a7A3ij7Ycl7QqcAmzbsH1H23OWpB5p+URE9Kw6pUInS2ubA9fZfsr2POAq4J2LnNX+re2H69NrgQ2H9WWR
4BMR0bvMUILPJEnTGpZDmo46A9he0tqSVgJ2AzYapBYHARc11epSSTf0c+y2pdstIqKXdT7gYI7trQfaaPsWSScClwJPAtOB+f3tK2lHSvDZrmH1drZnS1oXuEzSn2xf3Wkl0/KJiOhhIzHgwPZptrey/QbgYeDWxc4rvQI4FdjD9kMNZWfX/x8AzqdcO+pYgk9ERC8b/ms+1FYLkjamXO85q2n7xsDPgPfZvrVh/cqSVu17DOxM6cbrWLrdIiJ6lYEFI3KT6XmS1gaeAz5k+xFJhwLYPhk4Flgb+LYkgHm1K2894Py6bhngLNsXD6UCCT4RET2r/dZMR0e1t+9n3ckNjw8GDu5nn9uBVzavH4oEn4iIXpb0OhER0XUJPhER0VUjd81n1CX4RET0LIPHZ2bRBJ+IiF6WbreIiOiqdLtFRMSoSMsnIiK6LsEnIiK6a2RuMu0FCT4REb3KwIKMdouIiG5LyyciIrouwSciIrrLGWodERFdZnAyHERERNel5RMREV2Xaz4REdFVdoZaR0TEKEjLJyIius1p+URERHclvU5ERHRbplSIiIhRkft8IiKimwx4nLZ8lhrtCkRExADs0vLpZGmDpMMlzZB0s6Qj+tkuSV+XNFPSjZJe3bBtf0m31WX/ob60tHwiInrYcLd8JG0BvB/YBngWuFjSL2zPbNhtV2CzumwLfAfYVtJawHHA1pSG2Q2Spth+uNN6pOUTEdHLhr/lszlwne2nbM8DrgLe2bTPHsAPXFwLrCFpfeCtwGW259aAcxmwy1BeVlo+ERE96nEevuSXPndSh8VWkDSt4fkptk9peD4DOF7S2sDTwG5A4/4Ak4G7G57PqusGWt+xBJ+IiB5le0itihbHvEXSicClwJPAdGD+cJ+nlXS7RURMMLZPs72V7TcADwO3Nu0yG9io4fmGdd1A6zuW4BMRMcFIWrf+vzHles9ZTbtMAfaro95eCzxq+17gEmBnSWtKWhPYua7rWLrdIiImnvPqNZ/ngA/ZfkTSoQC2TwamUq4FzQSeAg6s2+ZK+jxwfT3O52zPHUoF5HGaNygiInpXut0iIqLrEnwiIqLrEnwiIqLrEnwiIqLrEnwiIqLrEnwiIqLrEnwiIqLr/j/19LdWC5auCQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pbo.sample_collection.count_samples import count_samples\n",
    "from pbo.utils.two_dimesions_mesh import TwoDimesionsMesh\n",
    "\n",
    "\n",
    "# The third value returned by count_samples is not needed for this plot.\n",
    "samples_count, n_outside_boxes, _ = count_samples(\n",
    "    replay_buffer.states,\n",
    "    replay_buffer.actions,\n",
    "    states_boxes,\n",
    "    actions_boxes,\n",
    "    replay_buffer.rewards,\n",
    ")\n",
    "# Share of the n_samples collected samples reported as outside the boxes, truncated to an integer.\n",
    "outside_percentage = int(100 * n_outside_boxes / n_samples)\n",
    "\n",
    "samples_visu_mesh = TwoDimesionsMesh(states, actions, sleeping_time=0)\n",
    "samples_visu_mesh.set_values(samples_count, zeros_to_nan=True)\n",
    "samples_visu_mesh.show(f\"Samples repartition, \\n{outside_percentage}% are outside the box.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train LSPI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:02.950221Z",
     "iopub.status.busy": "2022-09-19T11:21:02.950028Z",
     "iopub.status.idle": "2022-09-19T11:21:19.435641Z",
     "shell.execute_reply": "2022-09-19T11:21:19.434841Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49e7f57382aa4b4da146ccd209626958",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0]\n"
     ]
    }
   ],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from pbo.sample_collection.dataloader import SampleDataLoader\n",
    "from pbo.networks.learnable_q import TableQ\n",
    "\n",
    "\n",
    "# The loop below reads element [0, 0] of every batch, i.e. one transition at a time.\n",
    "# NOTE(review): the arguments 1 and None are presumably the batch size and the shuffle key - confirm.\n",
    "data_loader_samples = SampleDataLoader(replay_buffer, 1, None)\n",
    "q = TableQ(\n",
    "    n_states=n_states,\n",
    "    n_actions=n_actions,\n",
    "    gamma=gamma,\n",
    "    network_key=dummy_q_network_key,\n",
    "    zero_initializer=True\n",
    ")\n",
    "n_weights = q.weights_dimension\n",
    "\n",
    "# History over the iterations; index 0 holds the initial Q-function.\n",
    "# NOTE: bellman_iteration_functions[0] is never written and therefore stays at zero.\n",
    "history_shape = (max_bellman_iterations + 1, n_states, n_actions)\n",
    "q_functions = np.zeros(history_shape)\n",
    "bellman_iteration_functions = np.zeros(history_shape)\n",
    "v_functions = np.zeros(history_shape[:2])\n",
    "\n",
    "\n",
    "# Iteration 0: initial Q-table and its greedy policy.\n",
    "q_i = env.discretize(q, q.to_weights(q.params), states, actions)\n",
    "policy_q = q_i.argmax(axis=1)\n",
    "\n",
    "q_functions[0] = q_i\n",
    "v_functions[0] = env.value_function(policy_q)\n",
    "print(policy_q)\n",
    "\n",
    "for bellman_iteration in tqdm(range(1, max_bellman_iterations + 1)):\n",
    "    # Policy evaluation: accumulate the linear system A w = b over all samples,\n",
    "    # using one-hot features indexed by state * n_actions + action.\n",
    "    A = np.zeros((n_weights, n_weights))\n",
    "    b = np.zeros(n_weights)\n",
    "\n",
    "    for sample in data_loader_samples:\n",
    "        state = sample[\"state\"][0, 0]\n",
    "        action = sample[\"action\"][0, 0]\n",
    "        reward = sample[\"reward\"][0, 0]\n",
    "        next_state = sample[\"next_state\"][0, 0]\n",
    "        # The next action follows the current greedy policy.\n",
    "        # NOTE(review): the absorbing flag stored in the replay buffer is not used here;\n",
    "        # presumably the chain walk never terminates - confirm.\n",
    "        next_action = policy_q[next_state]\n",
    "\n",
    "        phi = np.zeros(n_weights)\n",
    "        phi[state * n_actions + action] = 1\n",
    "        next_phi = np.zeros(n_weights)\n",
    "        next_phi[next_state * n_actions + next_action] = 1\n",
    "\n",
    "        # A += phi (phi - gamma * phi')^T and b += r * phi\n",
    "        A += np.outer(phi, phi - gamma * next_phi)\n",
    "        b += reward * phi\n",
    "\n",
    "    # Solve for the weights, then improve the policy greedily.\n",
    "    q_weights_i = np.linalg.solve(A, b)\n",
    "    q_i = env.discretize(q, q_weights_i, states, actions)\n",
    "    policy_q = q_i.argmax(axis=1)\n",
    "\n",
    "    q_functions[bellman_iteration] = q_i\n",
    "    bellman_iteration_functions[bellman_iteration] = env.apply_bellman_operator(q_i)\n",
    "    v_functions[bellman_iteration] = env.value_function(policy_q)\n",
    "    print(policy_q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T11:21:19.439440Z",
     "iopub.status.busy": "2022-09-19T11:21:19.439090Z",
     "iopub.status.idle": "2022-09-19T11:21:19.465529Z",
     "shell.execute_reply": "2022-09-19T11:21:19.464724Z"
    }
   },
   "outputs": [],
   "source": [
    "# exist_ok=True makes the directory creation idempotent and removes the gap between\n",
    "# checking for the directory and creating it.\n",
    "output_directory = \"figures/data/LSPI/\"\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "# One file per recorded quantity, prefixed with the number of Bellman iterations.\n",
    "np.save(f\"{output_directory}{max_bellman_iterations}_Q.npy\", q_functions)\n",
    "np.save(f\"{output_directory}{max_bellman_iterations}_BI.npy\", bellman_iteration_functions)\n",
    "np.save(f\"{output_directory}{max_bellman_iterations}_V.npy\", v_functions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('env_cpu': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "af5525a3273d35d601ae265c5d3634806dd61a1c4d085ae098611a6832982bdb"
   }
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "01f6272db69444dea062a900278cd1e9": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_e0d8dec1a9a64f6688c88502c3e27c4b",
       "placeholder": "​",
       "style": "IPY_MODEL_2b74ef5680a24ed1bb120486c9186f51",
       "value": " 2/2 [00:03&lt;00:00,  1.94s/it]"
      }
     },
     "0f2cbf1c7d7344cd9f849521ed9c30b4": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "102d1fd4bb724fc99af13f7df4ecabc0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_0f2cbf1c7d7344cd9f849521ed9c30b4",
       "placeholder": "​",
       "style": "IPY_MODEL_c93f66067e5841db9311ef2174a46d61",
       "value": "100%"
      }
     },
     "2b74ef5680a24ed1bb120486c9186f51": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "49e7f57382aa4b4da146ccd209626958": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_102d1fd4bb724fc99af13f7df4ecabc0",
        "IPY_MODEL_8b31111590cc44bd85dba3b054b21bee",
        "IPY_MODEL_01f6272db69444dea062a900278cd1e9"
       ],
       "layout": "IPY_MODEL_9c4ac75eb3714a489139f5691e4fb29f"
      }
     },
     "4cea0bf3347848e996209531c58a5499": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "8b31111590cc44bd85dba3b054b21bee": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_4cea0bf3347848e996209531c58a5499",
       "max": 2.0,
       "min": 0.0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_a85d7cdb543348b1a5b936adde3d206f",
       "value": 2.0
      }
     },
     "9c4ac75eb3714a489139f5691e4fb29f": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "a85d7cdb543348b1a5b936adde3d206f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "c93f66067e5841db9311ef2174a46d61": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "e0d8dec1a9a64f6688c88502c3e27c4b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
