{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a748794",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# from gridworlds.grid_env import GridEnvironment\n",
    "from src.Generalist.generalist_meta_env import Generalist_MetaEpisodeEnv\n",
    "from src.Generalist.draw_gridworld import draw_policy\n",
    "\n",
    "# import gymnasium as gym\n",
    "from stable_baselines3 import PPO, A2C\n",
    "from stable_baselines3.common.vec_env import SubprocVecEnv\n",
    "from stable_baselines3.common.utils import set_random_seed\n",
    "\n",
    "#stablebaselines feature extractor\n",
    "from src.Generalist.feature_extractor import Custom_Flatten\n",
    "\n",
    "#For evaluation\n",
    "from src.Generalist.evals_utils import average_evals"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1282fb4",
   "metadata": {},
   "source": [
    "# Load Gridworlds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8e5d2f26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "976 train gridworlds loaded\n",
      "96 val gridworlds loaded\n",
      "200 test gridworlds loaded\n"
     ]
    }
   ],
   "source": [
    "# Load the train / val / test gridworld splits\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "from classes import Object  # NOTE(review): not referenced below; presumably kept so the pickled objects' classes resolve - confirm\n",
    "\n",
    "WORLDS_DIR = Path('src/world_builder/worlds')\n",
    "\n",
    "\n",
    "def load_gridworlds(split: str):\n",
    "    \"\"\"Unpickle `master_set_<split>.pkl` from WORLDS_DIR and disable early stopping on every gridworld.\n",
    "\n",
    "    Args:\n",
    "        split: split name used in the file name ('train', 'val' or 'test').\n",
    "\n",
    "    Returns:\n",
    "        The unpickled collection of gridworlds, each with `early_stopping` set to False.\n",
    "    \"\"\"\n",
    "    # pickle.load can execute arbitrary code: only load trusted, locally generated world files.\n",
    "    with open(WORLDS_DIR / f'master_set_{split}.pkl', 'rb') as f:\n",
    "        gridworlds = pickle.load(f)\n",
    "    print(f'{len(gridworlds)} {split} gridworlds loaded')\n",
    "    for grid in gridworlds:\n",
    "        grid.early_stopping = False\n",
    "    return gridworlds\n",
    "\n",
    "\n",
    "train_gridworlds = load_gridworlds('train')\n",
    "val_gridworlds = load_gridworlds('val')\n",
    "test_gridworlds = load_gridworlds('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ded7ac6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAFaCAYAAACJ9E8TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAAxOAAAMTgF/d4wjAAAsj0lEQVR4nO3deXgUZb728W+lsyeEHUGUSNhRZDcIKOBEkBlFVBDcRtQZ3FCPR4UBUTzndVBRRwcdLgcGERwdDi6AuG9BEwQUBAQVBQKiCCIQQghk637ePyrdqZgOZOlOd5L7c1198aSruurX1c3dtTxVZRljDCIiAkBEqAsQEQknCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFIEzzjgDy7KwLIu77rrrhOM+/vjjvnEjIyNrqcLw5V0Wdc2ECRN8tZ/ss/R4PCxcuJC0tDRatmxJTEwMbdq04YILLmDOnDkBqykvL4+XX36Ze+65h6FDh5KUlIRlWXTs2LFSr//ll1+YNGkS7du3JyYmhlNOOYWxY8fy5Zdf+h1/3759ZZaBZVk89NBDAXs/dZX+V//GSy+9xOOPP050dLTf4c8//3wtVyTBNGjQIDp27IjL5fI7PCcnh1GjRvHpp5+SlJTEwIEDadKkCXv27GHDhg0cOXKE2267LSC1bNu2jWuuuaZar/3+++8577zz2L9/PykpKYwePZqdO3fy6quvsmzZMpYsWcJll11W5jVxcXFcf/31AGzcuJFNmzbV+D3UC0ZMcnKyAUy/fv0MYJYsWeJ3vFWrVhnA9O/f3wDG5XLVcqXh59tvvzXffvttqMuosuuvv94AZsGCBRWO4/F4zNChQw1gbr75ZpObm1tmeEFBgfniiy8CVtP27dvNDTfcYJ555hmTmZlp3nzzTQOYDh06nPB1Ho/H9O7d2wDmuuuuM8XFxb5h//znPw1gEhMTzd69eyucxowZMwxgZsyYEai3U2dp89nhxhtvBCpeG5w/f36Z8QS6du1K165dQ11GUCxYsICVK1cyYsQInnvuORITE8sMj46Opl+/fgGbX4cOHXj++eeZNGkSgwYNIiEhoVKve+edd9iwYQNNmjRhzpw5ZdZ6J06cyO9+9zuOHj3K3//+94DVWp8pFB169OhBv379eP/999mzZ0+ZYUePHmXJkiWcdtppDB8+vMJpfPPNN8yYMYNBgwbRtm1boqOjad68OWlpaSxZsqTc+AUFBfTr1w/LsvjLX/5Sbrjb7WbIkCFYlsWtt95apfdz7Ngxnn76aQYPHkzTpk2JiYkhOTmZSy65hJdfftnv+I8++ih9+vShUaNGxMfHc+aZZzJ9+nSys7P9zqOifYre/bS7du0iPT2d4cOH07RpU+Li4ujTpw+LFi3yO72cnBymT59Ojx49SEhIICYmhlNPPZVBgwbx4IMPUlRUVKVlUBOzZ88G4L777qu1eVbH0qVLARg1alS54Aa4+uqrAXj99ddrta46K9SrquHAu/mckZFh5syZYwDz8MMPlxln/vz5BjD333+/2blzZ4WbzzfddJMBTNeuXc2IESPMuHHjzLnnnmsiIiIMYO6+++5yr9mxY4dp0qSJsSzLvP3222WGTZ061QCmd+/eJj8/v9Lvaffu3aZ79+4GMPHx8ebCCy8048ePN+edd55p3LixSU5OLjP+wYMHTa9evQxgkpKSzKhRo8wVV1xhWrRoYQDTvn17s3PnznLzAYy/r5F3mT7wwAPGsizTt29fM378eDNgwADfa5566qkyr8nLyzNnnXWWAUzLli3NJZdcYsaPH2+GDh1qWrdubQCTnZ1d5jXezb4hQ4ZUetkYc/LN53379vk+42PHjpkdO3aYRx55xNx8883mnnvuMUuWLDEFBQVVmmdVpaenV2rzuW/fvgYws2fP9jt806ZNBjCWZZmjR4/6HUebz6UUiqZsKB4+fNjExcWZjh07lhln0KBBxrIss2PH
jhOG4sqVK82OHTvKPb9161Zz2mmnGcCsXbu23PClS5cawLRo0cL8+OOPxhhj3n77bWNZlklKSjLbt2+v9Ptxu92+/aPDhw83+/fvLzP8+PHj5q233irz3Lhx4wxgUlNTzYEDB3zP5+bmmpEjRxrADBw4sNy8ThaKUVFRZsWKFWWGLViwwACmcePG5tixY77nFy5caAAzcuRIU1hYWO49rVy5slwQBSsU33//fQOYVq1amdmzZ5uoqCjfe/U+UlJSzKZNm6o036qobCg2a9bMAGbZsmV+hx86dMhX85YtW/yOo1Aspc3n32jcuDGXX34527dv55NPPgHgu+++Y9WqVQwZMoSUlJQTvr6icbp06cIDDzwAwKuvvlpu+OjRo7n77rs5cOAA48ePZ+fOnVx33XUYY5g/fz4dOnSo9HtYsWIF69ato02bNrz22mu0bNmyzPDY2Fh+//vf+/7evXs3r7zyCpZlMXfuXJo3b+4blpiYyLx584iNjeWzzz7js88+q3QdAHfccQcXX3xxmecmTJhA165dycnJYd26db7nf/nlFwAuvPBCoqKiyrwmIiKCIUOGlOsV0KJFC7p06UK7du2qVNfJHDx4EIBDhw5x5513cumll7J582Zyc3NZvXo1qampZGVlcdFFF/nGDZXc3FyACvdBOjepjxw5Uis11WXqkuPHjTfeyEsvvcTzzz/PkCFDfAdeKnuA5ejRo76d3wcOHKCwsBCAvXv3AnbI+vPYY4+xevVqVq1aRe/evcnJyeGOO+5gzJgxVar/3XffBex9Sf72Mf3Wp59+isfjoU+fPpx99tnlhrdt25YRI0awfPly0tPTGThwYKVrueSSS/w+361bN7Zu3Vpm323//v0BmDVrFs2bN+fiiy+mWbNmJ5z+pEmTmDRpUqXrqSxTcpnR4uJizj33XF555RXfsAEDBvDBBx/QqVMn9u7dy5w5c3w/eFL3aU3Rj2HDhtG+fXteffVVsrOzWbRoEUlJSZUKpxUrVnDGGWdw5ZVX8sgjjzBv3jwWLlzIwoULef/994GKf62joqJYvHgxUVFR5OTk0LNnT5544okq1//DDz8AVPqosDeY2rdvX+E43jXV3x6AOpmK1uCSkpIAyM/P9z03dOhQpkyZwv79+7n++ut9a4E33ngjy5cvx+PxVGneNdGoUSNf++abb/Y7/NprrwXgww8/rLW6/PHWmpeX53f40aNHfW3vcpeKKRT9sCyLCRMmcOzYMa6//nr27dvH+PHjiYuLO+Hr9uzZw7hx4zh48CCTJ09m06ZN5OTk4Ha7Mcbw3nvvAaVrIf4sXrzYd4R19+7dvrXLuioiompfsUcffZQdO3Ywe/Zsxo4dS15eHgsWLGD06NEMGDCgwv/4gebcBVLRLhPv86H+jM444wzA/r748+OPPwL29zo5Obm2yqqzFIoVmDBhAhEREaxYsQKo3KbzihUrOH78OJdddhmPPfYYZ599NklJSb5g2LZt2wlfn5mZyfTp04mPj2f8+PFkZ2czbty4KndD8a6dbd26tVLjt23bFoCsrKwKx/EO844bTGeccQZ33HEH//d//8dPP/3E559/TufOnfniiy+YNWtW0OcP0LlzZ98a2IEDB/yO432+MrsogqlPnz4AZfbPOnmf79SpU8hrrQsUihVo164dl156Kc2bN2fAgAGkpqae9DWHDh0C8PtrbIzx2zfQy3uApbi4mGeffZYXX3yRc889l7Vr1zJlypQq1X7RRRcB8J///KdSa1bnn38+ERERFZ7qtXfvXt9+ymHDhlWplkDo37+/71S6jRs31so8IyMjGT16NFDx5vEHH3wAwDnnnFMrNVXEe/reG2+84ffz9n7vLr/88lqtq65SKJ7A66+/zoEDB1i9enWlxu/WrRtgH112blK53W4efPDBCo/cGmO49tpr2bNnD9dffz033HADkZGRLF68mGbNmvHUU0+xfPnyStc9atQoevfuzc8//8zYsWPLHR3Nz8/n
nXfe8f3drl07xo4dizGGm2++ucz4eXl5TJw4kfz8fAYOHFilgyxVtXTpUt9BH6eioiJfKP/2B+fZZ5+la9eu/PGPfwx4PdOmTSMqKop58+bx5ptvlhn2+OOPk5mZicvl4vbbby8zbNeuXb5O7bt27Qp4Xb81cuRIevfuzeHDh7nttttwu92+YXPnzuWjjz4iMTHxpBc7kRIh7A4UNpz9FCujon6KRUVFvo60iYmJ5g9/+IO58sorTXJysomKijJTpkzx26fu4YcfNoDp3r27ycvLKzPsjTfeMJZlmaZNm/rtPF2RXbt2mS5duvg6bw8fPtxcddVV5vzzz/fbefvAgQOmZ8+evv6Do0ePNmPGjDEtW7asUeftimr210/wrrvu8vXVvPDCC80111xjRo0aZVq1amUA07ZtW18fTq9g9VP0euGFF3wd7/v162fGjBljunbt6vv8586dW+41O3bs8C2Xn376qUp1jR492qSmpprU1FTTrVs3A5iYmBjfc6mpqWbevHnlXrd161bfZ5WSkmLGjRtnzjnnHAOYyMhI8/rrr59wvuqnWEqhaAIXisbYnZ2nTZtmunTpYmJjY02rVq3M6NGjzbp163ydcZ3/gVeuXGlcLpeJj483X3/9td/53XPPPQYw55xzTrlOzSeSm5trHnvsMdO/f3/TqFEjExMTY5KTk82oUaPM4sWLy42fl5dnHnnkEdOrVy8THx9vYmNjTbdu3cy0adPMoUOH/M4jkKG4YcMG85e//MUMHjzYtG3b1kRHR5uWLVuavn37mpkzZ5bpVO4V7FA0xpjPP//cXHHFFaZVq1YmKirKtG7d2owdO9ZvJ3xjjFmyZIkBzIgRI6pUkzGly+1Ej4qCa+/eveb22283ycnJvmV3+eWXm/Xr1590vgrFUgpFaZCqEopV9ac//clYlmU2bNgQ8GkHi0KxlDpvS4P2r3/9i5UrV+JyuXxXQaqpDz74gKuvvppevXoFZHrBkpOT49vPWFsHsOoChaI0aKtWrWLVqlUBDcXaOLgSCMePH2fhwoWhLiPsWMacoCexiEgDoy45IiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgE7SKzRUVFZe4qJiJSUy6Xi6ioqKDOIyihWFRUxMyZMzl+/HgwJl9jLpeLYcOGkZ6eHnbBHc61geqrqXCuL5xr84qLi/PdejZYgrL57Ha7wzYQwb7ReVpaGpGR4Xc3hnCuDVRfTYVzfeFcm9fx48eDHtjapygi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxUCiKiDgoFEVEHCKDNWGXy0VkZNAmXyMxMTEARCU1xVNYGOJqyoqKjgYgISGBqKioEFdTXnQdqa9pUhSFhZ4QV1NedLS9zMJx+XmXnff/R0NlGWNMoCean59PZmYmaWlpgZ60iEhQBW1VLj09nYyMjGBNvkaikpoy/Z7/ou2Hhlx3qKspq5EL9qRZbN6eRbwV8N+rGrMsi5SUFLKysgjC72mNRXCM9h16wPa24MkNdTnlRTSCjnvYuWMzHuJDXU0Z3s921qxZ5OXlhbqcCk2dOpXY2NigTT9ooeh2u3G7wyxxSng3mXPdkFsc4mIq4PEYPFb4bf5FRNi7oY0xeDzhV5/lXWae3PAMxRLGePCY8Fp+3s+2sLCQgoKCEFcTOjrQIiLioFAUEXFQKIqIOCgURUQcFIoiIg4KRRERB4WiiIiDQlFExCE8T06uK7I2wZ29S/+++Ha45ZnQ1SOVsu9X
WLAU3voEduyGgzkQFwMpp8N5feGaiyG154mn8f1OmP8avLcKftwLecehdQvo0A7SzoU/XgptT6md9yOBpVCsibXLy/79+ZsKxTA3bwnc+zgcOVr2+aIi2Pit/Xjm37BpGZzdxf80Hp0HM56BwqKyz//ws/34eA20aQkTLgvKW5AgUyjWxJo37H9POQN+2QX7f7DXHlNOspohIfHE83Df43Y7Pg4mjoURg6FVMziUA6s3wqLlsH03VHQG4/+bAw+W/O61aQm3jodBfaBpEhzOhfVfw79X1MrbkSBRKFbXgZ9gx5d2e9x0mHsX5OfZa48KxbCzegNMedJut24BKxdBl/Zlx0kbCPffAnOXQJyf6w1s+AYe+ofdHtQH3noOGjcqO86wVLj3RgizK9JJFehAS3WtLVlLtCxIvQR6DC15XqsJ4eiOv5au/b38RPlA9IqIgFvG+x/+wGx7Go0bwdJnygeiU8mlCaUOUihWl3fTuWNfaNwS+oyw/96+Hg7sCV1dUs66LfZmLdhreMNSqz6NA9n2QRWAm6+Els0CV5+EF4VidRzLhc0r7bY3DPteVDr88zdquyI5gWUflrYvq+Z1j9PXQnHJZeZG/670+bxjsG0X7D9Y7fIkzCgUq2P9O1BcstPIG4andoQ2Hez2GoViOPliS2m7/1nVm8amrfa/lmUflV71JQz9IyT2hc4j4ZTBcNpQmPIEHD5S45IlhBSK1eENvYQm0GVA6fPetcav0u21SQkL3+0sbXdMrt40dpbsEWmSBB+tgWET4JMvyo6z5xeYNR/OudLuuyh1k0KxqtzF9poiQK80cLlKh3nXGosL4ct3a7828SvbsebW5AQHR07E26/RGLhpOkS64G9T4OdPoGATbF4O439vj7PtB7jqXntcqXsUilW15VM4mm23+44oO+zsYRBVcic0HYUOG3nHS9ux1bxR3bF8+9/DR+yDLi/NgrsnQJtW9pHmszrDf56ES0v2N676Ej74rEZlS4goFKtqjeMslj4XlR0WmwDdB9vtdW/ba5UScglxpe38at56JMZxN9IBPeGyC/2P95c/lbaXf1S9eUloKRSr6vOSNcDks6BF2/LDvWuPuYfg68zaq0sq5OxPmFPNXb0JjhvvDR9U8Xj9zoKSWzuzcWv15iWhpTNaqmLnV/bpfAA/bIGLT/KbsnY5nD002FXJSXROLj3wsX23vclbVac0L22fdoILPURGQoum8PN+OHi46vOR0NOaYlWsWX7ycZy0XzEs9O9R2nZ2z6mKTo6j1u6T3Jm0uOTOvi7976qTtKZYFd5N58Yt4b6XKh7vvX9BxhLYlwW7tsAZ1ewcJwExaph9ZRuwO3L/94SqT6NP99J21o8Vj5dfULqGeEqLqs9HQk+hWFkH9tin8AGcfYHdHaciBcfsUAT77BaFYkid2xvO6gRbtkHGevjkcxhyThWn0cvuo3j4CKxYCY/da3fk/q2Vn4O7ZE3xnB7lh0v40wp+ZX3+RmnHs16/O/G4PYaBq+T3Rme3hIXZ95e2r77PPjXPH48HnltsX0TWKTISbii5PuLWLPj7ovKvzS8ovawYwNUX16hkCRGFYmU59w+eaC0RIL5R6Zku276AQzq9IdSGpcJDk+z2z/uh9xVwz2PwXqZ9SbCP18Bfn4PuF8Ot/1PaL9Fp6sTSTeL/fgz+/AB8tBq+/BoWvwXnXQtfbLaHT7is4ovUSnjT5nNlHMuFTR/b7dYp9kVlT6b3hfBNpr12+fkKuGhiUEuUk5txu30x2GlP2xdy+NsL9sMf54lKXi2bwdvPwSW32cH6r1ftx2+NugD+8UAAC5dapTXFyvjy3dILQJxs09mrz/DSto5Ch407r4Pv34H/vcPeT3hKC3vTuFEC9OpmD//iFejR2f/r+5wJ37wJD99lH9Vu1hiiouwL114yDF77Oyx71r6yt9RNWlOsjMFj4c2xVXtNl1R48yR9NyQkTm0FD9xmP6qjcSP7Ct333xLYuiQ8aE1R
RMRBoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOQeuS43K5iIwMzx4/USU35W3kp4NuqHlrioiwiLDC7zfLKjnh17IsIiLCsL6S3/kidxvwVPPeA8FkEokCLCsi7D5f72cbHR1NTEw1L1FeD1jGBP5OEvn5+WRmZpKWVs37SYqIhEjQVuXS09PJyMgI1uRrJCEhgcmTJ7N5exYeT3jdXSgiwqJHxxSysrIIwu9VjVmWRUpK+NZn3BYdO6fw+rPZvpOQwklkNFw+qSnbv8/CcoXX8vN+trNmzSIvLy/U5VRo6tSpxMbGBm36QQtFt9uN23sNpTATFWVfLz7eMnis8DrrxLtJZYzB4wmv2gDfJnO41mc8dn3FhYRlKHp5PAYr3L57JZ9tYWEhBQXVvJlNPRBeOzVEREJMoSgi4qBQFBFxUCiKiDgoFEVEHBSKIiIOCkUREYfwPA9PJMCee+V2Mr78T7nnLcsiJjqBZklt6NiuH+f1uYruKYOrNO3sI/uY/NS5HMvPAeC8Pldxy9h/BKRuqX0KRWnQjDHkFxzl51+38fOv2/h0/X8Y1GssE8c8S6QrqlLTWLhiii8Qpe5TKEqDM/GKZ0g5rTcABsgvyGX33q9577N/8vOv21i18RUS45vyx0sePem01n/zDl9sWUFsdCL5hUeDXLnUBu1TlAanZbNkTm/dndNbd6dd6+50Tk4lbcCNTJ/4JkkJ9o2dP1zzPEePZZ9wOscLcnnhjfuwLIvRF9xbG6VLLVAoipRonNiSnp3tKzu5PcVs373uhOMvee9hDuX8zJB+19Lh9D61UaLUAoWiiEPTxm187eOFuRWOt233F3ywZj4JcU0YP+LB2ihNaolCUcTh4OE9vnbzxm39jlPsLmL+0rsxxsOVw6fTKKF5bZUntUChKFLiYM4e1n/7NgCNE1uR3KaH3/He/HQ2P+77hjNO7ckF50yoxQqlNujoszQ4vx76gR/jmwH20efj+UfY8eN63s6cQ37BUVyuKK67ZCYx0fHlXrvvwA6WffwklmUx4dJZYXlLBqkZhaI0OHNfu8Pv85GuaM7vexVpqTdVeOBk/tL/pqg4nyH9rqFTu/7BLFNCRKEoUqLYXciX37xL48RTOL11N6Kj4soMX7nu33yTlUF8bGPGj5gRoiol2BSK0uDc/+c3fKfyGWPILzzKnl++45P1L5H+xSJWfPI03/+whmk3LSMy0r7zY87RX3n5bfso89jh95OU2CJk9UtwaYeINGiWZREX04iO7fpx02VPMa5kDfC7XWt4M+MZ33gvrphK3vHDJLfpQVrqDaEqV2qB1hRFHEYOuoVlHz9BfuFR0j9fxOhh97D31+2s/up1APp2H8nmbenlXrd73ze+dvaRn9n03YcAnN66O80an1o7xUtAKBRFHCIjozm1VWeyfvqSA4d/JO94DgVFx33DX/9o1kmnsWX7J2zZ/gkAE8c8y5C+VwetXgk8bT6L/IZlWb52oSMQpWHQmqKIg9tdzN5ftwPgckWRlNCCpkmteemRQyd83TdZmfx13ihA11Os67SmKOLw/uq5vmsjdk8ZjMul9YaGRp+4NDjOM1rA3kT+5eAu1n3zJms3LwfAFRHJmLSpoSpRQkihKA1ORWe0eCXGNeXPY56hY7t+tVSRhBOFojR4MVHxJCY04/RTutGj0wUM7n0lifFNQ12WhIhCURqEW8b+I6gHP7qnDD7pwRipG3SgRUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDkHrkuNyuYiMDM8eP9HR9oVDLcsKu3tseC9GEI61QWl9xcXFeDyeEFdTnnG7ACi5NmzY8dYVEWFhhdnn6/1so6OjiYmJCXE1oWMZY0ygJ5qfn09mZiZpaWmBnrSISFAFbVUuPT2djIyMYE2+RhISEpg8eTJZWVkE4TehRizLIiUlJSxrA3sNsUuXLsycOZOCgoJQl1NOTEwM06ZN45st3xMZ6Qp1OeVERFh07Byen6/3uzdr1izy
8vJCXU6Fpk6dSmxsbNCmH7RQdLvduN3uYE2+RqKiogD7/hzhtgno3WQOx9oAX00FBQVhGYpelsuN5bJOPmIts8L48/V+9woLC8P6sw228NqpISISYgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEITJYE3a5XERGBm3yNRIdHQ2AZVlERITX74JlWQAUFxfj8XhCXE15xcXFAMTExIS4Ev+8dUVGHMdluUJcTXlWyXpIOH/3oqOjw/bzrQ2WMcYEeqL5+flkZmaSlpYW6EmLiARV0Fbl0tPTycjICNbkayQhIYHJkyeTlZVFEH4TaqS4uJguXbowc+ZMCgoKQl1OOTExMUybNo3NmzeH5ZZAZMRxOnXpA9vbgic31OWUF9EIOu5h547NeIgPdTVlWJZFSkoKs2bNIi8vL9TlVGjq1KnExsYGbfpB+1a73W7cbnewJl8jUVFRABhjwm4T1VtPQUFBWIaiV2RkZFiGom+T2ZMbnqFYwhgPHhNe3z3v5nxhYWFYf/eCLbx2aoiIhJhCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDuHX0UwkyPb9CguWwlufwI7dcDAH4mIg5XQ4ry9cczGk9vT/2p/3w5qNsGaT/Vj/DRw7bg9bMBMmXFZrb0OCRKEoDcq8JXDv43DkaNnni4pg47f245l/w6ZlcHaXsuN89R30HF1blUqoKBSlwXjiebjvcbsdHwcTx8KIwdCqGRzKgdUbYdFy2L4b/J3o5HzO5YKzOkFCHHy2oVbKl1qiUJQGYfUGmPKk3W7dAlYugi7ty46TNhDuvwXmLoE4P6fWNm8CM++GAT3hnB6QEA8vLFUo1jcKRWkQ7vhr6Zrey0+UD0SviAi4Zbz/Yae3gakTg1OfhA8dfZZ6b90WWP+13R7UB4alhrYeCW8KRan3ln1Y2r5Ml/iUk1AoSr33xZbSdv+zQleH1A0KRan3vttZ2u6YHLo6pG5QKEq9l32ktN2kUejqkLpBoSj1Xt7x0nZsw70fk1SSQlHqvYS40nZ+w73KvlSSQlHqvcaOTeac8L1ti4QJhaLUe50dB1e27w5dHVI3KBSl3uvfo7Tt7J4j4o9CUeq9UcNK286O3CL+KBSl3ju3t31FG4CM9fDJ56GtR8KbQlEahNn3l7avvg+27fI/nscDzy2G73f6Hy71n66SIw3CsFR4aBI89Kx99ezeV8DNV8LwQfb1FLOP2NdTfPEN+wyYDa/7n86r78HRY6V/Z6733wb7EmUXnRfwtyJBplCUBmPG7dA0CaY9DXnH4G8v2A9/XC7/z987C3742f+w+a/ZD68h/RWKdZFCURqUO6+DMSNg/qvwTgZk/QQHD9v3aOnQDs7vB9eNgh6dQ12phIpCURqcU1vBA7fZj6ra9VHg65HwogMtIiIOCkUREQeFooiIg0JRRMRBoSgi4qBQFBFxCFqXHJfLRWRkePb4iY6OBsCyLCIiwut3wVtPTEx4XiLaW1dkxHFcVgU9nEMowsovaYTpfQdK6vLVGUasknWk6OjosP3+1QbLGGMCPdH8/HwyMzNJS9P9JEWkbgnaqlx6ejoZGRnBmnyNJCQkMHnyZLKysgjCb0KNWJZFSkoK3333HR6PJ9TllBMZcZxOXfrA9rbgCcPLWEc0go572LVjLR4TG+pqyomw8jmjQ2p4Lr+SZff0kw+TfaQo1NVUaOrUqcTGBu+zDVoout1u3G53sCZfI1FRUQAYY8IueLybz5GRkWFXG1C6yezJDb//1A4eE4vbxIe6jIqF8fIrLCygoCB8QzHYwmuHmohIiCkU
RUQcFIoiIg4KRRERB4WiiIiDQlFExEGhKCLioFAUEXEIz5OTpc7b9yssWApvfQI7dsPBHPs+KCmnw3l94ZqLIbVn+dcdzIZlH8GHq2HjVvhpHxQU2Tec6tUVLr8Q/ngpxIXfySoBU51l5/HAmythzSbY+C18/wPs/RUKS5bdWZ1g1DC48QpISgzJ26ozFIoScPOWwL2Pw5GjZZ8vKrL/w278Fp75N2xaBmd3KR3+8RoYOdH+j/xb+w/C+6vsx5ML4NW/l31tfVHdZXfsOFx6u/9p/noI0tfajycWwJKnYGDvoL2FOk+hKAH1xPNw3+N2Oz4OJo6FEYPteysfyrHvrbxoOWzfba/dOB3KsQPR5YLhA2Hk+fZd9RLi7PH/9aodnNt+gAtvgi9fg7an1PpbDJqaLDuw156H9ofBfaF7B/sGXS4X7PnFXutcuMxuX/Rn2LjUXvOU8hSKEjCrN8CUJ+126xawchF0aV92nLSBcP8tMHdJ+U3gSBf8eSw8cCuc3qbssP494Ko/2PddfnKBveb4wGx4/q/Bez+1qabLLiEeDnxmh+lv9T0TRl0AY4bD8D9Bbh787xx44ZHgvJe6TgdaJGDu+GvpGszLT5T/T+0VEQG3jC8/fHQazP3f8oHo9P/utAMA4PUP/K8x1UU1XXaW5T8QnS4cBN062O0PPqtZvfWZQlECYt0WWP+13R7UB4alBmc+cbFwZke7nZNr38i+rqutZQcQX7KGmX0kePOo6xSKEhDLPixtXxbkawsXFJa2Y+vBBaJra9lt+d4+og+Qclrw5lPXaZ+iBMQXW0rb/c8K3nxycmFrlt1ufxo0SgjevGpLMJfdsePw4z47eJ9YAN5LnP5pTGDnU58oFCUgvttZ2u6YHLz5zPlP6ZridaOCN5/aFOhld/gIND3BJvj1o2HSNTWfT32lUJSAcO6jahKke0Zt/wFm/tNut24B994YnPnUttpYdmD3TZxxOwwfFLx51AcKRQmIvOOl7WDs58svgLF3w9Fj9hHYhY/Wj01nCPyyS0qEzcvtdkGRfVbM4rdh+cfw1EJo2wrO7FTz+dRXCkUJiIS40rMw8gtO3j2kKjweuG6KfTYHwMz/ql9rO4FedhERcFbn0r/7nglXjoQFr8ON99udwFfMgfP61Ww+9ZWOPktANHZs9uUE+H5Md/4VXn3Pbt8yDqb8ObDTD7VgLjunGy6HK4bb87j6PigsPPlrGiKFogREZ8cBgu27Azfdh56Ff7xst8deBP94MHDTDhfBWnb+eLv8/LRPHbgrolCUgOjfo7Tt7GJSE7NfhP/5h90ePgj+/Zi9aVjfBGPZVaRVs9L2NzuCO6+6qh5+xSQURg0rbTs7I1fXomXwXyXn5g7uC0ufgejomk83HAV62Z3IvgOlbcsK7rzqKoWiBMS5ve1r9gFkrIdPPq/+tJZ9CDdOB2Og31nw1nOBPXATbgK57E7mjfTSdsd2wZtPXaZQlICZfX9p++r7YNsu/+N5PPDcYvh+Z/lhH6yC8ffYZ16c1QnendswLopa02U3bwl8ve3E83j5TXjtfbvdJAnSzq12ufWauuRIwAxLhYcm2QdHft4Pva+Am6+09we2amZ3Ul69EV58wz6LY8PrZV+/dhOMvsM+Y6VZY3h6qn316L2/VjzP9qeVXjWnLqvpsnvrU5g4w+6gffFQ6N3Nfp3bYx+8eeVdWOrYNH/8XkisJ/08A02hKAE143b78vfTnoa8Y/C3F+yHPy5X2b/fybDP1QX7oqpplThjJX0hDD2nBgWHkZosO6/PNtiPisTFwpOT4U9ja1hsPaZQlIC78zoYMwLmv2oHXdZP9iW+4mKgQzs4v5993nKPziedVINT3WX3z4fgqt/Dp+vsy5DtPwQHsiG/EBonQrcU+N25cNMVcFrrULyzukOhKEFxait44Db7UVkPTbIfDV11lt0pLWDc7+2H1IwOtIiIOCgURUQcFIoiIg4KRRERB4WiiIiDQlFExCFoXXJc
LheRkeHZ4ye65MoClmUREWaXXbFKztIPx9oALO/vaEQQr5tfEyV1WVYEEZaWX5WU1BQVHUNMTPgtu9piGWNMoCean59PZmYmaWlBvteliEiABW1VLj09nYyMjGBNvkZiYmKYNm0as2bNojDMLj8cHR3N5MmTw7I2KK3v6ScfprCwINTllBMdHcN/3TNdy68aoqJjuPue6cycOZOCgvCqzenuu++mSZMmQZt+0ELR7Xbj9t5kNkzl5eWF3YcfE2PfuSgca4PS+rKPFFFQUBTiasrzbvZp+VWdd9kVFBSE5bKrLQ13x4GIiB8KRRERB4WiiIiDQlFExEGhKCLioFAUEXFQKIqIOITneXgiUmc0bdqUKVOmVGrcDz/8kA8/DPLNrWtIa4oiIg5aUxSRgFm9ejVr1qypcPjRo0drsZrqUSiKSMDk5eXxyy+/hLqMGtHms4iIg0JRRMRBm88iElAul4ukpCQAcnNzKS4uDnFFVaNQFJGAGTBgAEOHDvVddd/tdvPTTz+xdu1aNmzYQBCuaR1wCkURCZjExMQyf7tcLpKTk0lOTqZnz5689NJLYXnxXyeFoojU2O7du1m/fj07d+4kOzsbj8dDkyZN6N69O0OHDiUhIYEuXbowbtw4XnzxxVCXe0IKRRGpkezsbObMmVPu+YMHD5KRkcHXX3/NrbfeSqNGjTjzzDPp1KkT27ZtC0GllaOjzyISVIcOHeLdd9/1/d23b98QVnNyCkURCbotW7b47tmUnJwc4mpOTKEoIkFXUFDAsWPHgPIHY8KNQlFEaoVlWQB4PJ4QV3JiCkURCbr4+Hji4+MByMnJCXE1J6ZQFJGg69WrFxERdtxkZWWFuJoTUyiKSI307t3bt2nsz6mnnsrw4cMBe9N59erVtVVataifoojUyLhx4xg5ciSbN2/mhx9+IDs7G7fbTaNGjejatSv9+vUjKioKgPT09LC/tJhCUURqLCkpiUGDBjFo0CC/w91uNx999BEff/xxLVdWdQpFEamRl19+mY4dO9KmTRsaN25MXFwcLpeLgoICDh48yI4dO1i7di2HDh0KdamVolAUkRr56quv+Oqrr0JdRsDoQIuIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDkEJRZfLRVxcXDAmLSINXHR0dFCnb5kg3XOwqKjId6XdcBQbG0t+fn6oy/ArnGsD1VdT4VxfbGwshw8fDnUZFYqOjvZdgixYghaKIiJ1kfYpiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuKgUBQRcVAoiog4KBRFRBwUiiIiDgpFEREHhaKIiINCUUTEQaEoIuLw/wFAh6XgUl4moAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 400x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualise one gridworld from the test split.\n",
    "GRID_INDEX = 0  # <- change this index to inspect a different test gridworld\n",
    "\n",
    "env = test_gridworlds[GRID_INDEX]\n",
    "env.reset()\n",
    "env.render()\n",
    "plt.title(f'Max coins: {env.max_coins}', fontsize=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87c3384f",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6909093d",
   "metadata": {},
   "outputs": [],
   "source": [
    "## DREST CONFIG ##\n",
    "# If using the DREST config, set `self.m_values = self.env.max_coins` at lines 24 and 48 of\n",
    "# src/Generalist/generalist_meta_env.py (line numbers may drift - search for `self.m_values`).\n",
    "config = {\n",
    "    # Split sizes are derived from the loaded data so they cannot drift from the pickled world sets.\n",
    "    \"env_list\": str(len(train_gridworlds)),\n",
    "    \"test_env_list\": str(len(test_gridworlds)),\n",
    "    \"lambda_factor\": 0.9,\n",
    "    \"meta_ep_size\": 32,\n",
    "    \"hidden_layer_depth\": 512,  # NOTE(review): reads like units per hidden layer (width) - confirm\n",
    "    \"num_hidden_layers\": 3,\n",
    "    \"ent_coef\": 0.02,\n",
    "    # NOTE(review): at this LR the training log shows clip_fraction 0 and approx_kl ~1e-4 every iteration;\n",
    "    # policy updates may be too small to learn - consider tuning.\n",
    "    \"learning_rate\": 1e-6,\n",
    "    \"total_timesteps\": 500_000,\n",
    "    \"clip_range\": 0.2,\n",
    "    \"n_steps_ppo\": 8192,\n",
    "    \"batch_size\": 64,\n",
    "    \"vf_coef\": 0.55,\n",
    "    \"timesteps_per_run\": 1_000_000,  # NOTE(review): differs from total_timesteps - confirm which one training uses\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b152b4f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cpu device\n",
      "------------------------------\n",
      "| time/              |       |\n",
      "|    fps             | 2911  |\n",
      "|    iterations      | 1     |\n",
      "|    time_elapsed    | 8     |\n",
      "|    total_timesteps | 24576 |\n",
      "------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1553          |\n",
      "|    iterations           | 2             |\n",
      "|    time_elapsed         | 31            |\n",
      "|    total_timesteps      | 49152         |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00014181258 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.39         |\n",
      "|    explained_variance   | 0.15893596    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.156         |\n",
      "|    n_updates            | 10            |\n",
      "|    policy_gradient_loss | -0.000662     |\n",
      "|    value_loss           | 0.603         |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1320          |\n",
      "|    iterations           | 3             |\n",
      "|    time_elapsed         | 55            |\n",
      "|    total_timesteps      | 73728         |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 8.5672225e-05 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.39         |\n",
      "|    explained_variance   | 0.19197899    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.375         |\n",
      "|    n_updates            | 20            |\n",
      "|    policy_gradient_loss | -0.000426     |\n",
      "|    value_loss           | 0.602         |\n",
      "-------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1245         |\n",
      "|    iterations           | 4            |\n",
      "|    time_elapsed         | 78           |\n",
      "|    total_timesteps      | 98304        |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0001162888 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.39        |\n",
      "|    explained_variance   | 0.25031382   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.111        |\n",
      "|    n_updates            | 30           |\n",
      "|    policy_gradient_loss | -0.000466    |\n",
      "|    value_loss           | 0.747        |\n",
      "------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1163          |\n",
      "|    iterations           | 5             |\n",
      "|    time_elapsed         | 105           |\n",
      "|    total_timesteps      | 122880        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00012480286 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.39         |\n",
      "|    explained_variance   | 0.26307642    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.496         |\n",
      "|    n_updates            | 40            |\n",
      "|    policy_gradient_loss | -0.00048      |\n",
      "|    value_loss           | 0.73          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1150          |\n",
      "|    iterations           | 6             |\n",
      "|    time_elapsed         | 128           |\n",
      "|    total_timesteps      | 147456        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00010570235 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.39         |\n",
      "|    explained_variance   | 0.24897456    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.341         |\n",
      "|    n_updates            | 50            |\n",
      "|    policy_gradient_loss | -0.000571     |\n",
      "|    value_loss           | 0.793         |\n",
      "-------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1137         |\n",
      "|    iterations           | 7            |\n",
      "|    time_elapsed         | 151          |\n",
      "|    total_timesteps      | 172032       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0002712777 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.38        |\n",
      "|    explained_variance   | 0.29595226   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.364        |\n",
      "|    n_updates            | 60           |\n",
      "|    policy_gradient_loss | -0.000693    |\n",
      "|    value_loss           | 0.765        |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1120         |\n",
      "|    iterations           | 8            |\n",
      "|    time_elapsed         | 175          |\n",
      "|    total_timesteps      | 196608       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0002323658 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.38        |\n",
      "|    explained_variance   | 0.25505757   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.345        |\n",
      "|    n_updates            | 70           |\n",
      "|    policy_gradient_loss | -0.000686    |\n",
      "|    value_loss           | 0.974        |\n",
      "------------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 1111        |\n",
      "|    iterations           | 9           |\n",
      "|    time_elapsed         | 198         |\n",
      "|    total_timesteps      | 221184      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 8.77308e-05 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.38       |\n",
      "|    explained_variance   | 0.32742667  |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.348       |\n",
      "|    n_updates            | 80          |\n",
      "|    policy_gradient_loss | -0.00041    |\n",
      "|    value_loss           | 0.933       |\n",
      "-----------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1100          |\n",
      "|    iterations           | 10            |\n",
      "|    time_elapsed         | 223           |\n",
      "|    total_timesteps      | 245760        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00018165137 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.38         |\n",
      "|    explained_variance   | 0.3201332     |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.4           |\n",
      "|    n_updates            | 90            |\n",
      "|    policy_gradient_loss | -0.000714     |\n",
      "|    value_loss           | 1.07          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1093          |\n",
      "|    iterations           | 11            |\n",
      "|    time_elapsed         | 247           |\n",
      "|    total_timesteps      | 270336        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00022535835 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.38         |\n",
      "|    explained_variance   | 0.35599732    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.407         |\n",
      "|    n_updates            | 100           |\n",
      "|    policy_gradient_loss | -0.000673     |\n",
      "|    value_loss           | 0.982         |\n",
      "-------------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 1095        |\n",
      "|    iterations           | 12          |\n",
      "|    time_elapsed         | 269         |\n",
      "|    total_timesteps      | 294912      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.000276632 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.38       |\n",
      "|    explained_variance   | 0.40013808  |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.554       |\n",
      "|    n_updates            | 110         |\n",
      "|    policy_gradient_loss | -0.000754   |\n",
      "|    value_loss           | 0.946       |\n",
      "-----------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1082          |\n",
      "|    iterations           | 13            |\n",
      "|    time_elapsed         | 295           |\n",
      "|    total_timesteps      | 319488        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00022151829 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.38         |\n",
      "|    explained_variance   | 0.38441366    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.64          |\n",
      "|    n_updates            | 120           |\n",
      "|    policy_gradient_loss | -0.000835     |\n",
      "|    value_loss           | 1.09          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1068          |\n",
      "|    iterations           | 14            |\n",
      "|    time_elapsed         | 321           |\n",
      "|    total_timesteps      | 344064        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00019617243 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.38         |\n",
      "|    explained_variance   | 0.40714806    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.328         |\n",
      "|    n_updates            | 130           |\n",
      "|    policy_gradient_loss | -0.000714     |\n",
      "|    value_loss           | 1.08          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1052          |\n",
      "|    iterations           | 15            |\n",
      "|    time_elapsed         | 350           |\n",
      "|    total_timesteps      | 368640        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00024307931 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.38         |\n",
      "|    explained_variance   | 0.41911715    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.591         |\n",
      "|    n_updates            | 140           |\n",
      "|    policy_gradient_loss | -0.000722     |\n",
      "|    value_loss           | 1.18          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1035          |\n",
      "|    iterations           | 16            |\n",
      "|    time_elapsed         | 379           |\n",
      "|    total_timesteps      | 393216        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00033603268 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.37         |\n",
      "|    explained_variance   | 0.42383486    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.382         |\n",
      "|    n_updates            | 150           |\n",
      "|    policy_gradient_loss | -0.000996     |\n",
      "|    value_loss           | 1.02          |\n",
      "-------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1029          |\n",
      "|    iterations           | 17            |\n",
      "|    time_elapsed         | 405           |\n",
      "|    total_timesteps      | 417792        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00049343595 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.37         |\n",
      "|    explained_variance   | 0.4191054     |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 0.482         |\n",
      "|    n_updates            | 160           |\n",
      "|    policy_gradient_loss | -0.00123      |\n",
      "|    value_loss           | 0.992         |\n",
      "-------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1027         |\n",
      "|    iterations           | 18           |\n",
      "|    time_elapsed         | 430          |\n",
      "|    total_timesteps      | 442368       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0002889405 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.37        |\n",
      "|    explained_variance   | 0.3750825    |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.64         |\n",
      "|    n_updates            | 170          |\n",
      "|    policy_gradient_loss | -0.000858    |\n",
      "|    value_loss           | 1.2          |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1018         |\n",
      "|    iterations           | 19           |\n",
      "|    time_elapsed         | 458          |\n",
      "|    total_timesteps      | 466944       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0006332556 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.36        |\n",
      "|    explained_variance   | 0.43184948   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.404        |\n",
      "|    n_updates            | 180          |\n",
      "|    policy_gradient_loss | -0.00144     |\n",
      "|    value_loss           | 1.18         |\n",
      "------------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 1006        |\n",
      "|    iterations           | 20          |\n",
      "|    time_elapsed         | 488         |\n",
      "|    total_timesteps      | 491520      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.000521815 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.36       |\n",
      "|    explained_variance   | 0.41128016  |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.746       |\n",
      "|    n_updates            | 190         |\n",
      "|    policy_gradient_loss | -0.00119    |\n",
      "|    value_loss           | 1.16        |\n",
      "-----------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 996         |\n",
      "|    iterations           | 21          |\n",
      "|    time_elapsed         | 518         |\n",
      "|    total_timesteps      | 516096      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.001345871 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.35       |\n",
      "|    explained_variance   | 0.38773662  |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.425       |\n",
      "|    n_updates            | 200         |\n",
      "|    policy_gradient_loss | -0.00173    |\n",
      "|    value_loss           | 1.2         |\n",
      "-----------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 993          |\n",
      "|    iterations           | 22           |\n",
      "|    time_elapsed         | 544          |\n",
      "|    total_timesteps      | 540672       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0005700137 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.34        |\n",
      "|    explained_variance   | 0.40288287   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.59         |\n",
      "|    n_updates            | 210          |\n",
      "|    policy_gradient_loss | -0.00124     |\n",
      "|    value_loss           | 1.4          |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 999          |\n",
      "|    iterations           | 23           |\n",
      "|    time_elapsed         | 565          |\n",
      "|    total_timesteps      | 565248       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0012526414 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.33        |\n",
      "|    explained_variance   | 0.42143023   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.48         |\n",
      "|    n_updates            | 220          |\n",
      "|    policy_gradient_loss | -0.00165     |\n",
      "|    value_loss           | 1.27         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1001         |\n",
      "|    iterations           | 24           |\n",
      "|    time_elapsed         | 589          |\n",
      "|    total_timesteps      | 589824       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0009472478 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.32        |\n",
      "|    explained_variance   | 0.40632927   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.851        |\n",
      "|    n_updates            | 230          |\n",
      "|    policy_gradient_loss | -0.00142     |\n",
      "|    value_loss           | 1.47         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1003         |\n",
      "|    iterations           | 25           |\n",
      "|    time_elapsed         | 612          |\n",
      "|    total_timesteps      | 614400       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0008210715 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.32        |\n",
      "|    explained_variance   | 0.42621106   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.549        |\n",
      "|    n_updates            | 240          |\n",
      "|    policy_gradient_loss | -0.00153     |\n",
      "|    value_loss           | 1.43         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1004         |\n",
      "|    iterations           | 26           |\n",
      "|    time_elapsed         | 636          |\n",
      "|    total_timesteps      | 638976       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0009404524 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.31        |\n",
      "|    explained_variance   | 0.44333524   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.712        |\n",
      "|    n_updates            | 250          |\n",
      "|    policy_gradient_loss | -0.00137     |\n",
      "|    value_loss           | 1.48         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1006         |\n",
      "|    iterations           | 27           |\n",
      "|    time_elapsed         | 659          |\n",
      "|    total_timesteps      | 663552       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0012734842 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.3         |\n",
      "|    explained_variance   | 0.46461284   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.597        |\n",
      "|    n_updates            | 260          |\n",
      "|    policy_gradient_loss | -0.0019      |\n",
      "|    value_loss           | 1.37         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1010         |\n",
      "|    iterations           | 28           |\n",
      "|    time_elapsed         | 681          |\n",
      "|    total_timesteps      | 688128       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0007812732 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.28        |\n",
      "|    explained_variance   | 0.4519112    |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.712        |\n",
      "|    n_updates            | 270          |\n",
      "|    policy_gradient_loss | -0.00158     |\n",
      "|    value_loss           | 1.52         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1011         |\n",
      "|    iterations           | 29           |\n",
      "|    time_elapsed         | 704          |\n",
      "|    total_timesteps      | 712704       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0014570669 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.27        |\n",
      "|    explained_variance   | 0.47846895   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.744        |\n",
      "|    n_updates            | 280          |\n",
      "|    policy_gradient_loss | -0.00195     |\n",
      "|    value_loss           | 1.7          |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1012         |\n",
      "|    iterations           | 30           |\n",
      "|    time_elapsed         | 728          |\n",
      "|    total_timesteps      | 737280       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0016445514 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.25        |\n",
      "|    explained_variance   | 0.44151318   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 1.11         |\n",
      "|    n_updates            | 290          |\n",
      "|    policy_gradient_loss | -0.00208     |\n",
      "|    value_loss           | 1.94         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1013         |\n",
      "|    iterations           | 31           |\n",
      "|    time_elapsed         | 751          |\n",
      "|    total_timesteps      | 761856       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0009900521 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.24        |\n",
      "|    explained_variance   | 0.524621     |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 1.48         |\n",
      "|    n_updates            | 300          |\n",
      "|    policy_gradient_loss | -0.00176     |\n",
      "|    value_loss           | 1.7          |\n",
      "------------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 1015        |\n",
      "|    iterations           | 32          |\n",
      "|    time_elapsed         | 774         |\n",
      "|    total_timesteps      | 786432      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.001380661 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.23       |\n",
      "|    explained_variance   | 0.4876567   |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.708       |\n",
      "|    n_updates            | 310         |\n",
      "|    policy_gradient_loss | -0.00175    |\n",
      "|    value_loss           | 1.98        |\n",
      "-----------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1017         |\n",
      "|    iterations           | 33           |\n",
      "|    time_elapsed         | 797          |\n",
      "|    total_timesteps      | 811008       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0011239928 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.21        |\n",
      "|    explained_variance   | 0.48170626   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 1.63         |\n",
      "|    n_updates            | 320          |\n",
      "|    policy_gradient_loss | -0.00158     |\n",
      "|    value_loss           | 1.99         |\n",
      "------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1021          |\n",
      "|    iterations           | 34            |\n",
      "|    time_elapsed         | 817           |\n",
      "|    total_timesteps      | 835584        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00093260827 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.2          |\n",
      "|    explained_variance   | 0.4492888     |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 1.07          |\n",
      "|    n_updates            | 330           |\n",
      "|    policy_gradient_loss | -0.00148      |\n",
      "|    value_loss           | 1.82          |\n",
      "-------------------------------------------\n",
      "-----------------------------------------\n",
      "| time/                   |             |\n",
      "|    fps                  | 1021        |\n",
      "|    iterations           | 35          |\n",
      "|    time_elapsed         | 842         |\n",
      "|    total_timesteps      | 860160      |\n",
      "| train/                  |             |\n",
      "|    approx_kl            | 0.001274461 |\n",
      "|    clip_fraction        | 0           |\n",
      "|    clip_range           | 0.2         |\n",
      "|    entropy_loss         | -1.19       |\n",
      "|    explained_variance   | 0.52198625  |\n",
      "|    learning_rate        | 1e-06       |\n",
      "|    loss                 | 0.757       |\n",
      "|    n_updates            | 340         |\n",
      "|    policy_gradient_loss | -0.00153    |\n",
      "|    value_loss           | 1.83        |\n",
      "-----------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1023         |\n",
      "|    iterations           | 36           |\n",
      "|    time_elapsed         | 864          |\n",
      "|    total_timesteps      | 884736       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0011161772 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.18        |\n",
      "|    explained_variance   | 0.53321064   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.996        |\n",
      "|    n_updates            | 350          |\n",
      "|    policy_gradient_loss | -0.00138     |\n",
      "|    value_loss           | 2.31         |\n",
      "------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1024          |\n",
      "|    iterations           | 37            |\n",
      "|    time_elapsed         | 887           |\n",
      "|    total_timesteps      | 909312        |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00091700256 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.16         |\n",
      "|    explained_variance   | 0.53461885    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 1.1           |\n",
      "|    n_updates            | 360           |\n",
      "|    policy_gradient_loss | -0.00145      |\n",
      "|    value_loss           | 2.04          |\n",
      "-------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1026         |\n",
      "|    iterations           | 38           |\n",
      "|    time_elapsed         | 909          |\n",
      "|    total_timesteps      | 933888       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0011034744 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.15        |\n",
      "|    explained_variance   | 0.482486     |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 1.08         |\n",
      "|    n_updates            | 370          |\n",
      "|    policy_gradient_loss | -0.00153     |\n",
      "|    value_loss           | 1.97         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1029         |\n",
      "|    iterations           | 39           |\n",
      "|    time_elapsed         | 930          |\n",
      "|    total_timesteps      | 958464       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0010726157 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.15        |\n",
      "|    explained_variance   | 0.5020282    |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 0.919        |\n",
      "|    n_updates            | 380          |\n",
      "|    policy_gradient_loss | -0.00117     |\n",
      "|    value_loss           | 2.01         |\n",
      "------------------------------------------\n",
      "------------------------------------------\n",
      "| time/                   |              |\n",
      "|    fps                  | 1031         |\n",
      "|    iterations           | 40           |\n",
      "|    time_elapsed         | 953          |\n",
      "|    total_timesteps      | 983040       |\n",
      "| train/                  |              |\n",
      "|    approx_kl            | 0.0005020134 |\n",
      "|    clip_fraction        | 0            |\n",
      "|    clip_range           | 0.2          |\n",
      "|    entropy_loss         | -1.15        |\n",
      "|    explained_variance   | 0.52325255   |\n",
      "|    learning_rate        | 1e-06        |\n",
      "|    loss                 | 1.85         |\n",
      "|    n_updates            | 390          |\n",
      "|    policy_gradient_loss | -0.000904    |\n",
      "|    value_loss           | 2.43         |\n",
      "------------------------------------------\n",
      "-------------------------------------------\n",
      "| time/                   |               |\n",
      "|    fps                  | 1031          |\n",
      "|    iterations           | 41            |\n",
      "|    time_elapsed         | 977           |\n",
      "|    total_timesteps      | 1007616       |\n",
      "| train/                  |               |\n",
      "|    approx_kl            | 0.00092892087 |\n",
      "|    clip_fraction        | 0             |\n",
      "|    clip_range           | 0.2           |\n",
      "|    entropy_loss         | -1.16         |\n",
      "|    explained_variance   | 0.56079924    |\n",
      "|    learning_rate        | 1e-06         |\n",
      "|    loss                 | 1.44          |\n",
      "|    n_updates            | 400           |\n",
      "|    policy_gradient_loss | -0.0012       |\n",
      "|    value_loss           | 2.26          |\n",
      "-------------------------------------------\n",
      "Step count: 1000000\n",
      "Average Usefulness: 0.37537378181128206\n",
      "Average NEUTRALITY: 0.659433085823456\n",
      "Weighted Average: 14.341610500143513\n"
     ]
    }
   ],
   "source": [
    "## PICK ENVIRONMENT\n",
    "# Edit this line to switch datasets: train_env_list feeds make_env, while\n",
    "# test_env_list drives best-step selection in vec_learning and the Test Scores cell.\n",
    "train_env_list, test_env_list = train_gridworlds, test_gridworlds\n",
    "\n",
    "# Shape of the policy/value MLPs, read from the run config.\n",
    "hld, num_layers = config[\"hidden_layer_depth\"], config[\"num_hidden_layers\"]\n",
    "\n",
    "def net_arch(hidden_layer_depth, num_hidden_layers):\n",
    "    net_arch_list = []\n",
    "    for n in range(num_hidden_layers):\n",
    "          net_arch_list.append(hidden_layer_depth)\n",
    "    return net_arch_list      \n",
    "\n",
    "# The same hidden-layer spec is used for both the actor (pi) and critic (vf).\n",
    "net_arch_list = net_arch(hld, num_layers)\n",
    "\n",
    "policy_kwargs = {\n",
    "    \"features_extractor_class\": Custom_Flatten,\n",
    "    \"features_extractor_kwargs\": {\"features_dim\": 250},\n",
    "    \"net_arch\": {\"pi\": net_arch_list, \"vf\": net_arch_list},\n",
    "}\n",
    "\n",
    "# Number of parallel worker processes in each SubprocVecEnv.\n",
    "num_cpu = 3\n",
    "\n",
    "#Set-up for vectorised environments\n",
    "def make_env(rank, seed=0):\n",
    "    \"\"\"Build a zero-argument factory for one Generalist_MetaEpisodeEnv worker.\n",
    "\n",
    "    The factory is meant to be handed to SubprocVecEnv. Each env is built from\n",
    "    the global train_env_list and the meta_ep_size / lambda_factor config values.\n",
    "\n",
    "    :param rank: (int) index of the subprocess; offsets the reset seed so that\n",
    "        different workers do not generate identical experience\n",
    "    :param seed: (int) the initial seed for RNG\n",
    "    :return: a callable that creates, seeds and returns the environment\n",
    "    \"\"\"\n",
    "\n",
    "    def _init():\n",
    "        env = Generalist_MetaEpisodeEnv(\n",
    "            train_env_list,\n",
    "            meta_ep_size=config[\"meta_ep_size\"],\n",
    "            lambda_factor=config[\"lambda_factor\"],\n",
    "        )\n",
    "        # Use a different seed for each environment (seed + rank);\n",
    "        # otherwise every worker would generate the same experiences.\n",
    "        env.reset(seed=seed + rank)\n",
    "        return env\n",
    "\n",
    "    # Seeds the global RNGs of the process that builds the factory.\n",
    "    set_random_seed(seed)\n",
    "    return _init\n",
    "\n",
    "\n",
    "def vec_learning_run(model, timesteps):\n",
    "    \"\"\"Train ``model`` for ``timesteps`` additional steps on a fresh SubprocVecEnv.\n",
    "\n",
    "    The env previously attached to the model is closed once it has been replaced,\n",
    "    so its worker subprocesses are not leaked across calls. The model's timestep\n",
    "    counter is kept (reset_num_timesteps=False), so repeated calls continue\n",
    "    training and logging instead of restarting from zero.\n",
    "\n",
    "    :param model: a stable-baselines3 model (PPO in this notebook)\n",
    "    :param timesteps: (int) number of additional environment steps to train for\n",
    "    :return: the same model object, trained further\n",
    "    \"\"\"\n",
    "    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)], start_method=\"fork\")\n",
    "\n",
    "    previous_env = model.get_env()\n",
    "    model.set_env(env)\n",
    "    if previous_env is not None:\n",
    "        # The model no longer references the old env; shut its workers down.\n",
    "        previous_env.close()\n",
    "\n",
    "    model.learn(total_timesteps=timesteps, reset_num_timesteps=False)\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "def vec_learning(train_env_list, timesteps_per_run, total_timesteps, eval_env_list=None):\n",
    "    \"\"\"Train a PPO model in chunks, evaluating it after every chunk.\n",
    "\n",
    "    After each ``timesteps_per_run`` steps the model is scored on ``train_env_list``\n",
    "    (metrics are printed) and on ``eval_env_list`` (used to pick the best step),\n",
    "    with selection score = 0.7 * usefulness + 0.3 * neutrality.\n",
    "\n",
    "    NOTE(review): eval_env_list defaults to the global test_env_list, so model\n",
    "    selection looks at the test set; passing val_gridworlds would keep the\n",
    "    reported test scores unbiased.\n",
    "\n",
    "    :param train_env_list: gridworlds used for the printed training metrics (the\n",
    "        training envs themselves come from the global train_env_list via make_env)\n",
    "    :param timesteps_per_run: (int) steps trained between evaluations\n",
    "    :param total_timesteps: (int) keep training until at least this many steps\n",
    "    :param eval_env_list: gridworlds used for best-step selection\n",
    "        (default: the global test_env_list)\n",
    "    :return: (model, best_model_step) -- the final model (not a best checkpoint)\n",
    "        and the cumulative step count at which the selection score peaked\n",
    "    \"\"\"\n",
    "    if eval_env_list is None:\n",
    "        eval_env_list = test_env_list\n",
    "\n",
    "    steps_count = 0\n",
    "    # Start below any achievable score so the first evaluation is always recorded.\n",
    "    best_eval_weighted_average = float(\"-inf\")\n",
    "    best_model_step = 0\n",
    "\n",
    "    # PPO needs an env at construction time; vec_learning_run replaces it before training.\n",
    "    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)], start_method=\"fork\")\n",
    "\n",
    "    # Create the PPO model with the custom architecture\n",
    "    model = PPO(\"MlpPolicy\",                   # MANUALLY CHANGE with features_extractor_class\n",
    "                env,                           # Change for vectorised Envs\n",
    "                verbose=1,\n",
    "                ent_coef=config[\"ent_coef\"],\n",
    "                learning_rate=config[\"learning_rate\"],\n",
    "                clip_range=config[\"clip_range\"],\n",
    "                n_steps=config[\"n_steps_ppo\"],\n",
    "                batch_size=config[\"batch_size\"],\n",
    "                vf_coef=config[\"vf_coef\"],\n",
    "                policy_kwargs=policy_kwargs,   # MANUALLY CHANGE\n",
    "                )\n",
    "\n",
    "    while steps_count < total_timesteps:\n",
    "        model = vec_learning_run(model, timesteps_per_run)\n",
    "        steps_count += timesteps_per_run\n",
    "\n",
    "        train_av_traj_ratio, train_av_usefulness, train_av_entropy = average_evals(train_env_list, model)\n",
    "        _, eval_av_usefulness, eval_av_entropy = average_evals(eval_env_list, model)\n",
    "        eval_weighted_average = 0.7 * eval_av_usefulness + 0.3 * eval_av_entropy\n",
    "\n",
    "        print(f'Step count: {steps_count}')\n",
    "        print(f'Average Usefulness: {train_av_usefulness}')\n",
    "        print(f'Average NEUTRALITY: {train_av_entropy}')\n",
    "        print(f'Average Trajectory Ratio: {train_av_traj_ratio}')\n",
    "        print(f'Eval Weighted Average: {eval_weighted_average}')\n",
    "\n",
    "        if eval_weighted_average > best_eval_weighted_average:\n",
    "            # At this point the model has been trained for steps_count steps in total.\n",
    "            best_model_step = steps_count\n",
    "            best_eval_weighted_average = eval_weighted_average\n",
    "\n",
    "    return model, best_model_step\n",
    "\n",
    "model, best_model_step = vec_learning(\n",
    "    train_env_list,\n",
    "    timesteps_per_run=config['timesteps_per_run'],\n",
    "    total_timesteps=config['total_timesteps'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c5c42ca",
   "metadata": {},
   "source": [
    "# Train Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a924e442",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average evals for train data\n",
      "Average Trajectory Ratio:14.342\n",
      "Average USEFULNESS:0.37537378181128206\n",
      "Average NEUTRALITY:0.659433085823456\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Re-score the final model on the training gridworlds.\n",
    "train_av_traj, train_av_usefulness, train_av_entropy = average_evals(train_env_list, model)\n",
    "\n",
    "print('Average evals for train data')\n",
    "print(f'Average Trajectory Ratio:{train_av_traj:.3f}')\n",
    "print(f'Average USEFULNESS:{train_av_usefulness}')\n",
    "print(f'Average NEUTRALITY:{train_av_entropy}')\n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b8eb08",
   "metadata": {},
   "source": [
    "# Test Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5577ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the final model on the held-out gridworlds selected in test_env_list.\n",
    "test_av_traj, test_av_usefulness, test_av_entropy = average_evals(test_env_list, model)\n",
    "\n",
    "print('Average evals for test data')\n",
    "print(f'Average Trajectory Ratio:{test_av_traj:.3f}')\n",
    "print(f'Average USEFULNESS:{test_av_usefulness}')\n",
    "print(f'Average NEUTRALITY:{test_av_entropy}')\n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe358c51",
   "metadata": {},
   "source": [
    "# Draw Policy Diagrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d973b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the learned policy for every gridworld in test_env_list, so the dataset\n",
    "# chosen in the PICK ENVIRONMENT cell is honored here too.\n",
    "for env in test_env_list:\n",
    "    draw_policy(env, model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ipp_env3 (3.12.9)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
