{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pendulum Environment, OpenAI Gym\n",
    "* Left force: -50N, Right force: 50N, Nothing: 0N, with some amount of noise added to the action\n",
    "* Generate trajectories by starting upright, and then applying random forces. \n",
    "* Failure if the pendulum exceeds +/- pi/2\n",
    "* Open questions in setting this problem up: how should the force (in Newtons) be encoded as an action? Episodes start upright, so how do we determine success?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import configargparse\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "from environments import PendulumEnv\n",
    "from models.agents import NFQAgent\n",
    "from models.networks import NFQNetwork, ContrastiveNFQNetwork\n",
    "from util import get_logger, close_logger, load_models, make_reproducible, save_models\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "import seaborn as sns\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/amandyam/.conda/envs/research/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
      "  warnings.warn(msg, FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:ylabel='Density'>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhaklEQVR4nO3deXSc9X3v8fd3ZrRYiyVr8b7I+4IxmIgtGAhrIOSStNCEBLKSULL0tsk9bWly70nb3HtPktumTdM0gQANhZKNBAqBBDCbAbPZ2Bgjg2153yTZkizLWkfzvX/M2MhGtsb2PDPSo8/rnDmzPfP8vozRZ575Pb/5/czdERGR8InkugAREQmGAl5EJKQU8CIiIaWAFxEJKQW8iEhIxXJdQH9VVVVeU1OT6zJERIaNlStX7nX36oGeG1IBX1NTw4oVK3JdhojIsGFmW4/1nLpoRERCSgEvIhJSCngRkZBSwIuIhJQCXkQkpBTwIiIhpYAXEQkpBbyISEgp4EVEQmpI/ZJ1pLr/lW2nvI9Pnjs1A5WISJjoCF5EJKQU8CIiIaWAFxEJqUAD3szKzewBM3vbzNaZ2flBticiIu8K+iTrD4A/uPv1ZpYPFAXcnoiIpAQW8GZWBlwEfBbA3XuAnqDaExGRIwXZRTMdaAL+3cxWmdmdZlZ89EZmdouZrTCzFU1NTQGWIyIysgQZ8DHgLODH7r4YOAjcdvRG7n6Hu9e6e2119YCrTomIyEkIMuB3ADvc/ZXU/QdIBr6IiGRBYAHv7nuA7WY2N/XQZUBdUO2JiMiRgh5F82fAf6ZG0GwCPhdweyIikhJowLv7aqA2yDZERGRg+iWriEhIKeBFREJKAS8iElIKeBGRkFLAi4iElAJeRCSkFPAiIiGlgBcRCSkFvIhISCngRURCSgEvIhJSCngRkZBSwIuIhJQCXkQkpBTwIiIhpYAXEQkpBbyISEgp4EVEQkoBLyISUgp4EZGQUsCLiISUAl5EJKQU8CIiIaWAFxEJqViQOzezLcABoA+Iu3ttkO2JiMi7Ag34lEvcfW8W2hERkX7URSMiElJBB7wDT5jZSjO7ZaANzOwWM1thZiuampoCLkdEZOQIOuCXuPtZwNXAV8zsoqM3cPc73L3W3Wurq6sDLkdEZOQINODdfWfquhF4EDgnyPZERORdgQW8mRWbWemh28CVwNqg2hMRkSMFOYpmHPCgmR1q5353/0OA7YmISD+BBby7bwLOCGr/IiJyfBomKSISUgp4EZGQUsCLiISUAl5EJKQU8CIiIaWAFxEJKQW8iEhIKeBFREJKAS8iElIKeBGRkFLAi4iElAJeRCSkFPAiIiGlgBcRCSkFvIhISCngRURCSgEvIhJSCngRkZBSwIuIhJQCXkQkpBTwIiIhpYAXEQkpBbyISEgp4EVEQirwgDezqJmtMrPfBd2WiIi8KxtH8H8OrMtCOyIi0k+gAW9mk4FrgDuDbEdERN4r6CP4fwb+CkgE3I6IiBwlsIA3sw8Dje6+cpDtbjGzFWa2oqmpKahyRERGnCCP4C8ArjWzLcAvgEvN7L6jN3L3O9y91t1rq6urAyxHRGRkCSzg3f1v3H2yu9cANwBPu/tNQbUnIiJH0jh4EZGQimWjEXd/Fng2G22JiEiSjuBFREJKAS8iElIKeBGRkEor4M3st2Z2jZnpA0FEZJhIN7D/DfgksMHMvmNmcwOsSUREMiCtgHf3pe5+I3AWsAVYambLzexzZpYXZIEiInJy0u5yMbNK4LPAF4BVwA9IBv6TgVQmIiKnJK1x8Gb2IDAXuBf4b+6+O/XUL81sRVDFiYjIyUv3h04/dffH+j9gZgXu3u3utQHUJSIipyjdLpr/PcBjL2WyEBERyazjHsGb2XhgEjDKzBYDlnpqNFAUcG0iInIKBuui+SDJE6uTge/3e/wA8I2AahIRkQw4bsC7+z3APWZ2nbv/Jks1iYhIBgzWRXOTu98H
1JjZ149+3t2/P8DLRERkCBisi6Y4dV0SdCEiIpJZg3XR3J66/rvslCMiIpmS7mRj3zOz0WaWZ2ZPmVmTmWn5PRGRISzdcfBXunsb8GGSc9HMAv4yqKJEROTUpRvwh7pyrgF+7e77A6pHREQyJN2pCn5nZm8DncCXzKwa6AquLBEROVXpThd8G/B+oNbde4GDwEeCLExERE5NukfwAPNIjofv/5r/yHA9IiKSIelOF3wvMBNYDfSlHnYU8CIiQ1a6R/C1wAJ39yCLERGRzEl3FM1aYHyQhYiISGalewRfBdSZ2atA96EH3f3aY73AzAqBZUBBqp0H3P1bp1CriIicgHQD/m9PYt/dwKXu3p5amPsFM/u9u798EvsSEZETlFbAu/tzZjYNmO3uS82sCIgO8hoH2lN381IX9eGLiGRJunPRfBF4ALg99dAk4KE0Xhc1s9VAI/Cku78ywDa3mNkKM1vR1NSUbt0iIjKIdE+yfgW4AGgDcPcNwNjBXuTufe5+JskVoc4xs4UDbHOHu9e6e211dXXahYuIyPGlG/Dd7t5z6E7qx05pd7e4eyvwDHDVCVUnIiInLd2TrM+Z2TdILr59BfBl4JHjvSA1X02vu7ea2SjgCuC7p1TtCLO/s5e397TRE09QNiqPeeNHkx9L9zNZREa6dAP+NuBm4E3gT4HHgDsHec0Ekuu5Rkl+U/iVu//uZAsdSXriCR5bu5tXNzcf8XhRfpTL5o3lvBmVmFmOqhOR4SLdUTQJM3sIeMjd0zoT6u5rgMWnUNuI1NET587nN7OnrYvzZ1Ry7vQKRo/KY1drJ8+ub+KRNbvZ3tLJH581iVhER/MicmyDLbptwLeAr5LqrzezPuCH7v73wZc3svT2Jbj3pa00tXfzmfOnMXf86MPPzaguoaaqmGffaWLpugbcnT+pnUJER/IicgyDHQJ+jeTombPdvcLdK4BzgQvM7GuBVzfC/G7NLrY1d/Cx2ilHhPshETMunTeWKxeM440d+1la15CDKkVkuBgs4D8FfMLdNx96wN03ATcBnw6ysJFmQ8MBXtvSwpLZVZw+qey42148p5raaWN4dn0TGxvbj7utiIxcgwV8nrvvPfrBVD98XjAljTy9fQkeXLWT6pICLp8/btDtzYwPL5pIdUkBD6zcTkdPPAtVishwM1jA95zkc3ICXqrfR2tnLx85cyJ50fROnObHInz87Cm0d8d5/C111YjIew02iuYMM2sb4HEDCgOoZ8TZ39nLc+ubmDOuhBnVJSf02onlozh/RiXL6/exalsLi6eOCahKERmOjnu46O5Rdx89wKXU3dVFkwF3Pb+Jzt4+PnjayU23f9n8cZQWxvjWw2+RSGguNxF5lwZS51BHT5z/eHkrCyaMZkLZqJPaR2FelCsXjGfNjv08tnZ3hisUkeFMAZ9DD6zcQWtHLxfOrjql/Zw5tZy540r5h8ffobcvkaHqRGS4U8DnSF/CueuFzZw1tZxplcWntK+IGX999Vy27OvggZU7MlShiAx3CvgcWbahia37Ovj8kukZ2d8lc8eyaHIZP362nriO4kUEBXzO/PyVbVSV5HPlgsysZW5mfPWSWWxr7uCRNbsysk8RGd4U8DnQ0NbFU283ct37Jmd0+t/L549j3vhS/u2Zeo2oEREFfC48sHIHfQnnhrOnZnS/kYjx5UtmsaGxnSfq9mR03yIy/Cjgs8zd+e3rOzinpoLpVad2cnUg15w+gelVxfzw6Y0k1z0XkZFKAZ9ldbvbqG86yLVnTgxk/9GI8eUPzOStXW08u16LmIuMZAr4LHvkjd3EIsaHTp8QWBsfXTyJCWWF/HTZpsDaEJGhTwGfRe7OI2/sYsnsKiqK8wNrJy8a4XMX1LC8fh9rd+4PrB0RGdoU8Fn0+rYWdrZ2cu0ZwXTP9HfDOVMpKYjx0+d1FC8yUings+jh1bsoiEW48iQnFjsRowvz+MQ5U/jdmt3sbO0MvD0RGXoU8FkS70vw6Ju7
uWz+WEoK0lrr/JR97oLpGPCzFzcPuq2IhI8CPkte2rSPve09WemeOWRi+SiuWTSBn7+6nbau3qy1KyJDgwI+Sx55YxclBTE+MHdsVtv94oUzaO+O84tXt2W1XRHJPQV8FvQlnKfWNXLpvLEU5kWz2vbCSWW8f2Yld7+whZ64JiETGUkCC3gzm2Jmz5hZnZm9ZWZ/HlRbQ93r21rYd7CHKxYMvqB2EL544Qz2tHXx6JuahExkJAnyCD4O/A93XwCcB3zFzBYE2N6Q9WRdA3lR4wNzq3PS/sVzqpk9toQ7lm3W9AUiI0hgAe/uu9399dTtA8A6YFJQ7Q1V7s6TdQ2cN6OS0sLcLGMbiRhfvHAG63a3sbx+X05qEJHsy0ofvJnVAIuBVwZ47hYzW2FmK5qawjd3Sn1TO5v3HuTKHHXPHPKRxROpKingDk1fIDJiBB7wZlYC/Ab4C3dvO/p5d7/D3Wvdvba6OjddGEF6sq4RgMtzHPAFsSifff80nlvfxDt7DuS0FhHJjkAD3szySIb7f7r7b4Nsa6h6sm4PCyeNZkLZqFyXwo3nTmNUXlTTF4iMEEGOojHgLmCdu38/qHaGsqYD3aza3soV84OfmiAdY4rz+VjtZP5r9U4a2rpyXY6IBCzII/gLgE8Bl5rZ6tTlQwG2N+Q8ta4Bd3I2PHIgn18ynb6Ec8/yLbkuRUQCFtikKO7+AmBB7X84eLKugUnlo5g/oTTXpRw2rbKYD542nvte3spXLplFcZbmxRGR7NMvWQPS0RPnhY17uWLBOJK9VUPHFy+aQVtXnF+t2J7rUkQkQAr4gDy/YS/d8UTOh0cO5KypY6idNoa7XthMvE/TF4iElb6fB+TJugZGF8Y4e3pFVtq7/5UTm0xs3vhSVmxt4X8+tJZFk8v55LlTA6pMRHJFR/AB6Es4T7/dyCXzxpIXHZpv8bwJo6kqKeDZd5pIaPoCkVAamukzzK3c2kJzDicXS0fEjEvnVbOnrYu3dr3n92ciEgIK+AAsXZecXOziOUP7l7mLJpdTXVLA0283kEjoKF4kbBTwGXZocrHzZ1blbHKxdCWP4sfS0NbNY2t357ocEckwBXyGHZpc7Ir52V256WSdPrmM6tICfrB0A306ihcJFQV8hj1R1wDkfnKxdEXMuGzeWDY0tvO7NVoQRCRMFPAZ9mRdA6dPKhsSk4ula+GkMuaNL+Ufn1hPd7wv1+WISIYo4DOo8UAXq7e3DunRMwOJmPGND81nW3OH5qgRCREFfAY9va5xyE0ulq6L5lRzydxqfvjURva1d+e6HBHJAAV8Bj1Z18DkMaOYN37oTC52Ir55zXw6evv456Ubcl2KiGSAAj5DDk0udvn8oTe5WLpmjS3lxnOncv+r29jQoFWfRIY7BXyGLFs/dCcXOxF/cfkcivOjfOvht3BNYSAyrCngM+QPa3dTXpTHOVmaXCwoFcX5/PXV81hev49fr9iR63JE5BQo4DOgO97HU+sauWL+OGJDdHKxE/GJs6dy7vQKvv1onZb2ExnGhn8aDQHL6/dxoDvO1acPjbVXT1UkYnznukX0xBP8r4fWqqtGZJhSwGfAH97cQ0lBjAtmVeW6lIyZXlXM16+YwxN1DTz25p5clyMiJ0EBf4rifQmeqNvDpfPGUhCL5rqcjLp5yXQWTS7jGw++yY6WjlyXIyInSAF/il7d3ExLRy9XLwxH90x/sWiEf7lhMYmE89X7V9ET1/J+IsOJAv4U/X7tHgrzIlw8d2jP/X6yaqqK+d71i1i9vZXv/P7tXJcjIidAAX8KEgnn8bf28IE5YynKD+/ytlefPoHPXVDD3S9u5vdvat54keFCAX8KVm1vofFAN1eFsHvmaH9z9XzOnFLO13/1Bqu3t+a6HBFJQ2ABb2Z3m1mjma0Nqo1ce3j1LvJjES4dJot7nIr8WISffrqW6tICPv+z19jU1J7rkkRkEEEewf8MuCrA/edUb1+CR9bs5or54xg9xJfmy5Tq
0gLu+fw5GPDpu1+l8YB+BCUylAUW8O6+DGgOav+5tmx9E80He/ijxZNyXUpWTa8q5u7Pns2+9h4+fZdCXmQoUx/8SXpw1U7GFOVx0Zxwjp45njOmlHPnZ2rZ1tzBx37yEtubNUZeZCjKecCb2S1mtsLMVjQ1NeW6nLQc6OrlyboGPrxoIvmxnL+FOXHBrCru+8K5tHT0cv1Plmt6YZEhKOdj+9z9DuAOgNra2mEx6cnv1+6hO57gj84KT/fM/a9sO6nXfeb8Gv79xc388Y+X888fP5PL5g/v6ZJFwmRkHn6eoodW7WRaZRGLp5TnupScG19WyK0Xz2RaZRE337OCf3j8HfoSw+JzWiT0ghwm+XPgJWCume0ws5uDaiubdu/v5KVN+/jomZOG7cpNmTamOJ8Hbn0/H6+dwr8+s5FP3fWK5q4RGQKCHEXzCXef4O557j7Z3e8Kqq1senDVTtwZcaNnBlOYF+W71y/ie9clpzW48p+Wcc/yLSR0NC+SM+qiOQGJhPPzV7dx7vQKaqqKc13OkPSxs6fwxNcuoramgm89/BbX/2Q5K7e25LoskRFJAX8Cnt+4l+3Nndx43rRclzKkTR5TxD2fO5t//JMz2NbcyXU/Xs6t966kXr9+FcmqnI+iGU7uf2UrFcX5fPA0jRQZjJlx3fsmc9XC8dz5/GbuWFbPE3V7uPr0CfzpRTNYNLk81yWKhJ4CPk07WztZuq6RL1w4PXQLewSpuCDGn18+mxvPm8qdz2/mP1/eyqNrdnPejApuOm8aVy4YT34sctLDNA/55LlTM1SxSHgo4NN0z/ItAHz6/Jqc1jFUpRPQUyuK+NoVc3htSzMv1e/jq/evoig/yllTx7BochmTykdpZJJIBing03CwO87PX93GVQvHM6l8VK7LGdYK86JcOLuaC2ZVsbGxnde2NLO8fi8vbNzLmKI8Fk4q4/RJCnuRTFDAp+FXK7ZzoCvOzUum57qU0IiYMWdcKXPGldLRE6duVxtrd+3nxY17eX7DXspG5TF7bAlzxpUys7qEUfnqFhM5UQr4QXTH+7j9uU2cXTOGs6aOyXU5oVSUH6O2poLamgo6euKs293G23sO8ObO/azY2kLEYMqYIqZXF1NTWczUiiIK8zIb+DoHIGGkgB/EAyt3sKeti//3J4tyXcqIUJQf433TKnjftAr6Es725g7WNx5gY2M7y9Y38aw3YSSnSJhWWUxNZRGTxxTh7urSETmKAv44euIJfvxsPWdOKWfJrKpclzPiRCNGTVUxNVXFXLkg+W1qe3MnW/cdZOu+Dl7f2sLLm/YBcPuyek6bOJrTJpYdvp5RXUxeVD/1kJFLAX8c9728lR0tnfyfPzpdR4dDQEEsyqyxJcwaWwJAX8LZs7+Lna2dFBdEWburjfte3kp3PAFALGJMqyxiRnUJM6tLmFldzMyxJcysKqGsaGSswiUjmwL+GPZ39vLDpzewZFYVF83W0ftQFI0Yk8aMYtKYUYf7wON9CTbtPcjanfvZ2NhOfVM79U0HefadRnr73p0Xp7I4n2mVRUytKGJqZTG7WjoZU5xPRXE+pYUxIvpAlxBQwB/Dvz69gdbOXm67ep6O3oeRWDRyeHROf/G+BNtbOtnUlAz9TU0H2dbcwWtbWnj4jV30nxMtFrFk2BflU1GSui5+9xJEt49O8koQFPADqNvVxt0vbuHjtVNYOKks1+VIBsSiEaZXFTO9qvg9i5L0xBPc/lw9+w720Hywh5aDPew72ENLRw+b9x2kJ9Xlc0hpYSwZ9qngLy/KY2J5IeNGJy9jivJ0UCBDggL+KH0J55sPvUn5qDxuu3persuRLMiPRagsKaCypOA9z7k7HT19NKfCv7mjh+b25PWmvQdZvb0VB37z+s539xeNUF1acLi7Z3RhHqWFMUoPX8coyo9RlB9NXWJsa+4gPxohP5a6RCPkRU0fFHJKFPBHuX1ZPau2tfJPHz+D8qL8XJcjOWZmFBfEKC6IMaWi
6D3Px/sSHOiKc97MChraumlo6zp83dLRw4GuOJv2ttPWGedAVy8He/rSbxvIOyr0C2KR5IfGqDxGF+YxelTyA2Tz3oNMKCvM+O8DZHhTwPezZkcr339iPdcsmsBHz9SCHjK4WDTCmOJ83jetIq3t430JDvb00dnTx8GeOJ09fXT09PHomt309CXoiSfo6UvQG0/QHU/Q2/fudU88QWdvH3vaulnf2H5E19GdL2wGoKokn4nlo5hYNip5XV6Yuk7eriouIBLJ/LcCnUMYmhTwKXvbu/nSfa9TXVrA//2ohkVKMGLRCGWjIpSNOnKY5sbGE58rv6u3j7auXto64yyYOJrdrZ3s2t/JrtYu6pvaWbahiY6jvjHkRyOMHV1AWeobQNmo5KUgL0IsEiEWNWIRIxoxevuc3r7E4UtP/Mj78YTTl3Difc6u/Z2HV+/Kix7aT7KbqaggRmlB7HA31ejCPCpLgjlZLUdSwJP8Q/nSfSvZ297Nr289X2Okh6FTPYIcjgrzohTmRRlbCte/b/J7nnd32jrj7GztZFe/8G9s62J/Zy/7O3upb2pnf2cvB7riJDwZ2Al3Ep4chhqNvBv40YgRtX63I0bEjIhB1Iz8vAhO8ltKR0+C3r5eevucjp44Xb1Hnqg2oKI4n+rSAsaWFlJSGGPxlHImj9Ekc5k04gO+ty/BV+9/nde2tPAvn1ishShkWErnAy4WiSTH/Q9wLiFoPfEE7d3J8xCtnb00Heim8UA3jW1dbGhIftsAqC4tYPGUchZPHcP5Mys5fVIZ0QC6lEaKER3wnT19/PdfrGLpuka+/dGFXHvGxFyXJBJK+bEIFbHksNKjF7zsSziLp5azansrq7a2sGp7K0/UNQDJIannzajkgpmVXDCrilljS3SEfwJGbMA3tHVx630rWb29lb+79jQ+pXVWRXIiGjHW7NhP1OzwrKLt3fHDP0pbubWFJ/sF/szqEmZVlzBzbMnhcxk6STuwERnwS+sa+KvfrKGzp48f33gWVy2ckOuSRKSfkoIYiyaXH+4ybTnYQ31TOxub2tnQcIDV21sBqCopYNbYYipL8jlvRuV7Tl6PdCMq4Dc1tfPdP7zN4281MG98Kf/6ycXMGls6+AtFBjEST/Jm05jifGqLk0f3CXca2rqob0wG/sqtLby8qZmIwemTy1kyq5Jzpldy5pTyER/4oQ/4RMJ5bUsz9768lcfe3E1BLMpfXTWXLyyZQX5Mw7REhpuIGRPKRjGhbBRLZlcTTySYO66UF+v38eLGvfzkuU386Jl6zGBWdQlnTR3D4qnlzJ8wmjnjSkfU6mCBBryZXQX8AIgCd7r7d4Js75CdrZ2s2NLMa1uaeWpdI7v3d1FaGOMLF87glotmUDXAT9JFZHiKRSKcO6OSc2dU8vUr5nCgq5c3tu/n9W0tvL6thT+8tYdfrtgOgBlMryxm7vhSZo0tYUpqVNHUiiLGjy4M5EdguRRYwJtZFPgRcAWwA3jNzB5297pMttOXcH7yXD1b9iYXgdi09yB727sBKM6Pcv7MSv7yg3O5euGEEfXJLTJSlRbmsWR2FUtS03wnEs7W5g7e2dPGut0HeGfPAep2t/H4W3uOmEU0L2pUlxRQVVpAVUkB1SUFjEnNJ1RSkLoUJn+0VVQQS80dZORFI4cv+dEIebFDvw8wzJLfOIzkh0u2RwAFeQR/DrDR3TcBmNkvgI8AGQ34aMT46fObyItGmF5ZzCVzqzlt4mhqayqYN76UmH4tJzKiRSJ2eCbR/gMqevsS7GrtZFtzB9ubO9ne0kFjWzd725NzCa3duZ+Wjp4j1hHISD2poD90bSTH/7/w15dmtB0INuAnAdv73d8BnHv0RmZ2C3BL6m67mb1zsg2ueO9DVcDek91fFqnOzFKdmTeka73x3ZtDus5+jqhzPWC3nfS+jjnGO+cnWd39DuCOIPZtZivcvTaIfWeS6sws1Zl5w6VW1XmkIPsvdgJT+t2fnHpMRESyIMiA
fw2YbWbTzSwfuAF4OMD2RESkn8C6aNw9bmZfBR4nOUzybnd/K6j2jiGQrp8AqM7MUp2ZN1xqVZ39mHtmzxCLiMjQoDGEIiIhpYAXEQmpUAW8mVWY2ZNmtiF1PWaAbc40s5fM7C0zW2NmH89ifVeZ2TtmttHsvaNezazAzH6Zev4VM6vJVm1H1TFYnV83s7rU+/eUmeVkruXB6uy33XVm5maWk+Fz6dRpZh9Lvadvmdn92a4xVcNg/+5TzewZM1uV+rf/UI7qvNvMGs1s7TGeNzP7l9R/xxozOyvbNabqGKzOG1P1vWlmy83sjIwX4e6huQDfA25L3b4N+O4A28wBZqduTwR2A+VZqC0K1AMzgHzgDWDBUdt8GfhJ6vYNwC9z8B6mU+clQFHq9peGap2p7UqBZcDLQO1QrBOYDawCxqTujx2idd4BfCl1ewGwJdt1ptq+CDgLWHuM5z8E/J7kyoDnAa8M0Trf3+/f/Oog6gzVETzJqRDuSd2+B/jo0Ru4+3p335C6vQtoBKqzUNvhqRvcvQc4NHVDf/3rfwC4zLK/fM2gdbr7M+7ekbr7MsnfOGRbOu8nwLeB7wJd2Syun3Tq/CLwI3dvAXD3xizXCOnV6cDo1O0yYFcW63u3CPdlQPNxNvkI8B+e9DJQbmZZX/RhsDrdffmhf3MC+jsKW8CPc/fdqdt7gHHH29jMziF5tFIfdGEMPHXDpGNt4+5xYD9QmYXaBqwhZaA6+7uZ5NFStg1aZ+qr+RR3fzSbhR0lnfdzDjDHzF40s5dTs7BmWzp1/i1wk5ntAB4D/iw7pZ2wE/1/eCgI5O8o51MVnCgzWwqMH+Cpb/a/4+5uZsccA5r6RL8X+Iy7J461nRybmd0E1AIX57qWo5lZBPg+8Nkcl5KOGMlumg+QPIpbZmanu3trLosawCeAn7n7P5rZ+cC9ZrZQfz+nxswuIRnwSzK972EX8O5++bGeM7MGM5vg7rtTAT7gV10zGw08Cnwz9RUuG9KZuuHQNjvMLEbya/C+7JT3nhoOGXCKCTO7nOSH6sXu3p2l2vobrM5SYCHwbKqXazzwsJld6+4DzEsXmHTezx0k+197gc1mtp5k4L+WnRKB9Oq8GbgKwN1fMrNCkpNm5aJL6XiGzTQpZrYIuBO42t0z/rceti6ah4HPpG5/BvivozdITZvwIMk+ugeyWFs6Uzf0r/964GlPnYHJokHrNLPFwO3AtTnqL4ZB6nT3/e5e5e417l5Dso8z2+E+aJ0pD5E8esfMqkh22WzKYo2QXp3bgMsAzGw+UAg0ZbXK9DwMfDo1muY8YH+/rtshw8ymAr8FPuXu6wNpJBdnl4O6kOyvfgrYACwFKlKP15JcUQrgJqAXWN3vcmaW6vsQyZlB60l+ewD4e5LBA8k/mF8DG4FXgRk5eh8Hq3Mp0NDv/Xt4KNZ51LbPkoNRNGm+n0ayO6kOeBO4YYjWuQB4keQIm9XAlTmq8+ckR7/1kvz2czNwK3Brv/fzR6n/jjdz+O8+WJ13Ai39/o5WZLoGTVUgIhJSYeuiERGRFAW8iEhIKeBFREJKAS8iElIKeBGRkFLAi4iElAJeRCSk/j/5va79t+YrugAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Collect ten rollouts from a fresh PendulumEnv and plot how the third field\n",
    "# of every rollout item (gathered below as `rewards`) is distributed.\n",
    "env = PendulumEnv()\n",
    "rollouts = []\n",
    "for episode in range(10):\n",
    "    # The second return value of generate_rollout is not needed in this cell.\n",
    "    rollout, _unused_cost = env.generate_rollout()\n",
    "    rollouts.extend(rollout)\n",
    "rewards = [item[2] for item in rollouts]\n",
    "# histplot is the supported replacement for the deprecated distplot;\n",
    "# stat=\"density\" with kde=True keeps the density-scaled histogram plus KDE curve.\n",
    "sns.histplot(rewards, kde=True, stat=\"density\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(init_experience=100, bg_only=False, agent=None):\n",
    "    \"\"\"Collect rollouts from a background and a foreground pendulum environment.\n",
    "\n",
    "    Builds PendulumEnv(group=0) and PendulumEnv(group=1), then gathers\n",
    "    `init_experience` rollouts from the background environment and, unless\n",
    "    `bg_only` is true, the same number from the foreground environment.\n",
    "    `agent` is forwarded unchanged to each generate_rollout call.\n",
    "\n",
    "    Returns a tuple (all_rollouts, env_bg, env_fg); all_rollouts holds the\n",
    "    background rollout items first, followed by the foreground ones.\n",
    "    \"\"\"\n",
    "    env_bg = PendulumEnv(group=0)\n",
    "    env_fg = PendulumEnv(group=1)\n",
    "    background = []\n",
    "    foreground = []\n",
    "    # A non-positive request collects nothing at all.\n",
    "    n_rollouts = init_experience if init_experience > 0 else 0\n",
    "    for _ in range(n_rollouts):\n",
    "        bg_items, _ = env_bg.generate_rollout(agent, render=False, group=0)\n",
    "        background += bg_items\n",
    "        if bg_only:\n",
    "            continue\n",
    "        fg_items, _ = env_fg.generate_rollout(agent, render=False, group=1)\n",
    "        foreground += fg_items\n",
    "    all_rollouts = background + foreground\n",
    "    return all_rollouts, env_bg, env_fg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/1001 [00:01<22:02,  1.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation bg: [73.81142197848965, 72.28765194473992, 73.73941817745897, 73.48222940956178, 73.49007680694399] Evaluation fg: [74.01250678450806, 73.65066927176343, 74.84724716461659, 73.11356339568685, 75.0214335207005]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 301/1001 [01:17<06:24,  1.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation bg: [25.91933239636995, 25.943631812838813, 25.918313560945293, 25.93831096791118, 25.91540663172479] Evaluation fg: [25.96212970651441, 25.909092804253845, 25.90101909324823, 25.93344660781633, 25.913366980623046]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 601/1001 [02:33<03:46,  1.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation bg: [25.930066512891948, 25.919891857887908, 25.920398622003734, 25.98697165556666, 25.919573906670177] Evaluation fg: [25.91221921137715, 25.935719469985738, 25.930626828078918, 25.91169053116927, 25.9646164792548]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 901/1001 [03:50<00:55,  1.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation bg: [25.971403504800346, 25.982871052864095, 25.90971079444938, 25.962789946549336, 25.91520164596239] Evaluation fg: [25.941113174043924, 25.9494665088202, 25.936848622450885, 25.93515375752261, 25.946367689179176]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1001/1001 [04:15<00:00,  3.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation bg: 25.94022201048524 Evaluation fg: 25.939652176939322\n"
     ]
    }
   ],
   "source": [
    "# Run settings.\n",
    "is_contrastive = True\n",
    "epoch = 1000\n",
    "hint_to_goal = False\n",
    "evaluations = 5\n",
    "\n",
    "# Two independent data sets; bg_only=True collects background-group rollouts only.\n",
    "train_rollouts, train_env_bg, train_env_fg = generate_data(\n",
    "    init_experience=200, bg_only=True\n",
    ")\n",
    "test_rollouts, eval_env_bg, eval_env_fg = generate_data(\n",
    "    init_experience=200, bg_only=True\n",
    ")\n",
    "\n",
    "if hint_to_goal:\n",
    "    goal_state_action_b_bg, goal_target_q_values_bg, group_bg = train_env_bg.get_goal_pattern_set(group=0)\n",
    "    goal_state_action_b_fg, goal_target_q_values_fg, group_fg = train_env_fg.get_goal_pattern_set(group=1)\n",
    "\n",
    "    # Convert the goal patterns to float tensors so they can be concatenated\n",
    "    # with the generated pattern set inside the training loop.\n",
    "    (\n",
    "        goal_state_action_b_bg,\n",
    "        goal_target_q_values_bg,\n",
    "        goal_state_action_b_fg,\n",
    "        goal_target_q_values_fg,\n",
    "    ) = (\n",
    "        torch.FloatTensor(values)\n",
    "        for values in (\n",
    "            goal_state_action_b_bg,\n",
    "            goal_target_q_values_bg,\n",
    "            goal_state_action_b_fg,\n",
    "            goal_target_q_values_fg,\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Network, optimizer and the agent wrapping both.\n",
    "nfq_net = ContrastiveNFQNetwork(\n",
    "    state_dim=train_env_bg.state_dim,\n",
    "    is_contrastive=is_contrastive,\n",
    "    deep=False,\n",
    ")\n",
    "optimizer = optim.Adam(nfq_net.parameters(), lr=1e-1)\n",
    "nfq_agent = NFQAgent(nfq_net, optimizer)\n",
    "\n",
    "# Sliding windows holding the outcome (1 = success) of the last three evaluations.\n",
    "bg_success_queue = [0, 0, 0]\n",
    "fg_success_queue = [0, 0, 0]\n",
    "# Counts epochs spent with the shared layers frozen.\n",
    "eval_fg = 0\n",
    "# Training loop: rebuild the pattern set, take (at most) one training step,\n",
    "# evaluate on both environments, and stop as soon as the foreground evaluation\n",
    "# has succeeded three times in a row.\n",
    "for ep in tqdm.tqdm(range(epoch + 1)):\n",
    "    state_action_b, target_q_values, groups = nfq_agent.generate_pattern_set(train_rollouts)\n",
    "    if hint_to_goal:\n",
    "        # Append the goal patterns of both groups to this epoch's batch.\n",
    "        goal_state_action_b = torch.cat([goal_state_action_b_bg, goal_state_action_b_fg], dim=0)\n",
    "        goal_target_q_values = torch.cat([goal_target_q_values_bg, goal_target_q_values_fg], dim=0)\n",
    "        state_action_b = torch.cat([state_action_b, goal_state_action_b], dim=0)\n",
    "        target_q_values = torch.cat([target_q_values, goal_target_q_values], dim=0)\n",
    "        goal_groups = torch.cat([group_bg, group_fg], dim=0)\n",
    "        groups = torch.cat([groups, goal_groups], dim=0)\n",
    "\n",
    "    if nfq_net.freeze_shared:\n",
    "        # Once the shared layers are frozen, training is skipped for the first\n",
    "        # 50 epochs and resumes afterwards.\n",
    "        eval_fg += 1\n",
    "        if eval_fg > 50:\n",
    "            loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "    else:\n",
    "        loss = nfq_agent.train((state_action_b, target_q_values, groups))\n",
    "\n",
    "    # Evaluate on both environments and slide the three-entry success windows.\n",
    "    (eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg) = nfq_agent.evaluate_pendulum(eval_env_bg, render=False)\n",
    "    bg_success_queue = bg_success_queue[1:]\n",
    "    bg_success_queue.append(1 if eval_success_bg else 0)\n",
    "\n",
    "    (eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg) = nfq_agent.evaluate_pendulum(eval_env_fg, render=False)\n",
    "    fg_success_queue = fg_success_queue[1:]\n",
    "    fg_success_queue.append(1 if eval_success_fg else 0)\n",
    "\n",
    "    # Three consecutive background successes trigger the one-time freeze.\n",
    "    if sum(bg_success_queue) == 3 and not nfq_net.freeze_shared:\n",
    "        nfq_net.freeze_shared = True\n",
    "        print(\"FREEZING SHARED\")\n",
    "        if is_contrastive:\n",
    "            for param in nfq_net.layers_shared.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_last_shared.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                param.requires_grad = True\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                param.requires_grad = True\n",
    "        else:\n",
    "            for param in nfq_net.layers_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "            for param in nfq_net.layers_last_fg.parameters():\n",
    "                param.requires_grad = False\n",
    "\n",
    "        # From here on the agent optimizes only the foreground layers.\n",
    "        optimizer = optim.Adam(\n",
    "            itertools.chain(\n",
    "                nfq_net.layers_fg.parameters(),\n",
    "                nfq_net.layers_last_fg.parameters(),\n",
    "            ),\n",
    "            lr=1e-1,\n",
    "        )\n",
    "        nfq_agent._optimizer = optimizer\n",
    "    if sum(fg_success_queue) == 3:\n",
    "        print(\"Done Training\")\n",
    "        break\n",
    "    if ep % 300 == 0:\n",
    "        # Periodic report: episode cost of several evaluation runs per environment.\n",
    "        perf_bg = []\n",
    "        perf_fg = []\n",
    "        for it in range(evaluations):\n",
    "            (eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg) = nfq_agent.evaluate_pendulum(eval_env_bg, render=False)\n",
    "            (eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg) = nfq_agent.evaluate_pendulum(eval_env_fg, render=False)\n",
    "            perf_bg.append(eval_episode_cost_bg)\n",
    "            perf_fg.append(eval_episode_cost_fg)\n",
    "            # NOTE(review): all four environments are closed after every run,\n",
    "            # including the training ones; confirm whether closing once after\n",
    "            # this loop would be enough.\n",
    "            train_env_bg.close()\n",
    "            train_env_fg.close()\n",
    "            eval_env_bg.close()\n",
    "            eval_env_fg.close()\n",
    "        print(\"Evaluation bg: \" + str(perf_bg) + \" Evaluation fg: \" + str(perf_fg))\n",
    "# Final report: mean episode cost over evaluations * 10 runs per environment.\n",
    "perf_bg = []\n",
    "perf_fg = []\n",
    "for it in range(evaluations * 10):\n",
    "    # Evaluate with the pendulum routine, the same one the training loop uses\n",
    "    # on these PendulumEnv instances (this cell previously called evaluate_car).\n",
    "    (eval_episode_length_bg, eval_success_bg, eval_episode_cost_bg) = nfq_agent.evaluate_pendulum(eval_env_bg, render=False)\n",
    "    (eval_episode_length_fg, eval_success_fg, eval_episode_cost_fg) = nfq_agent.evaluate_pendulum(eval_env_fg, render=False)\n",
    "    perf_bg.append(eval_episode_cost_bg)\n",
    "    perf_fg.append(eval_episode_cost_fg)\n",
    "    eval_env_bg.close()\n",
    "    eval_env_fg.close()\n",
    "print(\"Evaluation bg: \" + str(sum(perf_bg)/len(perf_bg)) + \" Evaluation fg: \" + str(sum(perf_fg)/len(perf_fg)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research [~/.conda/envs/research/]",
   "language": "python",
   "name": "conda_research"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
