{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "# Automatically reload imported modules (e.g. the local `src` package) when their source changes.\n",
    "# %reload_ext (instead of %load_ext) keeps this cell idempotent when it is re-run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "N_DIMS = 1           # input dimensionality passed to the model (dim_input)\n",
    "NUM_SAMPLES = 330    # num_samples passed to each DistDataset\n",
    "BS = 990             # train batch size; larger than NUM_SAMPLES, so presumably one batch per epoch\n",
    "NUM_EPOCHS = 5000\n",
    "SEED = 10\n",
    "LR = 1e-2\n",
    "DROPOUT = 0.20\n",
    "\n",
    "# Prefer the second GPU (as originally configured), but fall back to the first GPU\n",
    "# or the CPU so the notebook also runs on machines with fewer devices.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    DEVICE = 'cuda:1'\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = 'cuda:0'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "# NOTE: the setup reportedly breaks when changing the number of datapoints,\n",
    "# the scales, the means, or when moving to 2D."
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# Seed every RNG used in this notebook (Python, PyTorch, NumPy) for reproducible runs.\n",
    "for seed_rng in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_rng(SEED)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "# Define model: 3 outputs (the training loop labels p, q and m samples as classes 0, 1, 2).\n",
    "# It is moved to DEVICE before the optimizer is constructed, as PyTorch recommends.\n",
    "model = RatioCritic1D_overlap(dim_input=N_DIMS, dim_output=3, dropout=DROPOUT).to(DEVICE)\n",
    "# Optional custom initialisation (currently disabled): model.apply(weights_init)\n",
    "\n",
    "# Define optimizer\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# Define distributions (scale_p=1e-6 makes p very narrow around mu1)\n",
    "p, q, m = get_dists_1d_overlap(mu1=0., mu2=0., mu3=0., scale_p=1e-6, scale_q=1., scale_m=1.0)\n",
    "print(m)\n",
    "# Previously noted alternative settings (presumably mu1, mu2): -5, 5 with m_var=3.0; -10, 10 with m_var=3.0"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[Normal(loc: 0.0, scale: 9.999999974752427e-07), MixtureSameFamily(\n",
      "  Categorical(probs: torch.Size([2]), logits: torch.Size([2])),\n",
      "  Independent(Normal(loc: torch.Size([2]), scale: torch.Size([2])), 0)), MixtureSameFamily(\n",
      "  Categorical(probs: torch.Size([2]), logits: torch.Size([2])),\n",
      "  Independent(Normal(loc: torch.Size([2]), scale: torch.Size([2])), 0)), MixtureSameFamily(\n",
      "  Categorical(probs: torch.Size([2]), logits: torch.Size([2])),\n",
      "  Independent(Normal(loc: torch.Size([2]), scale: torch.Size([2])), 0)), Normal(loc: 0.0, scale: 1.0)]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# Define datasets (the dataloaders are built in the next cell).\n",
    "# Train and test are sampled separately, each with NUM_SAMPLES samples.\n",
    "train_ds = DistDataset(p, q, None, num_samples=NUM_SAMPLES)\n",
    "test_ds = DistDataset(p, q, None, num_samples=NUM_SAMPLES)  # evaluated as one full batch via test_dl"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sampling p\n",
      "Sampling q\n",
      "None\n",
      "Linear mixing for m samples\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n",
      "Sampling p\n",
      "Sampling q\n",
      "None\n",
      "Linear mixing for m samples\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n",
      "torch.Size([330])\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# Define dataloaders. The test loader yields the whole test set as a single batch,\n",
    "# so shuffling it would only permute that batch; keep it in a fixed order.\n",
    "train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True)\n",
    "test_dl = DataLoader(test_ds, batch_size=len(test_ds), shuffle=False)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "# Set up live-updating figure: train loss | true vs. CoB log-ratio | test loss\n",
    "fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))\n",
    "\n",
    "# The training loop runs len(train_dl) iterations per epoch and validates every 50 iterations.\n",
    "total_iters = NUM_EPOCHS * len(train_dl)\n",
    "num_val_steps = max(total_iters // 50, 1)\n",
    "\n",
    "# Placeholder artists; the training loop replaces their data in place.\n",
    "# Empty scatters avoid drawing (and consuming NumPy RNG for) dummy points.\n",
    "line, = ax1.plot([0, 1], [0, 1])\n",
    "scat1 = ax2.scatter([], [], label='True p/q', alpha=0.9, s=10., c='b')\n",
    "scat2 = ax2.scatter([], [], label='CoB p/q', alpha=0.9, s=10., c='r')\n",
    "test_line, = ax3.plot([0, 1], [0, 1])\n",
    "\n",
    "ax1.set_xlabel('Iteration')\n",
    "ax1.set_ylabel('Train Loss')\n",
    "ax1.set_xlim([0, total_iters])\n",
    "ax1.set_ylim([0, 10])\n",
    "\n",
    "ax2.set_xlabel('Samples')\n",
    "ax2.set_ylabel('Log Ratio')\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-6, 10])\n",
    "ax2.set_ylim([-1500, 5000])\n",
    "\n",
    "# Test loss is recorded once per validation pass, not once per training iteration.\n",
    "ax3.set_xlabel('Validation step')\n",
    "ax3.set_ylabel('Test Loss')\n",
    "ax3.set_xlim([0, num_val_steps])\n",
    "ax3.set_ylim([0, 10])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "loss_store = []\n",
    "test_loss_store = []"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABDEAAAEYCAYAAABFkW3UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyIklEQVR4nO3de7gcVZno/+87SSDIXQzXoImIMEEgQgggyMTAhIAekBGFnEGDw5yIPxlBRxDU3wEdHEE8iDqMYx5BQBCCCMphkHCRiKASEuSaGAk3SQQJAUHuJHnPH107NDu7s3eyd3d19f5+nqefrlpVXfX2SlfVyrtXrYrMRJIkSZIkqd39TdkBSJIkSZIk9YVJDEmSJEmSVAkmMSRJkiRJUiWYxJAkSZIkSZVgEkOSJEmSJFWCSQxJkiRJklQJTUtiRMT5EfFkRNxXV/bmiLghIh4o3jdt1v4labCKiEci4t6IuCsi5hRlPZ5/o+bbEbEwIu6JiN3qtjO1WP+BiJha1veRJNm2lqQuzeyJcQEwuVvZycBNmbk9cFMxL0kaeO/LzLGZOa6Yb3T+PQjYvnhNA74LtYYxcCqwJzAeONXGsSSV6gJsW0tS85IYmXkL8HS34kOBC4vpC4EPNmv/kqQ3aHT+PRS4KGt+C2wSEVsBBwI3ZObTmfkMcAOrNp4lSS1i21qSaoa2eH9bZObjxfQTwBaNVoyIadT+Ksj666+/+4477tiC8CRp7cydO/epzBxRdhyFBK6PiAS+l5nTaXz+3QZ4rO6zi4qyRuVv4LlaUpW02bl6INi2ltSRVne+bnUSY6XMzKKB3Wj5dGA6wLhx43LOnDkti02S1lREPFp2DHX2zczFEbE5cENE/L5+YW/n3zXhuVpSlbTZuXpA2baW1ElWd75u9dNJ/lx0U6Z4f7LF+5ekjpeZi4v3J4GrqI1p0ej8uxjYtu7jI4uyRuWSpPZh21rSoNPqJMbVQNcI91OBn7V4/5LU0SJi/YjYsGsamATcR+Pz79XAx4qnlOwFPFt0TZ4JTIqITYsBPScVZZKk9mHbWtKg07TbSSLiUmAC8JaIWERtlPszgMsj4hjgUeAjzdq/JA1SWwBXRQTUzvE/yszrIuIOej7/XgscDCwEXgQ+DpCZT0fEvwF3FOt9JTO7DygnSWoR29aSVNO0JEZmTmmwaP9m7VOSBrvMfAjYtYfypfRw/s3MBD7VYFvnA+cPdIyS1s5rr73GokWLePnll8sOpe0NHz6ckSNHMmzYsLJDGTC2rSWpprSBPSVJktR3ixYtYsMNN2TUqFEUva3Ug8xk6dKlLFq0iNGjR5cdjiRpgLV6TAxJkiSthZdffpnNNtvMBEYvIoLNNtvMHiuS1KFMYkiSJFWECYy+sZ4kqXOZxJAkSZIkSZVgEkOSJEm9Wrp0KWPHjmXs2LFsueWWbLPNNivnX3311abv//HHH2fSpElN348kqb05sKckSZJ6tdlmm3HXXXcBcNppp7HBBhvwuc99buXyZcuWMXRo85qW1113HQceeGDTti9JqgZ7YkiSJGmtHH300Rx77LHsueeenHTSSZx22ml84xvfWLn8Xe96F4888ggAF198MePHj2fs2LF84hOfYPny5atsb9SoUZx00knsvPPOjB8/noULF65cdt1113HQQQeRmRx33HHssMMOHHDAARx88MFcccUVTf+ukqT2YBJDkiSpQ11/PXzhC7X3Zlm0aBG//vWvOfvssxuuM3/+fGbMmMFtt93GXXfdxZAhQ7jkkkt6XHfjjTfm3nvv5bjjjuOEE04AYPny5SxYsIAxY8Zw1VVXsWDBAubNm8dFF13Er3/962Z8LUlSm/J2EkmSpA50/fVw1FHwyivw/e/DxRdDM4aU+PCHP8yQIUNWu85NN93E3Llz2WOPPQB46aWX2HzzzXtcd8qUKSvfP/OZzwBw++23s+eeewJwyy23MGXKFIYMGcLWW2/NxIkTB+qr
SJIqwCSGJElSB5o1q5bAWH99eOGF2nwzkhjrr7/+yumhQ4eyYsWKlfMvv/wyAJnJ1KlT+drXvtbr9uofj9o1/fOf/5zJkycPVMiSpArzdhJJkqQONGECrLtuLYGx7rq1+WYbNWoUd955JwB33nknDz/8MAD7778/V1xxBU8++SQATz/9NI8++miP25gxY8bK97333huo9eQ44IADANhvv/2YMWMGy5cv5/HHH+fmm29u6neSJLUXe2JIkiR1oEmTareQzJpVS2C04umkH/rQh7jooovYaaed2HPPPXnnO98JwJgxYzj99NOZNGkSK1asYNiwYZx77rm87W1vW2UbzzzzDLvssgvrrrsul156KUuWLGH48OFsuOGGABx22GH84he/YMyYMbz1rW9dmeiQJA0OJjEkSZI61KRJzUlenHbaaT2Wr7feelzfYBTRI444giOOOKLXbZ944omceeaZK+cvvvhiJtV9iYjgP/7jP1bOH3300X0LWpLUEUxiSJIkqW0dddRRZYcgSWojJjEkSZLUFh555JE1/swFF1ww4HFIktqXA3tKkiRJkqRKMIkhSZIkSZIqwSSGJEmSJEmqBJMYkiRJkiSpEkxiSJIkqU+eeOIJjjzySLbbbjt23313Dj74YP7whz80XP+RRx5hvfXWY+zYsey666685z3vYcGCBWu83zPOOINLLrmkP6FLkjqESQxJkiT1KjM57LDDmDBhAg8++CBz587la1/7Gn/+859X+7ntttuOu+66i7vvvpupU6fy7//+72u875kzZzJp0qS1DV2S1EFMYkiSJKlXN998M8OGDePYY49dWbbrrrvy3ve+l8zkxBNP5F3vehc777wzM2bM6HEbzz33HJtuuukq5bNmzWK//fbj/e9/PzvssAPHHnssK1asWPmZV199lREjRvDwww+z9957s/POO/OlL32JDTbYoDlfVpLUtoaWHYAkSZKa5PrrYdYsmDAB+tmT4b777mP33XfvcdmVV165srfFU089xR577MF+++0HwIMPPsjYsWP561//yosvvsjtt9/e4zZmz57NvHnzeNvb3sbkyZO58sorOfzww7nxxhvZf//9ATj++OP55Cc/ycc+9jHOPffcfn0fSVI12RNDkiSpE11/PRx1FJx7bu39+uubtqtbb72VKVOmMGTIELbYYgv+7u/+jjvuuAN4/XaSBx98kHPOOYdp06b1uI3x48fz9re/nSFDhjBlyhRuvfVWAK677joOOuggAG677TamTJkCwEc/+tGmfR9JUvsyiSFJktSJZs2CV16B9devvc+a1a/N7bTTTsydO7df2zjkkEO45ZZbelwWET3Oz549m/HjxzdcT5I0uJjEkCRJ6kQTJsC668ILL9TeJ0zo1+YmTpzIK6+8wvTp01eW3XPPPfzqV7/ive99LzNmzGD58uUsWbKEW2655Q2Jhy633nor2223XY/bnz17Ng8//DArVqxgxowZ7Lvvvtx///3suOOODBkyBIB99tmHyy67DMCnlUjSIGUSQ5IkqRNNmgQXXwyf+lTtvZ9jYkQEV111FTfeeCPbbbcdO+20E6eccgpbbrklhx12GLvssgu77rorEydO5Otf/zpbbrkl8PqYGLvuuitf+MIX+P73v9/j9vfYYw+OO+44/vZv/5bRo0dz2GGH8fOf/5zJkyevXOdb3/oW5557LjvvvDOLFy/u1/eRJFWTA3tKUgeKiCHAHGBxZn4gIkYDlwGbAXOBj2bmqxGxLnARsDuwFDgiMx8ptnEKcAywHPh0Zs5s/TeR1C+TJvU7eVFv66235vLLL+9x2VlnncVZZ531hrJRo0bx0ksv9WnbG220Eddcc80bymbOnMlFF120cn706NH85je/WTl/zjnn9DFySVKnsCeGJHWm44H5dfNnAt/MzHcAz1BLTlC8P1OUf7NYj4gYAxwJ7ARMBv6zSIxIUsvccMMNbLXVVmWHIUlqIyYxJKnDRMRI4P3A94v5ACYC
VxSrXAh8sJg+tJinWL5/sf6hwGWZ+UpmPgwsBFa9wV2SBsCECRNW6YXRF88//3wTopEktTOTGJLUec4BTgJWFPObAX/JzGXF/CJgm2J6G+AxgGL5s8X6K8t7+MxKETEtIuZExJwlS5YM8NeQ1F1mlh1CJVhPktS5TGJIUgeJiA8AT2Zm/56D2EeZOT0zx2XmuBEjRrRil9KgNXz4cJYuXep/0HuRmSxdupThw4eXHYokqQkc2FOSOss+wCERcTAwHNgI+BawSUQMLXpbjAS6hvVfDGwLLIqIocDG1Ab47CrvUv8ZSSUYOXIkixYtwl5PvRs+fDgjR44sOwxJUhOYxJCkDpKZpwCnAETEBOBzmfmPEfFj4HBqTyiZCvys+MjVxfxviuW/yMyMiKuBH0XE2cDWwPbA7BZ+FUndDBs2jNGjR5cdhiRJpTKJIUmDw+eByyLidOB3wHlF+XnADyNiIfA0tSeSkJn3R8TlwDxgGfCpzFze+rAlSZKk15nEkKQOlZmzgFnF9EP08HSRzHwZ+HCDz38V+GrzIpQkSZLWjAN7SpIkSZKkSigliRERn4mI+yPivoi4NCIcPlqSJElaC7atJQ0mLU9iRMQ2wKeBcZn5LmAIxT3YkiRJkvrOtrWkwaas20mGAusVj/N7E/CnkuKQJEmSqs62taRBo+VJjMxcDHwD+CPwOPBsZl7f6jgkSZKkqrNtLWmwKeN2kk2BQ4HRwNbA+hFxVA/rTYuIORExZ8mSJa0OU5IkSWp7tq0lDTZl3E5yAPBwZi7JzNeAK4H3dF8pM6dn5rjMHDdixIiWBylJkiRVgG1rSYNKGUmMPwJ7RcSbIiKA/YH5JcQhSZIkVZ1ta0mDShljYtwOXAHcCdxbxDC91XFIkiRJVWfbWtJgM7SMnWbmqcCpZexbkiRJ6iS2rSUNJmU9YlWSJEmSJGmNmMSQJEmSJEmVYBJDkiRJkiRVgkkMSZIkSZJUCSYxJEmSJElSJZjEkCRJkiRJlWASQ5IkSZIkVYJJDEmSJEmSVAkmMSRJkiRJUiWYxJAkSZIkSZVgEkOSJEmSJFWCSQxJkiRJklQJJjEkSZIkSVIlmMSQJEmSJEmVYBJDkiRJkiRVgkkMSZIkSZJUCSYxJKmDRMTwiJgdEXdHxP0R8eWifHRE3B4RCyNiRkSsU5SvW8wvLJaPqtvWKUX5gog4sKSvJEmSJK1kEkOSOssrwMTM3BUYC0yOiL2AM4FvZuY7gGeAY4r1jwGeKcq/WaxHRIwBjgR2AiYD/xkRQ1r5RSRJkqTuTGJIUgfJmueL2WHFK4GJwBVF+YXAB4vpQ4t5iuX7R0QU5Zdl5iuZ+TCwEBjf/G8gSZIkNWYSQ5I6TEQMiYi7gCeBG4AHgb9k5rJilUXANsX0NsBjAMXyZ4HN6st7+IwkSZJUCpMYktRhMnN5Zo4FRlLrPbFjs/YVEdMiYk5EzFmyZEmzdiNJkiQBJjEkqWNl5l+Am4G9gU0iYmixaCSwuJheDGwLUCzfGFhaX97DZ+r3MT0zx2XmuBEjRjTja0iSJEkrmcSQpA4SESMiYpNiej3g74H51JIZhxerTQV+VkxfXcxTLP9FZmZRfmTx9JLRwPbA7JZ8CUmSJKmBob2vIkmqkK2AC4snifwNcHlmXhMR84DLIuJ04HfAecX65wE/jIiFwNPUnkhCZt4fEZcD84BlwKcyc3mLv4skSZL0BiYxJKmDZOY9wLt7KH+IHp4ukpkvAx9usK2vAl8d6BglSZKkteXtJJIkSZIkqRJMYkiSJEmSpEowiSFJkiRJkirBJIYkSZIkSaoEkxiSJEmSJKkSTGJIkiRJkqRKMIkhSZIkSZIqwSSGJEmSJEmqBJMYkiRJkiSpEkxiSJIkSZKkSjCJIUmSJEmSKsEkhiRJkiRJqgSTGJIkSZIkqRJKSWJExCYRcUVE/D4i5kfE3mXEIUmSJFWdbWtJg8nQkvb7LeC6zDw8ItYB3lRSHJIkSVLV2baWNGi0
PIkRERsD+wFHA2Tmq8CrrY5DkiRJqjrb1pIGm15vJ4mI7SJi3WJ6QkR8OiI26cc+RwNLgB9ExO8i4vsRsX4P+50WEXMiYs6SJUv6sTtJkiSpY9m2ljSo9GVMjJ8AyyPiHcB0YFvgR/3Y51BgN+C7mflu4AXg5O4rZeb0zByXmeNGjBjRj91JkiRJHcu2taRBpS9JjBWZuQw4DPhOZp4IbNWPfS4CFmXm7cX8FdROvJI0aETExhHxza6/ikXE/ym6BEuStCZsW0saVPqSxHgtIqYAU4FrirJha7vDzHwCeCwidiiK9gfmre32JKmizgeeAz5SvJ4DflBqRJKkpouIr0fERhExLCJuioglEXHU2m7PtrWkwaYvA3t+HDgW+GpmPhwRo4Ef9nO//wJcUoye/FCxD0kaTLbLzA/VzX85Iu4qKxhJUstMysyTIuIw4BHgH4BbgIv7sU3b1pIGjV6TGJk5D/g0QERsCmyYmWf2Z6eZeRcwrj/bkKSKeyki9s3MWwEiYh/gpZJjkiQ1X1f7+/3AjzPz2Yjo1wZtW0saTHpNYkTELOCQYt25wJMRcVtmfrbJsUlSJ/skcGExDkYAT1M8Hk+S1NGuiYjfU0tcfzIiRgAvlxyTJFVGX24n2Tgzn4uIfwYuysxTI+KeZgcmSZ2s+KvZrhGxUTH/XLkRSZJaITNPjoivA89m5vKIeAE4tOy4JKkq+pLEGBoRW1EbeO6LTY5HkjpaRByVmRdHxGe7lQOQmWeXEpgkqSUi4sPAdUUC40vUniRyOvBEuZFJUjX05ekkXwFmAg9m5h0R8XbggeaGJUkda/3ifcMeXhuUFZQkqWX+/8z8a0TsCxwAnAd8t+SYJKky+jKw54+BH9fNPwR8qPEnJEmNZOb3iskbM/O2+mXF4J6SpM62vHh/PzA9M/87Ik4vMyBJqpJee2JExMiIuCoinixeP4mIka0ITpI62Hf6WLZGImLbiLg5IuZFxP0RcXxR/uaIuCEiHijeNy3KIyK+HRELI+KeiNitbltTi/UfiIip/Y1NkgTA4oj4HnAEcG1ErEvfekdLkujbmBg/AH4EfLiYP6oo+/tmBSVJnSoi9gbeA4zoNi7GRsCQAdjFMuBfM/POiNgQmBsRN1B78slNmXlGRJwMnAx8HjgI2L547UmtS/OeEfFm4FRqj+zLYjtXZ+YzAxCjJA1mHwEmA9/IzL8UY8+dWHJMklQZfcn6jsjMH2TmsuJ1ATCiyXFJUqdah9rYF0N543gYzwGH93fjmfl4Zt5ZTP8VmA9sQ23k+wuL1S4EPlhMH0rtyVOZmb8FNika1AcCN2Tm00Xi4gZqjW5JUj9k5ovAg8CBEXEcsHlmXl9yWJJUGX3pibE0Io4CLi3mpwBLmxeSJHWuzPwl8MuIuCAzH23mviJiFPBu4HZgi8x8vFj0BLBFMb0N8FjdxxYVZY3Ku+9jGjAN4K1vfesARi9Jnam4ze9/AVcWRRdHxPTM7PcthZI0GPQlifFP1O7T/ia1LsW/ptYtWZK09l6MiLOAnYDhXYWZOXEgNh4RGwA/AU7IzOe6HuFa7CMjIgdiP5k5HZgOMG7cuAHZpiR1uGOAPTPzBYCIOBP4DQMwLpIkDQa93k6SmY9m5iGZOSIzN8/MDwLHNz80SepolwC/B0YDXwYeAe4YiA1HxDBqCYxLMrPrL31/Lm4ToXh/sihfDGxb9/GRRVmjcklS/wSvP6GEYjoarCtJ6mZtR0L+yIBGIUmDz2aZeR7wWmb+MjP/Ceh3L4yodbk4D5ifmWfXLboa6HrCyFTgZ3XlHyueUrIX8Gxx28lMYFJEbFo8yWRSUSZJ6p8fALdHxGkRcRrwW2rnbUlSH/TldpKemC2WpP55rXh/PCLeD/wJePMAbHcf4KPAvRFxV1H2BeAM4PKIOAZ4lNeT0dcCBwMLgReBjwNk5tMR8W+83jvkK5n59ADEJ0mDWmaeHRGzgH2Loo8Dfy4vIkmqloZJjOLxej0uwiSG
JPXX6RGxMfCv1O6D3gg4ob8bzcxbaXyO3r+H9RP4VINtnQ+c39+YJElvVDxF6s6u+Yj4I+DoyJLUB6vriTGX2kCePTWGX21OOJI0OGTmNcXks8D7ACJin/IikiSVyD8QSlIfNUxiZOboVgYiSYNBRAyhdivHNsB1mXlfRHyA2i0f61F7JKokaXDx6U6S1EdrOyaGJGntnEftqR+zgW9HxJ+AccDJmfnTMgOTJDVPRHyHnpMVAWzS2mgkqbpMYkhSa40DdsnMFRExHHgC2C4zl5YclySpueas5TJJUh2TGJLUWq9m5gqAzHw5Ih4ygSFJnS8zLyw7BknqBH1KYhT3cG9Rv35m/rFZQUlSB9sxIu4ppgPYrpgPag8L2aW80CRJkqT21msSIyL+BTiV2vOrVxTFCdjQlqQ197dlByBJkiRVVV96YhwP7GB3Z0nqv8x8tOwYJEnliYh9MvO23sokST37mz6s8xjwbLMDkSRJkgaB7/SxTJLUg770xHgImBUR/w280lWYmWc3LSpJkiSpg0TE3sB7gBER8dm6RRsBQ8qJSpKqpy9JjD8Wr3WKlyRJkqQ1sw6wAbX294Z15c8Bh5cSkSRVUK9JjMz8cisCkaTBJCLupTZIcr1ngTnA6Y5DJEmdJTN/CfwyIi7oGh8pIv4G2CAznys3OkmqjoZJjIg4JzNPiIj/y6oNbTLzkKZGJkmd7efAcuBHxfyRwJuAJ4ALgP9RTliSpCb7WkQcS+0acAewUUR8KzPPKjkuSaqE1fXE+GHx/o1WBCJJg8wBmblb3fy9EXFnZu4WEUeVFpUkqdnGZOZzEfGP1BLaJwNzAZMYktQHDZMYmTm3eP9l68KRpEFjSESMz8zZABGxB68P7LasvLAkSU02LCKGAR8E/iMzX4uIVXo9S5J61uuYGBGxPfA1YAwwvKs8M9/exLgkqdP9M3B+RGwABLWB3Y6JiPWpnXMlSZ3pe8AjwN3ALRHxNmrXAElSH/Tl6SQ/AE4Fvgm8D/g48DfNDEqSOl1m3gHsHBEbF/PP1i2+vJyoJEnNlpnfBr5dV/RoRLyvrHgkqWr6koxYLzNvAiIzH83M04D3NzcsSepsEbFxRJwN3ATcFBH/pyuhIUnqXBGxRUScFxE/L+bHAFNLDkuSKqMvSYxXisc/PRARx0XEYdSecS1JWnvnA38FPlK8nqPW802S1NkuAGYCWxfzfwBOKCsYSaqaviQxjqf22L9PA7sDR2G2WJL6a7vMPDUzHypeXwYca0iSOlREdN3G/ZbMvBxYAZCZy6g9blWS1AerTWJExBDgiMx8PjMXZebHM/NDmfnbFsUnSZ3qpYjYt2smIvYBXioxHklSc80u3l+IiM2ABIiIvYBnG35KkvQGDQf2jIihmbmsvpEtSRowxwIX1Y2D8Qz2cpOkThbF+2eBq4HtIuI2YARweGlRSVLFrO7pJLOB3YDfRcTVwI+BF7oWZuaVTY5NkjpWZt4N7BoRGxXzz0XECcA9pQYmSWqWERHx2WL6KuBaaomNV4AD8PwvSX3Sl0esDgeWAhOpdXuL4r1fSYziVpU5wOLM/EB/tiVJVZWZz9XNfhY4p6RQJEnNNYTa4PjRrfxNA7Fx29aSBovVJTE2L7LF9/F68qJLDsC+jwfmAxsNwLYkqRN0b9hKkjrH45n5lSZu37a1pEFhdQN7dmWLNwA2rJvueq21iBgJvB/4fn+2I0kdZiASxJKk9tS0RLVta0mDyep6YjQzW3wOcBK15EiPImIaMA3grW99a5PCkKTWioi/0nOyIoD1WhyOJKl19m/its/BtrWkQWJ1PTGaki2OiA8AT2bm3NWtl5nTM3NcZo4bMWJEM0KRpJbLzA0zc6MeXhtmZl/GKVqtiDg/Ip6MiPvqyt4cETdExAPF+6ZFeUTEtyNiYUTcExG71X1marH+AxHhU1MkqZ8y8+lmbNe2taTBZnVJjGZli/cBDomIR4DLgIkRcXGT9iVJg80FwORuZScDN2Xm
9sBNxTzAQcD2xWsa8F2oJT2AU4E9gfHAqV2JD0lS27FtLWlQaZjEaFa2ODNPycyRmTkKOBL4RWYe1Yx9SdJgk5m3AN3P34cCFxbTFwIfrCu/KGt+C2wSEVsBBwI3ZObTmfkMcAOrJkYkSW3AtrWkwWZ1PTEkSZ1hi8x8vJh+AtiimN4GeKxuvUVFWaPyVUTEtIiYExFzlixZMrBRS5IkSd2UmsTIzFk+x1qSWiczkwF8Cor3WEtS+7BtLWkwsCeGJHW+Pxe3iVC8P1mULwa2rVtvZFHWqFySJEkqlUkMSep8VwNdTxiZCvysrvxjxVNK9gKeLW47mQlMiohNiwE9JxVlkiRJUqn6/Tg/SVL7iIhLgQnAWyJiEbWnjJwBXB4RxwCPAh8pVr8WOBhYCLwIfBxqAztHxL8BdxTrfaVZgz1LkiRJa8IkhiR1kMyc0mDRKo/NLsbH+FSD7ZwPnD+AoUmSJEn95u0kkiRJkiSpEkxiSJIkSZKkSjCJIUmSJEmSKsEkhiRJkiRJqgSTGJIkSZIkqRJMYkiSJEmSpEowiSFJkiRJkirBJIYkSZIkSaoEkxiSJEmSJKkSTGJIkiRJkqRKMIkhSZIkSZIqwSSGJEmSJEmqBJMYkiRJkiSpEkxiSJIkSZKkSjCJIUmSJEmSKsEkhiRJkiRJqgSTGJIkSZIkqRJMYkiSJEmSpEowiSFJkiRJkirBJIYkSZIkSaoEkxiSJEmSJKkSTGJIkiRJkqRKMIkhSZIkSZIqwSSGJEmSJEmqBJMYkiRJkiSpEkxiSJIkSZKkSjCJIUlqKCImR8SCiFgYESeXHY8kSZIGt6FlByBJak8RMQQ4F/h7YBFwR0RcnZnzyo1MA215xID+VWMFMCRzALcoSZJUY08MSVIj44GFmflQZr4KXAYcWnJMGmADncCAWuNiecQAb1WSJMkkhiSpsW2Ax+rmFxVlK0XEtIiYExFzlixZ0tLgNDCa1RCwgSFJkprBNoYkaa1l5vTMHJeZ40aMGFF2OFoLKyq2XUmSNLiZxJAkNbIY2LZufmRRpg4yJHPAEw6OiSFJkprFgT0lSY3cAWwfEaOpJS+OBP5nuSGpGQY64TBkQLcmSZL0upb3xIiIbSPi5oiYFxH3R8TxrY5BktS7zFwGHAfMBOYDl2fm/eVGJUmqZ9ta0mBTRk+MZcC/ZuadEbEhMDcibvCRfZLUfjLzWuDasuOQJDVk21rSoNLynhiZ+Xhm3llM/5XaX/e2Wf2nJEmSJHVn21rSYFPqwJ4RMQp4N3B7mXFIkiRJVWfbWtJgUFoSIyI2AH4CnJCZz/WwfFpEzImIOUuWLGl9gJIkSVJF2LaWNFiUksSIiGHUTrKXZOaVPa2TmdMzc1xmjhsxYkRrA5QkSZIqwra1pMGkjKeTBHAeMD8zz271/iVJkqROYdta0mBTRk+MfYCPAhMj4q7idXAJcUiSJElVZ9ta0qDS8kesZuatQLR6v5IkSVKnsW0tabAp9ekkkiRJkiRJfWUSQ5IkSZIkVYJJDEmSJEmSVAkmMSRJkiRJUiWYxJAkSZIkSZVgEkOSJEmSJFWCSQxJkiRJklQJJjEkSZIkSVIlmMSQJEmSJEmVYBJDkiRJkiRVgkkMSZIkSZJUCSYxJEmSJElSJZjEkCRJkiRJlWASQ5IkSZIkVYJJDEmSJEmSVAkmMSRJkiRJUiVUIomxfEWWHYIkSZLUEVakbWtJ1VWJJMajS18sOwRJkiSpIyx9/tWyQ5CktVaJJIYkSZIkSVIlkhgRZUcgSZIkdQjb1pIqrBJJDElS7yLiwxFxf0SsiIhx3ZadEhELI2JBRBxYVz65KFsYESfXlY+OiNuL8hkRsU4rv4skqXnMYUiqMpMYktQ57gP+AbilvjAixgBHAjsBk4H/jIghETEEOBc4CBgDTCnWBTgT+GZmvgN4BjimNV9BkiRJaqwSSQyzxZLUu8ycn5kLelh0KHBZZr6SmQ8DC4HxxWthZj6Uma8ClwGH
RkQAE4Eris9fCHyw6V9AkiRJ6kUlkhiSpH7ZBnisbn5RUdaofDPgL5m5rFv5KiJiWkTMiYg5S5YsGfDAJUkDL/wToaQKG1p2AH3ieVaSAIiIG4Ete1j0xcz8WavjyczpwHSAcePGZav3L0laC7atJVVYJZIYZoslqSYzD1iLjy0Gtq2bH1mU0aB8KbBJRAwtemPUry9Jqjhb1pKqzNtJJKnzXQ0cGRHrRsRoYHtgNnAHsH3xJJJ1qA3+eXVmJnAzcHjx+alAy3t5SJIkSd2ZxJCkDhERh0XEImBv4L8jYiZAZt4PXA7MA64DPpWZy4teFscBM4H5wOXFugCfBz4bEQupjZFxXmu/jSRJkrSqatxOYp83SepVZl4FXNVg2VeBr/ZQfi1wbQ/lD1F7eokkqcPYtJZUZfbEkCRJkgYTsxiSKswkhiRJkiRJqoRKJDFMFkuSJEkDw7a1pCqrRBLDM60kSZI0UGxcS6quaiQxJEmSJA0McxiSKqwSSYzwTCtJkiQNCFvWkqqsEkkMSZIkSZKkSiQxzBZLkiRJA8O2taQqq0QSwzOtJEmSJEkqJYkREZMjYkFELIyIk8uIQZIkSeoEa9y29g+Ekiqs5UmMiBgCnAscBIwBpkTEmNV+phWBSZIkSRVj21rSYFNGT4zxwMLMfCgzXwUuAw4tIQ5JkiSp6mxbSxpUhpawz22Ax+rmFwF7dl8pIqYB04rZVyLivhbE1ldvAZ4qO4hCO8UCxtMb42msnWKBNY/nbc0KpCrmzp37fEQsKDuObtrtd9WlHeMypr5px5igPeNqx5h2KDuAJqh627rdfifGs3rtFE87xQLG05sBa1uXkcTok8ycDkwHiIg5mTmu5JBWaqd42ikWMJ7eGE9j7RQLtF88FbGg3eqsXf8d2zEuY+qbdowJ2jOudo2p7BjK0q5t63aKBYynN+0UTzvFAsbTm4GMp4zbSRYD29bNjyzKJEmSJK0Z29aSBpUykhh3ANtHxOiIWAc4Eri6hDgkSZKkqrNtLWlQafntJJm5LCKOA2YCQ4DzM/P+Xj42vfmRrZF2iqedYgHj6Y3xNNZOsUD7xVMF7Vhn7RgTtGdcxtQ37RgTtGdcxtQCHdC2bqdYwHh6007xtFMsYDy9GbB4IjMHaluSJEmSJElNU8btJJIkSZIkSWvMJIYkSZIkSaqEtk5iRMTkiFgQEQsj4uQW7XPbiLg5IuZFxP0RcXxR/uaIuCEiHijeNy3KIyK+XcR4T0Ts1oSYhkTE7yLimmJ+dETcXuxzRjGIExGxbjG/sFg+qgmxbBIRV0TE7yNifkTsXXLdfKb4d7ovIi6NiOGtrJ+IOD8inqx/1vra1EdETC3WfyAipg5wPGcV/173RMRVEbFJ3bJTingWRMSBdeUDcuz1FE/dsn+NiIyItxTzpdRPUf4vRR3dHxFfrytvav10qkb1Wbbuv7mSY2l4XJYQS9v9nqPBtbgdRLdrctl6ui63QUyrXJtLiqPP1+jBoozjvdHxvDbtpQGMybZ143hWOX5bWT9rctyurj7CtnXT66dRLNGKdnVmtuWL2sBEDwJvB9YB7gbGtGC/WwG7FdMbAn8AxgBfB04uyk8GziymDwZ+DgSwF3B7E2L6LPAj4Jpi/nLgyGL6v4BPFtP/H/BfxfSRwIwmxHIh8M/F9DrAJmXVDbAN8DCwXl29HN3K+gH2A3YD7qsrW6P6AN4MPFS8b1pMbzqA8UwChhbTZ9bFM6Y4rtYFRhfH25CBPPZ6iqco35baAGSPAm8puX7eB9wIrFvMb96q+unEV6P6LPvV02+u5Hh6PC5LiKMtf880uBaXHVcRzxuuyWW/6OG6XHI8PV6bS4qlz9fowfAq63hvdDw3+rdo1B4Y4JhsW/cci23rvsVj27px3bSkXT2gB+EAH0R7AzPr5k8BTikhjp8Bfw8sALYqyrYCFhTT3wOm1K2/cr0B2v9I4CZgInBN8SN8qu7AWVlPxQ9372J6aLFeDGAsG1M7
sUW38rLqZhvgseIAHFrUz4Gtrh9gVLeDd43qA5gCfK+u/A3r9TeebssOAy4ppt9wTHXVz0Afez3FA1wB7Ao8wusn2lLqh9qF+YAe1mtJ/XTaq1F9lv3q6TfXLq/647KEfVfi90xxLW6DON5wTW6DeHq8LpccU0/X5kklxtOna/RgeLXL8Y5t6/pYbFv3HEefjttG9YFt65a1rXv4t2pJu7qdbyfpOoi6LCrKWqboEvVu4HZgi8x8vFj0BLBFMd3sOM8BTgJWFPObAX/JzGU97G9lLMXyZ4v1B8poYAnwg6IL3vcjYn1KqpvMXAx8A/gj8Di17zuX8uqny5rWRyt/6/9ELSNbWjwRcSiwODPv7raorPp5J/DeohvkLyNij5LjqbpG9Vma1fzm2kX9cdlqbf977nYtLts5vPGaXLZG1+XS9HRtzszry4ypm0bX6MGg9OPdtvUqbFv3jW3r1WiztnVL2tXtnMQoVURsAPwEOCEzn6tflrU0UbYghg8AT2bm3Gbvq4+GUusy9N3MfDfwArUuXSu1qm4AivvhDqV2AdgaWB+Y3Ip991Ur66M3EfFFYBlwSYkxvAn4AvC/y4qhB0Op/cVhL+BE4PKIiHJDam8RcWNxr2z316GUVJ+9xFTKb66XmLrWKf24bGeruxaXEEu7XZOhD9flVuvp2hwRR5UZUyPtdI0eDGxb98i29Rpqp+O2Ha7hbdi2bkk7cOhAb3AALaZ2b0+XkUVZ00XEMGon2Usy88qi+M8RsVVmPh4RWwFPtiDOfYBDIuJgYDiwEfAtYJOIGFpkPOv31xXLoogYSq2L2tIBigVqmbFFmdn117ArqJ1oy6gbgAOAhzNzCUBEXEmtzsqqny5rWh+LgQndymcNZEARcTTwAWD/4uS/unhYTXl/bUftwnh3cT4bCdwZEeNXE0+z62cRcGVRL7MjYgXwltXEw2rKB4XMPKDRsoj4JD3X55IyYoqInWnwm8vMJ8qIqS62o1n1uGy10q61vWlwLS7TKtfkiLg4M8v8D3qj63KZero2vwe4uNSoXtfoGj0Y2La2bd0b29Z9ZNu6oZa0q9u5J8YdwPZRGw13HWqDxVzd7J0WmaLzgPmZeXbdoquBqcX0VGr383WVfyxq9qLWbfJxBkBmnpKZIzNzFLXv/4vM/EfgZuDwBrF0xXh4sf6ANYyLBv9jEbFDUbQ/MI8S6qbwR2CviHhT8e/WFU8p9VNnTetjJjApIjYtMuCTirIBERGTqXWbPCQzX+wW55FRG1l6NLA9MJsmHnuZeW9mbp6Zo4rf9SJqg309QUn1A/yU2iBERMQ7qQ0q9BQl1E+H+Ck912cpevnNlWY1x2WrteXveTXX4tI0uCaX2sNgNdflMvV0bZ5fckz1Gl2jBwPb1rate2Pbug9sW6/WT2lFuzr7OfBKM1/URlT9A7URS7/Yon3uS62L0j3AXcXrYGr3d90EPEBtxNU3F+sHcG4R473AuCbFNYHXR1B+e/GPvhD4Ma+P/jq8mF9YLH97E+IYC8wp6uen1Ea0La1ugC8DvwfuA35IbcTbltUPcCm1ewZfo3bSOGZt6oPa/XQLi9fHBziehdTuNev6Pf9X3fpfLOJZABw00MdeT/F0W/4Irw8+VFb9rEPtL4T3AXcCE1tVP534Wl19tsOLNhnYc3XHZQmxtN3vmQbX4rLjqotvAm0wsGcRy1i6XZfbIKZVrs0lxdHna/RgeZVxvDc6nhv9W6yuPTDAca08jrFtXR+Pbeve47Ft3bhuWtKujuKDkiRJkiRJba2dbyeRJEmSJElaySSGJEmSJEmqBJMYkiRJkiSpEkxiSJIkSZKkSjCJIUmSJEmSKsEkhiohIp4v3kdFxP8c4G1/odv8rwdy+5I0mETEFyPi/oi4JyLuiog9m7ivWRExrlnbl6ROZLtaVWcSQ1UzClijk21EDO1llTecbDPzPWsYkyQJiIi9gQ8Au2XmLsABwGPlRiVJamAUtqtVQSYxVDVnAO8t
/rr3mYgYEhFnRcQdxV/9PgEQERMi4lcRcTUwryj7aUTMLf5COK0oOwNYr9jeJUVZV3Y6im3fFxH3RsQRddueFRFXRMTvI+KSiIgS6kKS2s1WwFOZ+QpAZj6VmX+KiP9dnKfvi4jpXefM4lz6zYiYExHzI2KPiLgyIh6IiNOLdUbVnWvnF+feN3XfcURMiojfRMSdEfHjiNigKD8jIuYV14hvtLAuJKnd2a5WJUVmlh2D1KuIeD4zN4iICcDnMvMDRfk0YPPMPD0i1gVuAz4MvA34b+Bdmflwse6bM/PpiFgPuAP4u8xc2rXtHvb1IeBYYDLwluIzewI7AD8DdgL+VOzzxMy8tfk1IUntq0gc3Aq8CbgRmJGZv+w6/xbr/BC4PDP/b0TMAm7PzM9HxPHA54HdgaeBB4FdgQ2Bh4F9M/O2iDgfmJeZ3yg+/zngEeBK4KDMfCEiPg+sC5wL/BrYMTMzIjbJzL+0pDIkqU3ZrlbV2RNDVTcJ+FhE3AXcDmwGbF8sm911oi18OiLuBn4LbFu3XiP7Apdm5vLM/DPwS2CPum0vyswVwF3UuuNJ0qCWmc9TS0JMA5YAMyLiaOB9EXF7RNwLTKTWWO1ydfF+L3B/Zj5e9OR4iNq5GuCxzLytmL6Y2vm53l7AGOC24nowlVqj+1ngZeC8iPgH4MWB+q6S1IFsV6sSerunSWp3AfxLZs58Q2Ets/xCt/kDgL0z88Xir3fD+7HfV+qml+OxJEkAZOZyYBYwq0hafALYBRiXmY9FxGm88fzbdT5dwRvPrSt4/dzavdto9/kAbsjMKd3jiYjxwP7A4cBx1JIokqRV2a5WJdgTQ1XzV2pdi7vMBD4ZEcMAIuKdEbF+D5/bGHimONHuSO2vdl1e6/p8N78CjijuDxwB7AfMHpBvIUkdKCJ2iIj6v8aNBRYU008Vt5scvhabfmvUBg2F2iB03bsZ/xbYJyLeUcSxfnE92ADYODOvBT5D7fYUSVKN7WpVklkuVc09wPKi+9oFwLeodTm7sxgEaAnwwR4+dx1wbETMp9ag/m3dsunAPRFxZ2b+Y135VcDewN3U/up3UmY+UZysJUmr2gD4TkRsAiwDFlK7teQvwH3AE9Tug15TC4BPdY2HAXy3fmFmLiluW7m0uI8b4EvUGug/i4jh1P7C+Nm12LckdSrb1aokB/aUJEltKyJGAddk5rvKjkWSJJXP20kkSZIkSVIl2BNDkiRJkiRVgj0xJEmSJElSJZjEkCRJkiRJlWASQ5IkSZIkVYJJDEmSJEmSVAkmMSRJkiRJUiX8P8HjpA77q/iMAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 1080x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "# TODO: confirm q_list_test handling in validation/visualization against Akash's code.\n",
    "\n",
    "VAL_EVERY = 50  # run a validation + plot pass every VAL_EVERY training iterations\n",
    "loss_crit = torch.nn.functional.cross_entropy\n",
    "\n",
    "\n",
    "def three_class_loss(log_p, log_q, log_m):\n",
    "    \"\"\"Sum of cross-entropy losses with p, q and m outputs labelled 0, 1 and 2.\n",
    "\n",
    "    Each argument holds the model outputs for one batch (first dim = batch size);\n",
    "    labels are created on the same device as those outputs.\n",
    "    \"\"\"\n",
    "    total = 0.\n",
    "    for label, logits in enumerate((log_p, log_q, log_m)):\n",
    "        target = torch.full((logits.shape[0],), label, dtype=torch.long, device=logits.device)\n",
    "        total = total + loss_crit(logits, target)\n",
    "    return total\n",
    "\n",
    "\n",
    "def to_model_input(batch):\n",
    "    \"\"\"Add a trailing feature dim ((B,) -> (B, 1)) and move the batch to DEVICE.\"\"\"\n",
    "    return batch.unsqueeze(1).to(DEVICE)\n",
    "\n",
    "\n",
    "model = model.to(DEVICE)  # no-op when the model already lives on DEVICE\n",
    "model.train()\n",
    "\n",
    "# Fixed limits for the log-ratio panel (these override the setup cell's limits).\n",
    "ax2.set_xlim(-25., 25.)\n",
    "ax2.set_ylim(-400, 100)\n",
    "\n",
    "i = 0\n",
    "for _ in trange(NUM_EPOCHS):\n",
    "    for p_batch, q_batch, m_batch in train_dl:\n",
    "        i += 1\n",
    "\n",
    "        # Reshape unconditionally: the model needs (batch, 1) inputs on CPU as well as GPU.\n",
    "        logP = model(to_model_input(p_batch))\n",
    "        logQ = model(to_model_input(q_batch))\n",
    "        logM = model(to_model_input(m_batch))\n",
    "\n",
    "        optim.zero_grad()\n",
    "        loss = three_class_loss(logP, logQ, logM)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "\n",
    "        if i % VAL_EVERY != 0:\n",
    "            continue\n",
    "\n",
    "        # Validation/Test on the single, full-size test batch\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            p_test, q_test, m_test = next(iter(test_dl))\n",
    "            log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, m_test, calc_true_kl=True)\n",
    "            _, kl_from_p_q = get_gt_ratio_kl(p, q, p_test)  # not used below; kept for later inspection\n",
    "\n",
    "            logP = model(to_model_input(p_test))\n",
    "            logQ = model(to_model_input(q_test))\n",
    "            logM = model(to_model_input(m_test))\n",
    "\n",
    "            # CoB estimate of log p/q: difference between the class-0 and class-1 outputs.\n",
    "            kl_from_cob = torch.mean(logP[:, 0] - logP[:, 1])  # averaged over p samples\n",
    "            log_ratio_p_q_from_cob = logM[:, 0] - logM[:, 1]   # on m samples, for plotting\n",
    "\n",
    "            test_loss = three_class_loss(logP, logQ, logM)\n",
    "\n",
    "        # Visualize\n",
    "        line.set_data(range(len(loss_store)), loss_store)\n",
    "        ax1.set_xlim(0, len(loss_store))\n",
    "\n",
    "        m_plot = m_test.cpu().squeeze()\n",
    "        scat1.set_offsets(np.vstack([m_plot, log_ratio_p_q.cpu().detach()]).T)\n",
    "        scat2.set_offsets(np.vstack([m_plot, log_ratio_p_q_from_cob.cpu().detach()]).T)\n",
    "\n",
    "        test_loss_store.append(test_loss.item())\n",
    "        test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "        ax3.set_xlim(0, len(test_loss_store))\n",
    "\n",
    "        # Clear first, then redraw and print: printing before clear_output(wait=True)\n",
    "        # lets the next display() wipe the metrics immediately.\n",
    "        clear_output(wait=True)\n",
    "        display(fig)\n",
    "        print('iteration: ', i)\n",
    "        print('KLD: ', true_kl_p_q)\n",
    "        print('CoB: ', kl_from_cob.item())\n",
    "\n",
    "        model.train()"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      " 12%|█▏        | 599/5000 [00:07<00:58, 75.86it/s]\n"
     ]
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-1271455e48ad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     82\u001b[0m                     \u001b[0mclear_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 83\u001b[0;31m                     \u001b[0mdisplay\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     84\u001b[0m                     \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/display.py\u001b[0m in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, *objs, **kwargs)\u001b[0m\n\u001b[1;32m    311\u001b[0m             \u001b[0mpublish_display_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    312\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m             \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minclude\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    314\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    315\u001b[0m                 \u001b[0;31m# nothing to display (e.g. _ipython_display_ took over)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mformat\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mmd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    179\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m                 \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    181\u001b[0m             \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    182\u001b[0m                 \u001b[0;31m# FIXME: log the exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<decorator-gen-2>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    222\u001b[0m     \u001b[0;34m\"\"\"show traceback on failed format call\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m         \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    225\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    226\u001b[0m         \u001b[0;31m# don't warn on NotImplementedErrors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m    339\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    340\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mprinter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    342\u001b[0m             \u001b[0;31m# Finally look for special method names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    343\u001b[0m             \u001b[0mmethod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_real_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_method\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(fig)\u001b[0m\n\u001b[1;32m    246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    247\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'png'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 248\u001b[0;31m         \u001b[0mpng_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'png'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    249\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'retina'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m'png2x'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    250\u001b[0m         \u001b[0mpng_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mretina_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, **kwargs)\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0mFigureCanvasBase\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    131\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m     \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbytes_io\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    133\u001b[0m     \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbytes_io\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfmt\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svg'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m   2215\u001b[0m                     \u001b[0morientation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morientation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2216\u001b[0m                     \u001b[0mbbox_inches_restore\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_bbox_inches_restore\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2217\u001b[0;31m                     **kwargs)\n\u001b[0m\u001b[1;32m   2218\u001b[0m             \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2219\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mbbox_inches\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mrestore_bbox\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1637\u001b[0m             \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1638\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1639\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1640\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1641\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/backends/backend_agg.py\u001b[0m in \u001b[0;36mprint_png\u001b[0;34m(self, filename_or_obj, metadata, pil_kwargs, *args)\u001b[0m\n\u001b[1;32m    507\u001b[0m             \u001b[0;34m*\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mincluding\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mdefault\u001b[0m \u001b[0;34m'Software'\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    508\u001b[0m         \"\"\"\n\u001b[0;32m--> 509\u001b[0;31m         \u001b[0mFigureCanvasAgg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    510\u001b[0m         mpl.image.imsave(\n\u001b[1;32m    511\u001b[0m             \u001b[0mfilename_or_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuffer_rgba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"png\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morigin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"upper\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/backends/backend_agg.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    405\u001b[0m              (self.toolbar._wait_cursor_for_draw_cm() if self.toolbar\n\u001b[1;32m    406\u001b[0m               else nullcontext()):\n\u001b[0;32m--> 407\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    408\u001b[0m             \u001b[0;31m# A GUI class may be need to update a window using this draw, so\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    409\u001b[0m             \u001b[0;31m# don't forget to call the superclass.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     39\u001b[0m                 \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0martist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_agg_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m   1862\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1863\u001b[0m             mimage._draw_list_compositing_images(\n\u001b[0;32m-> 1864\u001b[0;31m                 renderer, self, artists, self.suppressComposite)\n\u001b[0m\u001b[1;32m   1865\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1866\u001b[0m             \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'figure'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/image.py\u001b[0m in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m    129\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mnot_composite\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_images\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m             \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;31m# Composite any adjacent images together\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     39\u001b[0m                 \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0martist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_agg_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*inner_args, **inner_kwargs)\u001b[0m\n\u001b[1;32m    409\u001b[0m                          \u001b[0;32melse\u001b[0m \u001b[0mdeprecation_addendum\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    410\u001b[0m                 **kwargs)\n\u001b[0;32m--> 411\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minner_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0minner_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    412\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    413\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/axes/_base.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer, inframe)\u001b[0m\n\u001b[1;32m   2745\u001b[0m             \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstop_rasterizing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2746\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2747\u001b[0;31m         \u001b[0mmimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_draw_list_compositing_images\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2749\u001b[0m         \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'axes'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/image.py\u001b[0m in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m    129\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mnot_composite\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_images\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m             \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;31m# Composite any adjacent images together\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     39\u001b[0m                 \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0martist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_agg_filter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/axis.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1162\u001b[0m         \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_gid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1164\u001b[0;31m         \u001b[0mticks_to_draw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_ticks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1165\u001b[0m         ticklabelBoxes, ticklabelBoxes2 = self._get_tick_bboxes(ticks_to_draw,\n\u001b[1;32m   1166\u001b[0m                                                                 renderer)\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/axis.py\u001b[0m in \u001b[0;36m_update_ticks\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1022\u001b[0m         \u001b[0mmajor_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmajor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_ticks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmajor_locs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1023\u001b[0m         \u001b[0mmajor_ticks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_major_ticks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmajor_locs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1024\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmajor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_locs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmajor_locs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1025\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mtick\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmajor_ticks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmajor_locs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmajor_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1026\u001b[0m             \u001b[0mtick\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_position\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/ticker.py\u001b[0m in \u001b[0;36mset_locs\u001b[0;34m(self, locs)\u001b[0m\n\u001b[1;32m    779\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_useOffset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    780\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_offset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 781\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_order_of_magnitude\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    782\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_format\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    783\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/sr/lib/python3.7/site-packages/matplotlib/ticker.py\u001b[0m in \u001b[0;36m_set_order_of_magnitude\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    833\u001b[0m             \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    834\u001b[0m         \u001b[0;31m# restrict to visible ticks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 835\u001b[0;31m         \u001b[0mvmin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvmax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_view_interval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    836\u001b[0m         \u001b[0mlocs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlocs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    837\u001b[0m         \u001b[0mlocs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlocs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvmin\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mlocs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlocs\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mvmax\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Print the most recent iteration index and KL values.\n",
    "# NOTE: `i`, `true_kl_p_q` and `kl_from_cob` are not defined in this cell;\n",
    "# they must already exist from an earlier (training) cell.\n",
    "print('iteration: ', i)\n",
    "print('KLD: ', true_kl_p_q)\n",
    "print('CoB: ', kl_from_cob)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "# Compare the true log p/q with the change-of-basis (CoB) estimate on the\n",
    "# samples `m_batch` (all tensors come from earlier cells).\n",
    "# Assumes 1-D samples (N_DIMS = 1), so each tensor flattens to one value per sample.\n",
    "X_LIM = (-25, 25)    # fixed axis ranges so figures from different runs are comparable\n",
    "Y_LIM = (-400, 100)\n",
    "\n",
    "m_np = m_batch.detach().cpu().reshape(-1).numpy()\n",
    "true_np = log_ratio_p_q.detach().cpu().reshape(-1).numpy()\n",
    "cob_np = log_ratio_p_q_from_cob.detach().cpu().reshape(-1).numpy()\n",
    "\n",
    "fig, ax2 = plt.subplots(1, 1, figsize=(6, 4))\n",
    "# Scatter the real data directly instead of random placeholders + set_offsets,\n",
    "# so this cell no longer consumes draws from the seeded global NumPy RNG.\n",
    "scat1 = ax2.scatter(m_np, true_np, c='b', s=10., alpha=0.9,\n",
    "                    label='True Log p/q, KL = ' + str(np.around(true_kl_p_q.item(), 2)))\n",
    "scat2 = ax2.scatter(m_np, cob_np, c='r', s=10., alpha=0.9,\n",
    "                    label='CoB Log p/q, KL = ' + str(np.around(kl_from_cob.item(), 2)))\n",
    "\n",
    "ax2.set_xlabel('Samples')\n",
    "ax2.set_ylabel('Log Ratio')\n",
    "ax2.set_xlim(X_LIM)\n",
    "ax2.set_ylim(Y_LIM)\n",
    "ax2.legend(loc='best')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('../plots/cob_mu1e6.png')"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmdUlEQVR4nO3de3RU9b3//+ebGIIEEUEtalCpDSqXECCAF9QoGvB+V1AUpBbxh8ufeNQiekQ9XluP9bSgVq2IB2y9tPVWL4NoDsWKGGxEUKjRgsZiRRCQWxB4f/+YnXEISRggM7Nn8nqslbVnf/ae2e/sFXn52XvP52PujoiISNi0SHcBIiIi9VFAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEUloDysweN7OvzWx+XFt7M5tuZp8Ey72CdjOzX5tZlZnNM7Pe6atcRESSLd09qCeAwXXaxgEz3L0QmBGsA5wMFAY/o4CHUlSjiIikQVoDyt1nAivqNJ8JTAleTwHOimt/0qNmA+3MbL+UFCoiIim3W7oLqMeP3H1p8Por4EfB6wOAL+L2qw7alsa1YWajiPawyM/P73PYYYclt1oREdklc+fO/cbd96nbHsaAinF3N7MdGovJ3R8BHgEoKSnxioqKpNQmIiJNw8yW1Nee7ntQ9fl37aW7YPl10P4l0Cluv4KgTUREslAYA+pFYHjwejjwQlz7pcHTfEcAq+IuBYqISJZJ6yU+M/s9UArsbWbVwATgHuAZM/spsAS4INj9FeAUoApYB1yW8oJFRCRl0hpQ7j60gU0D69nXgTHJrUikefj++++prq5mw4YN6S5FmpFWrVpRUFBAbm5uQvuH+iEJEUmO6upq9thjDw4++GDMLN3lSDPg7ixfvpzq6mo6d+6c0HvCeA9KRJJsw4YNdOjQQeEkKWNmdOjQYYd67QookWZK4SSptqN/cwooEREJJQWUiKTc8uXLKS4upri4mI4dO3LAAQfE1jdu3NgkxygtLSWVX9SfPXs2P/vZz3bqvSNGjOC5554DYMWKFfTq1YvJkyezePFiunfv3pRlMnLkSPbdd99tPvc///M/KSoqori4mLKyMv71r39t894lS5bQu3dviouL6datGw8//DAA69at49RTT+Wwww6jW7dujBs3bpv37gwFlIikXIcOHaisrKSyspLRo0czduzY2HrLli3ZtGlTukvcYa+++iqDB9cd+3rHrFq1ikGDBjFq1Cguuyw536QZMWIEr7322jbt119/PfPmzaOyspLTTjuN22+/fZt99ttvP9555x0qKyt59913ueeee2JBdt1117Fw4UL+/ve/8/bbb/Pqq6/ucq0KKBEJhREjRjB69Gj69+/PDTfcwK233sp9990X2969e3cWL14MwNSpU+nXrx/FxcVcccUVbN68OaFjrFixgrPOOouioiKOOOII5s2bB8CyZcs46aST6NatG5dffjkHHXQQ33zzzTbvb9OmDWPHjqVbt24MHDiQZcuWxbbNmDGDE088kfXr1zNkyBAOP/xwzj77bPr3759QT27NmjWcfPLJXHTRRVx55ZUJ/T4749hjj6V9+/bbtLdt2zb2eu3atfXeL2rZsiV5eXkA1NTUsGXLFgBat27N8ccfH9und+/eVFdX73KtCigRSUgkAuPHR5fJUl1dzd/+9jfuv//+Bvf5+OOPefrpp3n77beprKwkJyeHadOmJfT5EyZMoFevXsybN4+77rqLSy+9FIDbbruNE044gQULFnDeeefx+eef1/v+tWvXUlJSwoIFCzjuuOO47bbbAPjmm2/Izc1lzz335KGHHqJ169Z8/PHH3HbbbcydOzeh2q699loGDBjA2LFjE9q/1rRp02KXR+N/zjvvvB36HICbbrqJTp06MW3atHp7UABffPEFRUVFdOrUiZ///Ofsv//+W21fuXIlL730EgMHbvN11h2mgBKR7YpEYNgwmDQpukxWSJ1/
/vnk5OQ0us+MGTOYO3cuffv2pbi4mBkzZvDZZ58l9PmzZs3ikksuAeCEE05g+fLlrF69mlmzZjFkyBAABg8ezF577VXv+1u0aMGFF14IwLBhw5g1axYAkUiEsrIyAGbOnMmwYcMAKCoqoqioKKHaTjjhBF544QW+/vrr7e8c5+KLL45dHo3/qb2ntSPuvPNOvvjiCy6++GImTpxY7z6dOnVi3rx5VFVVMWXKFP7973/Htm3atImhQ4dy9dVX8+Mf/3iHj1+XAkpEtqu8HGpqID8/uiwvT85x8vPzY69322232CUkIPb9GXdn+PDhsX+IFy1axK233pqcgraj9jJYU9x/GjJkCKNHj+aUU07hu+++S/h9TdmDqnXxxRfzxz/+sdF99t9/f7p3785f//rXWNuoUaMoLCzkmmuu2eljx1NAich2lZZCXh6sXRtdlpYm/5gHH3ww77//PgDvv/8+//znPwEYOHAgzz33XKynsWLFCpYsqXe2hm0cc8wxscuB5eXl7L333rRt25ajjz6aZ555Boj2hr799tt6379ly5ZYz+Spp55iwIABuDvz5s2juLgYiN7jeeqppwCYP39+7D4XwKWXXsqcOXMarG/s2LEMHDiQc845J+GnGZuqB/XJJ5/EXr/wwgvUN5dedXU169evB+Dbb79l1qxZHHrooQDcfPPNrFq1igceeGCHjtsYBZSIbFdZGUydCmPGRJfB1aykOvfcc1mxYgXdunVj4sSJdOnSBYCuXbtyxx13UFZWRlFRESeddBJLl9Y/scGpp55KQUEBBQUFnH/++dx6663MnTuXoqIixo0bx5Qp0cm7J0yYQCQSoXv37jz77LN07NiRPfbYY5vPy8/PZ86cOXTv3p0333yTW265hblz59KrV69Yb+rKK69kzZo1HH744dxyyy306dMn9v558+Ztc8+mrnvvvZeCggIuueQStmzZwqJFi2K/Q0FBAc8+++xOnc9aQ4cO5cgjj4x97u9+9zsAxo0bR/fu3SkqKiISifA///M/AFRUVHD55ZcD0ft//fv3p2fPnhx33HFcd9119OjRg+rqau68804++uij2GPojz322C7VCWDRMVizkyYsFKnfxx9/zOGHH57uMkKjpqaGnJwcdtttN9555x2uvPJKKisrt9mvTZs2rFmzZqu2O+64g5/85Cexe1h1lZaWct9999GlSxd++tOf7nLAZLr6/vbMbK67l9TdV4PFikiz9/nnn3PBBRewZcsWWrZsyaOPPprwe2+++eaE9mvbtm2zD6cdpYASkWavsLCQv//979vdr27vKRHlyXqipBnQPSgREQklBZSIiISSAkpEREJJASUiIqGkgBKRtPjqq68YMmQIhxxyCH369OGUU07hH//4R4P7L168mN13353i4mJ69uzJUUcdxaJFi+rdr6mnqNie0aNH8/bbb+/Ue9u0aRN7/corr9ClSxeWLFmyzWC5u2rhwoUceeSR5OXlbfW5GzZsoF+/fvTs2ZNu3boxYcKEet//8MMP06NHD4qLixkwYAAfffQRANOnT6dPnz706NGDPn368OabbzZZzQooEUk5d+fss8+mtLSUTz/9lLlz53L33XdvNa5bfQ455BAqKyv54IMPGD58OHfddVeKKm7c7NmzOeKII3bpM2bMmMHVV1/Nq6++ykEHHdRElf2gffv2/PrXv+a6667bqj0vL48333yTDz74gMrKSl577TVmz569zfsvuugiPvzwQyorK7nhhhu49tprAdh777156aWX+PDDD5kyZUpsrMOmoIASkZR76623yM3NZfTo0bG2nj17cswxx+DuXH/99XTv3p0ePXrw9NNP1/sZq1evbnBQ1/rMmDGDXr160aNHD0aOHElNTQ0Q7bUcdthh9OnTh6uvvprTTjttm/c+8cQTnHnmmZSWllJYWBgbxRyiXzzt0qULOTk5zJ07l549e9KzZ8/Y75CImTNn8rOf/YyXX36ZQw45JOHfaUfsu+++
9O3bl9zc3K3azSzWi/v+++/5/vvv651qo6HpOHr16hUbHaNbt26sX78+dm53lb4HJSKJiUSio8SWlu7yWEfz58/fagigeH/6059ivaRvvvmGvn37cuyxxwLw6aefUlxczHfffce6det49913Ezrehg0bGDFiBDNmzKBLly5ceumlPPTQQ4wePZorrriCmTNn0rlzZ4YOHdrgZ8yZM4f58+fTunVr+vbty6mnnkpJSclWA8VedtllTJw4kWOPPZbrr78+odpqamo466yzKC8vr3f8u8ZceOGF9V7mvPbaa2NTiSRi8+bN9OnTh6qqKsaMGUP//v3r3W/SpEncf//9bNy4sd5LeX/84x/p3bt3bM6oXaUelIhsX6rm2yA6JcbQoUPJycnhRz/6Eccddxzvvfce8MMlvk8//ZQHHniAUaNGJfSZixYtonPnzrHx/IYPH87MmTNZuHAhP/7xj+ncuTNAowF10kkn0aFDB3bffXfOOeec2FQbr7/+OoMHD2blypWsXLkyFqaJXurKzc3lqKOOio2JtyOefvrpegeK3ZFwAsjJyaGyspLq6upYENdnzJgxfPrpp9x7773ccccdW21bsGABP//5z/ntb3+7w79HQxRQIpmksBBatIguU6mJ59vo1q1bwhP5NeSMM85g5syZu/QZO6LuZS8zY926daxcuXK7A8A2pkWLFjzzzDPMmTNnh++pXXjhhfVOtfHkk0/uVC3t2rXj+OOPr3dK+HhDhgzh+eefj61XV1dz9tln8+STTzbpJUoFlEimKCyEqipwjy5TGVJNPN/GCSecQE1NDY888kisbd68efz1r3/lmGOO4emnn2bz5s0sW7aMmTNn0q9fv20+Y9asWQn/Y3jooYeyePFiqqqqAPjf//1fjjvuOA499FA+++yz2FTyDd3vgujTaitWrGD9+vU8//zzHH300bz11luxqc7btWtHu3btYj2r+Fl+v/zyy0ZnmG3dujV/+ctfmDZt2g71pJqiB7Vs2TJWrlwJwPr165k+fXq9lxrjp+P4y1/+QmHw97dy5UpOPfVU7rnnHo4++uiEj5sI3YMSyRTBP64NridT7XwbTXQPysz485//zDXXXMO9995Lq1atOPjgg3nggQcYMGAA77zzDj179sTM+MUvfkHHjh1ZvHhx7B6Uu9OyZcsGp3SonUqi1q9+9SsmT57M+eefz6ZNm+jbty+jR48mLy+PBx98kMGDB5Ofn0/fvn0brLlfv36ce+65VFdXM2zYMEpKSrjqqqu2mhhw8uTJjBw5EjOLzbALsHTpUnbbrfF/btu3b89rr73Gscceyz777ANER0qPn1+purq60c9ozFdffUVJSQmrV6+mRYsWPPDAA3z00UcsXbqU4cOHs3nzZrZs2cIFF1wQe1DklltuoaSkhDPOOIOJEyfyxhtvkJuby1577RWbqmTixIlUVVVx++23x6aJj0Qi7Lvvvjtday1NtyGSKep5soqd/O9X0238YM2aNbRp0wZ3Z8yYMRQWFjJ27Nit9nniiSeoqKjYZhr03r178+67727zZBxEv4912mmnMX/+fCZOnMiBBx7IGWeckdTfJRNoug0RkQQ9+uijTJkyhY0bN9KrVy+uuOKKhN9bO+Pv9lx11VU7W16zph6USKZQD0qywI70oPSQhEimaNGi8fUdlM3/cyrhtKN/cwookUwR3DhvcH0HtGrViuXLlyukJGXcneXLl9OqVauE36N7UCKZYsMG4uPENmzY6Y8qKCigurqaZcuW7XpdIglq1arVVk9Xbo8CSiRDrF+1nlZ11nffyc/Kzc2NjZ4gEla6xCeSITbQqtF1kWyjgBLJEEs4qNF1kWyjgBLJECto3+i6SLbJuIAys8FmtsjMqsxsXLrrEUmVvNzG10WyTUYFlJnlAJOAk4GuwFAz65reqkRSo127xtdFsk1GBRTQD6hy98/cfSPwB+DMNNckkhodO+IQ+6Fjx/TWI5JkmRZQBwBfxK1XB20i
We/driNZzZ5spCWr2ZN3u45Md0kiSZVpAbVdZjbKzCrMrEJfQpRsUjCyjIqc/mwkl4qc/hSM3LUpL0TCLtMC6kugU9x6QdAW4+6PuHuJu5fsswtDwYiETdf/HsnAzRH2YC0DN0fo+t/qQUl2y7SAeg8oNLPOZtYSGAK8mOaaRFKiw1vPNboukm0yaqgjd99kZlcBrwM5wOPuviDNZYmkRG6LzY2ui2SbjAooAHd/BXgl3XWIpNqaVh3Ys2bdVuvt0leOSNJl2iU+kWbr922vjI1m7sG6SDZTQIlkiLoT4GpCXMl2CiiRDFH63UsYsIUWWLAuks0UUCKZ4vTTMTNy2IKZwemnp7sikaTKuIckRJqtG2+ETz6B11+HQYOi6yJZTD0okUwRicDLL8OaNdFlJJLuikSSSgElkinKy6GmBvLzo8vy8nRXJJJUCiiRTFFaCnl5sHZtdFlamu6KRJJK96BEMkVZGUydGu05lZZG10WymAJKJJOUlSmYpNnQJT4REQklBZRIBolEYPx4PcAnzYMCSiRDRCIwbBhMmhRdKqQk2ymgRDJEeTmsXg3r10eXespcsp0ekhDJEF99Ff36U/y6SDZTD0okQyxcGF2abb0ukq0UUCIZ4vTTo+HkHl1qrFjJdrrEJ5IhaseGfemlaDhprFjJdubu298rQ5WUlHhFRUW6yxARkUaY2Vx3L6nbrkt8IiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJZBCNZi7NiQJKJENEInDBBXD//dGlQkqynQJKJENMnhwdxXzjxuhy8uR0VySSXBrqSCRDuMNJHuE4yvk/SnHX1O+S3RRQIhni+p4RDnxmGLlew8/sMT7vORVQSEn20iU+kQzR57ty2u1eQ4s2+bTbvYY+35WnuSKR5FJAiWSK0lI8N4/c79fiuXlQWpruikSSSpf4RDJEhDIe+n4qR3xfzuycUq6kTBf4JKspoEQyxOTJ8Py6Mp6nDNZBq8lQpoSSLKZLfCIZYunSxtdFso0CSiRD7Ldf4+si2UYBJZIhLrsM9twT8vKiy8suS3dFIsmle1AiGaKsDJ55BsrLow/w6f6TZLu09KDM7HwzW2BmW8yspM62G82syswWmdmguPbBQVuVmY1LfdUi6VdWBnfdpXCS5iFdl/jmA+cAM+MbzawrMAToBgwGHjSzHDPLASYBJwNdgaHBviLNyt13w1FHRZci2S4tAeXuH7v7ono2nQn8wd1r3P2fQBXQL/ipcvfP3H0j8IdgX5Fm4+67YfX4u/nlO0exevzdCinJemF7SOIA4Iu49eqgraH2bZjZKDOrMLOKZcuWJa1QkVRr99Dd3MlNHMk73MlNtHtICSXZLWkBZWZvmNn8en6S2vNx90fcvcTdS/bZZ59kHkokpU7nJcBxWgAerItkr+0+xWdmewK3AscETf8H3O7uqxp7n7ufuBP1fAl0ilsvCNpopF2kWSi48nS2jJ+NsQUwCq48Pd0liSRVIo+ZP070oYYLgvVLgMlEH3Joai8CT5nZ/cD+QCEwBzCg0Mw6Ew2mIcBFSTi+SGjdzY2sBk7jJV7mdNpyIzemuyiRJDJ3b3wHs0p3L95e2w4d1Oxs4DfAPsBKoNLdBwXbbgJGApuAa9z91aD9FOABIAd43N3v3N5xSkpKvKKiYmfLFAmVo46Cd96BFi1gyxY48kj429/SXZXIrjOzue5eUrc9kR7UejMb4O6zgg86Gli/K8W4+5+BPzew7U5gm/Bx91eAV3bluCKZ7PTTYfbsaDiZRddFslkiAXUlMCW4F2XACmBEMosSkW3dGFzPe+mlaDjdqOt7kuW2e4kvtqNZWwB3X53UipqQLvGJiITfDl/iM7Nh7j7VzK6t0w6Au9/f5FWKiIgEGrvElx8s96hnW2LdLhFpWpGIRouVZqPBgHL33wYv33D3t+O3BQ9KiEgqRSIwbBjU1MBjj8HUqQopyWqJjCTxmwTbRCSZysuj4ZSf
H12Wl6e7IpGkauwe1JHAUcA+de5DtSX6XSQRSaXS0mjPae3a6KyFpaXprkgkqRq7B9USaBPsE38fajVwXjKLEpF6lJVFL+vpHpQ0E4mMJHGQuy9JUT1NSo+Zi4iE366MJLHOzH5JdBLBVrWN7n5CE9YnIiKylUQekpgGLAQ6A7cBi4H3kliTiDQgEoHx46NLkWyXSEB1cPffAd+7+/+5+0hAvSeRFKt9ynzSpOhSISXZLpGA+j5YLjWzU82sF9A+iTWJSD30lLk0N4kE1B3BQLH/AVwHPAZck8yiRGRbpaXRUcyXL48u9ZS5ZLvtPiTh7i8HL1cBx4NGkhBJp2A4TJGs12APysxyzGyomV1nZt2DttPM7G/AxJRVKCJA9JKeO7RvH13qEp9ku8Yu8f0OuBzoAPzazKYC9wG/cPdeqShORH5QWhodQEIDSUhz0dglvhKgyN23mFkr4CvgEHdfnprSRCSeBpKQ5qaxgNro7lsA3H2DmX2mcBJJr7IyBZM0H40F1GFmNi94bcAhwboB7u5FSa9ORESarcYC6vCUVSEiIlJHYxMWZuQAsSIikh0S+aKuiIhIyimgREQklBRQIiISStsd6sjMPgTqzmq4CqgA7tCj5yIikgyJTFj4KrAZeCpYHwK0JvrF3SeA05NSmYiINGuJBNSJ7t47bv1DM3vf3Xub2bBkFSYiIs1bIvegcsysX+2KmfUFcoLVTUmpSkTqpRl1pTlJpAd1OfC4mbUhOorEauCnZpYP3J3M4kTkB7Uz6tbUwGOPRcfl07BHks0SmQ/qPaBHMGkh7r4qbvMzySpMRLYWP6Pu2rXRdQWUZLPtXuIzsz3N7H5gBjDDzP67NqxEJHVKS2GQRbh2+XgGWUTTbUjWS+QS3+PAfOCCYP0SYDJwTrKKEpFtlRGhlGFsthqu5jFaMhVQF0qyVyIPSRzi7hPc/bPg5zbgx8kuTETqKC+npdewe/t8WnqNptSVrJdIQK03swG1K2Z2NLA+eSWJSL00pa40M4lc4hsNPBl33+lbYHjyShKRemlKXWlmEnmK7wOgp5m1DdZXm9k1wLxG3ygiTU9T6kozkvBgse6+2t1XB6vX7spBzeyXZrbQzOaZ2Z/NrF3cthvNrMrMFpnZoLj2wUFblZmN25Xji4hI+O3saOa2i8edDnQPpo3/B3AjgJl1JTrWXzdgMPCgmeWYWQ4wCTgZ6AoMDfYVEZEstbMBVXd08x17s3vE3WuHSZoNFASvzwT+4O417v5PoAroF/xUBU8RbgT+EOwrIiJZqsF7UGb2HfUHkQG7N2ENI4Gng9cHEA2sWtVBG8AXddr71/dhZjYKGAVw4IEHNmGZIiKSSg0GlLvvsSsfbGZvAB3r2XSTu78Q7HMT0QFnp+3KseK5+yPAIwAlJSW71NMTEZH0SeQx853i7ic2tt3MRgCnAQPdvTZIvgQ6xe1WELTRSLuIiGShtEz5bmaDgRuAM9x9XdymF4EhZpZnZp2BQmAO8B5QaGadzawl0QcpXkx13SIikjpJ60Ftx0QgD5huZgCz3X20uy8ws2eAj4he+hvj7psBzOwq4HWic1E97u4L0lO6SPpEIvqerjQf9sPVtexTUlLiFRUV6S5DpEnEzweVl6f5oCR7mNlcdy+p256WS3wisuPKy+HotRFu3zyeo9dGNFasZL10XeITkR107h4RDlo/jFyv4WJ7jCV7aLoNyW7qQYlkiD7flbPn7jVYm3z23L2GPt+Vp7kikeRSQIlkitJScvPzaNtiLbn5mm5Dsp8u8YlkCk23Ic2MAkokk2i6DWlGdIlPRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKJEMEonA+PHRpUi2U0CJZIhIBIYNg0mTokuFlGQ7BZRIhigvh5oayM+PLsvL012RSHIpoEQyRGkp5OXB2rXR
ZWlpuisSSS5N+S6SIcrKYOrUaM+ptFQzv0v2U0CJZJCyMgWTNB+6xCciIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlEgG0Vh80pwooEQyhMbik+ZGASWSITQWnzQ3CiiRDKGx+KS50VBHIhlCY/FJc5OWHpSZ/ZeZzTOzSjOLmNn+QbuZ2a/NrCrY3jvuPcPN7JPgZ3g66hZJt7IyuOsuhZM0D+m6xPdLdy9y92LgZeCWoP1koDD4GQU8BGBm7YEJQH+gHzDBzPZKddEiIpI6aQkod18dt5oPePD6TOBJj5oNtDOz/YBBwHR3X+Hu3wLTgcEpLVpERFIqbfegzOxO4FJgFXB80HwA8EXcbtVBW0PtIiKSpZLWgzKzN8xsfj0/ZwK4+03u3gmYBlzVhMcdZWYVZlaxbNmypvpYERFJsaT1oNz9xAR3nQa8QvQe05dAp7htBUHbl0BpnfbyBo77CPAIQElJide3j0imikT0FJ80H+l6iq8wbvVMYGHw+kXg0uBpviOAVe6+FHgdKDOzvYKHI8qCNpFmQyNJSHOTrntQ95jZocAWYAkwOmh/BTgFqALWAZcBuPsKM/sv4L1gv9vdfUVqSxZJr/iRJNauja6rFyXZLC0B5e7nNtDuwJgGtj0OPJ7MukTCrLQUHntMI0lI86GRJEQyhEaSkOZGASWSQcrKFEzSfGiwWBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKRERCSQElIiKhpIASEZFQUkCJiEgoKaBERCSUFFAiIhJKCigREQklBZSIiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoEREJJQUUCIiEkoKKBERCSUFlIiIhJICSkREQkkBJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSgooEREJJQWUiIiEkgJKJJNEIjB+fHQpkuV2S3cBIpKgSISNFwxj84Yach58jJbPTIWysnRXJZI0ae1Bmdl/mJmb2d7BupnZr82syszmmVnvuH2Hm9knwc/w9FUtkh6fTS5n/eoaVm7MZ/3qGj6bXJ7ukkSSKm0BZWadgDLg87jmk4HC4GcU8FCwb3tgAtAf6AdMMLO9UlqwSJqVeykbyaM1a9lIHuVemu6SRJIqnT2oXwE3AB7XdibwpEfNBtqZ2X7AIGC6u69w92+B6cDglFcskkYFI8sY03Yqj+aOYUzbqRSM1OU9yW5puQdlZmcCX7r7B2YWv+kA4Iu49eqgraH2+j57FNHeF8AaM1vUVHU3kb2Bb9JdRIjp/DSqXdtnsX3Z+NDXzw5auTrd1YSU/oYaF8bzc1B9jUkLKDN7A+hYz6abgPFEL+81OXd/BHgkGZ/dFMyswt1L0l1HWOn8bJ/OUeN0fhqXSecnaQHl7ifW125mPYDOQG3vqQB438z6AV8CneJ2LwjavgRK67SXN3nRIiISGim/B+XuH7r7vu5+sLsfTPRyXW93/wp4Ebg0eJrvCGCVuy8FXgfKzGyv4OGIsqBNRESyVNi+B/UKcApQBawDLgNw9xVm9l/Ae8F+t7v7ivSUuMtCe/kxJHR+tk/nqHE6P43LmPNj7r79vURERFJMQx2JiEgoKaBERCSUFFApYGa/NLOFwfBNfzazdnHbbgyGdlpkZoPSWGZamdn5ZrbAzLaYWUmdbTpHgJkNDs5BlZmNS3c9YWBmj5vZ12Y2P66tvZlND4ZFm96cR50xs05m9paZfRT89/X/B+0ZcY4UUKkxHeju7kXAP4AbAcysKzAE6EZ0ZIwHzSwnbVWm13zgHGBmfKPOUVTwO08iOhxYV2BocG6auyfYdlSZccAMdy8EZgTrzdUm4D/cvStwBDAm+LvJiHOkgEoBd4+4+6ZgdTbR73FBdGinP7h7jbv/k+jTi/3SUWO6ufvH7l7fqB86R1H9gCp3/8zdNwJ/IHpumjV3nwnUfaL3TGBK
8HoKcFYqawoTd1/q7u8Hr78DPiY6Ck9GnCMFVOqNBF4NXic8hFMzpnMUpfOQuB8F358E+Ar4UTqLCQszOxjoBbxLhpyjsH0PKmM1NrSTu78Q7HMT0S73tFTWFhaJnCORpuTubmbN/rs0ZtYG+CNwjbuvjh8DNcznSAHVRBoa2qmWmY0ATgMG+g9fPmtoaKestL1z1IBmdY4aofOQuH+b2X7uvjSYDeHrdBeUTmaWSzScprn7n4LmjDhHusSXAmY2mOjUIme4+7q4TS8CQ8wsz8w6E50Ha046agwxnaOo94BCM+tsZi2JPjjyYpprCqsXgdpJTYcDzbZ3btGu0u+Aj939/rhNGXGONJJECphZFZAHLA+aZrv76GDbTUTvS20i2v1+tf5PyW5mdjbwG2AfYCVQ6e6Dgm06R4CZnQI8AOQAj7v7nemtKP3M7PdEB5LeG/g30YlNnweeAQ4ElgAXZPDQaLvEzAYAfwU+BLYEzeOJ3ocK/TlSQImISCjpEp+IiISSAkpEREJJASUiIqGkgBIRkVBSQImISCgpoESSzMxuCkaSnmdmlWbWP4nHKq87GrxIptJIEiJJZGZHEh1BpLe715jZ3kDLNJclkhHUgxJJrv2Ab9y9BsDdv3H3f5nZLWb2npnNN7NHgm/81/aAfmVmFWb2sZn1NbM/BfP23BHsc3Awv9i0YJ/nzKx13QObWZmZvWNm75vZs8F4bJjZPcH8QPPM7L4UnguRHaKAEkmuCNDJzP5hZg+a2XFB+0R37+vu3YHdifayam109xLgYaJD0IwBugMjzKxDsM+hwIPufjiwGvj/4g8a9NRuBk50995ABXBt8P6zgW7B/GR3JOF3FmkSCiiRJHL3NUAfYBSwDHg6GDj4eDN718w+BE4gOiFjrdox9j4EFgRz+tQAn/HDgLFfuPvbweupwIA6hz6C6MSGb5tZJdHx1g4CVgEbgN+Z2TnAOkRCSvegRJLM3TcD5UB5EEhXAEVAibt/YWa3Aq3i3lITLLfEva5dr/1vtu4YZXXXDZju7kPr1mNm/YCBwHnAVUQDUiR01IMSSSIzO9TMCuOaioHamYO/Ce4LnbcTH31g8AAGwEXArDrbZwNHm9lPgjryzaxLcLw93f0VYCzQcyeOLZIS6kGJJFcb4Ddm1o7oaOxVRC/3rQTmE53N9L2d+NxFwBgzexz4CHgofqO7LwsuJf7ezPKC5puB74AXzKwV0V7WtTtxbJGU0GjmIhkmmLr75eABC5GspUt8IiISSupBiYhIKKkHJSIioaSAEhGRUFJAiYhIKCmgREQklBRQIiISSv8PPMWWqytPyJsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sr",
   "language": "python",
   "name": "sr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}