{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b09e6a3-0ed0-418e-b584-e8880e2d2970",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "libEGL warning: Not allowed to force software rendering when API explicitly selects a hardware device.\n",
      "pybullet build time: May 20 2022 19:45:31\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "\n",
    "import gym\n",
    "import d4rl\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('fivethirtyeight')\n",
    "\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from jaxrl_m.common import TrainStateEQX\n",
    "from src.agents.iql_equinox import GaussianPolicy, GaussianIntentPolicy\n",
    "\n",
    "import optax\n",
    "import ott\n",
    "import wandb\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def eval_ensemble_psi(ensemble, s):\n",
    "    \"\"\"Evaluate `psi_net` of every ensemble member on a batch of states.\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 (the\n",
    "    stacked members, presumably) while broadcasting `s`; inside, `psi_net` is\n",
    "    vmapped over the leading axis of `s`.\n",
    "    \"\"\"\n",
    "    psi_batched = eqx.filter_vmap(ensemble.psi_net)\n",
    "    return psi_batched(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None))\n",
    "def eval_ensemble_phi(ensemble, s):\n",
    "    \"\"\"Evaluate `phi_net` of every ensemble member on a batch of states.\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 (the\n",
    "    stacked members, presumably) while broadcasting `s`; inside, `phi_net` is\n",
    "    vmapped over the leading axis of `s`.\n",
    "    \"\"\"\n",
    "    phi_batched = eqx.filter_vmap(ensemble.phi_net)\n",
    "    return phi_batched(s)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_viz(ensemble, s, g, z):\n",
    "    \"\"\"Evaluate `classic_icvf_initial` of every ensemble member on batches of (s, g, z).\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 while\n",
    "    broadcasting `s`, `g` and `z`; inside, `classic_icvf_initial` is vmapped over\n",
    "    the leading axis of each array argument.\n",
    "    \"\"\"\n",
    "    icvf_batched = eqx.filter_vmap(ensemble.classic_icvf_initial)\n",
    "    return icvf_batched(s, g, z)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g, z), g - dim 29, z - dim 256\n",
    "def eval_ensemble_icvf_latent_z(ensemble, s, g, z):\n",
    "    \"\"\"Evaluate `classic_icvf` of every ensemble member on batches of (s, g, z).\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 while\n",
    "    broadcasting `s`, `g` and `z`; inside, `classic_icvf` is vmapped over the\n",
    "    leading axis of each array argument. Per the decorator comment, `g` is\n",
    "    presumably 29-dim and `z` 256-dim.\n",
    "    \"\"\"\n",
    "    icvf_batched = eqx.filter_vmap(ensemble.classic_icvf)\n",
    "    return icvf_batched(s, g, z)\n",
    "\n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None)) # V(s, g, z), g, z - dim 256\n",
    "def eval_ensemble_icvf_latent_zz(ensemble, s, g, z):\n",
    "    \"\"\"Evaluate `icvf_zz` of every ensemble member on batches of (s, g, z).\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 while\n",
    "    broadcasting `s`, `g` and `z`; inside, `icvf_zz` is vmapped over the leading\n",
    "    axis of each array argument. Per the decorator comment, `g` and `z` are\n",
    "    presumably both 256-dim.\n",
    "    \"\"\"\n",
    "    icvf_batched = eqx.filter_vmap(ensemble.icvf_zz)\n",
    "    return icvf_batched(s, g, z)\n",
    "    \n",
    "@eqx.filter_vmap(in_axes=dict(ensemble=eqx.if_array(0), s=None, g=None, z=None))\n",
    "def eval_ensemble_icvf_latent_zzz(ensemble, s, g, z):\n",
    "    \"\"\"Evaluate `icvf_zzz` of every ensemble member on batches of (s, g, z).\n",
    "\n",
    "    The decorator vmaps over the array leaves of `ensemble` along axis 0 while\n",
    "    broadcasting `s`, `g` and `z`; inside, `icvf_zzz` is vmapped over the\n",
    "    leading axis of each array argument.\n",
    "    \"\"\"\n",
    "    icvf_batched = eqx.filter_vmap(ensemble.icvf_zzz)\n",
    "    return icvf_batched(s, g, z)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d29df4f8-c9db-4ad7-932a-f52a760a8d33",
   "metadata": {},
   "source": [
    "# PointUMaze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a2402bfb-5e66-45f0-a3c2-1ecd5050072d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbYAAAG1CAYAAACVndHeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA4UElEQVR4nO3de3SV1Z3/8U+ukOTkgkJCFAJJhDoIqOlYpl5GxSp1pctapJdfHcrqtLM6QbpAVhlaWjvptDVStIaCMNracbDepo7+KgO/maWkLYo2llIJrVPRkyBEAnJNOCeEXE5+f2DShuwd2CfPyTnn4f1aiz/4PvvsZz+XnG+ek+/ZO+X48eO9AgDAJ1LjPQAAALxEYgMA+AqJDQDgKyQ2AICvkNgAAL5CYgMA+AqJDQDgKyQ2AICvkNgAAL5CYgMA+MqIJbYdO3bo05/+tEpKSnTRRRfpYx/7mJ5//vmR2j0A4DyRPhI72bp1q+644w6NHj1ac+fOVSAQ0AsvvKAvfvGLam5u1le/+tWRGAYA4DyQEutJkLu7u3XVVVdp//79evHFFzVz5kxJUmtrq2666Sbt3btX27dvV0lJSSyHAQA4T8T8o8itW7eqqalJ8+bN609qkpSfn6+lS5eqs7NTTz31VKyHAQA4T8Q8sb3yyiuSpNmzZw/adtNNN0mStm3bFuthAADOEzH/G1swGJQklZeXD9pWVFSkQCCgxsbGs/azZcsWdXZ29v8/JSXF2M4Wt4l1+5HYR7K393If0ew7lv34WW+v218xvGrv2k80+6D9yO3jbNc5MzOz/yHoXMU8sbW1tUmS8vLyjNtzc3P72wyls7PzvE9syXLMiXiOXJHYzi7eb3heSLREEolEYtp/Mia2aPA9NgCAr8Q8sfU9qdmeyk6cOGF9mgMAwFXMP4rs+9taMBjUFVdcMWDbwYMHFQqFVFFRcdZ+UlJSBnw8FK+P07z8iMqrjxwT7aPFaMY5Eh93xrIfr0QzHi8/mosl12Nz/YgqWc7DUFJTzc8aXn28Z7sGQ/Xv+hqv2p/rdpOYP7Fdc801kqS6urpB27Zs2TKgDQAAwxXzxHb99ddr8uTJevbZZ9XQ0NAfb21t1Q9/+ENlZmbqc5/7XKyHAQA4T8T8o8j09HT96Ec/0h133KHKysoBU2rt27dP3/3udzVp0qRYDwMAcJ4Ykbki//Zv/1b//d//rZqaGj3//PPq6urStGnT9J3vfEdz584diSEAAM4TI5LYJOnDH/6wnn322ZHaHQDgPDViiW24zqyKHKqda79ece3LVgHlVf+xrqL0snI02asi41ld6VW1oVf7jXV1opf7jVeFZbQVgvEQ6+rHs7VPyKpIAABGEokNAOArJDYAgK+Q2AAAvkJiAwD4StJURZ7JtVLmm9/8pjHe3t7uxXA8lYiVUQASTyLOj5mdnW2Mf//73zfGY1FdyxMbAMBXSGwAAF8hsQEAfIXEBgDwFRIbAMBXkqYqcrhzRdqqH8Ph8LDGBQA4O+aKBAAgSiQ2AICvkNgAAL5CYgMA+AqJDQDgK0lTFXkm5lMEgOTHXJEAAJwFiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRFnjlXZLyqIvOmVFi3XVQ+1RjPyMjwZN8tLQfMGxyrh9It4xk3bqzrkIxOnjxpjB8/dtz+ol7L9bSE8/JyjfFAbsAYT001/w5nix8/1mqMd3V1GeNpaWnGeHqa+UfsognFxnhmZqYxPpTOU53G+P73Wozx7p5uY7ynp8cYt92/BWPynfqJRCLGeDhknq+1ra3NGFcUxXL5YwqM8ays0e6dGRw+dNgY7+4yn2sry/0+vni844jMujrN929L027ra9re/r0n+2auSAAAokRiAwD4CokNAOArJDYAgK+Q2AAAvkJiAwD4StKU+5+rWH8NwFbSL0mRrDHGeLjbseTXomzmh43xFFuNsIWt3PvY0eOuQzIabSkDLy+Z4tyX7XqG
w+YS8XC7+asGaZbf4dJSzD8CxVMuNsazsrKMcdtXKDIzzP23tYWM8V7L1waGkjrKXF4/bUKpMW4r+e6y3Ke2r28cPXLUGO9OsZT7yxzPLjT/3IwrzTHGo5kc98SJE8Z4e8cp575MLvrQdGM83fF62o7syKEjjiMyS882fx2muNT+vhbrcn8bJkEGAOADJDYAgK+Q2AAAvkJiAwD4CokNAOArSVMVeeYkyPEy1ITGturHU6fMlVeuVT+2CXKdz4t5zlx1dHS49WORYakEzBzlPsGv7dhOnDBXFXacNB+DrR/b5MUXjL3AGE+3HFt6urmftHRz+/ZwuzEezT1uu4/GXGCuTrVNRtzrOLtwu6UC1TYJcm+veb+jLRMR2+6XaKrleo6bx+TVPZ9uuc6ZmW4ToNuOzXWcrvdRtuM4RwKTIAMA8AESGwDAV0hsAABfIbEBAHyFxAYA8JWkqYo8k61SJjU11rnau6q1bsc5JLu7zZVdrlVDtn5cx2Pv39yPV/1LUleXZb5DS9xVyDK3YCRiPneumpv3GeO2azMUa0VmhuO9aik2tFU/2uaQdBXraynF/p609WN/PzKfbK/eK2xVmvHEXJEAAESJxAYA8BXnxPbMM89oyZIluuGGG1RYWKiCggI98cQT1vZtbW1asWKFpk+frsLCQs2YMUP33HOPQiHzF2wBABgO5w9iv/e972nfvn268MILVVRUpH37zH8rkE6vmVVZWaldu3Zp9uzZmjdvnhoaGrRmzRpt27ZNmzdv1ujR5lkHAACIhvMT25o1a9TQ0KBgMKi///u/H7Lt6tWrtWvXLi1ZskTPPfecqqur9dxzz2nJkiXasWOH1q1bF/XAAQAwcU5sN9xwg0pKSs7arre3V48//rgCgYCWLVs2YNuyZcsUCAS0YcOGc95v31yRZ/vn+npvpcT4H/px6rzHuYter+WflX9P9rm+V7v8cxWz4pFgMKiWlhbNmjVLOTkDl3jPycnRrFmztGfPHjU3N8dqCACA81BME5sklZWVGbf3xfvaAQDghZgltra2NklSfr556Yy8vLwB7QAA8ALfYwMA+ErMElvfE1lra6txe9+TWl87AAC8ELMJxcrLyyVJjY2Nxu198b525+Ivq2NslTKxXmW7peWAdVvZzA8b47aVr13nf9vTtMcYd51Szbaqb1m5+e+hduYd27583xhscuzf7oILxhjjZeWlnvR/oOWgU9zVxEkTjXHbvI9Dsd1Hje/sce7LJCsryxgvK5/sSf/Hj5v/HNEU3GN5hfscgmMLxxrj4wrHOfdl8l7ze8Z4V6c3812WOp5r21yRnZbxNDX8znVIznwxV2R5ebmKi4tVX1+vcDg8YFs4HFZ9fb0mTZqkCRMmxGoIAIDzUMwSW0pKiubPn69QKKRVq1YN2LZq1SqFQiEtWLAgVrsHAJynnD+K3LBhg1577TVJ0ptvvilJevzxx/XKK69Ikj760Y/qC1/4giRp8eLF2rx5s2pra9XQ0KDLL79cO3fuVF1dnSoqKlRVVeXVcQAAICmKxPbaa6/pqaeeGhD7zW9+o9/85jf9/+9LbDk5Odq0aZPuu+8+bdy4US+//LKKioq0aNEiLV++3PrZPQAA0XJObOvXr9f69evPuX1+fr5qampUU1PjuisAAJwl3jKrFmfOGRbr6kerISp1UixzvXlVwWnbtWv1kK25+zm1tTfHh1PldK48uy88Otc2tmFGM37ba7w73+Z+UlItY3Xcre2I3cfv3c+mM4/uF/t7hdtK3NZ+bDuO/Y+mlW2sfauPJ9RckQAAxAOJDQDgKyQ2AICvkNgAAL5CYgMA+ErSVEWeq7hVSw7JrXrQqxKlWcf2G+P/ePBtY3zHnguN8f836xZjvHfU6OgGlgwS8TZKFrG9raPYsZ8lzzHHvnr3z3hiAwD4CokNAOArJDYAgK+Q2AAAvkJiAwD4StJURZ7rXJGxropMzzCvPi1J3T095g2dneb21hW0zVVCtpW4bVVF
turHG4rMqwYXt7QY46/97w5j/OC0vzaPJ+I2/mjY6qg6LefaVWqa+Xe+zFGWY3As7OrqMl/7SE/ErSNJPRHza7w636lp5lW9O095c657He/3aEQs58ir+8W2YrUzy9uX9Vxb2kci5vHY3qPSM2KfCqJ9b2auSADAeY/EBgDwFRIbAMBXSGwAAF8hsQEAfCVpqiITxbhCc0WhJB07eswY7+joMMbtVZFmZZeUGeO2qiHb3I/F+81zSP7x1CljfE+7OX54b7MxnpsbMMYnlkw0xodmrpg7cuSoMb7PMiZXxReNN8azi7I96f/dpneN8e5uS2XtENLTzVWLk0onOfdl0h5uN8b37X3Pk/4LxuQb4xNLLvakf0k6ePCQMX740BHLK9zKXCeUTDDGXSs7bRXOTcEmp35sVZqjR5vndx07bqy1rwNOe04MPLEBAHyFxAYA8BUSGwDAV0hsAABfIbEBAHwlaaoiE2WuyGTy//5mjjH+2v/+zhjfEzZXPx4pmmzeQQxWvh2M63neittK3HHfuS8xVyQAAFEisQEAfIXEBgDwFRIbAMBXSGwAAF9JmqrIc0VV5J/1jjLPC3fwr64yxg/v3WfpiEowJJCErJbE2djem23zYw4HT2wAAF8hsQEAfIXEBgDwFRIbAMBXSGwAAF9JmqrIRJkr8mS7eTVhSRo9Js8Yz8jIMMa7u7uc9h06ETJvcDzk3kjEGLetfO0qLc18W52wjT8KKZaDzs3NtbzCrfKqq9N8bU70eHMM2TnmlbitFWJDDN92z3t1vnt6zKt6u94vtkPwbPxDnKP0NPMq4wGP7vmOkx3GeOepTqd+ei0H4TpO2wraaZbzcPLYcaf+oxHtezZzRQIAznskNgCAr5DYAAC+QmIDAPgKiQ0A4CtJUxV5pnjNCXl8iOqh8klTjfHMUZnGeHd3t9O+G99pNMZd51rLzDSPZ+KkiU792Jw4ccIYf//A+570L0kXXniBMV5YVOjYk/nctew/YIy3D1EV62LS5EnGeHq6uWptKLb76N2mvc59mWRnmys4iy8e70n/tp+p9w8c8qR/SSosGmeMX+hRVeS+ve8Z452dtqpI831ne18rLS91Go+tKtJWpdnY+JZT/4mOJzYAgK+Q2AAAvuKc2Pbv369169bpU5/6lKZPn65x48Zp6tSpmj9/vrZv3258TVtbm1asWKHp06ersLBQM2bM0D333KNQyLsv7AIAIEXxN7ZHHnlEtbW1Ki0t1Y033qixY8cqGAxq06ZN2rRpk37yk59o7ty5/e3D4bAqKyu1a9cuzZ49W/PmzVNDQ4PWrFmjbdu2afPmzRo92rxuGAAArpwTW0VFhf7rv/5L11577YD4q6++qk9+8pNaunSpKisrNWrUKEnS6tWrtWvXLi1ZskTV1dX97aurq1VbW6t169Zp6dKlwzsKAAA+4PxR5G233TYoqUnS1Vdfreuuu07Hjx/Xm2++Kel0td7jjz+uQCCgZcuWDWi/bNkyBQIBbdiwIcqhAwAwmKfFI32T/fZNtBkMBtXS0qJZs2YpJydnQNucnBzNmjVLe/bsUXNz81n77psEOdp/gFmK5R98xXaZ43b5z7/7biTfsz1LbPv27dOvfvUrjR8/Xpdddpmk04lNksrKyoyv6Yv3tQMAYLg8SWxdXV36yle+olOnTqm6urr/ia2trU2SlJ+fb3xdXl7egHYAAAzXsBNbJBLRwoUL9eqrr2rBggX63Oc+58W4AACIyrASWyQS0V133aWf//zn+sxnPqMHH3xwwPa+J7LW1lbj6/ue1PraAQAwXFHPFdn3pPb0009r3rx5Wr9+vVJTB+bJ8vJySVJjo3mOw754XzsAAIYrqsT2l0lt7ty5evjhh41LjpeXl6u4uFj19fUKh8MDKiPD4bDq6+s1adIkTZgwIfojwHnK3xVknrCdIrc5s89PcTt33tzXtqpCa7VhHH+czjbWaCoknT+K7Pv48emnn9btt9+uRx55xJjU+gY0
f/58hUIhrVq1asC2VatWKRQKacGCBc6DBgDAxvmJbeXKlXrqqacUCAR0ySWXDEpYklRZWamZM2dKkhYvXqzNmzertrZWDQ0Nuvzyy7Vz507V1dWpoqJCVVVVwz8KAAA+4JzY9u49vcZTKBTS/fffb2xTUlLSn9hycnK0adMm3Xfffdq4caNefvllFRUVadGiRVq+fLmysrKGMXwAAAZyTmzr16/X+vXrnV6Tn5+vmpoa1dTUuO4OAAAnrMcGAPCVqMv9R9qZ84e5Vv184QtfMMa7u7ud+skqnGgdY/7YQmPctkx7r6XCylYEFCozL2/vei5s4wnk5jr1k5pqjtvOacfJDmN8qH3Y4qMtH2H3rSoxuB+3/Z5sP2mM2+8Xc/+2crNAbsAYP/MrM+ciEokY46ETlxnjvbbSPks4Pd1cHJaVnW3uxnJj2+J/bHrPGH+5Ybd5QFHo6uwyxo8cPupJ/4HcHGM8RebrbGMrujx6xG2cffP2DhqP5UYd6rvEh532bOda3TiiVZEAACQyEhsAwFdIbAAAXyGxAQB8hcQGAPCVpK2KdHX99dcb47YqNHsloP13Ads24qcNdf286su1vVdxm1i3l+zVhrFu7xq3VW/a7Nr3vlP7oRw8YO4rdCLkSf8TSy42xjMzM536sZ27xuAep34yMsxv7aNHjzbGcwLmqs5kxRMbAMBXSGwAAF8hsQEAfIXEBgDwFRIbAMBXkqYqcqRFU50W60q6ZO9nqIrSWFctJlrcJp5Vka79uIqmWtYrcVsgOk4rcaekWKqMLYtCp/Ym3jMOc0UCAPABEhsAwFdIbAAAXyGxAQB8hcQGAPCVpK2KTKaqNde+Eq1qcSTOtetrYj1Wr/p3NSIVgpZ9xHruR9fx+FqMqyXTLNWP6ZZ4Wq857iXX+244eGIDAPgKiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRFDncF7Wj25xKP5jXJMvdjrKsrvRxTolVXukqmCkHbuXOtljzZ3m6Mt7x3wHlMNlk5WcZ4cW7Ak/6PHWs1xiM9PY49ma9/8UXjnXoJBCzHZSlAbHnnj079e+ls9zxzRQIAznskNgCAr5DYAAC+QmIDAPgKiQ0A4CtJUxUJu0Sr4BuJ1ceTJe4qmaoibVyPobvbXDnYbqmWjEYgz1wlmJ2T7Un/Rw4fNcY7Ozud+rGdu+KLi8wvsFQ55lqqPW3nuqur+6xjSyY8sQEAfIXEBgDwFRIbAMBXSGwAAF8hsQEAfMV3VZHxqmaLRqKNKZ6rjydaBWcyVT96tTJxrI/NOp7kLwSNH8u5S0+3vbWbX2BbcTtZ8cQGAPAVEhsAwFdIbAAAXyGxAQB8hcQGAPAV31VFxlo0lX0jse949BNPyVL96mXlqI2t2tCraklXztfGVtpnOxWxHb4vZGRmGOOpKeZnmfS02KeCkbwfeWIDAPgKiQ0A4CvOia2jo0MrVqzQrbfeqksvvVRFRUWaOnWq5syZo5/97Gfq6uoa9Jq2tjatWLFC06dPV2FhoWbMmKF77rlHoVDIk4MAAKCPc2ILh8P66U9/qpSUFN1yyy2666679IlPfEL79+/XokWL9NnPflaRSGRA+8rKSq1bt05Tp07VwoULNWXKFK1Zs0a33XabOjo6PD0gAMD5zfkvhmPGjNHevXuVmZk5IN7d3a3bb79ddXV1evHFFzVnzhxJ0urVq7Vr1y4tWbJE1dXV/e2rq6tVW1urdevWaenSpcM7CgAAPuCc2FJTUwclNen03GSf+MQn9Morr6ixsVHS6WqXxx9/XIFAQMuWLRvQftmyZfrJT36iDRs2kNiGKV6VgCOxgnay8Kr60cvzkGjVkjYBy2rPkyaXeLaPY0eOGeNHDx/xpP/C8eYVrjMy3N5ibdfg3aa9Tv30dJn7yc42rxh+0cXF1r7ecdpzYvCseCQSiWjLli2SpGnTpkmSgsGgWlpaNGvWLOXk5Axon5OTo1mzZmnPnj1qbm72ahgAgPNc1F9e6Ozs
1AMPPKDe3l4dO3ZMv/71r7V7927deeeduv766yWdTmySVFZWZuyjrKxMW7ZsUTAY1IQJE6IdCgAA/YaV2FauXNn//5SUFH31q1/VP//zP/fH2traJEn5+fnGPvLy8ga0AwBguKL+KDIQCOj48eM6evSo/vjHP+r+++/Xhg0b9IlPfIJEBQCIm2H/jS01NVUXX3yxvvSlL2n16tX6zW9+owceeEDSn5/IWltbja/tS4B97QAAGC5PJwi78cYbJUmvvPKKJKm8vFyS+qskz9QX72vnIlnmZfS6r1hKxHMa62rDeK0aHs/7y7Va0iu2/lNTLfMXWleBdmc75u7uHk/6T/PoGLwap+1cp6WbV8pOHzW40n2kxKJK19MptQ4cOCBJysg4PQFneXm5iouLVV9fr3A4PKBtOBxWfX29Jk2aROEIAMAzzontT3/6k9rb2wfF29vb9c1vflOSdPPNN0s6nYnnz5+vUCikVatWDWi/atUqhUIhLViwIJpxAwBg5Pys//zzz2vdunX6m7/5G5WUlCg3N1f79+/XSy+9pKNHj+qjH/2oFi5c2N9+8eLF2rx5s2pra9XQ0KDLL79cO3fuVF1dnSoqKlRVVeXpAQEAzm/Oie3jH/+4Dhw4oNdff12vv/66wuGw8vLydNlll+mOO+7Q3/3d3w34XDknJ0ebNm3Sfffdp40bN+rll19WUVGRFi1apOXLlysrK8vTAwIAnN+cE9uVV16pK6+80uk1+fn5qqmpUU1NjevuAABwwgra6Bfrarl4zUUIWFfjhi+x0CgAwFdIbAAAXyGxAQB8hcQGAPAVEhsAwFdIbAAAX0nacv9kKk2nzP3sbOfI9Tq79uMVryZyjeb+9eqYXfvxKh6JRIxxryYoluznNd0yKbCrHusxdDv1Y7tkruO0neueHvM5jXR2OvXvpVj8bPLEBgDwFRIbAMBXSGwAAF8hsQEAfIXEBgDwlaStisSfeVWdFuv9elWxmoi8qt70ct9etY+10ImQMf7unnc920dhUaExPq5wnCf973u32RjvdKw2tN0vZZeUOvUzfvx4Y7zHUqX55v82OPWf6HhiAwD4CokNAOArJDYAgK+Q2AAAvkJiAwD4iu+qImNd8TVU/8k0f2Wi7TfWlX3xqtRMpmrJWO/XPp7EqtL0g66uLmO8xzL/ZneP25yWiY4nNgCAr5DYAAC+QmIDAPgKiQ0A4CskNgCAr/iuKtImEec1TLQxxbqCb6j2rvtwPRexnjfTq/F4uYK2V/3Eei7SBJu6MjoJNg2qbeXu7i5z3LaytpdGskqXJzYAgK+Q2AAAvkJiAwD4CokNAOArJDYAgK+cN1WR8ZRoK1l7V83m3XEl2irgySRZ7iNbPD3d/DaUnZ1tjEcjYqn6a29v96T/zFGZxrjt2KzTY1qqK9vDbuNMTU2z7Ne844yMDKf+Ex1PbAAAXyGxAQB8hcQGAPAVEhsAwFdIbAAAX0maqsje3t5zqv6K9byGI7GCdryq07yauzISiRjjqan236PiVRUZzVgTTbyqIm3nzrWfrOwsY7z4omJj3M5+Ht4/+L4xfvjwEcd9mE0smWCMZ2aaqyVtbOeo8Z0mp36yslqNcVulacGYfKf+R0LfuYjm/k6en14AAM4BiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRFniuvVoGOproy0aoWbRJtv9HsI9bVjLb+bfeFa9xLiXb941XhmnDLWMeRbUXsnh7zCtoRme93L8X++v8ZT2wAAF8hsQEAfMWTxFZbW6uCggIVFBTot7/97aDtbW1tWrFihaZPn67CwkLNmDFD99xzj0KhkBe7BwCg37AT25tvvqmamhrl5OQYt4fDYVVWVmrdunWaOnWqFi5cqClTpmjNmjW67bbb1NHRMdwhAADQb1iJraurS1VVVZoxY4YqKyuNbVavXq1du3ZpyZIleu6551RdXa3nnntOS5Ys0Y4dO7Ru3brhDAEAgAGGldjuv/9+/elPf9LatWuV
ljZ4xdbe3l49/vjjCgQCWrZs2YBty5YtUyAQ0IYNG85pX31zRZ7rnJFne/3Z+nFt7yWv9u3VMUciEeM/W/t4/ovXWG37TaZ/8btuMv7zVorlX5JwHL7tXPdEIsZ/Q90XySjqxPbGG2/ogQce0PLly3XppZca2wSDQbW0tGjWrFmDPqrMycnRrFmztGfPHjU3N0c7DAAABogqsZ06dar/I8jFixdb2wWDQUlSWVmZcXtfvK8dAADDFdUXtO+9914Fg0H96le/Mn4E2aetrU2SlJ9vXhIhLy9vQDsAAIbL+Ynt9ddf15o1a/S1r31N06ZNi8WYAACImlNi6+7uVlVVlS677DLdfffdZ23f90TW2mpe9K7vSa2vHQAAw+X0UWQoFOr/e9i4ceOMbW6++WZJ0s9+9rP+opLGxkZj2754eXm5yzCisnXrVmO8u9s8d5ptjr/swonWfeSNLTLG09Mtp9lS+mXbt+0L7bbiKFs/tvEEcnOd+kmxzMvY3d1ljJ86af/Ooutci6OzzKsujxo1ytKP237b208a4z2W+8W2A9u1sZ3r1FT3Sr1IxHwfnThxwqkfW8Ws7X6xrXxt68dW6fjHpveM8ePHzL8QD7VSto3tGArGFDj3ZRIOtxvj7Za4q4KCAqf2GRkZxrjt/godPOw6pJjru4+iqQh3SmyjRo3S/PnzjdteffVVBYNB3XrrrRo7dqxKSkpUXl6u4uJi1dfXKxwOD6iMDIfDqq+v16RJkzRhgnlZdQAAXDkltqysLK1Zs8a4raqqSsFgUEuXLtVVV13VH58/f75+8IMfaNWqVaquru6Pr1q1SqFQSEuXLo1u5AAAGMR82ZrFixdr8+bNqq2tVUNDgy6//HLt3LlTdXV1qqioUFVVVayHAAA4j8R8dv+cnBxt2rRJVVVV2r17t9auXavdu3dr0aJF+sUvfqEsy99KAACIhmdPbOvXr9f69euN2/Lz81VTU6OamhqvdgcAgFHSrKB95jyHrpUytjkpXZfOGffhW6zbyq/8iDGeOSrTGLdVZNo0vWOuLnU9F5mZ5vFMKLFXfLoIWarx3j/wfhS9mY/tgrEXGuMFY8Y49WPT8l6LMX6y3Zsqt5LSSca4tYJ2CLb7aE/Tu859mWRnZxvjxRcXe9L/8WPHjfEjR4540r8kFRYVGuO5uQHHnsz30b695mkBOzs7nXq3VemWlZc69WOriuw8ZR5PS5tbBW007NWyQ8ejqYpkoVEAgK+Q2AAAvkJiAwD4CokNAOArJDYAgK/4rirSVlWEc2A7dSOzcLiFV9czIQ8upmy/tSbnmsiJIjneX1wrEP32Y8ATGwDAV0hsAABfIbEBAHyFxAYA8BUSGwDAV5KmKvJMzlU/iF5CFhTGulrSv6iWxCApsf9hZq5IAACiRGIDAPgKiQ0A4CskNgCAr5DYAAC+QmIDAPhK0pb7x0vBmALrthNtbcZ4T4+5kLq7u8tp32MLxxnjrvM+RyLm8tn3Dxx068giLc18WxWOL7S+JtVyELZK31Od5iXuvTqG7JxsYzw3L9fyCreS5GNHjpl7iaK02Tbx91Dn20V3T48x7tW5Ts/IMMa9Gr8kdXZY7pf29z3pPz8/3xhPSXX9OonlZ/Og2zjT080/g2lpacb4UO9rh5327C4WX9HiiQ0A4CskNgCAr5DYAAC+QmIDAPgKiQ0A4CtJUxXZ29t7TtUzsZ4EOSs7y7qtveOUMd7R0WGMd3d3O+17XJG5SsxWFWfTecpcIXb4/UNO/dgEAuaKwtSuVutrIqdOGuPZBRcY4ynpo43xA0dDZxndmcz3SyDXXP1oq5Z0rYo8cvioMd7jeE9IUpqlAm5ckbmK1lV7uN0YP3zoiCf92yryci3XIBrvh80VnCdOuN4vZrZjyMzMdOqnt9dcQf3+QbefTVtV5OjR5p+bnCz7+5pXmAQZAIAokdgAAL5CYgMA+AqJDQDgKyQ2AICvJE1V5LmKdVUk/iy11zyHYNmeTcb4ZaPs
VXQnTpmrwbb/KccYP1pUYYyn5RYb4z3W28J1Lj/Xfvx7P9p+KzZfSZwbr+7HxDOS7808sQEAfIXEBgDwFRIbAMBXSGwAAF8hsQEAfCVpqiLPda5IDIdbZd/syA5j/K4rzHMLFgUC1j23Wqoi65vN82+u3rXdGG+feosxnp5hniPPdktlpMW6WtK/zs9qyfPvOrtirkgAAKJEYgMA+AqJDQDgKyQ2AICvkNgAAL6SNFWRieLQEKtMX3zpTGPctpqt6wra7+17z7zBsWooPSPDGJ9QMsEYT+0xr7h9Z9f7xvj0EnMFYmphgXVMhZaVr221Zr8/Yq6W3N35e2P8/UPmc9QRMV+b/UfGG+NpReWWEbldg6LxRcZ4aqr775qRiLnesPndfc59mWSOMq8CPaFkoif9h8JhY3yfR+OXpPyCPGO8YMwYT/o/ctg8D6rrz3iK5Yaf6Hiu09PTnMbz3lt/cOo/0fHEBgDwFRIbAMBXokpsM2bMUEFBgfFfZWXloPanTp3SypUrVVFRoaKiIl166aVavHixDh2yf6wHAEA0ov4bW15enqqqqgbFS0pKBvw/Eono85//vLZs2aKrrrpKt912m4LBoDZs2KBf//rXeumllzR27NhohwEAwABRJ7b8/Hx94xvfOGu7J598Ulu2bNG8efP04x//WCkf/HX0pz/9qZYuXarvfe97qq2tjXYYAAAMEPOqyA0bNkiSvv3tb/cnNUn64he/qB/96Ef6+c9/rpqaGmVlZQ3ZT6LMFdndZa9ySk8zn87MTHNVWWqa2yfBXZ3m6kSvzout+i2t27xSdlGmefypAXPVpSaZKwElKaXXXNlXcPikMX7tdHM1233/51JjPDPdfI56Iub4154xH/MLQXM1pitbZaqtmm0otkq3Tsv94irNUtVru19cpbSb5xb1avySlGKpNvXqGLy6BimWskjbe4iN9T6yvFV0d3U59R+NpJgrsrOzU0888YQeeOABPfLII9q+ffCktB0dHdq+fbumTJky6CPKlJQU3XjjjQqHw/r9780l2gAAuIr6ie3gwYO66667BsQqKir06KOPqrS0VJLU1NSkSCSisrIyYx998WAwqKuvvjraoQAA0C+qJ7Y777xTv/jFL/T2229r//792rp1qz772c9qx44duu2223TixAlJUltbm6TTf48zycvLG9AOAIDhiuqJ7etf//qA/8+cOVMPP/ywJOmZZ57Rv//7v2vRokXDHx0AAI48/YL2F7/4RUlSfX29pD8/kbW2thrb9z2p9bUDAGC4PK2KvPDCCyVJ7R9UOU2ePFmpqalqbGw0tu+Ll5fb5t9zlwiVk+fMNtQEW4y3J91csbq9o8QYn9J+wBhPfct8H0hSb9hcVXYwZK5OfPrQdGP80f97pTE+Ks1cdVmYba5a23X0qDEudVjiAIYyku/Nnj6x9VVG9lVAZmVl6cMf/rDefvtt7d27d0Db3t5e/fKXv1ROTo6uvNL8ZgQAgCvnxLZ79+7+J7Iz49XV1ZKkefPm9ccXLFggSfqXf/mXARn73/7t37Rnzx59+tOfPut32AAAOFfOH0X+53/+p9atW6err75aEydOVHZ2tt555x29+OKL6urq0tKlS3XNNdf0t//85z+v559/Xs8++6zeffddXXPNNWpsbNTGjRs1adIkfetb3/L0gAAA5zfnxHbddddp9+7damho0Guvvab29nZdeOGFuvnmm/XlL39Zs2fPHtA+NTVVTz75pB588EE988wzWrduncaMGaP58+frW9/6FvNEAgA85ZzYrr32Wl177bVOrxk1apS+/vWvD/qaAAAAXkuaFbR7e3sHrBRsW2nYNteaV4bqvtdS5ug6R5q9WtK8c+cjtvTjWrX0ZMbNxnhkz4vG+EdGWVYAl7Q/bB7Tjw9+yBhvumCqMZ7fbe7nVLd57ry2U+a/77Z1mtt7dn+53hND9mUOezVWWzfeVbmZ+/H0Z9n2o+bRMdjG6noMXp1r63uO/WZx
6t9LCTVXJAAAiYjEBgDwFRIbAMBXSGwAAF8hsQEAfCVpqiLPFK85IcdfVGzdduT9w8Z4R4d5fkHbqrs2pZeY17VLSXWraOo8ZZ4fsekd+1yOLu7PnWWMF463r6CtXLd4z6EjxrhXx2C7zkNdfxfvNu0xxnsc7wnJvsK17X5x1R42r3Dteq5tP7H5YwqM8TKPxi9J7x84aI4ffN+T/ieWTDDGXVe+tr2vNQabbK8wRtMt98To0aON8eLi8dYxmd/V3EW7gnY0eGIDAPgKiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRF9vb2DqtKJqlW1naVJCtxe8nHh4akZLsjbXH/vh95Nq8lc0UCAHAaiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRFnikxqxxtKwGbW9vmc7OxtXddpTeSHjHGXcdjk56eEdP+JSk9w9yXV+fINsefLe66gnKWZc6+np6ecxjdQGlp5tW+bfMCus7N191lnr8yI8N8nW1s/duumaf3S4zvSfdjcLsGrv14uvq4RyIR8/tOLN7LeWIDAPgKiQ0A4CskNgCAr5DYAAC+QmIDAPhK0lRFnjlXZLyqfro6u6zb0rMDjr05rnzdaV75OsWxn+4ec5WbrYrOVVqa+fcl28rd0bBdf9sxWO8XS9hWweV6DWy7HTV6lGW/7hVitmOznW9rVaSlws52LkZnWe4XyyFELPtNtawA7+X9Yrsnvbrnu20rn7vOm2g5efZxuvVvq67sOml/X4sX5ooEAOADJDYAgK+Q2AAAvkJiAwD4CokNAOArSVMVeaZ4zRW5P/gn67aLyi81xrMz3ebUs2na+Vtj3PVU2OZZHFc4znVIRiePHjPGg0OcO1e5ebnGeCDXXJnqWkV7MPimMd7Zccr8AlvRpWW/RePHG+Np6eZ5H4fSY5nL8Z3fv22MW392LOFMSwXnBReOMb/Aci5st+mJg0eM8QNtbZZXuCsYU2CM52Rne9L/e2/tMsa7uxyrDS33S3Gx+X6xM59tW/Xj/uBux/7djeR7Nk9sAABfIbEBAHyFxAYA8BUSGwDAV0hsAABfIbEBAHwlacr9z3US5FiXlLa+vSOqbcmgJd4DcHAo3gMYpgPxHoAH9sV7AA4Ox3sAw5Ts45fc35uZBBkAgA+Q2AAAvkJiAwD4CokNAOArJDYAgK8kTVXkmWyVMq6T3QIA4icWlew8sQEAfGVYiW3jxo26/fbbVVpaqqKiIs2cOVNf+tKX1NzcPKBdW1ubVqxYoenTp6uwsFAzZszQPffco1AoNKzBAwBwpqg+iuzt7dXdd9+txx57TKWlpbrjjjsUCATU0tKibdu2ad++fZowYYIkKRwOq7KyUrt27dLs2bM1b948NTQ0aM2aNdq2bZs2b96s0aNHe3pQAIDzV1SJ7V//9V/12GOP6ctf/rJWrlyptLSBiyN2d/954cPVq1dr165dWrJkiaqrq/vj1dXVqq2t1bp167R06dLoRg8AwBmcP4o8efKkVq5cqcmTJ+u+++4blNQkKT39dL7s7e3V448/rkAgoGXLlg1os2zZMgUCAW3YsCHKoQMAMJjzE1tdXZ2OHz+uO++8Uz09Pdq8ebOCwaDy8/N1ww03qKysrL9tMBhUS0uLbrrpJuXk5AzoJycnR7NmzdKWLVvU3Nzc/9GlzXDnisz2aAn4kUBlJ4BzEeu5caNhe68dybkinRPbG2+8IUlKS0vTNddco3feead/W2pqqhYuXKjvfe97kk4nNkkDkt1fKisr05YtWxQMBs+a2AAAOBfOH0UePnx6numHHnpIeXl5qqurU3NzszZv3qxLLrlEa9eu1aOPPirpdDWkJOXn5xv7ysvLG9AOAIDhck5skUhEkpSZmaknnnhCFRUVCgQCuvrqq/XYY48pNTVVa9eu9XygAACcC+fE1veUdcUVV6i4uHjAtmnTpmny5MlqamrS8ePH+9u2trYa++p7UutrBwDAcDkntilTpkiyf7zYF+/o6FB5ebkk
qbGx0di2L97XDgCA4XIuHrnuuuskSbt37x60raurS42NjcrJydHYsWNVVFSk4uJi1dfXKxwOD6iMDIfDqq+v16RJk6IqHHGdK/L73/++U3ubodp71Zdr3Kv+veonmqpOL69DPPqJV/9e8qrCLl79RLNf22tc417179qPTd+fjLzoP9ZjTYi5IktLSzV79mw1NjYO+g7agw8+qNbWVlVWVio9PV0pKSmaP3++QqGQVq1aNaDtqlWrFAqFtGDBguEdAQAAfyHl+PHjzumyqalJt9xyiw4dOqQ5c+ZoypQpamho0NatWzVx4kS99NJLKioqknT6yWzOnDn6wx/+oNmzZ+vyyy/Xzp07VVdXp4qKCm3atElZWVln3ef//M//qLOz8+wHFOPf/Hli875/L8cU6/0mWv9e4okt+rhX/bv2Y+OnJ7bMzEzNmTPHqc+oJkEuLS3VL3/5S33+85/XG2+8oYcffliNjY36h3/4B9XV1fUnNen0F7E3bdqkqqoq7d69W2vXrtXu3bu1aNEi/eIXvzinpAYAwLmK6oktHnhi44ltuPuIZT/x6t9LPLFFH/eqf9d+bHhiAwDAR5JmBe3hzhUZ6/bxFK+xermKuZdPfy5ivRK76zWIZr+Jdk8m4hNbvJ6cYv2E59VxRWOkzkU0x8ITGwDAV0hsAABfIbEBAHyFxAYA8BUSGwDAV5K2KtImXtWSXu7Dtf/U1OT//cT13MXre2+xlmgVjkOJdfVjPL/Hlmjtz8dzRFUkAAAfILEBAHwlaT6KzMzMHPD/eH1ElYjTRSV7+2j6SqZj8Cs+ivR/+5HYx9mu/5nv/eciaeaKBADgXPBRJADAV0hsAABfIbEBAHyFxAYA8BUSGwDAV0hsAABfIbEBAHwl4RPbjh079OlPf1olJSW66KKL9LGPfUzPP/98vIcFB/v379e6dev0qU99StOnT9e4ceM0depUzZ8/X9u3bze+pq2tTStWrND06dNVWFioGTNm6J577lEoFBrh0WM4amtrVVBQoIKCAv32t78dtJ3rnLw2btyo22+/XaWlpSoqKtLMmTP1pS99Sc3NzQPaxeMaJ/QXtLdu3ao77rhDo0eP1ty5cxUIBPTCCy9o3759+u53v6uvfvWr8R4izkF1dbVqa2tVWlqqa6+9VmPHjlUwGNSmTZvU29urn/zkJ5o7d25/+3A4rI9//OPatWuXZs+erZkzZ6qhoUF1dXWqqKjQ5s2bNXr06DgeEc7Fm2++qRtvvFHp6ekKh8N68cUXddVVV/Vv5zonp97eXt1999167LHHVFpaqptuukmBQEAtLS3atm2bfvzjH+ujH/2opPhd44RNbN3d3brqqqu0f/9+vfjii5o5c6YkqbW1VTfddJP27t2r7du3q6SkJM4jxdm88MILuuCCC3TttdcOiL/66qv65Cc/qZycHL311lsaNWqUJOnee+/VD37wAy1ZskTV1dX97fsS5Le//W0tXbp0JA8Bjrq6uvSxj31MGRkZKisr03/8x38MSmxc5+S0fv16feMb39CXv/xlrVy5UmlpaQO2d3d3Kz399GyN8brGCZvY6urqNHfuXN1555166KGHBmx78skntXDhQn3jG9/Q8uXL4zRCeGHu3Lmqq6vTL3/5S1155ZXq7e3VtGnTdOLECb311lvKycnpbxsOh/WhD31IY8eO1RtvvBG/QeOsampqVFtbq1//+tdavXq1nnrqqQGJjeucnE6ePKm/+qu/UkFBgbZv396fwEzieY0T9m9sr7zyiiRp9uzZg7bddNNNkqRt27aN6JjgvYyMDEnq/60vGAyqpaVFs2bNGvCDIEk5OTmaNWuW9uzZM+hzfCSON954Qw888ICWL1+uSy+91NiG65yc6urqdPz4cVVWVqqnp0cvvPCCHnzwQf30pz9VY2PjgLbxvMYJm9iCwaAkqby8fNC2oqIiBQKBQScSyWXfvn361a9+pfHjx+uyyy6T9OfrXlZWZnxNX7yvHRLLqVOn
VFVVpRkzZmjx4sXWdlzn5NT3dJWWlqZrrrlGX/jCF/Sd73xHS5cu1V//9V/rW9/6Vn/beF7jhE1sbW1tkqS8vDzj9tzc3P42SD5dXV36yle+olOnTqm6urr/ia3vmubn5xtf13c/cO0T07333qtgMKiHHnpo0N9e/hLXOTkdPnxYkvTQQw8pLy9PdXV1am5u1ubNm3XJJZdo7dq1evTRRyXF9xonbGKDf0UiES1cuFCvvvqqFixYoM997nPxHhI88Prrr2vNmjX62te+pmnTpsV7OIiBSCQi6fQaaU888YQqKioUCAR09dVX67HHHlNqaqrWrl0b51EmcGI7WzY/ceKE9WkOiSsSieiuu+7Sz3/+c33mM5/Rgw8+OGB73zVtbW01vv5sT/KIj+7ublVVVemyyy7T3Xfffdb2XOfk1Hc9rrjiChUXFw/YNm3aNE2ePFlNTU06fvx4XK9xwq6g3fe3tWAwqCuuuGLAtoMHDyoUCqmioiIOI0O0+p7Unn76ac2bN0/r169XaurA3636rrvt76d9cdPfXhE/oVCo/28l48aNM7a5+eabJUk/+9nP+otKuM7JZcqUKZLsHy/2xTs6OuL6s5ywie2aa67RD3/4Q9XV1emOO+4YsG3Lli39bZAc/jKpzZ07Vw8//LDxbzDl5eUqLi5WfX29wuHwoBLh+vp6TZo0SRMmTBjJ4eMsRo0apfnz5xu3vfrqqwoGg7r11ls1duxYlZSUcJ2T1HXXXSdJ2r1796BtXV1damxsVE5OjsaOHauioqK4XeOE/Sjy+uuv1+TJk/Xss8+qoaGhP97a2qof/vCHyszM5G8zSaLv48enn35at99+ux555BFrYUFKSormz5+vUCikVatWDdi2atUqhUIhLViwYCSGDQdZWVlas2aN8d9HPvIRSdLSpUu1Zs0azZw5k+ucpEpLSzV79mw1NjZqw4YNA7Y9+OCDam1tVWVlpdLT0+N6jRP2C9oSU2r5RU1NjVauXKlAIKB//Md/NCa1ysrK/tllwuGw5syZoz/84Q+aPXu2Lr/8cu3cubN/Gp5NmzYpKytrpA8DUaqqqhr0BW2J65ysmpqadMstt+jQoUOaM2eOpkyZooaGBm3dulUTJ07USy+9pKKiIknxu8YJndgk6Xe/+51qamr0+uuvq6urS9OmTdNdd901YG5BJLa+N7ahPPTQQ7rzzjv7/9/a2qr77rtPGzdu1MGDB1VUVKTbb79dy5cvV25ubqyHDA/ZEpvEdU5Wzc3Nuvfee7VlyxYdPXpURUVFuvXWW/VP//RPg/7GGo9rnPCJDQAAFwn7NzYAAKJBYgMA+AqJDQDgKyQ2AICvkNgAAL5CYgMA+AqJDQDgKyQ2AICvkNgAAL5CYgMA+AqJDQDgKyQ2AICvkNgAAL7y/wH7Db046Egh9wAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Expert data for PointUMaze-v2; set EXPERT_DATA_DIR to load it from another location.\n",
    "expert_data_dir = os.environ.get('EXPERT_DATA_DIR', '/home/m_bobrin/d3il/expert_data/PointUMaze-v2')\n",
    "# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.\n",
    "pointumaze_sac_imgs = np.load(os.path.join(expert_data_dir, 'expert_ims.npy'), allow_pickle=True)\n",
    "pointumaze_sac_obs = np.load(os.path.join(expert_data_dir, 'expert_obs.npy'), allow_pickle=True)\n",
    "\n",
    "# Presumably (episode, frame) indices into the image array -- confirm against the data layout.\n",
    "episode_idx, frame_idx = 18, 0\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(pointumaze_sac_imgs[episode_idx][frame_idx])\n",
    "ax.grid(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "968b4739-e48e-45bb-b972-972d5ee6a499",
   "metadata": {},
   "outputs": [],
   "source": [
    "from envs.maze_envs import CustomPointUMazeSize3Env, CustomAntUMazeSize3Env\n",
    "\n",
    "# NOTE(review): episode_limit is not read by any cell in this notebook;\n",
    "# presumably intended as a per-episode step cap for rollouts.\n",
    "episode_limit = 1000\n",
    "env = CustomPointUMazeSize3Env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cceb98df-3201-4c60-b08c-3f9f220b11c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "env.reset()\n",
    "# The offscreen renderer presumably selects its GPU from CUDA_VISIBLE_DEVICES:\n",
    "# expose device 4 only for the render call, then restore the previous value\n",
    "# (even if rendering fails) instead of overwriting the device selection made\n",
    "# in the import/setup cell at the top of the notebook.\n",
    "previous_devices = os.environ.get('CUDA_VISIBLE_DEVICES')\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n",
    "try:\n",
    "    env.render(mode='rgb_array')\n",
    "finally:\n",
    "    if previous_devices is None:\n",
    "        os.environ.pop('CUDA_VISIBLE_DEVICES', None)\n",
    "    else:\n",
    "        os.environ['CUDA_VISIBLE_DEVICES'] = previous_devices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9b7c8a85-7ad4-4888-8b4a-9d4e18723ef0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moviepy - Building video /home/m_bobrin/AILOT/notebooks/rl-video-episode-0.mp4.\n",
      "Moviepy - Writing video /home/m_bobrin/AILOT/notebooks/rl-video-episode-0.mp4\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                                                                              "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moviepy - Done !\n",
      "Moviepy - video ready /home/m_bobrin/AILOT/notebooks/rl-video-episode-0.mp4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "from gymnasium.utils import save_video\n",
    "\n",
    "def record(env, num_episodes=1, max_steps=10, fps=30, video_folder='.', render_devices='4'):\n",
    "    \"\"\"Roll out uniformly random actions in `env` and save the rendered frames as one video.\n",
    "\n",
    "    Args:\n",
    "        env: gym-style environment with `reset()`, `step(action)` returning\n",
    "            `(obs, reward, done, info)`, `action_space.sample()` and\n",
    "            `render(mode='rgb_array')`.\n",
    "        num_episodes: number of episodes to roll out.\n",
    "        max_steps: maximum steps per episode; an episode also stops as soon as\n",
    "            `env.step` reports `done`, so a finished episode is never stepped.\n",
    "        fps: frame rate of the saved video.\n",
    "        video_folder: directory the video file is written to.\n",
    "        render_devices: value CUDA_VISIBLE_DEVICES is set to while rendering.\n",
    "\n",
    "    Returns:\n",
    "        List with the summed reward of each episode.\n",
    "    \"\"\"\n",
    "    frames = []\n",
    "    episode_rewards = []\n",
    "    for _episode in range(num_episodes):\n",
    "        env.reset()\n",
    "        episode_reward = 0\n",
    "        for _step in range(max_steps):\n",
    "            _obs, reward, done, _info = env.step(env.action_space.sample())\n",
    "            # The offscreen renderer presumably picks its GPU from\n",
    "            # CUDA_VISIBLE_DEVICES; switch it only around the render call and\n",
    "            # restore the caller's value afterwards, even if rendering fails.\n",
    "            previous_devices = os.environ.get('CUDA_VISIBLE_DEVICES')\n",
    "            os.environ['CUDA_VISIBLE_DEVICES'] = render_devices\n",
    "            try:\n",
    "                frames.append(env.render(mode='rgb_array'))\n",
    "            finally:\n",
    "                if previous_devices is None:\n",
    "                    os.environ.pop('CUDA_VISIBLE_DEVICES', None)\n",
    "                else:\n",
    "                    os.environ['CUDA_VISIBLE_DEVICES'] = previous_devices\n",
    "            episode_reward += reward\n",
    "            if done:\n",
    "                break\n",
    "        episode_rewards.append(episode_reward)\n",
    "    # All episodes' frames go into a single video file.\n",
    "    save_video.save_video(frames, video_folder=video_folder, fps=fps)\n",
    "    return episode_rewards\n",
    "\n",
    "\n",
    "episode_rewards = record(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "536d3405-43a1-4ce2-92ed-754edbd83f67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f27c0fbc-385f-4fff-9b81-7f4fd9d5a28f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
