{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PBO max linear on the chain walk environment\n",
    "\n",
    "## Define parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:38.287271Z",
     "iopub.status.busy": "2022-09-19T13:26:38.287065Z",
     "iopub.status.idle": "2022-09-19T13:26:39.075211Z",
     "shell.execute_reply": "2022-09-19T13:26:39.074437Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "import jax\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Read the experiment configuration; the context manager closes the file handle.\n",
    "with open(\"parameters.json\") as parameters_file:\n",
    "    parameters = json.load(parameters_file)\n",
    "\n",
    "# Environment\n",
    "n_states = parameters[\"n_states\"]\n",
    "n_actions = 2\n",
    "sucess_probability = parameters[\"sucess_probability\"]\n",
    "gamma = parameters[\"gamma\"]\n",
    "env_seed = parameters[\"env_seed\"]\n",
    "\n",
    "# Sample collection: every (state, action) pair is sampled n_repetitions times\n",
    "n_repetitions = parameters[\"n_repetitions\"]\n",
    "n_samples = n_states * n_actions * n_repetitions\n",
    "\n",
    "# Weights collection\n",
    "n_weights = parameters[\"n_weights\"]\n",
    "\n",
    "# Trainings\n",
    "max_bellman_iterations = parameters[\"max_bellman_iterations\"]\n",
    "training_steps = parameters[\"training_steps\"]\n",
    "fitting_steps = parameters[\"fitting_steps_pbo\"]\n",
    "batch_size_samples = parameters[\"batch_size_samples\"]\n",
    "batch_size_weights = parameters[\"batch_size_weights\"]\n",
    "initial_weight_std = parameters[\"initial_weight_std\"]\n",
    "# NOTE(review): \"duration\" counts sample batches only, while the training loop also iterates\n",
    "# over weight batches; confirm this is the intended schedule length.\n",
    "learning_rate = {\"first\": parameters[\"starting_lr_pbo\"], \"last\": parameters[\"ending_lr_pbo\"], \"duration\": training_steps * fitting_steps * n_samples // batch_size_samples}\n",
    "\n",
    "# Visualisation of errors and performances: the validation iterates past the trained horizon\n",
    "max_bellman_iterations_validation = max_bellman_iterations + 20\n",
    "\n",
    "# Search for an unused seed.\n",
    "# Result files are named {max_bellman_iterations}_{Q|BI|V}_{n_repetitions}_{seed}.npy;\n",
    "# any other entry in the directory (e.g. hidden files) is skipped instead of crashing int().\n",
    "data_directory = \"figures/data/PBO_max_linear/\"\n",
    "os.makedirs(data_directory, exist_ok=True)\n",
    "max_used_seed = 0\n",
    "for file_name in os.listdir(data_directory):\n",
    "    stem, extension = os.path.splitext(file_name)\n",
    "    name_parts = stem.split(\"_\")\n",
    "    if extension != \".npy\" or len(name_parts) != 4:\n",
    "        continue\n",
    "    iterations_part, repetitions_part, seed_part = name_parts[0], name_parts[2], name_parts[3]\n",
    "    if not (iterations_part.isdigit() and repetitions_part.isdigit() and seed_part.isdigit()):\n",
    "        continue\n",
    "    if int(iterations_part) == max_bellman_iterations and int(repetitions_part) == n_repetitions:\n",
    "        max_used_seed = max(max_used_seed, int(seed_part))\n",
    "\n",
    "# keys\n",
    "env_key = jax.random.PRNGKey(env_seed)\n",
    "seed = max_used_seed + 1\n",
    "key = jax.random.PRNGKey(seed)\n",
    "shuffle_key, q_network_key, random_weights_key, pbo_network_key = jax.random.split(key, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:39.078140Z",
     "iopub.status.busy": "2022-09-19T13:26:39.077866Z",
     "iopub.status.idle": "2022-09-19T13:26:39.521090Z",
     "shell.execute_reply": "2022-09-19T13:26:39.520556Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pbo.environments.chain_walk import ChainWalkEnv\n",
    "\n",
    "\n",
    "# Integer indices of every state and every action of the chain walk.\n",
    "states, actions = np.arange(n_states), np.arange(n_actions)\n",
    "\n",
    "# Box edges centred on each integer index: -0.5, 0.5, ..., n - 0.5.\n",
    "states_boxes, actions_boxes = (np.arange(n_items + 1) - 0.5 for n_items in (n_states, n_actions))\n",
    "\n",
    "env = ChainWalkEnv(env_key, n_states, sucess_probability, gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Samples on the mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:39.526178Z",
     "iopub.status.busy": "2022-09-19T13:26:39.525212Z",
     "iopub.status.idle": "2022-09-19T13:26:40.377537Z",
     "shell.execute_reply": "2022-09-19T13:26:40.377037Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from pbo.sample_collection.replay_buffer import ReplayBuffer\n",
    "\n",
    "\n",
    "replay_buffer = ReplayBuffer()\n",
    "\n",
    "# Visit every (state, action) pair n_repetitions times: repeated samples of the same\n",
    "# pair are needed to capture the randomness of the transitions.\n",
    "for state in states:\n",
    "    state_array = jnp.array([state])\n",
    "    for action in actions:\n",
    "        action_array = jnp.array([action])\n",
    "        for _repetition in range(n_repetitions):\n",
    "            env.reset(state_array)\n",
    "            next_state, reward, absorbing, _ = env.step(action_array)\n",
    "            replay_buffer.add(state_array, action_array, reward, next_state, absorbing)\n",
    "\n",
    "replay_buffer.cast_to_jax_array()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:40.379699Z",
     "iopub.status.busy": "2022-09-19T13:26:40.379496Z",
     "iopub.status.idle": "2022-09-19T13:26:41.289514Z",
     "shell.execute_reply": "2022-09-19T13:26:41.288975Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEYCAYAAACZaxt6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkmUlEQVR4nO3dd5hkVZ3G8e9LzqAOIDAgLsKKAVAJBlTGgCRxd11JKosJwRVBBBFFUVQW14AREVFZQVCUqARBiSIgoJIRkSAZhgwiwvS7f5zTTk3R3VU1t3u6avr9PE890133nntPVdfU755wf0e2iYiIaGKBya5AREQMvgSTiIhoLMEkIiIaSzCJiIjGEkwiIqKxhSa7AhERMbI3zVjS990/q6cyl13xxC9tbzZBVRpVgklERJ+aef8sLv7l9J7KLLzSX6ZNUHXGlGASEdG3zCwPTXYlupIxk4iIaCwtk4iIPmVgiMHIUpJgEhHRx4YYjG6uBJOIiD5lzKwByZ+YYBIR0cfSzRUREY0YmJVgEhERTaVlEhERjRgyZhIREc0NxlyuBJOIiL5lnDGTiIhoyDBrMGJJgklERL8qd8APhgSTiIi+JWahya5EVxJMIiL6lIGhdHNFRERTaZlEREQj5Q74BJOIiGhoyAkmERHRQFomERHRmBGzBmRB3ASTiIg+lm6uiIhoJN1cERExDsQsj383l6SbgUeAWcBTttcfZb8NgAuB7Wz/bKxjJphERPSpkk5lwsZMZtieOdpGSQsCXwDO6OZgCSYREX1sEru5dgOOAzboZucEk4iIPmXPVTfXNEmXtvx+mO3D2g8NnCHJwHfat0taBfh3YAYJJhERg2+o95bJzNHGQFpsbPt2SSsAZ0q6zvZ5Ldu/Cuxje0jq7vwJJhERU4zt2+u/90g6AdgQaA0m6wM/roFkGrCFpKdsnzjaMRNMIiL6VJkaPL4D8JKWBBaw/Uj9eVPggDnOaz+3Zf8jgF+MFUggwSQioo9NyNTgFYETaqtjIeBo26dL2gXA9qFzc9AEk4iIPjURU4Nt3wisO8LzIwYR2zt1c9wEk4iIPjYr6VQiIqKJJHqMiIhxMTQB6VQmQoJJRESfmojZXBMlwSQiok8ZZcwkIiKam8BEj+MqwSQiok/ZTEgK+omQYBIR0bc0N7m5JkWCSUREnzJpmURExDjIbK6IiGjEiKHM5oqIiKbSMomIiEZM7oCPiIjGNJlrwPckwSQiok+lZRIREeMiLZOIiGjEVlomERHRXG5ajIiIRsqyvenmioiIRpSWSURENFNmcw1Gy2QwQl5ERPS1tEwiIvpY0qlEREQjSfQYERHjIsv2RkREI2XZ3rRMIiKioXRzRUREI2XMJN1cERHRUBI9RkREI4N002KCSURE30o3V0REjIMkeoyIiEYyNTgiIsZFurkiIqKRpFOJiIhxkTGTiIhoJFODIyJiXGTMJCIimnHGTCIioiGTMZOIiBgHaZlEREQjGYCPiIhxkWASERGN5KbFiIgYFxMxAC/pZuARYBbwlO3127a/HdgHUN1vV9uXj3XMBJOIiH7lCe3mmmF75ijbbgJea/sBSZsDhwEbjXWwBJOIiD41WQPwtn/b8utFwPROZQbj1sqIiOjWNEmXtjx2HmEfA2dIumyU7a3eA5zW6aRpmURE9LG5aJnMbB8DGcHGtm+XtAJwpqTrbJ/XvpOkGZRgsnGnk6ZlEhHRp4Znc/Xy6Oq49u3133uAE4AN2/eRtA5wOPAW2/d1OmaCSUREH7PV06MTSUtKWnr4Z2BT4Kq2fVYDjgfeafv6buqZbq6IiD42AVODVwROkAQlBhxt+3RJuwDYPhT4FPAs4JC639OmD7dLMImI6FOegKnBtm8E1h3h+UNbfn4v8N5ejptgEhHRx7rpuuoHCSYREX0r6VQiImIcpGUSERGNJAV9REQ05zIIPwgSTCIi+liW7Y2IiEZMxkwiIqKxzOaKiIhxMChj
JsnNFZNK0qclHTXZ9Zhokj4u6fAxtr9d0hnzsk4xGMY7N9dESTCZoiRtLOm3kh6SdL+kCyRtMNn1mh9I2kTSba3P2T6wpqhA0uqSLGmhlu0/sr3pvK5r9Dd7cIJJurmmIEnLAL8AdgWOBRYBXg08MZn1miiSFrL91Lw617w4T0wdgzJmkpbJ1LQWgO1jbM+y/bjtM2xfASBpDUlnSbpP0kxJP5K03HBhSTdL2lvSFZIek/Q9SStKOk3SI5J+JekZdd/hq/CdJd0h6U5Je41WMUkvry2mByVdLmmTlm07SbqxnuMmSW8f5RiflvQzSUdJehjYSdKytZ53Srpd0uckLdhy3AskfbO21K6T9PqW471L0rX1vDdKen/Ltk0k3SZpH0l3AcdQVqVbWdKj9bFyW3fe8CJED9btr6h1+E3LcV8p6ZJan0skvbJl2zmSPlvr/IikMyRNG/MvHgPL7u0xWRJMpqbrgVmS/k/S5sNf/C0E/A+wMrA2sCrw6bZ93gq8kRKY3kz5Av04sDzlc/Whtv1nAGtS1k7YR9Ib2islaRXgFOBzwDOBvYDjJC1f1134OrC57aWBVwJ/HOM1vgX4GbAc8CPgCOAp4HnAS2o9WrOibgT8BZgG7A8cL+mZdds9wFbAMsC7gIMlvbSl7LNrfZ8D7AhsDtxhe6n6uKOtbq+p/y5Xt1/Y9j48s74PX6ekAf8KcIqkZ7XstkOtywqUluWoAToG26B0cyWYTEG2H6Ysw2ngu8C9kk6WtGLdfoPtM20/YfteypfZa9sO8w3bd9cV284HLrb9B9t/p6zc9pK2/T9j+zHbVwI/ALYfoWrvAE61fartIdtnApcCW9TtQ8CLJC1u+07bV4/xMi+0faLtIUoQ2ALYo9bhHuBgYLuW/e8Bvmr7Sds/Af4EbFnfj1Ns/8XFucAZlG7BYUPA/vX9enyMOnVrS+DPto+0/ZTtY4DrKEF72A9sX1/Pdyyw3jicN/qM6S2QJJjEPGf7Wts72Z4OvIjSCvkqQO2y+nHtDnoYOIpyxd7q7pafHx/h96Xa9r+15edb6vnaPQd4W+3ielDSg5Sgt5Ltx4BtgV2AOyWdIun5Y7zE1vM9B1i4lhs+7ncoV/XDbrfn6CT4Zx1r6+2iOlHhQUpgan0/7q1BdLysXM/f6hZglZbf72r5+W88/f2O+YR7fEyWBJPA9nWUbqAX1acOpHwuX2x7GUqLoeklz6otP68GtHf9QAkAR9peruWxpO2Daj1/afuNwEqUK/XvjnG+1v9Xt1ImF0xrOe4ytl/Yss8qklpf42rAHZIWBY4DvgSsaHs54FTmfD/a/w93+j/dafsdlADYajXg9g7lYn4zQLO5EkymIEnPl/QRSdPr76tSup0uqrssDTwKPFTHMfYeh9N+UtISkl5I6ev/yQj7HAW8WdKbJC0oabE6wD29tpbeUsdOnqj1G+rmxLbvpHRNfVnSMpIWUJlk0Np1twLwIUkLS3obZazoVMp4xKLAvcBTkjanjLeM5W7gWZKWHWX7vbXu/zLK9lOBtSTtIGkhSdsCL6DMwOtIZYLETt3sGwNgQJomCSZT0yOUAeeLJT1GCSJXAR+p2z8DvBR4iDIQfPw4nPNc4Abg18CXbD/tBj3bt1IGzj9O+cK9lRLIFqiPPSlX7fdTxnB27eH8O1ICwzXAA5TB+ZVatl9MmSAwE/g88J+277P9CGUywbG13A7AyWOdqLb0jgFurN1qK7dt/1s9xwV1+8vbtt9HGfD/CHAf8FFgK9szO71ISYtQBu0v6rRvDIZBaZnIg3KvfgwkSasDNwELz6t7PXpVr+Lfa3vjya5LU5I2Bv7b9kgTHGLALLbGKl71oF6umeCGbT55me31J6hKo8oNVhHzEdu/AX7TcccYGINyvZ9gEhHRpwYpBX3GTGJC2b7Ztvq1iwvA9hHzQxdXzIcMWL09JklaJhERfWxQurnSMokpR9Khkj45xnZLet44nesISZ8bp2OdI+m9nfeM
+UqmBsdkkfRMSSeoJGG8RdIOLdvWlXS1SgLHPVueX1jSxfWek4HTyxet7V1sf3YC6jBHssaI5gYnnUq6ueZP3wL+AaxIydl0iqTLay6r/6EkBbwCuELS0bbvotzDcVy912NcSFrQ9qzxOl7ElJRurpgM9Q7xtwKftP1onSp6MvDOustzgbNqgsY/A6tJek4tc3AXx/+ppLtqavTz6h3tw9uOkPRtSafWmyFnqKRfP07SvSpp49uzCbcee1lJP6z73iJpP0kL1G1zrMiolgWmJH2eknjxmyop3b+p4mBJ90h6WNKVkl7UUs/PtRxrb5XU9HdIendbnRaV9CVJf5V0d+0iW3yEuq8NHAq8otbhwZbNz1DJJfZIbf2t0VLu+ZLOVMn79SdJ23T4E6wh6Xf1NZ2k2ZmNkbR1bXU+WFtqa9fn96nnXaj+vmvdb7EO54rJlnQqMYnWAp6yfX3Lc5cDw1/6VwGbqqRSWZ2Sdv1rwN62n+zi+KdR7hRfAfg9Jb17qx0od3cvDfwW+Hk9/yrA64E9JL1plGN/A1iWkmbktZS71t/VqUK2P0HJXPzBmtL9g5SUJ6+hvB/LAttQ7iafg6TNKC21N9bX1Z4a/6B6jPUo6etXAT41Qh2upSShvLDWYbmWzdtRsgo8g5IF4PP13EsCZwJHU97P7YBDJL1gjJe7I/Buyt37T1HS1CNpLcpd93tQlgE4Ffh5vSP+i5QUNPtJWpOSe+0d45ycMiZKxkxikiwFPNz23EOUL3coX5y7UlorHwZeRUmvclO90j1XJTfViGx/3/Yjtp+grHGyrubMQXWS7Qtq6vcXA8vbPsD2P2zfSEnOuF37cVUWqtoO2Lce/2bgy8xuUfXqyfqan0/J9HBtzdHVbhtKOverambiT7fUScDOwIdt319Tqxw4Uv07OMH27+r06B8xO138VsDNtn9QU83/gZJUctT3n5IIc7iunwS2qe/dtsApdemAJymJKRcHXln/FjtS0sKcDPxvPVcMBPX4mBwZM5n/PEpZv6PVMpSAge1bqOuDSFoCuJByFf8NSvLFU4CrJP3a9v2tB6lfWp+nfNktz+xEi9MoAQuenvp95bYunwUprYh20yhp4ltTr7enXe+a7bMkfZMyfvQcSccDe7ms5dJqZeCytnMOWx5YArhMsxMKq76GXoyWLv45wEZt789CwJFjHKs9lf/ClPdujrT1tock3Up9/2zfLOlsyt/+Wz3WPyZTxkxiklwPLFS7M4atC4y0kNSngO/avpvSirjU9kPAbZQunXY7UBIxvoHSdbR6fX60dOy3Aje1pZRf2vYWPN1MSmuiNfV6a9r1xyhf7MOe3Vb+af/lbH/d9ssoGXfXYuTsx3fy9PT4rXV6HHhhS/2XtT3a2iG9/re/FTi37f1ZyvZYyZja6/pkreccaetrq2pV6vsnaUvgFZREm1/ssZ4xmebHbi6V1N3tV73RR2r3x/HAAZKWlPQqSgCY42q39stvAny7PnUT8DqV1RbXBP46wuGXpvS930f5Yj+wQ3V+BzxSB4AXV0kr/yJJG4xQ71mUzLyfl7R0nRSwJyUtPZQlel8jabXarbZv2yHupiWlu6QNJG0kaWFKIPo7I6esP5ayRvwLaktt/5Y6DVG65Q6WtEI97ipjjPncDUyv4xTd+AUl1fw7VaZmL1zrvfYYZd7RUtcDgJ+1vHdbSnp9fc0fofytfquyPvzhlGWK/4uS5n+kgB79ZoDugO8YTCQdrbIGxJKUwdtrJI3H+hYxcT5A6S+/hzIou6ufvsTtt4DdW6bu7kvpU78aOLBOF273Q0pXyu2UVO5jpjmvx96KMkZwE+UK+nBKq2Yku1G++G+kJCs8Gvh+PdaZlG64KyjdUu1re3wN+E9JD0j6OqVr77uUtPG3UALg067IbZ9GWWHyLMrg+Fltu+xTn79IZdXJXwH/Okr9z6K8f3dJ6pguvo7BbEoZg7mD0h32Bcr6KaM5krKQ2V3AYpS/Gbb/RFnE
7BuU9/nNwJtt/wM4jDKWdWpNb/8e4HDVNeXr7LNXt58o+oPd22OydExBL+mPtteT9HbKGhcfAy6zvc68qGBExFS16OrT/exPjjqbfkR/fe8+k5KCvpturoVrs/nfgJPrTJEBGRKKiBhw80s3F/Ad4GZgSeC82pfdPiPmaSR9v94wdlWzKkZETF1yb4/J0jGY1Bkxq9jewsUtwIwujn0EsFnTCkZETFm9zuSaxGDS8T4TSYtSUm2s3rb/AWOVs32eypKtERExVya366oX3dy0eBLlhrTLKFMNIyJiXhmQEepugsl02xPWXSVpZ0rKCrTIIi9beMUVJupUERET7h+33jbT9vLjdsD5KJj8VtKLbV85ERWwfRhlHjyLrraqV9lrj4k4TUTEPHHT7nvd0nmvHsxHwWRjyh3CN1G6uQQ495lEREyw4TvgB0A3wWTzuTmwpGMo6TqmSboN2N/29+bmWBERU9VkTvftRcdgYvsWSetSFh8CON/25V2U275p5SIiprwJCCaSbqZkEp9FWf9o/bbtoqQo2oKS6Xon278f65jd5ObanbIGwwr1cZSk3ebmBURERN+YYXu9UVKvbE5J+LomZYLUt0fYZw7ddHO9B9ioZqNF0hcoa2B8o+sqR0TEIHkL8EOX5I0XSVpO0kqjLDAHdJdORZSm0LBZTOZyXhERU8gEpVMxcIaky+rtGe1WYc6F2G6jw0J13bRMfgBcLOmE+vu/ARlIj4iYF3qfzTVN0qUtvx9Wb8FotbHt2+s6PWdKus72eU2q2c0A/FcknUOZIgzwrqwfHRExD8xdvq2ZnVLQ2769/ntPbShsCLQGk9uZc1XP6cxe9XREo3ZzDa+oKOmZlKzBR9XHLfW5iIiYaOOc6LGuwLr08M+UBdras7ufDOyo4uXAQ2ONl8DYLZOjKavkXdZWRdXf/2WkQhERMX4m4D6TFYETyuxfFgKOtn26pF0AbB8KnEqZFnwDZWrwuzoddNRgYnur+u9z57bGkjajzFVeEDjc9kFze6yIiClpnIOJ7RuBdUd4/tCWnw38dy/H7eY+k19389wI+yxIWWd8c+AFwPaSXtBL5SIiprxBX89E0mLAEpSZAc9g9nTgZegwRazaELihRkEk/Zgyd/maRjWOiJgiJnv1xF6MNWbyfmAPYGXKuMlwMHkY+GYXxx5pnvJGvVcxImIKG/REj7a/BnxN0m62J+xu99b1TIAnbtp9ryZrxk8DZqb8wJbvhzqkfD4DTcv/a4OyTzcftEyGDUlazvaDALXLa3vbh3Qo19U85db1TCRd2ml+9FhSfrDL90MdUj6fgfEoP7dlRzzegASTbtKpvG84kADYfgB4XxflLgHWlPRcSYsA21HmLkdERLcGfQC+xYKSVKeKDc/SWqRTIdtPSfog8EvK1ODv2766UW0jIqaS+WQAftjpwE8kfaf+/n7gtG4ObvtUys0v3WrPH9OrlB/s8v1Qh5Sf3PL9UIfJLj+nAQkmqg2O0XeQFqAMkL++PnUF8GzbPd3QEhERvVlslVW92q579lTmz5/c87Km41Zzo+OYie0h4GJKfq4NgdcB105stSIiAiYsBf24G+umxbWA7etjJvATANsz5k3VIiJiUIw1ZnIdcD6wle0bACR9eJ7UKiIiigEZMxkrmPwHZTrv2ZJOB37MOK6wKOn5lPQqw6lZbgdOtj3PutBqHVYBLrb9aMvzm9k+vYvyG1Jyol1S845tBlxXJx7MTX1+aHvHuSy7MaUb8irbZ3Sx/0bAtbYflrQ48DHgpZR0NwfafqhD+Q8BJ9i+daz9xig/PF38Dtu/krQD8EpKF+phtp/s4hj/QvmcrkpZAfR6SgbUh+emThF9Z4Bmc406ZmL7RNvbAc8HzqakVllB0rclbdrkpJL2YXZw+l19CDhG0seaHLsev2O65PpleBKwG3CVpLe0bD6wi/L7A18Hvi3pfygpZpYEPibpE12UP7nt8XPgP4Z/76L871p+fl89/9LA/l2+
h9+npJaGktl5WeAL9bkfdFH+s5QVOM+X9AFJy3dRptUPgC2B3SUdCbyNMja3AXB4p8L173cosFgtsyglqFwkaZMe6xJVXXlvsuvwrMmuQ18ZkPtMOs7mmmPncvf724Btbb++0/5jHOd64IXtV5/1avVq22vO7bHrcf5qe7UO+1wJvML2o5JWB34GHGn7a5L+YPslXZRfj/IldhcwveUq/2Lb63Qo/3tKK+BwykdAwDGUq3Vsn9uh/D/rKOkSYAvb99bFbi6y/eIO5a+1vfZwXWy/tGXbH22v1+n8wMuANwDbAltTcrgdAxxv+5EO5a+wvY6khSit0pVtz1JZZOHyLt6/K4H1apklgFNtbyJpNeCkTn+/eoxlgX0pS1GvQPk73EO5yDio9WbdXkk6zfbmHfZZpp5/OnCa7aNbth1i+wMdyj8b2B8YAj5FuTB6K6V1t3unxYxGWOROlL/hSyjfDfePVb4e45+t+Pp+foUS3K8CPmz77g7lDwK+ZHumpPWBY+vrWRjYsYv/B78HjgeOsf2XTvUdofz6wBcpn8F9KRdZG1JauTt3WlVW0lLARynv+3TgH8BfgENtH9FrfdottvKqXv19vc3m+tMBfTqbq5XtB2wf1iSQVEOUBJLtVqrbOpJ0xSiPKymLv3SywHDXlu2bgU2AzSV9he66856yPcv234C/DHet2H68y9ewPuU/7icoq5idAzxu+9xO/4GG6y/pGfUqTrbvred/DHiqi/JXtbTgLq//qYYnXnTsYiqn8pDtM2y/h/L3PITS1Xdjl/VfhNKaWoLSMoISnBfuojzM7qZdFFiqVuqvPZQ/FngA2MT2M20/C5hRnzu2U2FJLx3l8TLKhUYnP6B81o4DtpN0nKRF67aXd1H+CMoFya2U3oPHKQsanU9ptXUyk/IZHH5cSun2/X39uRutrfgvA3cCb6ZkwPjOiCXmtKXt4TxYX6RcqD4PeGM9XifPAJajdMf/TtKHJY303TKaQ4D/BU4Bfgt8x/aylG7fTimjAH5E+by/CfgMpbfincAMSR17ODoR88Fsrgm2B/BrSX9mdmbh1YDnAR/s8hgrUv6AD7Q9L8qHopO7Ja1n+48AtYWyFeXKZMyr+uofkpaoweRl/zx5uTrrGEzqlOuDJf20/ns3vf09lmV2NmdLWsn2nfVKqZtg+F5KIs/9KF8qF0q6lfL3eG8X5ec4R21lngycXFsKnXyPMsljQUpA/amkGylfoj/uovzhwCWSLgZeTemio3a3dbyirla3/YW213EX8AVJ7+6i/CXAuYz8fi/XRfk1bL+1/nxi7R49S9LWXZQFWHE4CaukD7S8lm9Iek8X5femfGnvbfvKepybPPcL4q3f0qI9WNJ/dVFmIUkL2X4KWNz2JQC2r28JrGN5wPZewF6SXk2Zffp7SddSWiudbiBc2PZpAJK+YPtn9fy/lvSlLs6/eksL5CuSLrH92Xqhdg3w8S6OMbYBGTOZlGDiskTkWpTmZOsA/CW2Z3V5mF8ASw0Hg1aSzumi/I60XcHXD/SOmn23/1heY/uJWq41eCwMdPOfaPictwFvk7QlJb1/t+VWH2XTEPDvXZR/CNipdrU8l/JZuK1Tt0SLbcc49t9G29ayz8GShqeb3yHph5Qus+/a/t3YpUtWa0m/AtYGvmz7uvr8vcBrunwNt0j6KPB/w69b0orATsy5fMJorgXeb/vP7RtqYO5kUUkLDH9+bH9e0u3AedSWVgetPQs/bNu2YKfCtr9c/wYH1/ruT+9fXStI2pMSUJeRZqdeoruej0OAU2t31+mSvkbptnod8MdeKmL7fOB8SbtRguS2dL4b/e8qY8DLUi7K/s32iZJeS5nU0cljkja2/Zt6EXB/rctQ7bJtZoAG4CerZTL8BXxRg/KjXnnZ3qGL8reNse2CLso/McrzM5mL9NW2T6E0tRupX+Q39bD/w8Dlc3Ge63stM8Ix7mj5+UHKuFUv5a8GmuR725bSnXGuZg88301p
Yb2ti/KfZvQvzN26KP9zypfmr4afsH2EpLuAbpZ9OEnSUrYftb3f8JOSngf8qYvyrRczWwNnUroce/FdSlclwP9R0rffW8dz/tjF+b9Ru6Z3BdaifCetCZxImeTRydM+h/WC9PT66GQXSjfXEKWnY1dJR1AubrtJaLsLcLikNSmfxXfDP1vI3+qi/HyjpwH4iKlC0rtsdzOrbb4przJ5ZA3bVzU9/9zWIeXntPhKq/q57+5tAP7aAwdgAD5iCvnMVCtv+3Hbw4vTNT3/eBxjqpcvBmRq8KR1c0VMNklXjLaJLmYETvXy/VCHQS/fjYyZRPS/pjMCp3r5fqjDoJfvLMEkou81nRE41cv3Qx0GvfzYJrnrqhcZgI+I6FOLP3tVr7FjbwPwV39xcgbg0zKJiOhnA3K9n2ASEdHHBmUAPlODYyBJ+oSkq1Xysf1R0kaS9ugmlUu3+0X0hQGZGpxgEgNH0iuArYCXumQXfgMl/ckedHcHd7f7RUyuXgNJgklET1YCZrbkRpsJ/Cclc/HZks4GUFl759LagvlMfe5DI+y3qaQLJf1e0k9rskwkHSTpmtr66SbpX8S40lw8JkuCSQyiM4BVJV0v6RBJr7X9deAOYIbtGXW/T9RZLesAr5W0Tvt+kqYB+wFvcFnT5VJgT5XU/v9OWXdnHeBz8/g1RhQD0jLJAHwMnLpcwMsoqednAD/RyKtLbiNpZ8rnfCXgBUD7Hcsvr89fUJO8LgJcCDwE/B34nqRfUO4niJjnBmUAPsEkBlLNDHsOcE7NOjtH2n9JzwX2Ajaw/UDNBLvYCIcScKbt7Z+2QdoQeD2lC+2DlAy/EfPWgASTdHPFwJH0rzXl97D1gFuAR5idDn0Z4DHgobpGSesSuq37XQS8qqZtR9KSktaq4ybL2j4V+DCw7kS9nogxpZsrYsIsRVlNcDnKAmc3ADtTVtk7XdIddTzkD5TVHG8FWteoOaxtv52AYzR7Zb/9KAHnJEmLUVovvd2GHDEeBmhxrKRTiYjoU0ussKrX2qa365jLv5V0KhER0WZQWiYJJhER/SzBJCIimkrLJCIimhmg9UwSTCIi+lmCSURENCHSzRUREeNhQIJJ7oCPiIjG0jKJiOhjGpAby9MyiYjoVxO4OJakBSX9oWbFbt+2mqSz6/YrJG3R6XgJJhERfUzu7dGD3YFrR9m2H3Cs7ZcA2wGHdDpYgklERD+bgJaJpOnAlsDhY5x1mfrzspQF5caUMZOIiD42F1ODp0m6tOX3w2wf1rbPV4GPMnsphnafBs6QtBuwJPCGTidNMImI6Ge9B5OZY2UNlrQVcI/tyyRtMspu2wNH2P6ypFcAR0p6ke2h0Y6bYBIR0a8mZj2TVwFb10H1xYBlJB1l+x0t+7wH2AzA9oV1XZ9pwD2jHTRjJhER/Wycx0xs72t7uu3VKYPrZ7UFEoC/UpasRtLalKBz71jHTTCJiOhTw+lUJmg215znkg6QtHX99SPA+yRdDhwD7OQOKymmmysiop9N4E2Lts8Bzqk/f6rl+Wso3WFdSzCJiOhjSfQYERHNZD2TiIgYDxp1Mm5/STCJiOhnaZlERERTGTOJiIhmzITO5hpPCSYREX0sLZOIiGguwSQiIpoYvgN+ECSYRET0KztjJhER0VxaJhER0VyCSURENJWWSURENGNgaDCiSYJJREQ/G4xYksWxIiKiubRMIiL6WMZMIiKiudxnEhERTaVlEhERzWSlxYiIaKrk5hqMaJJgEhHRz7Jsb0RENJWWSURENJMxk4iIaC4p6CMiYhxkanBERDSXlklERDRiUGZzRUREY2mZREREY4MRSxJMIiL6We4ziYiI5hJMIiKiEZN0KhER0YxwurkiImIcJJhERERjCSYREdFIxkwiImI8ZMwkIiKaSzCJiIhmBicF/QKTXYGIiBh8aZlERPQrMzAtkwSTiIh+ltlcERHR1KDM5sqYSUREP7N7e3RJ0oKS/iDpF6Ns30bSNZKulnR0
p+OlZRIR0a8MDE1Yy2R34FpgmfYNktYE9gVeZfsBSSt0OlhaJhERfavHVkmXLRNJ04EtgcNH2eV9wLdsPwBg+55Ox0wwiYjoZ70Hk2mSLm157DzCUb8KfJTRh/fXAtaSdIGkiyRt1qma6eaKiOhnvQ/Az7S9/mgbJW0F3GP7MkmbjLLbQsCawCbAdOA8SS+2/eBox03LJCKiXw2PmfTy6OxVwNaSbgZ+DLxO0lFt+9wGnGz7Sds3AddTgsuoEkwiIvqWwUO9PTod0d7X9nTbqwPbAWfZfkfbbidSWiVImkbp9rpxrOMmmERE9LMJmhrcTtIBkrauv/4SuE/SNcDZwN627xurfMZMIiL61cRODcb2OcA59edPtTxvYM/66EqCSUREPxuQO+ATTCIi+lmCSURENDM465kkmERE9CsDQ4ORNjjBJCKin6VlEhERjSWYREREM13f1T7pEkwiIvqVwV3c1d4PEkwiIvpZWiYREdFYxkwiIqIRO1ODIyJiHKRlEhERTXlAWiZJQR8REY2lZRIR0beSmysiIpqa4PVMxlOCSUREP8tNixER0YQBp2USERGN2GmZREREc2mZREREcwPSMpEHZNpZRMRUI+l0YFqPxWba3mwi6jOWBJOIiGgsd8BHRERjCSYREdFYgklERDSWYBIREY0lmERERGP/DyUr3kRsrk8xAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pbo.sample_collection.count_samples import count_samples\n",
    "from pbo.utils.two_dimesions_mesh import TwoDimesionsMesh\n",
    "\n",
    "\n",
    "# Tally the replay-buffer samples over the state/action boxes.\n",
    "samples_count, n_outside_boxes, _ = count_samples(\n",
    "    replay_buffer.states,\n",
    "    replay_buffer.actions,\n",
    "    states_boxes,\n",
    "    actions_boxes,\n",
    "    replay_buffer.rewards,\n",
    ")\n",
    "percentage_outside = int(100 * n_outside_boxes / n_samples)\n",
    "\n",
    "# Empty boxes are shown as NaN so they stand out in the plot.\n",
    "samples_visu_mesh = TwoDimesionsMesh(states, actions, sleeping_time=0)\n",
    "samples_visu_mesh.set_values(samples_count, zeros_to_nan=True)\n",
    "samples_visu_mesh.show(f\"Samples repartition, \\n{percentage_outside}% are outside the box.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:41.291837Z",
     "iopub.status.busy": "2022-09-19T13:26:41.291612Z",
     "iopub.status.idle": "2022-09-19T13:26:42.091369Z",
     "shell.execute_reply": "2022-09-19T13:26:42.090800Z"
    }
   },
   "outputs": [],
   "source": [
    "from pbo.weights_collection.weights_buffer import WeightsBuffer\n",
    "from pbo.networks.learnable_q import TableQ\n",
    "\n",
    "\n",
    "weights_buffer = WeightsBuffer()\n",
    "\n",
    "# Add initial validation weights: the weights of a zero-initialised table Q-function\n",
    "q = TableQ(\n",
    "    n_states=n_states,\n",
    "    n_actions=n_actions,\n",
    "    gamma=gamma,\n",
    "    network_key=q_network_key,\n",
    "    zero_initializer=True\n",
    ")\n",
    "validation_initial_weight = q.to_weights(q.params)\n",
    "\n",
    "weights_buffer.add(validation_initial_weight)\n",
    "\n",
    "# Add random weights until the buffer holds n_weights entries.\n",
    "# Use the dedicated random_weights_key split in the parameters cell (not consumed anywhere\n",
    "# else) instead of reusing q_network_key. NOTE(review): network_key presumably seeds\n",
    "# random_init_weights -- confirm in TableQ.\n",
    "q_random = TableQ(\n",
    "    n_states=n_states,\n",
    "    n_actions=n_actions,\n",
    "    gamma=gamma,\n",
    "    network_key=random_weights_key,\n",
    "    zero_initializer=False\n",
    ")\n",
    "\n",
    "while len(weights_buffer) < n_weights:\n",
    "    weights = q_random.random_init_weights()\n",
    "    weights_buffer.add(weights)\n",
    "\n",
    "weights_buffer.cast_to_jax_array()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train max linear PBO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:26:42.093749Z",
     "iopub.status.busy": "2022-09-19T13:26:42.093475Z",
     "iopub.status.idle": "2022-09-19T13:27:16.553305Z",
     "shell.execute_reply": "2022-09-19T13:27:16.552691Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bd1967749f847b1a1fe2c9afd75df09",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/400 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      "[1 1 0 0 1 1 1 0 0 1 1 0 0 0 1 1 1 0 0 0]\n",
      "[1 0 0 0 1 1 0 1 1 1 1 0 0 0 1 1 0 0 1 0]\n",
      "[1 0 0 0 0 1 1 0 1 1 1 0 0 1 1 1 0 1 1 0]\n",
      "[1 0 0 0 0 0 0 1 1 1 1 1 0 1 1 1 1 1 1 0]\n",
      "[1 0 0 0 0 0 0 0 1 1 1 1 0 1 1 1 1 1 1 0]\n",
      "[1 0 0 0 0 0 0 0 1 1 1 1 0 1 1 1 1 1 1 0]\n",
      "[1 0 0 0 0 0 0 0 1 1 1 1 0 1 1 1 1 1 1 0]\n",
      "[1 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n",
      "[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]\n"
     ]
    }
   ],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from pbo.sample_collection.dataloader import SampleDataLoader\n",
    "from pbo.weights_collection.dataloader import WeightsDataLoader\n",
    "from pbo.networks.learnable_pbo import MaxLinearPBO\n",
    "\n",
    "\n",
    "# Each data loader gets its own PRNG key: reusing one JAX key for two independent\n",
    "# shuffling streams makes them correlated.\n",
    "shuffle_samples_key, shuffle_weights_key = jax.random.split(shuffle_key)\n",
    "data_loader_samples = SampleDataLoader(replay_buffer, batch_size_samples, shuffle_samples_key)\n",
    "data_loader_weights = WeightsDataLoader(weights_buffer, batch_size_weights, shuffle_weights_key)\n",
    "pbo_max_linear = MaxLinearPBO(\n",
    "    q=q,\n",
    "    max_bellman_iterations=max_bellman_iterations,\n",
    "    network_key=pbo_network_key,\n",
    "    learning_rate=learning_rate,\n",
    "    n_actions=n_actions,\n",
    "    initial_weight_std=initial_weight_std\n",
    ")\n",
    "# All ones: presumably every Bellman iteration is weighted equally in the loss -- confirm in learn_on_batch.\n",
    "importance_iteration = jnp.ones(max_bellman_iterations + 1)\n",
    "\n",
    "# Training: the target parameters stay frozen for fitting_steps passes over the data.\n",
    "for _ in tqdm(range(training_steps)):\n",
    "    params_target = pbo_max_linear.params\n",
    "    for _ in range(fitting_steps):\n",
    "        data_loader_weights.shuffle()\n",
    "        for batch_weights in data_loader_weights:\n",
    "            data_loader_samples.shuffle()\n",
    "            for batch_samples in data_loader_samples:\n",
    "                pbo_max_linear.params, pbo_max_linear.optimizer_state, _ = pbo_max_linear.learn_on_batch(\n",
    "                    pbo_max_linear.params, params_target, pbo_max_linear.optimizer_state, batch_weights, batch_samples, importance_iteration\n",
    "                )\n",
    "\n",
    "# Validation: iterate the learned operator from the zero weights, beyond the trained horizon,\n",
    "# recording Q_i, the environment's Bellman operator applied to Q_i, and the value of the greedy policy.\n",
    "q_functions = np.zeros((max_bellman_iterations_validation + 1, n_states, n_actions))\n",
    "bellman_iteration_functions = np.zeros((max_bellman_iterations_validation + 1, n_states, n_actions))\n",
    "v_functions = np.zeros((max_bellman_iterations_validation + 1, n_states))\n",
    "\n",
    "batch_iterated_weights = validation_initial_weight.reshape((1, -1))\n",
    "for bellman_iteration in range(max_bellman_iterations_validation + 1):\n",
    "    q_i = env.discretize(q, batch_iterated_weights[0], states, actions)\n",
    "    policy_q = q_i.argmax(axis=1)\n",
    "\n",
    "    q_functions[bellman_iteration] = q_i\n",
    "    bellman_iteration_functions[bellman_iteration] = env.apply_bellman_operator(q_i)\n",
    "    v_functions[bellman_iteration] = env.value_function(policy_q)\n",
    "    print(policy_q)\n",
    "\n",
    "    batch_iterated_weights = pbo_max_linear(pbo_max_linear.params, batch_iterated_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-19T13:27:16.555821Z",
     "iopub.status.busy": "2022-09-19T13:27:16.555581Z",
     "iopub.status.idle": "2022-09-19T13:27:16.568746Z",
     "shell.execute_reply": "2022-09-19T13:27:16.568287Z"
    }
   },
   "outputs": [],
   "source": [
    "# Output names follow {max_bellman_iterations}_{tag}_{n_repetitions}_{seed}.npy,\n",
    "# the pattern parsed by the seed search in the parameters cell.\n",
    "arrays_to_save = {\"Q\": q_functions, \"BI\": bellman_iteration_functions, \"V\": v_functions}\n",
    "for tag, values in arrays_to_save.items():\n",
    "    np.save(f\"figures/data/PBO_max_linear/{max_bellman_iterations}_{tag}_{n_repetitions}_{seed}.npy\", values)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 ('env_cpu': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "af5525a3273d35d601ae265c5d3634806dd61a1c4d085ae098611a6832982bdb"
   }
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "0c3ec18ceda04eeaac5065d198d6dc55": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "11e8266823c34d60b7b27efba21077e3": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_de378af01c3c461499fcb9cbf9694f34",
       "max": 400.0,
       "min": 0.0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_6d0ee008b055473da422ba9e9978c2cd",
       "value": 400.0
      }
     },
     "13eafe4b708644b180741998874b0158": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "14bbe8c369064d1eabc9e3a1e00ea423": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "DescriptionStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "DescriptionStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "description_width": ""
      }
     },
     "5bd1967749f847b1a1fe2c9afd75df09": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_e8be9278333d4053a3b2695519bd5253",
        "IPY_MODEL_11e8266823c34d60b7b27efba21077e3",
        "IPY_MODEL_a259a4c8a1cb4bf0aa3f231b541ffb9e"
       ],
       "layout": "IPY_MODEL_932e5bd94fdd4c6e953c3562cc87fd22"
      }
     },
     "6d0ee008b055473da422ba9e9978c2cd": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "932e5bd94fdd4c6e953c3562cc87fd22": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "a259a4c8a1cb4bf0aa3f231b541ffb9e": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_13eafe4b708644b180741998874b0158",
       "placeholder": "​",
       "style": "IPY_MODEL_14bbe8c369064d1eabc9e3a1e00ea423",
       "value": " 400/400 [00:18&lt;00:00, 25.09it/s]"
      }
     },
     "ac5413faeeec42568ffdcdaf9ca8bba1": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "de378af01c3c461499fcb9cbf9694f34": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "1.2.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "1.2.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "1.2.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "overflow_x": null,
       "overflow_y": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "e8be9278333d4053a3b2695519bd5253": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "1.5.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "1.5.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "1.5.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_tooltip": null,
       "layout": "IPY_MODEL_ac5413faeeec42568ffdcdaf9ca8bba1",
       "placeholder": "​",
       "style": "IPY_MODEL_0c3ec18ceda04eeaac5065d198d6dc55",
       "value": "100%"
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
