{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "# Re-import edited modules (such as the local src package) automatically before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "N_DIMS = 1\n",
    "NUM_SAMPLES = 50000  # num_samples for each DistDataset\n",
    "BS = 500  # batch size\n",
    "NUM_EPOCHS = 400\n",
    "SEED = 10\n",
    "LR = 1e-2  # Adam learning rate\n",
    "DROPOUT = 0.20  # passed to RatioCritic1D\n",
    "\n",
    "# Prefer the second GPU (cuda:1) when there is one, otherwise the first GPU, otherwise CPU.\n",
    "# A hard-coded 'cuda:1' fails with an invalid device ordinal on single-GPU machines.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    DEVICE = 'cuda:1'\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = 'cuda:0'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "# NOTE: the estimator is known to break when the number of datapoints, the scales\n",
    "# or the means are changed, or when moving to 2D."
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "source": [
    "# Seed every RNG used in this notebook (Python, PyTorch, NumPy) for reproducibility\n",
    "for seed_rng in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_rng(SEED)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "source": [
    "# Define model and move it to DEVICE *before* building the optimizer, so the optimizer\n",
    "# holds the parameters that will actually be trained (recommended PyTorch practice)\n",
    "model = RatioCritic1D(dim_input=N_DIMS, dim_output=3, dropout=DROPOUT).to(DEVICE)\n",
    "# Optional custom initialisation, currently disabled: model.apply(weights_init)\n",
    "\n",
    "# Define optimizer\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# Define distributions: p and q are the pair whose log-ratio p/q is estimated;\n",
    "# m is the third distribution sampled alongside them by DistDataset\n",
    "p, q, m = get_dists_1d(mu1=-1., mu2=2., mu3=0, scale_p=0.08, scale_q=0.15, scale_m=1.0)\n",
    "\n",
    "# Other settings noted for (mu1, mu2, m_var): (-5, 5, 3.0) and (-10, 10, 3.0)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "source": [
    "# Define datasets: each DistDataset samples p, q and m with num_samples=NUM_SAMPLES\n",
    "train_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)\n",
    "# The test set is also NUM_SAMPLES large; the training loop evaluates on a single shuffled batch of it\n",
    "test_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sampling p\n",
      "Sampling q\n",
      "Sampling m\n",
      "Sampling p\n",
      "Sampling q\n",
      "Sampling m\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "# Define dataloaders; both use the same batch size and reshuffle on every iteration pass\n",
    "loader_kwargs = dict(batch_size=BS, shuffle=True)\n",
    "train_dl = DataLoader(train_ds, **loader_kwargs)\n",
    "test_dl = DataLoader(test_ds, **loader_kwargs)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "source": [
    "# Set up a live three-panel figure: train loss | true vs. CoB log-ratio | test loss\n",
    "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))\n",
    "max_iters = NUM_EPOCHS * NUM_SAMPLES // BS  # total number of training iterations\n",
    "\n",
    "# Placeholder artists; the training loop replaces their data\n",
    "line, = ax1.plot([0, 1], [0, 1])\n",
    "x, y = np.random.random((2, 500))\n",
    "scat1 = ax2.scatter(x, y, label='True p/q', alpha=0.9, s=10., c='b')\n",
    "scat2 = ax2.scatter(x, y, label='CoB p/q', alpha=0.9, s=10., c='r')\n",
    "test_line, = ax3.plot([0, 1], [0, 1])\n",
    "\n",
    "# Both loss panels share one axis configuration\n",
    "for loss_ax, loss_name in ((ax1, 'Train Loss'), (ax3, 'Test Loss')):\n",
    "    loss_ax.set_xlabel('Iteration')\n",
    "    loss_ax.set_ylabel(loss_name)\n",
    "    loss_ax.set_xlim([0, max_iters])\n",
    "    loss_ax.set_ylim([0, 10])\n",
    "\n",
    "ax2.set_xlabel('Samples')\n",
    "ax2.set_ylabel('Log Ratio')\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-6, 10])\n",
    "ax2.set_ylim([-1500, 5000])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Loss histories appended to by the training loop\n",
    "loss_store = []\n",
    "test_loss_store = []"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABDAAAAEYCAYAAACqUwbqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAybElEQVR4nO3de7gcVZno/+9rEglyFyO3oIkMwoRbhBBEkImBEwL4ExlRyBk0Os6JeGBEHUXQOQM6OKI4gs4wzuQnCAhCkIGR4yhXiQgKIcFwFwk3SQSJAUFQbsl7/ui1QyfsvbOTvburdvf38zz9dNWq6lpvV3fWXnl71arITCRJkiRJkursVVUHIEmSJEmStCYmMCRJkiRJUu2ZwJAkSZIkSbVnAkOSJEmSJNWeCQxJkiRJklR7JjAkSZIkSVLttSyBERFnR8TjEXFnU9lrI+LqiLivPG/WqvolqVtFxEMRcUdELIyI+aWs1/Y3Gr4REYsi4vaI2L3pODPL/vdFxMyq3o8kyb61JEFrR2CcA0xfrewE4NrM3B64tqxLkobeOzJzYmZOKut9tb8HAduXxyzgm9DoFAMnAXsBk4GT7BhLUqXOwb61pC7XsgRGZl4PPLFa8aHAuWX5XODdrapfkrSKvtrfQ4HzsuEmYNOI2Ao4ELg6M5/IzCeBq3llx1mS1Cb2rSUJRra5vi0y89Gy/BiwRV87RsQsGr8GssEGG+yx4447tiE8Sd1qwYIFv8vMMVXHMUQSuCoiEviPzJxN3+3vNsAjTa9dXMr6Kl+FbbWkduuw9nqw7FtLqqVWtdXtTmCslJlZOtd9bZ8NzAaYNGlSzp8/v22xSeo+EfFw1TEMoX0zc0lEvB64OiJ+2bxxTe3v2rCtltRuHdZeDxn71pLqpFVtdbvvQvLbMjSZ8vx4m+uXpI6XmUvK8+PAZTTmsOir/V0CbNv08rGlrK9ySVJ92LeW1FXancC4HOiZyX4m8P021y9JHS0iNoiIjXqWgWnAnfTd/l4OfKDcjeStwFNlOPKVwLSI2KxM3jmtlEmS6sO+taSu0rJLSCLiQmAK8LqIWExjNvtTgYsj4sPAw8D7WlW/JHWpLYDLIgIabfx3M/OKiLiF3tvfHwIHA4uAPwIfAsjMJyLiH4Fbyn5fyMzVJ4+TJLWJfWtJamECIzNn9LFp/1bVKUndLjMfAHbrpXwZvbS/mZnAMX0c62zg7KGOUepWL774IosXL+a5556rOpRhYfTo0YwdO5ZRo0ZVHUot2LeWpAon8ZQkSeomixcvZqONNmLcuHGUUVLqQ2aybNkyFi9ezPjx46sOR5JUE+2eA0OSJKkrPffcc2y++eYmLwYgIth8880drSJJWoUJDEmSpDYxeTFwnitJ0upMYEiSJEmSpNozgSFJktQFli1bxsSJE5k4cSJbbrkl22yzzcr1F154oeX1P/roo0ybNq3l9UiSOpeTeEqSJHWBzTffnIULFwJw8skns+GGG/KpT31q5faXXnqJkSNb1zW84oorOPDAA1t2fElS53MEhiRJUpf64Ac/yNFHH81ee+3F8ccfz8knn8xXv/rVldt33nlnHnroIQDOP/98Jk+ezMSJE/nIRz7C8uXLX3G8cePGcfzxx7PLLrswefJkFi1atHLbFVdcwUEHHURmcuyxx7LDDjtwwAEHcPDBB3PJJZe0/L1KkoY/ExiSJEk1ddVV8NnPNp5bZfHixfzsZz/ja1/7Wp/73HPPPcyZM4cbb7yRhQsXMmLECC644IJe991kk0244447OPbYY/n4xz8OwPLly7n33nuZMGECl112Gffeey9333035513Hj/72c9a8bYkSR3IS0gkSZJq6Kqr4Kij4Pnn4VvfgvPPh1ZMIfHe976XESNG9LvPtddey4IFC9hzzz0B+NOf/sTrX//6XvedMWPGyudPfOITANx8883stddeAFx//fXMmDGDESNGsPXWWzN16tSheiuSpA5n
AkOSJKmG5s5tJC822ACefbax3ooExgYbbLByeeTIkaxYsWLl+nPPPQdAZjJz5ky+9KUvrfF4zbc/7Vn+0Y9+xPTp04cqZElSl/ISEkmSpBqaMgXWW6+RvFhvvcZ6q40bN45bb70VgFtvvZUHH3wQgP33359LLrmExx9/HIAnnniChx9+uNdjzJkzZ+Xz3nvvDTRGcBxwwAEA7LfffsyZM4fly5fz6KOPct1117X0PUmSOocjMCRJkmpo2rTGZSNz5zaSF+24A+l73vMezjvvPHbaaSf22msv3vzmNwMwYcIETjnlFKZNm8aKFSsYNWoUZ555Jm984xtfcYwnn3ySXXfdlfXWW48LL7yQpUuXMnr0aDbaaCMADjvsMH784x8zYcIE3vCGN6xMckiStCYmMCRJkmpq2rTWJC5OPvnkXsvXX399rupjxtAjjjiCI444Yo3H/vSnP82Xv/zllevnn38+05reRETwr//6ryvXP/jBDw4saElS1zOBIUmSpJY56qijqg5BktQhTGBIkiRpSDz00ENr/ZpzzjlnyOOQJHUmJ/GUJEmSJEm1ZwJDkiRJkiTVngkMSZIkSZJUeyYwJEmSJElS7ZnAkCRJ6hKPPfYYRx55JNtttx177LEHBx98ML/61a/63P+hhx5i/fXXZ+LEiey222687W1v4957713rek899VQuuOCCwYQuSZIJDEmSpG6QmRx22GFMmTKF+++/nwULFvClL32J3/72t/2+brvttmPhwoXcdtttzJw5k3/6p39a67qvvPJKpk2btq6hS5IEmMCQJEnqCtdddx2jRo3i6KOPXlm222678fa3v53M5NOf/jQ777wzu+yyC3PmzOn1GE8//TSbbbbZK8rnzp3LfvvtxyGHHMIOO+zA0UcfzYoVK1a+5oUXXmDMmDE8+OCD7L333uyyyy78/d//PRtuuGFr3qwkqSONrDoASZIk9eGqq2DuXJgyBQY5guHOO+9kjz326HXbpZdeunKUxe9+9zv23HNP9ttvPwDuv/9+Jk6cyB/+8Af++Mc/cvPNN/d6jHnz5nH33Xfzxje+kenTp3PppZdy+OGHc80117D//vsDcNxxx/HRj36UD3zgA5x55pmDej+SpO7jCAxJkqQ6uuoqOOooOPPMxvNVV7WsqhtuuIEZM2YwYsQItthiC/7iL/6CW265BXj5EpL777+fM844g1mzZvV6jMmTJ/OmN72JESNGMGPGDG644QYArrjiCg466CAAbrzxRmbMmAHA+9///pa9H0lSZzKBIUmSVEdz58Lzz8MGGzSe584d1OF22mknFixYMKhjvOtd7+L666/vdVtE9Lo+b948Jk+e3Od+kiQNlAkMSZKkOpoyBdZbD559tvE8ZcqgDjd16lSef/55Zs+evbLs9ttv56c//Slvf/vbmTNnDsuXL2fp0qVcf/31qyQdetxwww1st912vR5/3rx5PPjgg6xYsYI5c+aw7777ctddd7HjjjsyYsQIAPbZZx8uuugiAO9KIklaayYwJEmS6mjaNDj/fDjmmMbzIOfAiAguu+wyrrnmGrbbbjt22mknTjzxRLbccksOO+wwdt11V3bbbTemTp3KV77yFbbcckvg5TkwdtttNz772c/yrW99q9fj77nnnhx77LH8+Z//OePHj+ewww7jRz/6EdOnT1+5z9e//nXOPPNMdtllF5YsWTKo9yNJ6j5O4ilJHSgiRgDzgSWZ+c6IGA9cBGwOLADen5kvRMR6wHnAHsAy4IjMfKgc40Tgw8By4GOZeWX734nU5aZNG3TiotnWW2/NxRdf3Ou20047jdNOO22VsnHjxvGnP/1pQMfeeOON+cEPfrBK2ZVXXsl55523cn38+PH8/Oc/X7l+xhlnDDBySZIcgSFJneo44J6m9S8Dp2fmnwFP0khMUJ6fLOWnl/2IiAnAkcBOwHTg30pSRJIG7Oqrr2arrbaqOgxJUocwgSFJHSYixgKHAN8q6wFMBS4pu5wL
vLssH1rWKdv3L/sfClyUmc9n5oPAIuCVF8RLEjBlypRXjL4YiGeeeaYF0UiSOpUJDEnqPGcAxwMryvrmwO8z86WyvhjYpixvAzwCULY/VfZfWd7La1aKiFkRMT8i5i9dunSI34bUeTKz6hCGDc+VJGl1JjAkqYNExDuBxzNzcPdKHKDMnJ2ZkzJz0pgxY9pRpTRsjR49mmXLlvkf8wHITJYtW8bo0aOrDkWSVCNO4ilJnWUf4F0RcTAwGtgY+DqwaUSMLKMsxgI90/8vAbYFFkfESGATGpN59pT3aH6NpHUwduxYFi9ejKOVBmb06NGMHTu26jAkSTViAkOSOkhmngicCBARU4BPZeZfRcT3gMNp3IlkJvD98pLLy/rPy/YfZ2ZGxOXAdyPia8DWwPbAvDa+FanjjBo1ivHjx1cdhiRJw5YJDEnqDp8BLoqIU4BfAGeV8rOA70TEIuAJGnceITPvioiLgbuBl4BjMnN5+8OWJEmSGkxgSFKHysy5wNyy/AC93EUkM58D3tvH678IfLF1EUqSJEkD5ySekiRJkiSp9ipJYETEJyLiroi4MyIujAinmJYkSZLWgX1rSd2i7QmMiNgG+BgwKTN3BkZQrrmWJEmSNHD2rSV1k6ouIRkJrF9u2fca4DcVxSFJkiQNd/atJXWFticwMnMJ8FXg18CjwFOZeVW745AkSZKGO/vWkrpJFZeQbAYcCowHtgY2iIijetlvVkTMj4j5S5cubXeYkiRJUu3Zt5bUTaq4hOQA4MHMXJqZLwKXAm9bfafMnJ2ZkzJz0pgxY9oepCRJkjQM2LeW1DWqSGD8GnhrRLwmIgLYH7ingjgkSZKk4c6+taSuUcUcGDcDlwC3AneUGGa3Ow5JkiRpuLNvLambjKyi0sw8CTipirolSZKkTmLfWlK3qOo2qpIkSZIkSQNmAkOSJEmSJNWeCQxJkiRJklR7JjAkSZIkSVLtmcCQJEmSJEm1ZwJDkiRJkiTVngkMSZIkSZJUeyYwJEmSJElS7ZnAkCRJkiRJtWcCQ5IkSZIk1Z4JDEmSJEmSVHsmMCRJkiRJUu2ZwJAkSZIkSbVnAkOSJEmSJNWeCQxJkiRJklR7JjAkSZIkSVLtmcCQpA4SEaMjYl5E3BYRd0XE50v5+Ii4OSIWRcSciHh1KV+vrC8q28c1HevEUn5vRBxY0VuSJEmSABMYktRpngemZuZuwERgekS8FfgycHpm/hnwJPDhsv+HgSdL+ellPyJiAnAksBMwHfi3iBjRzjciSZIkNTOBIUkdJBueKaujyiOBqcAlpfxc4N1l+dCyTtm+f0REKb8oM5/PzAeBRcDk1r8DSZIkqXcmMCSpw0TEiIhYCDwOXA3cD/w+M18quywGtinL2wCPAJTtTwGbN5f38hpJkiSp7UxgSFKHyczlmTkRGEtj1MSOraorImZFxPyImL906dJWVSNJkiSZwJCkTpWZvweuA/YGNo2IkWXTWGBJWV4CbAtQtm8CLGsu7+U1zXXMzsxJmTlpzJgxrXgbkiRJEmACQ5I6SkSMiYhNy/L6wP8A7qGRyDi87DYT+H5ZvrysU7b/ODOzlB9Z7lIyHtgemNeWNyFJkiT1YuSad5EkDSNbAeeWO4a8Crg4M38QEXcDF0XEKcAvgLPK/mcB34mIRcATNO48QmbeFREXA3cDLwHHZObyNr8XSZIkaSUTGJLUQTLzduAtvZQ/QC93EcnM54D39nGsLwJfHOoYJUmSpHXhJSSSJEmSJKn2TGBIkiRJkqTaM4EhSZIkSZJqzwSGJEmSJEmqPRMYkiRJkiSp9kxgSJIkSZKk2jOBIUmSJEmSas8EhiRJkiRJqj0TGJIkSZIkqfZMYEiSJEmSpNozgSFJkiRJkmrPBIYkSZIkSao9ExiSJEmSJKn2KklgRMSmEXFJRPwyIu6JiL2riEOSJEka7uxbS+oWIyuq9+vAFZl5eES8GnhNRXFIkiRJw519a0ldoe0JjIjYBNgP+CBA
Zr4AvNDuOCRJkqThzr61pG6yxktIImK7iFivLE+JiI9FxKaDqHM8sBT4dkT8IiK+FREb9FLvrIiYHxHzly5dOojqJEmSpI5l31pS1xjIHBj/CSyPiD8DZgPbAt8dRJ0jgd2Bb2bmW4BngRNW3ykzZ2fmpMycNGbMmEFUJ0mSJHUs+9aSusZAEhgrMvMl4DDgXzLz08BWg6hzMbA4M28u65fQaHQlqWtExCYRcXrPr2ER8c9lGLAkSWvDvrWkrjGQBMaLETEDmAn8oJSNWtcKM/Mx4JGI2KEU7Q/cva7Hk6Rh6mzgaeB95fE08O1KI5IktVxEfCUiNo6IURFxbUQsjYij1vV49q0ldZOBTOL5IeBo4IuZ+WBEjAe+M8h6/xa4oMyS/ECpQ5K6yXaZ+Z6m9c9HxMKqgpEktc20zDw+Ig4DHgL+ErgeOH8Qx7RvLakrrDGBkZl3Ax8DiIjNgI0y88uDqTQzFwKTBnMMSRrm/hQR+2bmDQARsQ/wp4pjkiS1Xk//+xDge5n5VEQM6oD2rSV1izUmMCJiLvCusu8C4PGIuDEzP9ni2CSpk30UOLfMexHAE5Rb4EmSOtoPIuKXNJLWH42IMcBzFcckScPCQC4h2SQzn46IvwHOy8yTIuL2VgcmSZ2s/Fq2W0RsXNafrjYiSVI7ZOYJEfEV4KnMXB4RzwKHVh2XJA0HA0lgjIyIrWhMMve5FscjSR0tIo7KzPMj4pOrlQOQmV+rJDBJUltExHuBK0ry4u9p3DHkFOCxaiOTpPobyF1IvgBcCdyfmbdExJuA+1obliR1rA3K80a9PDasKihJUtv8n8z8Q0TsCxwAnAV8s+KYJGlYGMgknt8Dvte0/gDwnr5fIUnqS2b+R1m8JjNvbN5WJvKUJHW25eX5EGB2Zv53RJxSZUCSNFyscQRGRIyNiMsi4vHy+M+IGNuO4CSpg/3LAMvWSkRsGxHXRcTdEXFXRBxXyl8bEVdHxH3lebNSHhHxjYhYFBG3R8TuTceaWfa/LyJmDjY2SRIASyLiP4AjgB9GxHoMbFS0JHW9gcyB8W3gu8B7y/pRpex/tCooSepUEbE38DZgzGrzYGwMjBiCKl4C/i4zb42IjYAFEXE1jTucXJuZp0bECcAJwGeAg4Dty2MvGsOY94qI1wIn0bgtX5bjXJ6ZTw5BjJLUzd4HTAe+mpm/L3PNfbrimCRpWBhItndMZn47M18qj3OAMS2OS5I61atpzHUxklXnv3gaOHywB8/MRzPz1rL8B+AeYBsaM9yfW3Y7F3h3WT6Uxh2mMjNvAjYtnekDgasz84mStLiaRodbkjQImflH4H7gwIg4Fnh9Zl5VcViSNCwMZATGsog4CriwrM8AlrUuJEnqXJn5E+AnEXFOZj7cyroiYhzwFuBmYIvMfLRsegzYoixvAzzS9LLFpayv8tXrmAXMAnjDG94whNFLUmcql/b9L+DSUnR+RMzOzEFfRihJnW4gCYy/pnFd9uk0hhH/jMZQZEnSuvtjRJwG7ASM7inMzKlDcfCI2BD4T+Djmfl0z21aSx0ZETkU9WTmbGA2wKRJk4bkmJLU4T4M7JWZzwJExJeBnzME8yBJUqdb4yUkmflwZr4rM8dk5usz893Aca0PTZI62gXAL4HxwOeBh4BbhuLAETGKRvLigszs+YXvt+XSEMrz46V8CbBt08vHlrK+yiVJgxO8fCcSynL0sa8kqcm6znj8viGNQpK6z+aZeRbwYmb+JDP/Ghj06ItoDLU4C7gnM7/WtOlyoOdOIjOB7zeVf6DcjeStwFPlUpMrgWkRsVm5Y8m0UiZJGpxvAzdHxMkRcTJwE412W5K0BgO5hKQ3ZoklaXBeLM+PRsQhwG+A1w7BcfcB3g/cERELS9lngVOBiyPiw8DDvJyI/iFwMLAI+CPwIYDMfCIi/pGXR4V8ITOfGIL4JKmrZebXImIusG8p+hDw2+oikqTho88ERrmFXq+bMIEhSYN1SkRsAvwdjeue
NwY+PtiDZuYN9N1G79/L/gkc08exzgbOHmxMkqRVlbtF3dqzHhG/BpwJWZLWoL8RGAtoTNrZW0f4hdaEI0ndITN/UBafAt4BEBH7VBeRJKlC/jgoSQPQZwIjM8e3MxBJ6gYRMYLG5RvbAFdk5p0R8U4al3msT+O2p5Kk7uJdnCRpANZ1DgxJ0ro5i8bdPeYB34iI3wCTgBMy87+qDEyS1DoR8S/0nqgIYNP2RiNJw5MJDElqr0nArpm5IiJGA48B22XmsorjkiS11vx13CZJKkxgSFJ7vZCZKwAy87mIeMDkhSR1vsw8t+oYJGm4G1ACo1yzvUXz/pn561YFJUkdbMeIuL0sB7BdWQ8aNwXZtbrQJEmSpPpaYwIjIv4WOInG/alXlOIE7GRL0tr786oDkCRJkoajgYzAOA7YwSHOkjR4mflw1TFIkqoTEftk5o1rKpMkvdKrBrDPI8BTrQ5EkiRJ6gL/MsAySdJqBjIC4wFgbkT8N/B8T2Fmfq1lUUmSJEkdJCL2Bt4GjImITzZt2hgYUU1UkjS8DCSB8evyeHV5SJIkSVo7rwY2pNH/3qip/Gng8EoikqRhZo0JjMz8fDsCkaRuEhF30JgQudlTwHzgFOcdkqTOkpk/AX4SEef0zIcUEa8CNszMp6uNTpKGhz4TGBFxRmZ+PCL+L6/sZJOZ72ppZJLU2X4ELAe+W9aPBF4DPAacA/x/1YQlSWqxL0XE0TT+BtwCbBwRX8/M0yqOS5Jqr78RGN8pz19tRyCS1GUOyMzdm9bviIhbM3P3iDiqsqgkSa02ITOfjoi/opHMPgFYAJjAkKQ16DOBkZkLyvNP2heOJHWNERExOTPnAUTEnrw8idtL1YUlSWqxURExCng38K+Z+WJEvGK0syTpldY4B0ZEbA98CZgAjO4pz8w3tTAuSep0fwOcHREbAkFjErcPR8QGNNpcSVJn+g/gIeA24PqIeCONvwGSpDUYyF1Ivg2cBJwOvAP4EPCqVgYlSZ0uM28BdomITcr6U02bL64mKklSq2XmN4BvNBU9HBHvqCoeSRpOBpKIWD8zrwUiMx/OzJOBQ1obliR1tojYJCK+BlwLXBsR/9yTzJAkda6I2CIizoqIH5X1CcDMisOSpGFhIAmM58stnu6LiGMj4jAa97CWJK27s4E/AO8rj6dpjHiTJHW2c4Arga3L+q+Aj1cVjCQNJwNJYBxH49Z+HwP2AI7CLLEkDdZ2mXlSZj5QHp8HnFtIkjpURPRcuv26zLwYWAGQmS/RuKWqJGkN+k1gRMQI4IjMfCYzF2fmhzLzPZl5U5vik6RO9aeI2LdnJSL2Af5UYTySpNaaV56fjYjNgQSIiLcCT/X5KknSSn1O4hkRIzPzpeYOtiRpyBwNnNc078WTOLpNkjpZlOdPApcD20XEjcAY4PDKopKkYaS/u5DMA3YHfhERlwPfA57t2ZiZl7Y4NknqWJl5G7BbRGxc1p+OiI8Dt1camCSpVcZExCfL8mXAD2kkNZ4HDsD2X5LWaCC3UR0NLAOm0hjqFuV5UAmMcnnKfGBJZr5zMMeSpOEqM59uWv0kcEZFoUiSWmsEjYnwY7Xy1wzFwe1bS+oG/SUwXl+yxHfycuKiRw5B3ccB9wAbD8GxJKkTrN6plSR1jkcz8wstPL59a0kdr79JPHuyxBsCGzUt9zzWWUSMBQ4BvjWY40hShxmK5LAkqZ5alqS2by2pW/Q3AqOVWeIzgONpJEZ6FRGzgFkAb3jDG1oUhiS1V0T8gd4TFQGs3+ZwJEnts38Lj30G9q0ldYH+RmC0JEscEe8EHs/MBf3tl5mzM3NSZk4aM2ZMK0KRpLbLzI0yc+NeHhtl5kDmJepXRJwdEY9HxJ1NZa+NiKsj4r7yvFkpj4j4RkQsiojbI2L3ptfMLPvfFxHeHUWSBikzn2jFce1bS+om/SUwWpUl3gd4V0Q8BFwETI2I81tUlyR1m3OA6auVnQBcm5nbA9eWdYCDgO3LYxbw
TWgkPICTgL2AycBJPUkPSVLt2LeW1DX6TGC0KkucmSdm5tjMHAccCfw4M49qRV2S1G0y83pg9fb7UODcsnwu8O6m8vOy4SZg04jYCjgQuDozn8jMJ4GreWVSRJJUA/atJXWT/kZgSJI6wxaZ+WhZfgzYoixvAzzStN/iUtZX+StExKyImB8R85cuXTq0UUuSJElNKk1gZOZc71MtSe2TmckQ3u3Ea6olqT7sW0vqdI7AkKTO99tyaQjl+fFSvgTYtmm/saWsr3JJkiSpMiYwJKnzXQ703ElkJvD9pvIPlLuRvBV4qlxqciUwLSI2K5N3TitlkiRJUmUGfcs+SVJ9RMSFwBTgdRGxmMbdRE4FLo6IDwMPA+8ru/8QOBhYBPwR+BA0JnGOiH8Ebin7faFVEztLkiRJA2UCQ5I6SGbO6GPTK26NXebDOKaP45wNnD2EoUmSJEmD4iUkkiRJkiSp9kxgSJIkSZKk2jOBIUmSJEmSas8EhiRJkiRJqj0TGJIkSZIkqfZMYEiSJEmSpNozgSFJkiRJkmrPBIYkSZIkSao9ExiSJEmSJKn2TGBIkiRJkqTaM4EhSZIkSZJqzwSGJEmSJEmqPRMYkiRJkiSp9kxgSJIkSZKk2jOBIUmSJEmSas8EhiRJkiRJqj0TGJIkSZIkqfZMYEiSJEmSpNozgSFJkiRJkmrPBIYkSZIkSao9ExiSJEmSJKn2TGBIkiRJkqTaM4EhSZIkSZJqzwSGJEmSJEmqPRMYkiRJkiSp9kxgSJIkSZKk2jOBIUnqU0RMj4h7I2JRRJxQdTySJEnqXiYwJEm9iogRwJnAQcAEYEZETKg2Kg3G8ghyCB+/edU2Vb8lSZLURUxgSJL6MhlYlJkPZOYLwEXAoRXHpHW0PGLI/+hvlb8xiSFJktrGBIYkqS/bAI80rS8uZStFxKyImB8R85cuXdrW4LR2WvUHf8v8TYuOLEmStCoTGJKkdZaZszNzUmZOGjNmTNXhqB8rWnTcx2LrFh1ZkiRpVSYwJEl9WQJs27Q+tpRpGBqROeRJjEdja7Ze4VdCkiS1x8iqA5Ak1dYtwPYRMZ5G4uJI4H9WG5IGY0TmkB7PsReSJKmd2j4CIyK2jYjrIuLuiLgrIo5rdwySpDXLzJeAY4ErgXuAizPzrmqjkiQ1s28tqZtUMQLjJeDvMvPWiNgIWBARV2fm3RXEIknqR2b+EPhh1XFIkvpk31pS12j7CIzMfDQzby3Lf6Dxq573YJMkSZLWkn1rSd2k0kk8I2Ic8Bbg5irjkCRJkoY7+9aSOl1lCYyI2BD4T+Djmfl0L9tnRcT8iJi/dOnS9gcoSZIkDRP2rSV1g0oSGBExikYDe0FmXtrbPpk5OzMnZeakMWPGtDdASZIkaZiwby2pW1RxF5IAzgLuycyvtbt+SZIkqVPYt5bUTaoYgbEP8H5gakQsLI+DK4hDkiRJGu7sW0vqGm2/jWpm3gBEu+uVJEmSOo19a0ndpNK7kEiSJEmSJA2ECQxJkiRJklR7JjAkSZIkSVLtmcCQJEmSJEm1ZwJDkiRJkiTVngkMSZIkSZJUeyYwJEmSJElS7ZnAkCRJkiRJtWcCQ5IkSZIk1Z4JDEmSJEmSVHsmMCRJkiRJUu2ZwJAkSZIkSbVnAkOSJEmSJNWeCQxJkiRJklR7JjAkSZIkSVLtmcCQJEmSJEm1NywSGCsyqw5BkiRJ6gh2rSUNV8MigfHYU89VHYIkSZLUEZ55/sWqQ5CkdTIsEhiSJEmSJKm7mcCQJEmSJEm1ZwJDkjpERLw3Iu6KiBURMWm1bSdGxKKIuDciDmwqn17KFkXECU3l4yPi5lI+JyJe3c73IkmSJK3OBIYkdY47gb8Erm8ujIgJwJHATsB04N8iYkREjADOBA4CJgAzyr4AXwZOz8w/A54EPtyetyBJkiT1zgSGJHWIzLwnM+/tZdOhwEWZ+XxmPggsAiaXx6LMfCAzXwAuAg6NiACmApeU158LvLvlb0CS
JEnqhwkMSep82wCPNK0vLmV9lW8O/D4zX1qt/BUiYlZEzI+I+UuXLh3ywCVJkqQeI6sOQJI0cBFxDbBlL5s+l5nfb3c8mTkbmA0wadKkbHf9kiRJ6h4mMCRpGMnMA9bhZUuAbZvWx5Yy+ihfBmwaESPLKIzm/SVJkqRKeAmJJHW+y4EjI2K9iBgPbA/MA24Bti93HHk1jYk+L8/MBK4DDi+vnwm0fXSHJEmS1MwEhiR1iIg4LCIWA3sD/x0RVwJk5l3AxcDdwBXAMZm5vIyuOBa4ErgHuLjsC/AZ4JMRsYjGnBhntffdSJIkSavyEhJJ6hCZeRlwWR/bvgh8sZfyHwI/7KX8ARp3KZEkSZJqwREYkiRJkiSp9kxgSJIkSZKk2jOBIUmSJEmSas8EhiRJkiRJqj0TGJIkSZIkqfZMYEiSJEmSpNozgSFJkiRJkmrPBIYkSZIkSao9ExiSJEmSJKn2KklgRMT0iLg3IhZFxAlVxCBJkiR1AvvWkrpF2xMYETECOBM4CJgAzIiICe2OQ5IkSRru7FtL6iZVjMCYDCzKzAcy8wXgIuDQCuKQJEmShjv71pK6xsgK6twGeKRpfTGw1+o7RcQsYFZZfT4i7mxDbAPxOuB3VQdR1CkWqFc8xtK3OsVTp1h2qDqA4W7BggXPRMS9VcdR1Om7BfWKx1j6Vqd4jKVvttersm89dIylb3WKx1j6Vqd4WtJWV5HAGJDMnA3MBoiI+Zk5qeKQAGPpT53iMZa+1SmeusVSdQwd4N46fZ51iQXqFY+x9K1O8RhL32yv14196zUzlr7VKR5j6Vud4mlVW13FJSRLgG2b1seWMkmSJElrx761pK5RRQLjFmD7iBgfEa8GjgQuryAOSZIkabizby2pa7T9EpLMfCkijgWuBEYAZ2fmXWt42ezWRzZgxtK3OsVjLH2rUzzG0lnqdA7rFAvUKx5j6Vud4jGWvtUtnkrZtx5SxtK3OsVjLH2rUzwtiSUysxXHlSRJkiRJGjJVXEIiSZIkSZK0VkxgSJIkSZKk2qt1AiMipkfEvRGxKCJOaGE9D0XEHRGxsOd2LxHx2oi4OiLuK8+blfKIiG+UmG6PiN2bjjOz7H9fRMxci/rPjojHm+/HPZT1R8Qe5f0tKq+NtYzl5IhYUs7Pwog4uGnbieW490bEgU3lvX52ZYKpm0v5nDLZVF+xbBsR10XE3RFxV0QcV9W56SeWqs7N6IiYFxG3lXg+398xImK9sr6obB+3rnGuRSznRMSDTedmYqs/p6b9R0TELyLiB1Wdl24WEX8bEb8s34ev1CCev4uIjIjXVRzHaeW83B4Rl0XEphXEUIvvb19tapVWbzcqjmXTiLikfF/uiYi9K4zlE+UzujMiLoyI0W2se8D9Iw1Mu9qAsG+9pljsW9eobx016levIR771gCZWcsHjUmI7gfeBLwauA2Y0KK6HgJet1rZV4ATyvIJwJfL8sHAj4AA3grcXMpfCzxQnjcry5sNsP79gN2BO1tRPzCv7BvltQetZSwnA5/qZd8J5XNZDxhfPq8R/X12wMXAkWX534GP9hPLVsDuZXkj4Felzrafm35iqercBLBhWR4F3FzeR6/HAP438O9l+UhgzrrGuRaxnAMc3sv+Lf0Ol/0/CXwX+EF/57aV56VbH8A7gGuA9cr66yuOZ1saE9s9zGrtfAWxTANGluUvU9quNtZfm+8vfbSpFX8+q7QbFcdyLvA3ZfnVwKYVxbEN8CCwflm/GPhgG+sfcP/Ix4DOp31r+9b2rXuPpTb96jXEcw72rWs9AmMysCgzH8jMF4CLgEPbWP+hNDoQlOd3N5Wflw03AZtGxFbAgcDVmflEZj4JXA1MH0hFmXk98EQr6i/bNs7Mm7Lx7Tmv6VgDjaUvhwIXZebzmfkgsIjG59brZ1cye1OBS3p5X73F8mhm3lqW/wDcQ6Mz1fZz008sVZ2bzMxnyuqo8sh+jtF8zi4B9i91rlWcaxlLf+emZd/h
iBgLHAJ8q6z3d25bdl662EeBUzPzeYDMfLzieE4Hjqf/72RbZOZVmflSWb0JGNvmEGrz/V2HNrWlVm83qhQRm9D4D89ZAJn5Qmb+vsKQRgLrR8RI4DXAb9pV8Vr2j7RmVbcB9q3XzL51BeemTv3qNcTT37npmr51nRMY2wCPNK0vpnWdmwSuiogFETGrlG2RmY+W5ceALdYQ11DHO1T1b1OWBxvXsWVI0tnx8nDNtY1lc+D3TR34AcdShh+9hUYGstJzs1osUNG5KUO5FgKP02iQ7u/nGCvrLdufKnUOyfd59Vgys+fcfLGcm9MjYr3VYxlgnWv7OZ1B4z+sK8p6f+e2peelS70ZeHsZNviTiNizqkAi4lBgSWbeVlUM/fhrGr94tFMtv7+9tKlVOINV240qjQeWAt8uw3W/FREbVBFIZi4Bvgr8GngUeCozr6oiliZ99QG0Zvat7VuvZN/6FTHUpl/dWzz2rV9W5wRGO+2bmbsDBwHHRMR+zRtLZqqyX++qrh/4JrAdMJFGB+af21l5RGwI/Cfw8cx8unlbu89NL7FUdm4yc3lmTqTxK+5kYMd21b2mWCJiZ+DEEtOeNIaufabVcUTEO4HHM3NBq+vqZhFxTTSuh1/9cSiNX2tfS2NY4qeBi9d0XWULY/ks8A+tqnsd4unZ53PAS8AF7Yytjvpr39sYQ93ajZE0hpt/MzPfAjxLY0h325X/OBxKI6myNbBBRBxVRSy9qUH/SH2zb90/+9Z9x1LJualTv7q3eOxbv6zOCYwlNK5d7jG2lA258gtDz1Dny2h8aX9bhtdQnnuGQfcV11DHO1T1L2HVYcprHVdm/rb8I1oB/P80zs+6xLKMxpCmkQONJSJG0WjULsjMS0txJeemt1iqPDc9sjG0+Dpg736OsbLesn2TUueQfp+bYplehgZmNi4l+Dbrfm7W5nPaB3hXRDxEYwjaVODrVHxeOk1mHpCZO/fy+D6NzPml5bOfRyNb37LJM/uKhcZ1nuOB28r3YSxwa0Rs2apY+ounnBsi4oPAO4G/Kp3EdqrV97eP9r0Kr2g3IuL8CuNZDCxu+rXtEhoJjSocADyYmUsz80XgUuBtFcXSo68+gNbMvrV9a/vWa1CnfvVq8di37pHrODlPqx80foHo6YD2TOixUwvq2QDYqGn5ZzSurzuNVSez+UpZPoRVJ0mZly9PkvIgjQlSNivLr12LOMax6uQ+Q1Y/r5yk5eC1jGWrpuVP0Lh+CWAnVp2M5QEaE7H0+dkB32PVCV/+dz9xBI1rss5Yrbzt56afWKo6N2Mok7oB6wM/pfGfol6PARzDqhPqXLyuca5FLFs1nbszaMyL0NLPabW4pvDyRENtPy/d+gCOBr5Qlt9MY1hg1CCuh6h+Es/pwN3AmIrqr833lz7a1Kofze1GxXH8FNihLJ8MnFZRHHsBd9GY+yJoXNf8t22OYRwD6B/5GNC5tG9t39q+de+x1KZfvYZ47Ftn1jeBUd7owTRmpL0f+FyL6nhTOVG30fgj/blSvjlwLXAfjRn1ez7sAM4sMd0BTGo61l/TmJBkEfChtYjhQhpDpF6k8cvLh4eyfmAScGd5zb/Sz38m+ojlO6Wu24HLWbVh+Vw57r00zV7b12dXzve8EuP3KHcq6COWfWkMYbsdWFgeB1dxbvqJpapzsyvwi1LvncA/9HcMYHRZX1S2v2ld41yLWH5czs2dwPm8PJtyS7/DTa+ZwsuNbNvPS7c+aPzxOb98XrcCU6uOqcT1ENUnMBbRSOj0tCH/XkEMtfj+0kebWoPvycp2o+I4JgLzy/n5LwZ454UWxfJ54Jfl3/R36OdvUwvqHnD/yMeAz6l9a/vW9q1fGUtt+tVriMe+dWYjUEmSJEmSpDqr8xwYkiRJkiRJgAkMSZIkSZI0DJjAkCRJkiRJtWcCQ5IkSZIk1Z4JDEmSJEmSVHsmMFQLEfFMeR4XEf9ziI/92dXWfzaUx5ekbhIRn4uIuyLi9ohY
GBF7tbCuuRExqVXHl6ROZd9ancoEhupmHLBWjWxEjFzDLqs0spn5trWMSZIERMTewDuB3TNzV+AA4JFqo5Ik9WMc9q3VQUxgqG5OBd5eftX7RESMiIjTIuKW8mvfRwAiYkpE/DQiLgfuLmX/FRELyi+Ds0rZqcD65XgXlLKejHSUY98ZEXdExBFNx54bEZdExC8j4oKIiArOhSTVzVbA7zLzeYDM/F1m/iYi/qG003dGxOyeNrO0padHxPyIuCci9oyISyPivog4pewzrqmtvae0va9ZveKImBYRP4+IWyPiexGxYSk/NSLuLn8jvtrGcyFJw4F9a3WUyMyqY5CIiGcyc8OImAJ8KjPfWcpnAa/PzFMiYj3gRuC9wBuB/wZ2zswHy76vzcwnImJ94BbgLzJzWc+xe6nrPcDRwHTgdeU1ewE7AN8HdgJ+U+r8dGbe0PozIUn1VZIGNwCvAa4B5mTmT3ra37LPd4CLM/P/RsRc4ObM/ExEHAd8BtgDeAK4H9gN2Ah4ENg3M2+MiLOBuzPzq+X1nwIeAi4FDsrMZyPiM8B6wJnAz4AdMzMjYtPM/H1bToYk1Zh9a3UqR2Co7qYBH4iIhcDNwObA9mXbvJ4GtvhYRNwG3ARs27RfX/YFLszM5Zn5W+AnwJ5Nx16cmSuAhTSG30lSV8vMZ2gkIGYBS4E5EfFB4B0RcXNE3AFMpdFJ7XF5eb4DuCszHy0jOB6g0VYDPJKZN5bl82m0z83eCkwAbix/D2bS6Gw/BTwHnBURfwn8cajeqyR1KPvWGtbWdH2TVLUA/jYzr1ylsJFNfna19QOAvTPzj+VXu9GDqPf5puXl+G9FkgDIzOXAXGBuSVh8BNgVmJSZj0TEyaza/va0pytYtW1dwctt6+rDQVdfD+DqzJyxejwRMRnYHzgcOJZGAkWS1Dv71hrWHIGhuvkDjeHEPa4EPhoRowAi4s0RsUEvr9sEeLI0sDvS+LWux4s9r1/NT4EjyrWAY4D9gHlD8i4kqQNFxA4R0fwL3ETg3rL8u3KJyeHrcOg3RGOCUGhMNrf6sOKbgH0i4s9KHBuUvwcbAptk5g+BT9C4JEWS9DL71uooZr5UN7cDy8twtXOAr9MYYnZrmexnKfDuXl53BXB0RNxDozN9U9O22cDtEXFrZv5VU/llwN7AbTR+7Ts+Mx8rjbQk6ZU2BP4lIjYFXgIW0bic5PfAncBjNK55Xlv3Asf0zH8BfLN5Y2YuLZeqXFiu2Qb4exod8+9HxGgavyp+ch3qlqROZt9aHcVJPCVJUmUiYhzwg8zcuepYJElSvXkJiSRJkiRJqj1HYEiSJEmSpNpzBIYkSZIkSao9ExiSJEmSJKn2TGBIkiRJkqTaM4EhSZIkSZJqzwSGJEmSJEmqvf8HRFYPHJ8nqk0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1080x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "source": [
    "# TODO(review): confirm how q_list_test is handled in the validation/visualization code in src\n",
    "#\n",
    "# The critic is trained as a binary classifier with logit f(x) = out[:, 0] - out[:, 1]:\n",
    "# p-samples get label 0 and q-samples label 1. The BCE minimiser is f(x) = log q(x)/p(x),\n",
    "# so the estimated log p/q ratio is -f(x).\n",
    "\n",
    "model = model.to(DEVICE)  # no-op when the model is already on DEVICE\n",
    "bce = torch.nn.functional.binary_cross_entropy_with_logits\n",
    "\n",
    "\n",
    "def critic_logit(x):\n",
    "    \"\"\"Return the binary logit used by the loss: critic output column 0 minus column 1.\"\"\"\n",
    "    out = model(x)\n",
    "    return out[:, 0] - out[:, 1]\n",
    "\n",
    "\n",
    "def to_model_input(batch):\n",
    "    \"\"\"Insert the feature axis at dim 1 (presumably (B,) -> (B, 1) for 1-D data) and move to DEVICE.\"\"\"\n",
    "    return batch.unsqueeze(1).to(DEVICE)\n",
    "\n",
    "\n",
    "i = 0\n",
    "model.train()\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    for p_batch, q_batch, _ in train_dl:\n",
    "        i += 1\n",
    "        optim.zero_grad()\n",
    "\n",
    "        # Done unconditionally so the CPU path gets the same input shape as the GPU path\n",
    "        p_batch = to_model_input(p_batch)\n",
    "        q_batch = to_model_input(q_batch)\n",
    "\n",
    "        logP = critic_logit(p_batch)\n",
    "        logQ = critic_logit(q_batch)\n",
    "        loss = bce(logP, torch.zeros_like(logP)) + bce(logQ, torch.ones_like(logQ))\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "\n",
    "        # Validation/Test on one freshly shuffled test batch every 50 iterations\n",
    "        if i % 50 == 0:\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                p_batch, q_batch, _ = next(iter(test_dl))\n",
    "                # Ground truth is computed on the raw CPU batch, before reshaping/moving\n",
    "                log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, p_batch, calc_true_kl=True)\n",
    "\n",
    "                p_batch = to_model_input(p_batch)\n",
    "                q_batch = to_model_input(q_batch)\n",
    "                logP = critic_logit(p_batch)\n",
    "                logQ = critic_logit(q_batch)\n",
    "                test_loss = bce(logP, torch.zeros_like(logP)) + bce(logQ, torch.ones_like(logQ))\n",
    "                test_loss_store.append(test_loss.item())\n",
    "\n",
    "                # f(x) estimates log q/p (see top of cell): negate it to get log p/q.\n",
    "                # Averaged over p-samples this estimates KL(p || q).\n",
    "                log_ratio_p_q_from_cob = -logP\n",
    "                kl_from_cob = torch.mean(log_ratio_p_q_from_cob)\n",
    "\n",
    "            # Visualize\n",
    "            line.set_data(range(len(loss_store)), loss_store)\n",
    "            ax1.set_xlim(0, len(loss_store))\n",
    "\n",
    "            scat1.set_offsets(np.vstack([p_batch.cpu().squeeze(), log_ratio_p_q.cpu().detach()]).T)\n",
    "            scat2.set_offsets(np.vstack([p_batch.cpu().squeeze(), log_ratio_p_q_from_cob.cpu().detach()]).T)\n",
    "            ax2.set_xlim(-25., 25.)\n",
    "            ax2.set_ylim(-1000, 1000)\n",
    "\n",
    "            # One test point per 50 training iterations: plot each at the iteration it was measured\n",
    "            test_line.set_data(np.arange(1, len(test_loss_store) + 1) * 50, test_loss_store)\n",
    "            ax3.set_xlim(0, len(loss_store))\n",
    "\n",
    "            # Clear before drawing so the printed metrics appear below the figure\n",
    "            # instead of being wiped by clear_output(wait=True)\n",
    "            clear_output(wait=True)\n",
    "            display(fig)\n",
    "            print('iteration: ', i)\n",
    "            print('KLD: ', true_kl_p_q)\n",
    "            print('CoB: ', kl_from_cob.item())\n",
    "\n",
    "            model.train()"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAotklEQVR4nO3deXxU5dn/8c/VsCmbVlGQuGABBWIIEDYByyIItAUXVBAVfUTEuuLTx+LyCFp9an91q0u1qIBWFFyKohXKolilsgVihAAFBEtoigjKLgK5fn/MyTiEJAyQyZyQ7/v1mlfOuc8y1xwgX84y923ujoiISNj8KNkFiIiIFEcBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSkkNKDMba2ZfmdmSmLYfm9kMM1sZ/Dw+aDcze9LMVplZjpm1jtlmSLD+SjMbkozPIiIiZSvZZ1Djgd5F2kYCs9y9CTArmAfoAzQJXsOAZyESaMAooD3QDhhVGGoiIlJxJTWg3P3vwOYizf2Bl4Lpl4ALY9pf9oi5wHFm1gC4AJjh7pvd/RtgBgeGnoiIVDBVkl1AMU529/xg+j/AycF0Q2BdzHp5QVtJ7Qcws2FEzr6oWbNmm7PPPrsMyxYRkcORlZX1tbvXK9oexoCKcnc3szLri8ndxwBjADIzM33hwoVltWsRETlMZvZlce3JvgdVnA3BpTuCn18F7euBU2PWSw3aSmoXEZEKLIwBNQUofBJvCPBOTPvVwdN8HYAtwaXAvwG9zOz44OGIXkGbiIhUYEm9xGdmrwFdgRPNLI/I03gPA6+b2XXAl8BlwervA32BVcBO4FoAd99sZr8BFgTrPeDuRR+8EBGRCsYq63AbugclibBnzx7y8vL47rvvkl2KSOjUqFGD1NRUqlatul+7mWW5e2bR9UP9kIRIRZOXl0ft2rU544wzMLNklyMSGu7Opk2byMvLo1GjRnFtE8Z7UCIV1nfffccJJ5ygcBIpwsw44YQTDunqggJKpIwpnESKd6j/NhRQIiISSgookaPIpk2byMjIICMjg/r169OwYcPo/Pfff18m79G1a1fK8wGjuXPncv311x/Wttdccw1vvvkmAJs3b6ZVq1aMGzeOtWvXkpaWVmY1rlu3jm7dutG8eXNatGjBH/7wh+iyzZs307NnT5o0aULPnj355ptvgMg9mVtvvZXGjRuTnp7OokWLDtjvzp07+dnPfsbZZ59NixYtGDlyZHTZiBEjon+2TZs25bjjjiuzzxMWCiiRo8gJJ5xAdnY22dnZDB8+nBEjRkTnq1Wrxt69e5Nd4iGbOnUqvXsfWfeaW7Zs4YILLmDYsGFce+21ZVTZD6pUqcKjjz5Kbm4uc+fO5ZlnniE3NxeAhx9+mB49erBy5Up69OjBww8/DEQ+18qVK1m5ciVjxozhxhtvLHbfv/rVr1i+fDmLFy9mzpw5TJ06FYDHH388+md7yy23cPHFF5f550o2BZTIUe6aa65h+PDhtG/fnjvvvJPRo0fzyCOPRJenpaWxdu1aAF555RXatWtHRkYGN9xwA/v27YvrPTZv3syFF15Ieno6HTp0ICcnB4CNGzfSs2dPWrRowdChQzn99NP5+uuvD9i+Vq1ajBgxghYtWtCjRw82btwYXTZr1izOP/98du3axcCBA2nWrBkXXXQR7du3j+tMbvv27fTp04crrriixBA4Ug0aNKB168gIQLVr16ZZs2asXx/p0Oadd95hyJBI3wNDhgzh7bffjrZfffXVmBkdOnTg22+/JT8/f7/9HnvssXTr1g2AatWq0bp1a/Ly8g54/9dee41BgwYl5LMlkwJKJMmmT4e77478TJS8vDz+8Y9/8Nhjj5W4zrJly5g0aRJz5swhOzublJQUJkyYENf+R40aRatWrcjJyeH//u//uPrqqwG4//776d69O0uXLmXAgAH861//Knb7HTt2kJmZydKlS/npT3/K/fff
D8DXX39N1apVqVu3Ls8++yzHHnssy5Yt4/777ycrKyuu2u644w46d+7MiBEj4lq/0IQJE6KX0GJfAwYMKHW7tWvXsnjxYtq3bw/Ahg0baNCgAQD169dnw4YNAKxfv55TT/2hl7bU1NRoqBXn22+/5d1336VHjx77tX/55ZesWbOG7t27H9Lnqwj0PSiRJJo+Ha68EnbvhhdegFdegV69yv59Lr30UlJSUkpdZ9asWWRlZdG2bVsAdu3axUknnRTX/j/55BPeeustALp3786mTZvYunUrn3zyCZMnTwagd+/eHH988UO1/ehHP+Lyyy8H4Morr4xerpo+fTq9ggPy97//nVtvvRWA9PR00tPT46qte/fuvPPOO/zqV7+K+/MADB48mMGDB8e9PkTO1i655BKeeOIJ6tSpc8ByMzuspzz37t3LoEGDuPXWWznzzDP3WzZx4kQGDBhw0D/fikgBJZJEs2dHwqlmTdixIzKfiICqWbNmdLpKlSoUFBRE5wu/l+LuDBkyhN/+9rdlX8AhKvwlPnXqVO64444j2tfAgQPp1KkTffv25cMPP6R27dpxbTdhwgR+//vfH9DeuHHj6IMXsfbs2cMll1zC4MGD97sfdPLJJ5Ofn0+DBg3Iz8+PhmTDhg1Zt+6HkYLy8vJo2LDYkYIYNmwYTZo04fbbbz9g2cSJE3nmmWfi+kwVjS7xiSRR165QvXoknKpXj8wn2hlnnBF9YmzRokWsWbMGgB49evDmm2/y1VeRAQQ2b97Ml18WOwrCAbp06RK9HDh79mxOPPFE6tSpQ6dOnXj99deByNlQ4RNsRRUUFER/6b/66qt07twZdycnJ4eMjAwAzjvvPF599VUAlixZEr3PBXD11Vczf/78EusbMWIEPXr04OKLL477acbBgwdHH0KIfRUXTu7OddddR7NmzQ4I1H79+vHSS5ExWF966SX69+8fbX/55Zdxd+bOnUvdunWjlwJj3XvvvWzZsoUnnnjigGXLly/nm2++oWPHjnF9popGASWSRL16RS7r3XRT4i7vFXXJJZewefNmWrRowdNPP03Tpk0BaN68OQ8++CC9evUiPT2dnj17HnDTvtDPfvYzUlNTSU1N5dJLL2X06NFkZWWRnp7OyJEjo7+QR40axfTp00lLS+ONN96gfv36xZ7B1KxZk/nz55OWlsYHH3zAfffdR1ZWFq1atYqeTd14441s376dZs2acd9999GmTZvo9jk5OZxyyimlfu7f/e53pKamctVVV1FQUMCKFSuinyE1NZU33njjsI4nwJw5c/jzn//MBx98EL1X9f777wMwcuRIZsyYQZMmTZg5c2b0UfG+ffty5pln0rhxY66//nr++Mc/RvdXGMp5eXk89NBD5Obm0rp1azIyMnjhhRei602cOJGBAwcetV8OV2exImVo2bJlNGvWLNllhMbu3btJSUmhSpUqfPrpp9x4441kZ2cfsF6tWrXYvn37fm0PPvggjRs3ZuDAgcXuu2vXrjzyyCM0bdqU66677ogCRspPcf9G1FmsiJS7f/3rX1x22WUUFBRQrVo1nn/++bi3vffee+Nar06dOgqno5QCSkQSpkmTJixevPig6xU9e4rH7NmzD6MiqUh0D0pEREJJASUiIqGkgBIRkVBSQImISCgpoESOMg899BAtWrQgPT2djIwM5s2bB8DQoUOjPWwfqsMZniIlJYWMjAzS0tL4xS9+wbffflvq+tnZ2dHvDgFMmTIl2vN3PMaPH8/NN998SDUeqTZt2rB79+5D3m727Nn8/Oc/j87fe++99O7dm927dyd0OJPvvvuOdu3a0bJlS1q0aMGoUaOiy55++mkaN26MmRXboW+hX//616SlpZGWlsakSZOi7ddddx0tW7YkPT2dAQMGHNaDL0WFMqDM7Cwzy455bTWz281stJmtj2nvG7PNXWa2ysxWmNkFyaxfJFk+/fRT3nvvPRYtWkROTg4zZ86Mdkj6wgsv0Lx583Kr
5ZhjjiE7O5slS5bw4x//+KDd8RQNqH79+u03/lHYrFmzhoYNG1K9evUj2s+DDz7InDlzmDx58hHv62CqV6/OBx98wGeffUZ2djbTpk1j7ty5AHTq1ImZM2dy+umnl7j9X//6VxYtWkR2djbz5s3jkUceYevWrUBk+I/PPvuMnJwcTjvtNJ5++ukjrjeUAeXuK9w9w90zgDbATmBysPjxwmXu/j6AmTUHBgItgN7AH83s6Os5UeQg8vPzOfHEE6O/6E488cRoDwux/zOvVasW99xzDy1btqRDhw7RHrZXr15Nhw4dOOecc7j33nupVavWAe+xb98+/ud//oe2bduSnp7On/70p4PW1bFjx2hP3fPnz6djx460atWKc889lxUrVvD9999z3333MWnSJDIyMpg0adJ+Z0Rr166le/fupKen06NHjxJ7RS/OY489Fv0ff2x3Qb/5zW8466yz6Ny5M4MGDdpvCJJChUOVZGZm0rRpU957773osmnTpkXHqRo3bhxNmzalXbt2XH/99XGfyT366KNMnTqVd999l2OOOSbuz3S4zCz6Z7pnzx727NkT7YWiVatWnHHGGaVun5uby3nnnUeVKlWoWbMm6enpTJs2DSDaOa67s2vXrjLp3SKUAVVED2C1u5fWKVh/YKK773b3NcAqoF25VCdypMpwvI1evXqxbt06mjZtyi9/+Us++uijYtfbsWMHHTp04LPPPuO8886LfoH2tttu47bbbuPzzz8nNTW12G1ffPFF6taty4IFC1iwYAHPP/98tD+/4uzbt49Zs2bRr18/AM4++2w+/vhjFi9ezAMPPMDdd99NtWrVeOCBB7j88svJzs6O9mxe6JZbbmHIkCHk5OQwePDgaK/mB5OVlcW4ceOYN28ec+fO5fnnn2fx4sUsWLCAt956i88++4ypU6eWeklt7dq1zJ8/n7/+9a8MHz482rluYUDl5+czatQo5syZwyeffBL3ZdQ5c+bw3HPPMXXq1GL/I1CSbdu2FTsMSEZGRlzvvW/fPjIyMjjppJPo2bNndFiQeLRs2ZJp06axc+dOvv76az788MP9Ory99tprqV+/PsuXL+eWW26Je78lqQgBNRB4LWb+ZjPLMbOxZlbYd39DYF3MOnlB237MbJiZLTSzhbEDookkTeF4G888E/l5hCFVq1YtsrKyGDNmDPXq1ePyyy9n/PjxB6xXrVq16D2QNm3aRAcs/PTTT7n00ksBuOKKK0ooeTovv/wyGRkZtG/fnk2bNrFy5coD1tu1a1d06PkNGzbQs2dPIDK67aWXXkpaWhojRoxg6dKlB/1cn376abSeq666ik8++eSg20BkGJCLLrqImjVrUqtWLS6++GI+/vhj5syZQ//+/alRowa1a9fmF7/4RYn7uOyyy/jRj35EkyZNOPPMM1m+fDnff/89eXl5nHnmmcybN4+uXbtSr149qlWrdkC4lqRx48a4OzNmzIhr/UK1a9cuthPb7OzsuC7hpqSkkJ2dTV5eHvPnz2fJkiVxv3evXr3o27cv5557LoMGDaJjx477DfMxbtw4/v3vf9OsWbP97k8drlAHlJlVA/oBhf2YPAv8BMgA8oFHD2V/7j7G3TPdPbNevXplWarI4Ykdb2P37sj8EUpJSaFr167cf//9PP3009FxmmJVrVo1egkmJSXlkIaCd3eeeuqp6C/FNWvWRMdsilV4D+rLL7/E3aP3oP73f/+Xbt26sWTJEt59993oGUlYFb1UZWZ8/PHHdO7c+Yj2e/LJJ/P+++9z++238+GHH8a93aGcQa1bty667Lnnnttv2XHHHUe3bt2il+jidc8995Cdnc2MGTNw92hnw4VSUlIYOHBgsX/vDlWoAwroAyxy9w0A7r7B3fe5ewHwPD9cxlsPnBqzXWrQJhJuZTzexooVK/Y7m8nOzi71pndRHTp0iP5imThxYrHrXHDBBTz77LPs2bMHgH/+85/s2LGjxH0ee+yxPPnkkzz66KPs
3buXLVu2RMc9ij27q127Ntu2bSt2H+eee260ngkTJtClS5e4Pk+XLl14++232blzJzt27GDy5Ml06dKFTp06RcNx+/bt+91bKuqNN96goKCA1atX88UXX3DWWWcxbdo0+vTpA0D79u356KOP2LRpE3v27NmvX8DJkydz1113lbjvpk2b8pe//IUrr7yy2E50i3MoZ1CnnnpqdNnw4cPZuHFj9GnKXbt2MWPGDM4+++y43hcilwc3bdoERHqQz8nJoVevXrg7q1atAiL/gZkyZcoh7bckYQ+oQcRc3jOz2MFSLgIKz02nAAPNrLqZNQKaACUPDiMSFmU83sb27dsZMmQIzZs3Jz09ndzcXEaPHh339k888QSPPfYY6enprFq1irp16x6wztChQ2nevDmtW7cmLS2NG2644aBnYK1atSI9PZ3XXnuNO++8k7vuuotWrVrtt123bt3Izc2NPiQR66mnnmLcuHGkp6fz5z//mT/84Q/Fvs/48eP3G0LjpJNO4pprrqFdu3a0b9+eoUOH0qpVK9q2bUu/fv1IT0+nT58+nHPOOcV+VoDTTjuNdu3a0adPH5577jlq1KjB7Nmz+elPfwpAgwYNGD16NB07dqRTp0779dS9evXqYkfWjdW2bVvGjRtHv379WL16NXDgcCZlJT8/n27dupGenk7btm3p2bNn9FLvk08+SWpqKnl5eaSnpzN06FAAFi5cGJ3es2cPXbp0oXnz5gwbNoxXXnmFKlWqRAe7POecczjnnHPIz8/nvvvuO/KC3T2UL6AmsAmoG9P2Z+BzIIdIKDWIWXYPsBpYAfQ52P7btGnjImUtNzc32SUckR07dnhBQYG7u7/22mver1+/JFeUONu2bXP3yGdu06aNZ2VlHbDOkCFD/I033tivbd26dd67d+8S9ztu3Di/6aab3N198ODB/tVXX5Vh1RVfcf9GgIVezO/p0PZm7u47gBOKtF1VyvoPAQ8lui6Ro1lWVhY333wz7s5xxx3H2LFjk11SwgwbNozc3Fy+++47hgwZQuvWrePaLjU1lalTp8a17iuvvHIkJVZ6GrBQpAxpwEKR0h3KgIVhvwclUuFU1v/0iRzMof7bUECJlKEaNWqwadMmhZRIEe7Opk2bqFGjRtzbhPYelEhFVPgUlL4ILnKgGjVqlNhDSXEUUCJlqGrVqjRq1CjZZYgcFXSJT0REQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQml0AaUma01s8/NLNvMFgZtPzazGWa2Mvh5fNBuZvakma0ysxwza53c6kVE5EiFNqAC3dw9w90zg/mRwCx3bwLMCuYB+gBNgtcw4Nlyr1RERMpU2AOqqP7AS8H0S8CFMe0ve8Rc4Dgza5CE+kREpIyEOaAcmG5mWWY2LGg72d3zg+n/ACcH0w2BdTHb5gVt+zGzYWa20MwWbty4MVF1i4hIGaiS7AJK0dnd15vZScAMM1seu9Dd3cz8UHbo7mOAMQCZmZmHtK2IiJSv0J5Bufv64OdXwGSgHbCh8NJd8POrYPX1wKkxm6cGbSIiUkGFMqDMrKaZ1S6cBnoBS4ApwJBgtSHAO8H0FODq4Gm+DsCWmEuBIiJSAYX1Et/JwGQzg0iNr7r7NDNbALxuZtcBXwKXBeu/D/QFVgE7gWvLv2QRESlLoQwod/8CaFlM+yagRzHtDtxUDqWJiEg5CeUlPhEREQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioRS6gDKzU83sQzPLNbOlZnZb0D7azNabWXbw6huzzV1mtsrMVpjZBcmrXkREykqVZBdQ
jL3Af7v7IjOrDWSZ2Yxg2ePu/kjsymbWHBgItABOAWaaWVN331euVYuISJkK3RmUu+e7+6JgehuwDGhYyib9gYnuvtvd1wCrgHaJr1RERBIpdAEVy8zOAFoB84Kmm80sx8zGmtnxQVtDYF3MZnmUEGhmNszMFprZwo0bNyaqbJEjNn063H135KdIZRXagDKzWsBbwO3uvhV4FvgJkAHkA48e6j7dfYy7Z7p7Zr169cqyXJEyM3069OkDv/1t5KdCSiqrUAaUmVUlEk4T3P0vAO6+wd33uXsB8Dw/XMZbD5was3lq0CZSIV18MRQURKYLCuACPfYjlVToAsrMDHgRWObuj8W0N4hZ7SJgSTA9BRhoZtXNrBHQBJhfXvWKlLUdOw5sMyv/OkSSLYxP8XUCrgI+N7PsoO1uYJCZZQAOrAVuAHD3pWb2OpBL5AnAm/QEn4hIxRe6gHL3T4Di/r/4finbPAQ8lLCiRMpRr17F33cyA/fyr0ckWUJ3iU+ksvvb35JdgUg4KKBEKpA6dZJdgUj5OWhAmVldM3u88PtDZvaomdUtj+JEKquSLuVt21a+dYgkUzxnUGOBrcBlwWsrMC6RRYmIiMTzkMRP3P2SmPn7Y56uE5EEcT/w8fJTTklOLSLJEM8Z1C4z61w4Y2adgF2JK0lECrn/EEqnnALr9RV0qUTiOYO6EXgpuO9kwGbgmkQWJSI/UChJZXXQgHL3bKClmdUJ5rcmuigREZESA8rMrnT3V8zsjiLtAMR2QyQiIlLWSjuDqhn8rF3MMn2fXUREEqrEgHL3PwWTM919Tuyy4EEJERGRhInnKb6n4mwTEREpM6Xdg+oInAvUK3Ifqg6QkujCRESkcivtHlQ1oFawTux9qK3AgEQWJSIiUto9qI+Aj8xsvLt/WY41iYiIxPVF3Z1m9nugBVCjsNHduyesKhERqfTieUhiArAcaATcT2Q02wUJrElERCSugDrB3V8E9rj7R+7+X4DOnkREJKHiucS3J/iZb2Y/A/4N/DhxJYmIiMQXUA8GHcX+N5HvP9UBbk9kUSIiIge9xOfu77n7Fndf4u7d3L0NkR7NQ8XMepvZCjNbZWYjk12PiIgcmRIDysxSzGyQmf3KzNKCtp+b2T+Ap8utwjiYWQrwDNAHaA4MMrPmya1KRESORGmX+F4ETgXmA0+a2b+BTGCku79dDrUdinbAKnf/AsDMJgL9gdykViUiIoettIDKBNLdvcDMagD/ITL8+6byKe2QNATWxcznAe2LrmRmw4BhAKeddlr5VCYiIoeltHtQ37t7AYC7fwd8EdJwipu7j3H3THfPrFevXrLLETk4s/1fIpVIaWdQZ5tZTjBtwE+CeQPc3dMTXl381hO5HFkoNWgTqZiqVoW9ew9sNwPXcGxSOZQWUM3KrYojtwBoYmaNiATTQOCK5JYkcphKCieRSqa0zmIrTAex7r7XzG4G/kZkKJCx7r40yWWJHB6FkwgQ3xd1KwR3fx94P9l1iByxKlVKDild3pNKJJ6++ESkPO3ZEwmpohROUskcNWdQIkeVPXsOvo7IUe6gAWVmnwNF/+u2BVgIPFjRHz0XEZFwiucMaiqwD3g1mB8IHEvki7vjgV8kpDIREanU4gmo8929dcz852a2yN1bm9mViSpMREQqt3gekkgxs3aFM2bWlsij3AB6HlZERBIinjOoocBYM6tFpBeJrcB1ZlYT+G0iixMRkcrroAHl7guAc4JBC3H3LTGLX09UYSIiUrkd9BKfmdU1s8eAWcAsM3u0MKxEREQSJZ57UGOBbcBlwWsrMC6RRYmIiMRzD+on7n5JzPz9ZpadoHpERESA+M6gdplZ58IZM+sE7EpcSSIiIvGdQQ0HXo657/QNMCRxJYmIiMT3FN9nQEszqxPMbzWz24GcUjcUERE5AnH3Zu7uW919azB7R4LqERERAQ5/uA0r0ypERESKONyA0sA0IiKSUCXegzKzbRQfRAYck7CKREREKCWg
3L12eRYiIiISS0O+i4hIKIUqoMzs92a23MxyzGyymR0XtJ9hZrvMLDt4PRezTRsz+9zMVpnZk2amBzhERI4CoQooYAaQ5u7pwD+Bu2KWrXb3jOA1PKb9WeB6oEnw6l1u1YqISMKEKqDcfbq7Fw6COBdILW19M2sA1HH3ue7uwMvAhYmtUkREykOoAqqI/wKmxsw3MrPFZvaRmXUJ2hoCeTHr5AVtxTKzYWa20MwWbty4sewrFhGRMhNPX3xlysxmAvWLWXSPu78TrHMPkeHkJwTL8oHT3H2TmbUB3jazFof63u4+BhgDkJmZqe9yiYiEWLkHlLufX9pyM7sG+DnQI7hsh7vvBnYH01lmthpoCqxn/8uAqUGbiIhUcKG6xGdmvYE7gX7uvjOmvZ6ZpQTTZxJ5GOILd88HtppZh+DpvauBd5JQuoiIlLFyP4M6iKeB6sCM4GnxucETe+cBD5jZHqAAGO7um4NtfgmMJ9K7xVT2v28lIiIVVKgCyt0bl9D+FvBWCcsWAmmJrEtERMpfqC7xiYiIFFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkqhCygzG21m680sO3j1jVl2l5mtMrMVZnZBTHvvoG2VmY1MTuUiIlKWqiS7gBI87u6PxDaYWXNgINACOAWYaWZNg8XPAD2BPGCBmU1x99zyLFhERMpWWAOqOP2Bie6+G1hjZquAdsGyVe7+BYCZTQzWVUCJiFRgobvEF7jZzHLMbKyZHR+0NQTWxayTF7SV1H4AMxtmZgvNbOHGjRsTUbeIiJSRpASUmc00syXFvPoDzwI/ATKAfODRsnpfdx/j7pnunlmvXr2y2q2IiCRAUi7xufv58axnZs8D7wWz64FTYxanBm2U0i4iIhVU6C7xmVmDmNmLgCXB9BRgoJlVN7NGQBNgPrAAaGJmjcysGpEHKaaUZ80iIlL2wviQxP8zswzAgbXADQDuvtTMXify8MNe4CZ33wdgZjcDfwNSgLHuvjQJdYuISBkyd092DUmRmZnpCxcuTHYZIiKVnplluXtm0fbQXeITEREBBZSIiISUAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhFKqAMrNJZpYdvNaaWXbQfoaZ7YpZ9lzMNm3M7HMzW2VmT5qZJe0DiIhImamS7AJiufvlhdNm9iiwJWbxanfPKGazZ4HrgXnA+0BvYGoCyxQRkXIQqjOoQsFZ0GXAawdZrwFQx93nursDLwMXJr5CERFJtFAGFNAF2ODuK2PaGpnZYjP7yMy6BG0NgbyYdfKCtmKZ2TAzW2hmCzdu3Fj2VYuISJkp90t8ZjYTqF/Monvc/Z1gehD7nz3lA6e5+yYzawO8bWYtDvW93X0MMAYgMzPTD3V7EREpP+UeUO5+fmnLzawKcDHQJmab3cDuYDrLzFYDTYH1QGrM5qlBm4iIVHBhvMR3PrDc3aOX7sysnpmlBNNnAk2AL9w9H9hqZh2C+1ZXA+8Ut1MREalYQvUUX2AgBz4ccR7wgJntAQqA4e6+OVj2S2A8cAyRp/f0BJ+IyFEgdAHl7tcU0/YW8FYJ6y8E0hJcloiIlLMwXuITERFRQImISDgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKSQkoM7vUzJaaWYGZZRZZdpeZrTKzFWZ2QUx776Bt
lZmNjGlvZGbzgvZJZlatPD+LiIgkRrLOoJYAFwN/j200s+bAQKAF0Bv4o5mlmFkK8AzQB2gODArWBfgd8Li7Nwa+Aa4rn48gIiKJlJSAcvdl7r6imEX9gYnuvtvd1wCrgHbBa5W7f+Hu3wMTgf5mZkB34M1g+5eACxP+AUREJOGqJLuAIhoCc2Pm84I2gHVF2tsDJwDfuvveYtY/gJkNA4YFs9vNrLiQTIYTga+TXURI6diUTsenZDo2JQvbsTm9uMaEBZSZzQTqF7PoHnd/J1HvWxp3HwOMScZ7l8bMFrp75sHXrHx0bEqn41MyHZuSVZRjk7CAcvfzD2Oz9cCpMfOpQRsltG8CjjOzKsFZVOz6IiJSgYXtMfMpwEAzq25mjYAmwHxgAdAkeGKvGpEHKaa4uwMfAgOC7YcASTk7ExGRspWsx8wvMrM8oCPwVzP7G4C7LwVeB3KBacBN7r4vODu6GfgbsAx4PVgX4NfAHWa2isg9qRfL99OUidBddgwRHZvS6fiUTMemZBXi2FjkJERERCRcwnaJT0REBFBAiYhISCmgksjMfm9my80sx8wmm9lxMcuK7fKpsjic7rAqk5K6/qqszGysmX1lZkti2n5sZjPMbGXw8/hk1pgMZnaqmX1oZrnBv6fbgvYKcWwUUMk1A0hz93Tgn8BdUHKXT0mrMjkOqTus8i8veQ7S9VdlNZ7I34dYI4FZ7t4EmBXMVzZ7gf929+ZAB+Cm4O9KhTg2CqgkcvfpMb1gzCXyPS4oucunSuMwusOqTIrt+ivJNSWVu/8d2FykuT+R7s+gknaD5u757r4omN5G5CnohlSQY6OACo//AqYG0w05sGunErtwqmR0bHQM4nWyu+cH0/8BTk5mMclmZmcArYB5VJBjE7a++I468XT5ZGb3EDkVn1CetSVbGLvDkqOTu7uZVdrv1JhZLeAt4HZ33xrpZzsizMdGAZVgB+vyycyuAX4O9PAfvpRWWpdPR40EdIdVWegYxGeDmTVw93wzawB8leyCksHMqhIJpwnu/peguUIcG13iSyIz6w3cCfRz950xi0rq8kl0bKCErr+SXFMYTSHS/RlU0m7QgiGJXgSWuftjMYsqxLFRTxJJFHTPVJ1Ip7cAc919eLDsHiL3pfYSOS2fWvxejk5mdhHwFFAP+BbIdvcLgmWV+tgAmFlf4AkgBRjr7g8lt6LkMrPXgK5EhpHYAIwC3ibSddppwJfAZe5e9EGKo5qZdQY+Bj4HCoLmu4nchwr9sVFAiYhIKOkSn4iIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgRMqRmd0T9CqdY2bZZtY+ge81u2hP8CIViXqSECknZtaRSK8hrd19t5mdCFRLclkioaUzKJHy0wD42t13A7j71+7+bzO7z8wWmNkSMxsTfPu/8AzocTNbaGbLzKytmf0lGMPnwWCdM4IxxSYE67xpZscWfWMz62Vmn5rZIjN7I+ibDTN7OBgrKMfMHinHYyFyUAookfIzHTjVzP5pZn80s58G7U+7e1t3TwOOIXKWVeh7d88EniPSHc1NQBpwjZmdEKxzFvBHd28GbAV+GfumwZnavcD57t4aWAjcEWx/EdAiGJPswQR8ZpHDpoASKSfuvh1oAwwDNgKTgs6Cu5nZPDP7HOhOZDDGQoV97H0OLA3G99kNfMEPHcauc/c5wfQrQOcib92ByMCGc8wsm0jfa6cDW4DvgBfN7GJgJyIhontQIuXI3fcBs4HZQSDdAKQDme6+zsxGAzViNtkd/CyImS6cL/z3W7S/sqLzBsxw90FF6zGzdkAPYABwM5GAFAkFnUGJlBMzO8vMmsQ0ZQCFowZ/HdwXGnAYuz4teAAD4ArgkyLL5wKdzKxxUEdNM2savF9dd38fGAG0PIz3FkkYnUGJlJ9awFNmdhyRnthXEbnc9y2whMjIpgsOY78rgJvMbCyQCzwbu9DdNwaXEl8zs+pB873ANuAdM6tB5CzrjsN4
b5GEUW/mIhVYMIz3e8EDFiJHFV3iExGRUNIZlIiIhJLOoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQun/A+4zZ54ezlKqAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "  1%|          | 3/400 [00:09<21:31,  3.25s/it]\n"
     ]
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-52-c4a905b1ea31>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNUM_EPOCHS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mp_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm_batch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    361\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    362\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 363\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    364\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    365\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    401\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    402\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 403\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    404\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    405\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/cob/cob_pytorch/src/dataset.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mq_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mm_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "source": [
    "# Print the latest values of variables set by an earlier cell (presumably the\n",
    "# training loop); this cell does not define i, true_kl_p_q or kl_from_cob itself.\n",
    "print('iteration: ', i)\n",
    "print('KLD: ', true_kl_p_q)\n",
    "print('CoB: ', kl_from_cob)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "iteration:  366\n",
      "KLD:  tensor(200.2708)\n",
      "CoB:  tensor(-31.8573, device='cuda:1')\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "source": [
    "# Set up viz: compare the true log p/q with the single-ratio (CoB) estimate\n",
    "# on the last `p_batch` produced by the training loop.\n",
    "fig, ax2 = plt.subplots(1, 1, figsize=(6, 4))\n",
    "\n",
    "# Pass the data straight to scatter (no placeholder points): drawing dummy\n",
    "# points from np.random would advance the global NumPy RNG that is seeded\n",
    "# at the top of the notebook.\n",
    "samples = np.asarray(p_batch.cpu().squeeze())\n",
    "true_log_ratio = np.asarray(log_ratio_p_q.cpu().detach())\n",
    "cob_log_ratio = np.asarray(log_ratio_p_q_from_cob.cpu().detach())\n",
    "\n",
    "scat1 = ax2.scatter(samples, true_log_ratio, label='True Log p/q, KL = '+str(np.around(true_kl_p_q.item(),2)), alpha=0.9, s=10., c='b')\n",
    "scat2 = ax2.scatter(samples, cob_log_ratio, label='Single Ratio Log p/q, KL = '+str(np.around(kl_from_cob.item(),2)), alpha=0.9, s=10., c='r')\n",
    "\n",
    "ax2.set_title('Log-ratio estimates on the last training batch')\n",
    "ax2.set_xlabel(\"Samples\")\n",
    "ax2.set_ylabel(\"Log Ratio\")\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-25, 25])\n",
    "ax2.set_ylim([-1000, 1000])\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('../plots/bcfre_mu1_p.png')"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAApNUlEQVR4nO3deXxU5d3//9dHEFBAtIBIiQtKEEgMAcImaFlksxVcQEFUbEXEalW8vf3hUgGrd+2votalWlRQK4pbcasoi0KVgpBAjCwiIFjCnRsRFATZ+Xz/mJNxCEkYIJM5Ie/n4zGPzLnOOTOfORDenOucuS5zd0RERMLmqGQXICIiUhwFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCU1oMxsvJl9Y2aLYtp+ZmbTzGx58POEoN3M7FEzW2FmeWbWOmafIcH2y81sSDI+i4iIlK1kn0E9B/Qu0jYSmOHuqcCMYBmgD5AaPIYBT0Ik0IBRQHugHTCqMNRERKTiSmpAufu/gI1FmvsBzwfPnwcujGl/wSPmAsebWUOgFzDN3Te6+3fANPYPPRERqWCqJruAYjRw94Lg+f8BDYLnjYA1MdvlB20lte/HzIYROfuiZs2abZo1a1aGZYuIyKHIycn51t3rF20PY0BFububWZmNxeTu44BxAFlZWZ6dnV1WLy0iIofIzL4urj3Z16CKsy7ouiP4+U3QvhY4OWa7lKCtpHYREanAwhhQbwOFd+INAd6Kab8quJuvA7Ap6Ar8AOhpZicEN0f0DNpERKQCS2oXn5m9DHQB6plZPpG78R4AXjWza4CvgUuDzd8DzgdWAD8CvwZw941m9gdgfrDdve5e9MYLERGpYKyyTreha1CSCLt27SI/P5/t27cnuxSR0KlRowYpKSkcffTR+7SbWY67ZxXdPtQ3SYhUNPn5+dSuXZvTTjsNM0t2OSKh4e5s2LCB/Px8GjduHNc+YbwGJVJhbd++nbp16yqcRIowM+rWrXtQvQsKKJEypnASKd7B/m4ooEREJJQUUCJHkA0bNpCZmUlmZiYnnXQSjRo1ii7v3LmzTN6jS5culOcNRnPnzuXaa689pH2vvvpqXn/9dQA2btxIq1atmDBhAqtXryY9Pb3MalyzZg1du3alRYsWpKWl8Ze//CW6buPGjfTo0YPU1FR69OjBd999B0Suydx00000adKEjIwMFixYsN/r/vjjj/zyl7+kWbNmpKWlMXLkyOi6ESNGRP9smzZtyvHHH19mnycsFFAiR5C6deuSm5tLbm4uw4cPZ8SIEdHlatWqsXv37mSXeNCmTJlC796HN7zmpk2b6NWrF8OGDePXv/51GVX2k6pVqzJ27FiWLFnC3LlzeeKJJ1iyZAkADzzwAN27d2f58uV0796dBx54AIh8ruXLl7N8+XLGjRvH9ddfX+xr33bbbXzxxRcsXLiQ2bNnM2XKFAAefvjh6J/t7373Oy6++OIy/1zJpoASOcJdffXVDB8+nPbt23P77bczevRoHnzwwej69PR0Vq9eDcCLL75Iu3btyMzM5LrrrmPPnj1xvcfGjRu58MILycjIoEOHDuTl5QGwfv16evToQVpaGkOHDuXUU0/l22+/3W//WrVqMWLECNLS0ujevTvr16+PrpsxYwbnnXce27ZtY+DAgTRv3pyLLrqI9u3bx3Umt2XLFvr06cPll19eYggcroYNG9K6dWQGoNq1a9O8eXPWro0MaPPWW28xZEhk7IEhQ4bw5ptvRtuvuuoqzIwOHTrw/fffU1BQsM/rHnvssXTt2hWAatWq0bp1a/Lz8/d7/5dffplBgwYl5LMlkwJKJMmmToU774z8TJT8/Hz+/e9/89BDD5W4zdKlS3nllVeYPXs2ubm5VKlShYkTJ8b1+qNGjaJVq1bk5eXxP//zP1x11VUAjBkzhm7durF48WL69+/Pf/7zn2L337p1K1lZWSxevJhf/OIXjBkz
BoBvv/2Wo48+mjp16vDkk09y7LHHsnTpUsaMGUNOTk5ctd1666107tyZESNGxLV9oYkTJ0a70GIf/fv3L3W/1atXs3DhQtq3bw/AunXraNiwIQAnnXQS69atA2Dt2rWcfPJPo7SlpKREQ60433//Pe+88w7du3ffp/3rr79m1apVdOvW7aA+X0Wg70GJJNHUqXDFFbBjBzzzDLz4IvTsWfbvM2DAAKpUqVLqNjNmzCAnJ4e2bdsCsG3bNk488cS4Xv+TTz7hjTfeAKBbt25s2LCBzZs388knnzB58mQAevfuzQknFD9V21FHHcVll10GwBVXXBHtrpo6dSo9gwPyr3/9i5tuugmAjIwMMjIy4qqtW7duvPXWW9x2221xfx6AwYMHM3jw4Li3h8jZ2iWXXMIjjzzCcccdt996Mzukuzx3797NoEGDuOmmmzj99NP3WTdp0iT69+9/wD/fikgBJZJEM2dGwqlmTdi6NbKciICqWbNm9HnVqlXZu3dvdLnweynuzpAhQ/jjH/9Y9gUcpMJ/xKdMmcKtt956WK81cOBAOnXqxPnnn89HH31E7dq149pv4sSJ/PnPf96vvUmTJtEbL2Lt2rWLSy65hMGDB+9zPahBgwYUFBTQsGFDCgoKoiHZqFEj1qz5aaag/Px8GjUqdqYghg0bRmpqKrfccst+6yZNmsQTTzwR12eqaNTFJ5JEXbpA9eqRcKpePbKcaKeddlr0jrEFCxawatUqALp3787rr7/ON99EJhDYuHEjX39d7CwI+znnnHOi3YEzZ86kXr16HHfccXTq1IlXX30ViJwNFd7BVtTevXuj/+i/9NJLdO7cGXcnLy+PzMxMAM4991xeeuklABYtWhS9zgVw1VVXMW/evBLrGzFiBN27d+fiiy+O+27GwYMHR29CiH0UF07uzjXXXEPz5s33C9S+ffvy/POROViff/55+vXrF21/4YUXcHfmzp1LnTp1ol2Bse6++242bdrEI488st+6L774gu+++46OHTvG9ZkqGgWUSBL17Bnp1rvhhsR17xV1ySWXsHHjRtLS0nj88cdp2rQpAC1atOC+++6jZ8+eZGRk0KNHj/0u2hf65S9/SUpKCikpKQwYMIDRo0eTk5NDRkYGI0eOjP6DPGrUKKZOnUp6ejqvvfYaJ510UrFnMDVr1mTevHmkp6fz4Ycfcs8995CTk0OrVq2iZ1PXX389W7ZsoXnz5txzzz20adMmun9eXh4///nPS/3cf/rTn0hJSeHKK69k7969LFu2LPoZUlJSeO211w7peALMnj2bv//973z44YfRa1XvvfceACNHjmTatGmkpqYyffr06K3i559/PqeffjpNmjTh2muv5a9//Wv09QpDOT8/n/vvv58lS5bQunVrMjMzeeaZZ6LbTZo0iYEDBx6xXw7XYLEiZWjp0qU0b9482WWExo4dO6hSpQpVq1Zlzpw5XH/99eTm5u63Xa1atdiyZcs+bffddx9NmjRh4MCBxb52ly5dePDBB2natCnXXHPNYQWMlJ/ifkc0WKyIlLv//Oc/XHrppezdu5dq1arx9NNPx73v3XffHdd2xx13nMLpCKWAEpGESU1NZeHChQfcrujZUzxmzpx5CBVJRaJrUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJXKEuf/++0lLSyMjI4PMzEw+/fRTAIYOHRodYftgHcr0FFWqVCEzM5P09HQuuOACvv/++1K3z83NjX53CODtt9+Ojvwdj+eee44bb7zxoGo8XG3atGHHjh0Hvd/MmTP51a9+FV2+++676d27Nzt27EjodCbbt2+nXbt2tGzZkrS0NEaNGhVd9/jjj9OkSRPMrNgBfQvdfvvtpKWl0bx5c2666SYKv6q0c+dOhg0bRtOmTWnWrFl06KvDEcqAMrMzzSw35rHZzG4xs9Fmtjam/fyYfe4wsxVmtszMeiWzfpFkmTNnDu+++y4LFiwgLy+P6dOnRwckfeaZZ2jR
okW51XLMMceQm5vLokWL+NnPfnbA4XiKBlTfvn33mf8obFatWkWjRo2oXr36Yb3Offfdx+zZs5k8efJhv9aBVK9enQ8//JDPPvuM3Nxc3n//febOnQtAp06dmD59OqeeemqJ+//73/9m9uzZ5OXlsWjRIubPn8+sWbOAyH+MTjzxRL788kuWLFnCL37xi8OuN5QB5e7L3D3T3TOBNsCPwORg9cOF69z9PQAzawEMBNKA3sBfzezIGzlR5AAKCgqoV69e9B+6evXqRUdYiP2fea1atbjrrrto2bIlHTp0iI6wvXLlSjp06MBZZ53F3XffTa1atfZ7jz179vDf//3ftG3bloyMDP72t78dsK6OHTtGR+qeN28eHTt2pFWrVpx99tksW7aMnTt3cs899/DKK6+QmZnJK6+8ss8Z0erVq+nWrRsZGRl07969xFHRi/PQQw+Rnp5Oenr6PsMF/eEPf+DMM8+kc+fODBo0aJ8pSAoVTlWSlZVF06ZNeffdd6Pr3n///eg8VRMmTKBp06a0a9eOa6+9Nu4zubFjxzJlyhTeeecdjjnmmLg/06Eys+if6a5du9i1a1d0FIpWrVpx2mmnHXD/7du3s3PnTnbs2MGuXbto0KABAOPHj+eOO+4AIoP/1qtX77DrDWVAFdEdWOnupQ0K1g+Y5O473H0VsAJoVy7ViRyuMpxvo2fPnqxZs4amTZvy29/+Nvq/26K2bt1Khw4d+Oyzzzj33HOjX6C9+eabufnmm/n8889JSUkpdt9nn32WOnXqMH/+fObPn8/TTz8dHc+vOHv27GHGjBn07dsXgGbNmvHxxx+zcOFC7r33Xu68806qVavGvffey2WXXUZubm50ZPNCv/vd7xgyZAh5eXkMHjw4Oqr5geTk5DBhwgQ+/fRT5s6dy9NPP83ChQuZP38+b7zxBp999hlTpkwptUtt9erVzJs3j3/+858MHz48OrhuYUAVFBQwatQoZs+ezSeffBJ3N+rs2bN56qmnmDJlSrH/ESjJDz/8UOw0IJmZmXG99549e8jMzOTEE0+kR48e0WlB4tGxY0e6du1Kw4YNadiwIb169aJ58+bR7tvf//73tG7dmgEDBkT/03M4KkJADQRejlm+0czyzGy8mRWO3d8IWBOzTX7Qtg8zG2Zm2WaWHTshmkjSFM638cQTkZ+HGVK1atUiJyeHcePGUb9+fS677DKee+65/barVq1a9BpImzZtohMWzpkzhwEDBgBw+eWXl1DyVF544QUyMzNp3749GzZsYPny5fttt23btujU8+vWraNHjx5AZHbbAQMGkJ6ezogRI1i8ePEBP9ecOXOi9Vx55ZV88sknB9wHItOAXHTRRdSsWZNatWpx8cUX8/HHHzN79mz69etHjRo1qF27NhdccEGJr3HppZdy1FFHkZqayumnn84XX3zBzp07yc/P5/TTT+fTTz+lS5cu1K9fn2rVqu0XriVp0qQJ7s60adPi2r5Q7dq1ix3ENjc3N64u3CpVqpCbm0t+fj7z5s1j0aJFcb/3ihUrWLp0Kfn5+axdu5YPP/yQjz/+mN27d5Ofn8/ZZ5/NggUL6NixI7fddttBfa7ihDqgzKwa0BcoHMfkSeAMIBMoAMYezOu5+zh3z3L3rPr165dlqSKHJna+jR07IsuHqUqVKnTp0oUxY8bw+OOPF3ux+uijj4527VSpUuWgpoJ3dx577LHoP4qrVq2KztkUq/Aa1Ndff427R69B/f73v6dr164sWrSId955J3pGElZFB2I1Mz7++GM6d+58WK/boEED3nvvPW655RY++uijuPc7mDOoNWvWRNc99dRT+6w7/vjj6dq1K++//37c7z158mQ6dOhArVq1qFWrFn369GHOnDnUrVuXY489NjrNyIABA6Ij5h+OUAcU0AdY4O7rANx9nbvvcfe9wNP81I23Fjg5Zr+UoE0k3Mp4vo1ly5btczaTm5tb6kXvojp06BANtEmTJhW7Ta9evXjy
ySfZtWsXAF9++SVbt24t8TWPPfZYHn30UcaOHcvu3bvZtGlTdN6j2LO72rVr88MPPxT7GmeffXa0nokTJ3LOOefE9XnOOecc3nzzTX788Ue2bt3K5MmTOeecc+jUqVM0HLds2bLPtaWiXnvtNfbu3cvKlSv56quvOPPMM3n//ffp06cPAO3bt2fWrFls2LCBXbt27TMu4OTJk6PXZYrTtGlT/vGPf3DFFVcUO4hucQ7mDOrkk0+Orhs+fDjr16+Pdsdt27aNadOm0axZs7jeF+CUU05h1qxZ7N69m127djFr1iyaN2+OmXHBBRdEh5+aMWNGmdyQE/aAGkRM956ZxU6WchFQeG76NjDQzKqbWWMgFSh5chiRsCjj+Ta2bNnCkCFDaNGiBRkZGSxZsoTRo0fHvf8jjzzCQw89REZGBitWrKBOnTr7bTN06FBatGhB69atSU9P57rrrjvgGVirVq3IyMjg5Zdf5vbbb+eOO+6gVatW++zXtWtXlixZEr1JItZjjz3GhAkTyMjI4O9//zt/+ctfin2f5557bp8pNE488USuvvpq2rVrR/v27Rk6dCitWrWibdu29O3bl4yMDPr06cNZZ51V7GeFyD/K7dq1o0+fPjz11FPUqFGDmTNnRu9Sa9iwIaNHj6Zjx4506tRpn5G6V65cWezMurHatm3LhAkT6Nu3LytXrgT2n86krBQUFNC1a1cyMjJo27YtPXr0iHb1Pvroo6SkpJCfn09GRgZDhw4FIDs7O/q8f//+nHHGGZx11lm0bNmSli1bRrtH//SnPzF69Ojon9HYsQfVwVU8dw/lA6gJbADqxLT9HfgcyCMSSg1j1t0FrASWAX0O9Ppt2rRxkbK2ZMmSZJdwWLZu3ep79+51d/eXX37Z+/btm+SKEueHH35w98hnbtOmjefk5Oy3zZAhQ/y1117bp23NmjXeu3fvEl93woQJfsMNN7i7++DBg/2bb74pw6orvuJ+R4BsL+bf6dCOZu7uW4G6RdquLGX7+4H7E12XyJEsJyeHG2+8EXfn+OOPZ/z48ckuKWGGDRvGkiVL2L59O0OGDKF169Zx7ZeSksKUKVPi2vbFF188nBIrPU1YKFKGNGGhSOkOZsLCsF+DEqlwKut/+kQO5GB/NxRQImWoRo0abNiwQSElUoS7s2HDBmrUqBH3PqG9BiVSERXeBaUvgovsr0aNGiWOUFIcBZRIGTr66KNp3LhxsssQOSKoi09EREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJpdAGlJmtNrPPzSzXzLKDtp+Z2TQzWx78PCFoNzN71MxWmFmembVObvUiInK4QhtQga7ununuWcHySGCGu6cCM4JlgD5AavAYBjxZ7pWKiEiZCntAFdUPeD54/jxwYUz7Cx4xFzjezBomoT4RESkjYQ4oB6aaWY6ZDQvaGrh7QfD8/4AGwfNGwJqYffODtn2Y2TAzyzaz7PXr1yeqbhERKQNVk11AKTq7+1ozOxGYZmZfxK50dzczP5gXdPdxwDiArKysg9pXRETKV2jPoNx9bfDzG2Ay0A5YV9h1F/z8Jth8LXByzO4pQZuIiFRQoQwoM6tpZrULnwM9gUXA28CQYLMhwFvB87eBq4K7+ToAm2K6AkVEpAIKaxdfA2CymUGkxpfc/X0zmw+8ambXAF8DlwbbvwecD6wAfgR+Xf4li4hIWQplQLn7V0DLYto3AN2LaXfghnIoTUREykkou/hEREQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiIRS6ALKzE42s4/MbImZLTazm4P20Wa2
1sxyg8f5MfvcYWYrzGyZmfVKXvUiIlJWqia7gGLsBv7L3ReYWW0gx8ymBesedvcHYzc2sxbAQCAN+Dkw3cyauvuecq1aRETKVOjOoNy9wN0XBM9/AJYCjUrZpR8wyd13uPsqYAXQLvGViohIIoUuoGKZ2WlAK+DToOlGM8szs/FmdkLQ1ghYE7NbPiUEmpkNM7NsM8tev359osoWKTNTp8Kdd0Z+ilQ2oQ0oM6sFvAHc4u6bgSeBM4BMoAAYe7Cv6e7j3D3L3bPq169fluWKlLnUVOjVC/74R7j4YoWUVD6hDCgzO5pIOE10938AuPs6d9/j7nuBp/mpG28tcHLM7ilBm0iFlZoKK1b8tLx1K1x1VfLqEUmG0AWUmRnwLLDU3R+KaW8Ys9lFwKLg+dvAQDOrbmaNgVRgXnnVK1LWpk7dN5wKrVsHv/lN+dcjkixhvIuvE3Al8LmZ5QZtdwKDzCwTcGA1cB2Auy82s1eBJUTuALxBd/BJRXb++SWvmzgRxo8vv1pEkil0AeXunwBWzKr3StnnfuD+hBUlUo72lPLfq507y68OkWQLXRefiJRON0tIZaGAEqlgZs5MdgUi5eOAAWVmdczs4cLvD5nZWDOrUx7FiVRGVlwHd4wuXcqlDJGki+cMajywGbg0eGwGJiSyKJHKrF69kteddRb07Fl+tYgkUzwBdYa7j3L3r4LHGOD0RBcmUlmNGFF8+1FHwYMPFr9O5EgUz11828ysc3B3HWbWCdiW2LJEKq877oj8fOcdWL8e8vPhpJPgb3/T2ZNULubupW8Q+e7R80AdIrd/bwSudvfPEl5dAmVlZXl2dnayyxARqfTMLMfds4q2H/AMyt1zgZZmdlywvLnsyxMREdlXiQFlZle4+4tmdmuRdgBihyESEREpa6WdQdUMftYuZl3p/YIiIiKHqcSAcve/BU+nu/vs2HXBjRIiIiIJE89t5o/F2SYiIlJmSrsG1RE4G6hf5DrUcUCVRBcmIiKVW2nXoKoBtYJtYq9DbQb6J7IoERGR0q5BzQJmmdlz7v51OdYkIiIS10gSP5rZn4E0oEZho7t3S1hVIiJS6cVzk8RE4AugMTCGyGy28xNYk4iISFwBVdfdnwV2ufssd/8NoLMnERFJqHi6+HYFPwvM7JfA/wI/S1xJIiIi8QXUfcEEhf9F5PtPxwG3JLIoERGRA3bxufu77r7J3Re5e1d3b0NkRPNQMbPeZrbMzFaY2chk1yMiIoenxIAysypmNsjMbjOz9KDtV2b2b+DxcqswDmZWBXgC6AO0AAaZWYvkViUiIoejtC6+Z4GTgXnAo2b2v0AWMNLd3yyH2g5GO2CFu38FYGaTgH7AkqRWJSIih6y0gMoCMtx9r5nVAP6PyPTvG8qntIPSCFgTs5wPtC+6kZkNA4YBnHLKKeVTmYiIHJLSrkHtdPe9AO6+HfgqpOEUN3cf5+5Z7p5Vv379ZJcjUrI2bcBs38dvfpPsqkTKVWlnUM3MLC94bsAZwbIB7u4ZCa8ufmuJdEcWSgnaRCqeNm1gwYL92ydMgLVr4YMPyr8mkSQoLaCal1sVh28+kGpmjYkE00Dg8uSWJHKIigunQlOnRh49e5ZfPSJJUtpgsRVmgFh3321mNwIfEJkKZLy7L05yWSKJMXOmAkoqhXi+qFshuPt7wHvJrkMk4bp0SXYFIuUinrH4RKQ8uZe87oMPdPYklcYRcwYlckQpLaREKokDBpSZfQ4U/W3ZBGQD91X0W89FRCSc4jmDmgLsAV4KlgcCxxL54u5zwAUJqUxERCq1eALqPHdvHbP8uZktcPfWZnZFogoTEZHKLZ6bJKqYWbvCBTNrS+RWboDdCalKREQqvXjOoIYC482sFpFRJDYD15hZTeCPiSxOREQqrwMGlLvPB84KJi3E3TfFrH41UYWJiEjldsAuPjOrY2YPATOAGWY2tjCsREREEiWea1DjgR+AS4PHZmBCIosSERGJ5xrUGe5+SczyGDPL
TVA9IiIiQHxnUNvMrHPhgpl1ArYlriQREZH4zqCGAy/EXHf6DhiSuJJERETiu4vvM6ClmR0XLG82s1uAvFJ3FBEROQxxj2bu7pvdfXOweGuC6hEREQEOfboNK9MqREREijjUgNJcACIiklAlXoMysx8oPogMOCZhFYmIiFBKQLl77fIsREREJJamfBcRkVAKVUCZ2Z/N7AszyzOzyWZ2fNB+mpltM7Pc4PFUzD5tzOxzM1thZo+amW7gEBE5AoQqoIBpQLq7ZwBfAnfErFvp7pnBY3hM+5PAtUBq8OhdbtWKiEjChCqg3H2quxdOgjgXSCltezNrCBzn7nPd3YEXgAsTW6WIiJSHUAVUEb8BpsQsNzazhWY2y8zOCdoaAfkx2+QHbcUys2Fmlm1m2evXry/7ikVEpMzEMxZfmTKz6cBJxay6y93fCra5i8h08hODdQXAKe6+wczaAG+aWdrBvre7jwPGAWRlZem7XCIiIVbuAeXu55W23syuBn4FdA+67XD3HcCO4HmOma0EmgJr2bcbMCVoExGRCi5UXXxm1hu4Hejr7j/GtNc3syrB89OJ3AzxlbsXAJvNrENw995VwFtJKF1ERMpYuZ9BHcDjQHVgWnC3+Nzgjr1zgXvNbBewFxju7huDfX4LPEdkdIsp7HvdSkREKqhQBZS7Nymh/Q3gjRLWZQPpiaxLRETKX6i6+ERERAopoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQml0AWUmY02s7Vmlhs8zo9Zd4eZrTCzZWbWK6a9d9C2wsxGJqdyEREpS1WTXUAJHnb3B2MbzKwFMBBIA34OTDezpsHqJ4AeQD4w38zedvcl5VmwiIiUrbAGVHH6AZPcfQewysxWAO2CdSvc/SsAM5sUbKuAEhGpwELXxRe40czyzGy8mZ0QtDUC1sRskx+0ldS+HzMbZmbZZpa9fv36RNQtIiJlJCkBZWbTzWxRMY9+wJPAGUAmUACMLav3dfdx7p7l7ln169cvq5cVEZEESEoXn7ufF892ZvY08G6wuBY4OWZ1StBGKe0iIlJBha6Lz8waxixeBCwKnr8NDDSz6mbWGEgF5gHzgVQza2xm1YjcSPF2edYsIiJlL4w3Sfz/ZpYJOLAauA7A3Reb2atEbn7YDdzg7nsAzOxG4AOgCjDe3RcnoW4RESlD5u7JriEpsrKyPDs7O9lliIhUemaW4+5ZRdtD18UnIiICCigREQkpBZSIiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCKVQBZWavmFlu8FhtZrlB+2lmti1m3VMx+7Qxs8/NbIWZPWpmlrQPICIiZaZqsguI5e6XFT43s7HAppjVK909s5jdngSuBT4F3gN6A1MSWKaIiJSDUJ1BFQrOgi4FXj7Adg2B49x9rrs78AJwYeIrFBGRRAtlQAHnAOvcfXlMW2MzW2hms8zsnKCtEZAfs01+0FYsMxtmZtlmlr1+/fqyr1pERMpMuXfxmdl04KRiVt3l7m8Fzwex79lTAXCKu28wszbAm2aWdrDv7e7jgHEAWVlZfrD7i4hI+Sn3gHL380pbb2ZVgYuBNjH77AB2BM9zzGwl0BRYC6TE7J4StImISAUXxi6+84Av3D3adWdm9c2sSvD8dCAV+MrdC4DNZtYhuG51FfBWcS8qIiIVS6ju4gsMZP+bI84F7jWzXcBeYLi7bwzW/RZ4DjiGyN17uoNPROQIELqAcveri2l7A3ijhO2zgfQElyUiIuUsjF18IiIiCigREQknBZSIiISSAkpE
REJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCKSkBZWYDzGyxme01s6wi6+4wsxVmtszMesW09w7aVpjZyJj2xmb2adD+iplVK8/PIiIiiZGsM6hFwMXAv2IbzawFMBBIA3oDfzWzKmZWBXgC6AO0AAYF2wL8CXjY3ZsA3wHXlM9HEBGRREpKQLn7UndfVsyqfsAkd9/h7quAFUC74LHC3b9y953AJKCfmRnQDXg92P954MKEfwAREUm4qskuoIhGwNyY5fygDWBNkfb2QF3ge3ffXcz2+zGzYcCwYHGLmRUXkslQD/g22UWElI5N6XR8SqZjU7KwHZtTi2tMWECZ2XTgpGJW3eXubyXqfUvj7uOAccl479KYWba7Zx14y8pHx6Z0Oj4l07EpWUU5NgkLKHc/7xB2WwucHLOcErRRQvsG4HgzqxqcRcVuLyIiFVjYbjN/GxhoZtXNrDGQCswD5gOpwR171YjcSPG2uzvwEdA/2H8IkJSzMxERKVvJus38IjPLBzoC/zSzDwDcfTHwKrAEeB+4wd33BGdHNwIfAEuBV4NtAf4/4FYzW0HkmtSz5ftpykTouh1DRMemdDo+JdOxKVmFODYWOQkREREJl7B18YmIiAAKKBERCSkFVBKZ2Z/N7AszyzOzyWZ2fMy6Yod8qiwOZTisyqSkob8qKzMbb2bfmNmimLafmdk0M1se/DwhmTUmg5mdbGYfmdmS4Pfp5qC9QhwbBVRyTQPS3T0D+BK4A0oe8ilpVSbHQQ2HVf7lJc8Bhv6qrJ4j8vch1khghrunAjOC5cpmN/Bf7t4C6ADcEPxdqRDHRgGVRO4+NWYUjLlEvscFJQ/5VGkcwnBYlUmxQ38luaakcvd/ARuLNPcjMvwZVNJh0Ny9wN0XBM9/IHIXdCMqyLFRQIXHb4ApwfNG7D+0U4lDOFUyOjY6BvFq4O4FwfP/Axoks5hkM7PTgFbAp1SQYxO2sfiOOPEM+WRmdxE5FZ9YnrUlWxiHw5Ijk7u7mVXa79SYWS3gDeAWd98cGWc7IszHRgGVYAca8snMrgZ+BXT3n76UVtqQT0eMBAyHVVnoGMRnnZk1dPcCM2sIfJPsgpLBzI4mEk4T3f0fQXOFODbq4ksiM+sN3A70dfcfY1aVNOST6NhACUN/JbmmMHqbyPBnUEmHQQumJHoWWOruD8WsqhDHRiNJJFEwPFN1IoPeAsx19+HBuruIXJfaTeS0fErxr3JkMrOLgMeA+sD3QK679wrWVepjA2Bm5wOPAFWA8e5+f3IrSi4zexnoQmQaiXXAKOBNIkOnnQJ8DVzq7kVvpDiimVln4GPgc2Bv0HwnketQoT82CigREQkldfGJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiISSAkqkHJnZXcGo0nlmlmtm7RP4XjOLjgQvUpFoJAmRcmJmHYmMGtLa3XeYWT2gWpLLEgktnUGJlJ+GwLfuvgPA3b919/81s3vMbL6ZLTKzccG3/wvPgB42s2wzW2pmbc3sH8EcPvcF25wWzCk2MdjmdTM7tugbm1lPM5tjZgvM7LVgbDbM7IFgrqA8M3uwHI+FyAEpoETKz1TgZDP70sz+ama/CNofd/e27p4OHEPkLKvQTnfPAp4iMhzNDUA6cLWZ1Q22ORP4q7s3BzYDv4190+BM7W7gPHdvDWQDtwb7XwSkBXOS3ZeAzyxyyBRQIuXE3bcAbYBhwHrglWCw4K5m9qmZfQ50IzIZY6HCMfY+BxYH8/vsAL7ipwFj17j77OD5i0DnIm/dgcjEhrPNLJfI2GunApuA7cCzZnYx8CMiIaJrUCLlyN33ADOBmUEgXQdkAFnuvsbMRgM1YnbZEfzcG/O8cLnw97foeGVFlw2Y5u6DitZjZu2A7kB/4EYiASkSCjqD
EiknZnammaXGNGUChbMGfxtcF+p/CC99SnADBsDlwCdF1s8FOplZk6COmmbWNHi/Ou7+HjACaHkI7y2SMDqDEik/tYDHzOx4IiOxryDS3fc9sIjIzKbzD+F1lwE3mNl4YAnwZOxKd18fdCW+bGbVg+a7gR+At8ysBpGzrFsP4b1FEkajmYtUYME03u8GN1iIHFHUxSciIqGkMygREQklnUGJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiITS/wPlXY6qB1iubAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sr",
   "language": "python",
   "name": "sr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}