{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload all imported modules (e.g. the local src package) before executing\n",
    "# each cell, so edits to ../src take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import random\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import trange\n",
    "from IPython.display import clear_output, display\n",
    "\n",
    "# Local package: put the parent directory (assumed to contain `src`) on the\n",
    "# import path first, then import the project code.\n",
    "# NOTE(review): star import -- RatioCritic1D, get_dists_1d and DistDataset used\n",
    "# below come from here; consider explicit imports once the full list is known.\n",
    "sys.path.insert(0, '..')\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data\n",
    "N_DIMS = 1\n",
    "NUM_SAMPLES = 100_000  # points per dataset (train and test each)\n",
    "BS = 500               # mini-batch size\n",
    "\n",
    "# Optimisation\n",
    "# NUM_EPOCHS is set once here, to the value the training run uses, so that\n",
    "# anything derived from it (e.g. the plot x-limits) agrees with training.\n",
    "NUM_EPOCHS = 1500\n",
    "LR = 1e-2\n",
    "DROPOUT = 0.20\n",
    "\n",
    "# Reproducibility / hardware\n",
    "SEED = 10\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Stress tests: the setup can be made to fail by changing the number of\n",
    "# datapoints, the scales or means of p/q/m, or by moving to 2D."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's, PyTorch's and NumPy's global RNGs for reproducibility.\n",
    "for seed_fn in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_fn(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio critic: N_DIMS inputs -> 3 outputs, with dropout.\n",
    "model = RatioCritic1D(\n",
    "    dim_input=N_DIMS,\n",
    "    dim_output=3,\n",
    "    dropout=DROPOUT,\n",
    ")\n",
    "# Optional custom weight initialisation (currently disabled):\n",
    "# model.apply(weights_init)\n",
    "\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# The three distributions p, q, m: locations mu1/mu2/mu3 and their scales.\n",
    "p, q, m = get_dists_1d(\n",
    "    mu1=-2., mu2=2., mu3=0,\n",
    "    scale_p=0.1, scale_q=0.2, scale_m=1.0,\n",
    ")\n",
    "\n",
    "# Alternative settings noted earlier (presumably mu1, mu2 and m's variance):\n",
    "# -5, 5, m_var=3.0\n",
    "# -10, 10, m_var=3.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling p\n",
      "Sampling q\n",
      "Cauchy(loc: 0.0, scale: 1.0)\n",
      "Sampling m\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "Sampling p\n",
      "Sampling q\n",
      "Cauchy(loc: 0.0, scale: 1.0)\n",
      "Sampling m\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n"
     ]
    }
   ],
   "source": [
    "# Train and test sets: two separate draws of NUM_SAMPLES points each from the\n",
    "# same distributions p, q, m.\n",
    "# NOTE(review): an earlier comment said the test set should be only one batch\n",
    "# (BS) in size; the code uses NUM_SAMPLES -- confirm which is intended.\n",
    "train_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)\n",
    "test_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch loaders sharing one configuration; with shuffle=True the order\n",
    "# is reshuffled on every pass over a loader.\n",
    "# NOTE(review): shuffling test_dl only matters if evaluation uses a subset of\n",
    "# its batches -- confirm that is intended.\n",
    "loader_kwargs = dict(batch_size=BS, shuffle=True)\n",
    "train_dl = DataLoader(train_ds, **loader_kwargs)\n",
    "test_dl = DataLoader(test_ds, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABDAAAAEYCAYAAACqUwbqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAxCUlEQVR4nO3de7RdZX3v//enCRLuAgYEgiYi6gkoKOEmFhE8IVqPyBGUnFLjpb8cPXhBqy3o6QHPELXaWouiR7wBBbmUQqFU5KaBAUVCQJSbURCUCELAFhDllnx/f6wZ2En2TnYua625136/xphjzfnM2/dZ7Dzz4bvmfGaqCkmSJEmSpDb7o34HIEmSJEmStDomMCRJkiRJUuuZwJAkSZIkSa1nAkOSJEmSJLWeCQxJkiRJktR6JjAkSZIkSVLrdS2BkeRbSR5IcsuQsq2SXJbk583nlt06vySNV0nuTnJzkpuSLGjKRmx/kxyb5I4kC5McPKR8j+Y4dyQ5MUn6UR9Jkn1rSYLu3oFxCjBrhbJjgCuqamfgimZZkrT+va6qdq+qGc3ysO1vkunAEcAudNrsrySZ0OzzVWAusHMzrdimS5J65xTsW0sa57qWwKiqq4DfrlB8CHBqM38q8JZunV+StJyR2t9DgLOq6omqugu4A9gryXbA5lV1bVUVcBq22ZLUN/atJQkm9vh821bVfQBVdV+SbUbaMMlcOr/8sckmm+zxspe9rEchShovbrjhhgeranK/4+iCAi5NUsDXqupkRm5/dwB+OGTfRU3ZU838iuXLsa2W1G0D3FavD/atJbVCr9rqXicwRq3pcJ8MMGPGjFqwYEGfI5I0aJL8st8xdMl+VXVv05G9LMlPV7HtcONa1CrKly+wrZbUZQPcVveU7bWkbupVW93rt5Dc39yWTPP5QI/PL0kDr6rubT4fAM4H9mLk9ncRsOOQ3acA9zblU4YplyS1h31rSeNKrxMYFwJzmvk5wAU9Pr8kDbQkmyTZbNk8MBO4hZHb3wuBI5JsmGQancE65ze3JD+aZJ/m7SPvwDZbktrGvrWkcaVrj5AkORM4AHhekkXAccBngXOSvAf4FXB4t84vSePUtsD5zRtPJwLfqarvJbmeYdrfqro1yTnAbcDTwFFVtaQ51vvojHq/EXBxM0mS+sC+tSR1MYFRVbNHWHVQt84pSeNdVf0C2G2Y8ocYof2tqhOAE4YpXwDsur5jlMaTp556ikWLFvH444/3O5TWmzRpElOmTGGDDTbodyitZN9aklo8iKckSdJYt2jRIjbbbDOmTp1Kc2eUhlFVPPTQQyxatIhp06b1OxxJUkv1egwMSZKkcePxxx9n6623NnmxGknYeuutvVNFkrRKJjAkSZK6yOTF6Pg9SZJWxwSGJEmSJElqPRMYkiRJA+qhhx5i9913Z/fdd+f5z38+O+ywwzPLTz75ZNfPf9999zFz5syun0eSND44iKckSdKA2nrrrbnpppsAOP7449l000356Ec/+sz6p59+mokTu9cd/N73vsfBBx/cteNLksYXExiSJEnjyDvf+U622morfvSjH/GqV72KzTbbbLnExq677spFF13E1KlTOf300znxxBN58skn2XvvvfnKV77ChAkTljve1KlTefvb384PfvADAL7zne/w4he/GOgkMI477jiqig984AN8//vfZ9q0aVQV7373uznssMN6W3lJ0pjmIySSJEktcuml8PGPdz675Wc/+xmXX345f/d3fzfiNrfffjtnn30211xzDTfddBMTJkzgjDPOGHbbzTffnPnz5/P+97+fo48+GoAlS5awcOFCpk+fzvnnn8/ChQu5+eab+frXv86///u/d6NakqQB5x0YkiRJLXHppXDkkfDEE/CNb8Dpp0M3hpA4/PDDV7qTYkVXXHEFN9xwA3vuuScAf/jDH9hmm22G3Xb27NnPfH74wx8G4LrrrmPvvfcG4KqrrmL27NlM
mDCB7bffngMPPHB9VUWSNI6YwJAkSWqJefM6yYtNNoHHHussdyOBsckmmzwzP3HiRJYuXfrM8uOPPw5AVTFnzhw+85nPrPZ4Q1+Bumz+4osvZtasWcNuI0nS2vAREkmSpJY44ADYcMNO8mLDDTvL3TZ16lRuvPFGAG688UbuuusuAA466CDOPfdcHnjgAQB++9vf8stf/nLYY5x99tnPfO67775A5w6Ogw46CID999+fs846iyVLlnDfffc9M16GJElrwjswJEmSWmLmzM5jI/PmdZIXvXgD6Vvf+lZOO+00dt99d/bcc09e8pKXADB9+nQ+9alPMXPmTJYuXcoGG2zASSedxAtf+MKVjvHEE0+w9957s3TpUs4880wWL17MpEmT2HzzzQE49NBD+f73v8/LX/5yXvKSl/Da1762+xWTJA0cExiSJEktMnNmdxIXxx9//LDlG220EZeOMGLo29/+dt7+9rev9thHHXUUxx133DPLp59+OjOHVCIJX/7yl59Zfuc73zm6oCVJGsIEhiRJktarI488st8hSJIGkAkMSZIkrbW77757jfc55ZRT1nsckqTB5yCekiRJkiSp9UxgSJIkSZKk1jOBIUmSJEmSWs8EhiRJkiRJaj0TGJIkSQPsN7/5DUcccQQ77bQT06dP541vfCM/+9nPRtz+7rvvZqONNmL33Xdnt91249WvfjULFy5c4/N+5jOf4YwzzliX0CVJWo4JDEmSpAFVVRx66KEccMAB3Hnnndx22218+tOf5v7771/lfjvttBM33XQTP/7xj5kzZw6f/vSn1/jcl156KTNnzlzb0CVJWokJDEmSpAH1gx/8gA022ID3vve9z5Ttvvvu/PEf/zFVxcc+9jF23XVXXv7yl3P22WcPe4xHHnmELbfccqXyefPmsf/++3PooYcyffp03vve97J06dJn9nnyySeZPHkyd911F/vuuy977rknf/3Xf82mm27ancpKkgbexH4HIEmSpCEuvRTmzYMDDoB1vIPhlltuYY899hh23XnnnffMXRYPPvgge+65J/vvvz8Ad955J7vvvjuPPvoov//977nuuuuGPcb8+fO57bbbeOELX8isWbM477zzOOyww7j88ss56KCDAPjQhz7E+973Pt7xjndw0kknrVN9JEnjm3dgSJIktcWll8KRR8JJJ3U+L720a6e6+uqrmT17NhMmTGDbbbflta99Lddffz3w7CMkd955J1/84heZO3fusMfYa6+9eNGLXsSECROYPXs2V199NQDf+973eMMb3gDANddcw+zZswH4sz/7s67VR5I0+ExgSJIktcW8efDEE7DJJp3PefPW6XC77LILN9xww7DrqmpUx3jzm9/MVVddNey6JMMuz58/n7322mvE7SRJWhsmMCRJktrigANgww3hscc6nwccsE6HO/DAA3niiSf4+te//kzZ9ddfz5VXXsn+++/P2WefzZIlS1i8eDFXXXXVckmHZa6++mp22mmnYY8/f/587rrrLpYuXcrZZ5/Na17zGm699VZe9rKXMWHCBAD2228/zjrrLADfSiJJWicmMCRJktpi5kw4/XQ46qjO5zqOgZGE888/n8suu4yddtqJXXbZheOPP57tt9+eQw89lFe84hXstttuHHjggXzuc5/j+c9/PvDsGBi77bYbH//4x/nGN74x7PH33XdfjjnmGHbddVemTZvGoYceysUXX8ysWbOe2eYf/uEfOOmkk9hzzz15+OGH16k+kqTxzUE8JWnAJJkALAB+XVVvSrIVcDYwFbgbeFtV/Uez7bHAe4AlwAer6pKmfA/gFGAj4LvAh2q095tLWjczZ65z4mKo7bffnnPOOWfYdZ///Of5/Oc/v1zZ1KlT+cMf/jCqY2+88cYrvb3kkksu4bTTTntmedq0aVx77bXPLH/qU58abeiSJC3HOzAkafB8CLh9yPIxwBVVtTNwRbNMkunAEcAuwCzgK03yA+CrwFxg52aahSSN
wmWXXcZ2223X7zAkSQPIBIYkDZAkU4A/AYbe730IcGozfyrwliHlZ1XVE1V1F3AHsFeS7YDNq+ra5q6L04bsI0kAHHDAAVx00UVrvN/vfve7LkQjSRoPTGBI0mD5IvCXwNIhZdtW1X0Azec2TfkOwD1DtlvUlO3QzK9YvpIkc5MsSLJg8eLF66UC0qDx6avR8XuSJK2OCQxJGhBJ3gQ8UFXDvzNxmF2GKatVlK9cWHVyVc2oqhmTJ08e5Wml8WPSpEk89NBD/s/5alQVDz30EJMmTep3KJKkFnMQT0kaHPsBb07yRmASsHmS04H7k2xXVfc1j4c80Gy/CNhxyP5TgHub8inDlEtaQ1OmTGHRokV4h9LqTZo0iSlTpqx+Q0nSuGUCQ5IGRFUdCxwLkOQA4KNVdWSSzwNzgM82nxc0u1wIfCfJF4Dt6QzWOb+qliR5NMk+wHXAO4Av9bIu0qDYYIMNmDZtWr/DkCRpIJjAkKTB91ngnCTvAX4FHA5QVbcmOQe4DXgaOKqqljT7vI9nX6N6cTNJkiRJfWMCQ5IGUFXNA+Y18w8BB42w3QnACcOULwB27V6EkiRJ0ppxEE9JkiRJktR6fUlgJPlwkluT3JLkzCQOOS1JkiStBfvWksaLnicwkuwAfBCYUVW7AhOAI3odhyRJkjTW2beWNJ706xGSicBGSSYCG+Pr+SRJkqS1Zd9a0rjQ8wRGVf0a+Fs6I+HfBzxcVZf2Og5JkiRprLNvLWk86ccjJFsChwDTgO2BTZIcOcx2c5MsSLJg8eLFvQ5TkiRJaj371pLGk348QvJ64K6qWlxVTwHnAa9ecaOqOrmqZlTVjMmTJ/c8SEmSJGkMsG8tadzoRwLjV8A+STZOEuAg4PY+xCFJkiSNdfatJY0b/RgD4zrgXOBG4OYmhpN7HYckSZI01tm3ljSeTOzHSavqOOC4fpxbkiRJGiT2rSWNF/16jaokSZIkSdKomcCQJEmSJEmtZwJDkiRJkiS1ngkMSZIkSZLUeiYwJEmSJElS65nAkCRJkiRJrWcCQ5IkSZIktZ4JDEmSJEmS1HomMCRJkiRJUuuZwJAkSZIkSa1nAkOSJEmSJLWeCQxJkiRJktR6JjAkSZIkSVLrmcCQJEmSJEmtZwJDkiRJkiS1ngkMSZIkSZLUeiYwJGmAJJmUZH6SHye5Ncknm/KtklyW5OfN55ZD9jk2yR1JFiY5eEj5HklubtadmCT9qJMkSZIEJjAkadA8ARxYVbsBuwOzkuwDHANcUVU7A1c0yySZDhwB7ALMAr6SZEJzrK8Cc4Gdm2lWD+shSZIkLccEhiQNkOr4XbO4QTMVcAhwalN+KvCWZv4Q4KyqeqKq7gLuAPZKsh2weVVdW1UFnDZkH0mSJKnnTGBI0oBJMiHJTcADwGVVdR2wbVXdB9B8btNsvgNwz5DdFzVlOzTzK5ZLkiRJfWECQ5IGTFUtqardgSl07qbYdRWbDzeuRa2ifPmdk7lJFiRZsHjx4rWKV5IkSRoNExiSNKCq6j+BeXTGrri/eSyE5vOBZrNFwI5DdpsC3NuUTxmmfMVznFxVM6pqxuTJk9d3FSRJkqRnmMCQpAGSZHKS5zbzGwGvB34KXAjMaTabA1zQzF8IHJFkwyTT6AzWOb95zOTRJPs0bx95x5B9JEmSpJ6b2O8AJEnr1XbAqc2bRP4IOKeqLkpyLXBOkvcAvwIOB6iqW5OcA9wGPA0cVVVLmmO9DzgF2Ai4uJkkSZKkvjCBIUkDpKp+ArxymPKHgING2OcE4IRhyhcAqxo/Q5IkSeoZHyGRJEmSJEmtZwJDkiRJkiS1ngkMSZIkSZLUeiYwJEmSJElS65nAkCRJkiRJrWcCQ5IkSZIktZ4JDEmSJEmS1HomMCRJkiRJUuuZwJAkSZIkSa1nAkOSJEmSJLWeCQxJkiRJktR6JjAkSZIkSVLrmcCQJEmSJEmt15cERpLnJjk3yU+T3J5k337EIUmSJI119q0ljRcT
+3TefwC+V1WHJXkOsHGf4pAkSZLGOvvWksaFnicwkmwO7A+8E6CqngSe7HUckiRJ0lhn31rSeLLaR0iS7JRkw2b+gCQfTPLcdTjni4DFwLeT/CjJN5JsMsx55yZZkGTB4sWL1+F0kiRJ0sCyby1p3BjNGBj/DCxJ8mLgm8A04DvrcM6JwKuAr1bVK4HHgGNW3KiqTq6qGVU1Y/LkyetwOkmSJGlg2beWNG6MJoGxtKqeBg4FvlhVHwa2W4dzLgIWVdV1zfK5dBpdSRo3kmyR5O+X/RqW5O+SbNHvuCRJY459a0njxmgSGE8lmQ3MAS5qyjZY2xNW1W+Ae5K8tCk6CLhtbY8nSWPUt4BHgLc10yPAt/sakSSp65J8LsnmSTZIckWSB5McubbHs28taTwZzSCe7wLeC5xQVXclmQacvo7n/QBwRjNK8i+ac0jSeLJTVb11yPInk9zUr2AkST0zs6r+MsmhdO6eOBz4AevWv7ZvLWlcWG0Co6puAz4IkGRLYLOq+uy6nLSqbgJmrMsxJGmM+0OS11TV1QBJ9gP+0OeYJEndt+xO5jcCZ1bVb5Os0wHtW0saL1abwEgyD3hzs+1NwOIkV1bVR7obmiQNtPcBpzbjXgT4Lc0r8CRJA+1fk/yUTtL6fyWZDDze55gkaUwYzSMkW1TVI0n+HPh2VR2X5CfdDkySBlnza9luSTZvlh/pb0SSpF6oqmOS/A3wSFUtSfIYcEi/45KksWA0CYyJSbajM8jcJ7ocjyQNtCRHVtXpST6yQjkAVfWFvgQmSeqJJIcD32uSF/+bzhtDPgX8pr+RSVL7jeYtJP8XuAS4s6quT/Ii4OfdDUuSBtYmzedmw0yb9isoSVLP/HVVPZrkNcDBwKnAV/sckySNCaMZxPOfgH8asvwL4K0j7yFJGklVfa2Zvbyqrhm6rhnIU5I02JY0n38CfLWqLkhyfB/jkaQxY7V3YCSZkuT8JA8kuT/JPyeZ0ovgJGmAfWmUZWskyY5JfpDk9iS3JvlQU75VksuS/Lz53HLIPscmuSPJwiQHDynfI8nNzboTs67D5EuSAH6d5Gt0Hs/+bpINGd1d0ZI07o2msfw2cCGwPbAD8K9NmSRpDSXZN8lfAJOTfGTIdDwwYT2c4mngL6rqvwD7AEclmQ4cA1xRVTsDVzTLNOuOAHYBZgFfSbIsjq8Cc4Gdm2nWeohPksa7t9F5PHtWVf0nsBXwsb5GJEljxGgSGJOr6ttV9XQznQJM7nJckjSonkNnrIuJLD/+xSPAYet68Kq6r6pubOYfBW6nk3w+hM5z1jSfb2nmDwHOqqonquou4A5gr2bw5s2r6tqqKuC0IftIktZSVf0euBM4OMn7gW2q6tI+hyVJY8Jo3kLyYJIjgTOb5dnAQ90LSZIGV1VdCVyZ5JSq+mU3z5VkKvBK4Dpg26q6r4nhviTbNJvtAPxwyG6LmrKnmvkVy1c8x1w6d2nwghe8YD3XQJIGT/No3/8HnNcUnZ7k5Kpa58cIJWnQjSaB8W7gy8DfAwX8O/CubgYlSePA75N8ns6jG5OWFVbVgevj4Ek2Bf4ZOLqqHlnF8BXDrahVlC9fUHUycDLAjBkzVlovSVrJe4C9q+oxgCR/A1zLehgHSZIG3WofIamqX1XVm6tqclVtU1VvAT7Y/dAkaaCdAfwUmAZ8ErgbuH59HDjJBnSSF2dU1bJf+O5vHguh+XygKV8E7Dhk9ynAvU35lGHKJUnrJjz7JhKaeQdJlqRRWNsRj9+2XqOQpPFn66r6JvBUVV1ZVe+mM+jmOmneFPJN4Paq+sKQVRcCc5r5OcAFQ8qPSLJhkml0Buuc3zxu8miSfZpjvmPIPpKktfdt4LokxzcDOP+QTrstSVqN0TxCMhyzxJK0bp5qPu9L8id07m5YH6+o3g/4M+DmJDc1ZR8HPguck+Q9wK+AwwGq6tYk5wC30XmDyVFVteyXwfcBpwAbARc3kyRpHVTVF5LMA15Dp0/9LuD+
vgYlSWPEiAmMJFuNtAoTGJK0rj6VZAvgL+g897w5cPS6HrSqrmbkNvqgEfY5AThhmPIFwK7rGpMkaXnN26JuXLac5FeAIyFL0mqs6g6MGxh5ILcnuxOOJI0PVXVRM/sw8DqAJPv1LyJJUh/546AkjcKICYyqmtbLQCRpPEgygc44QjsA36uqW5K8ic5jHhvRee2pJGl88S1OkjQKazsGhiRp7XyTzls/5gMnJvklsC9wTFX9Sz8DkyR1T5IvMXyiIsBzexuNJI1NJjAkqbdmAK+oqqVJJgEPAi+uqt/0OS5JUnctWMt1kqSGCQxJ6q0nq2opQFU9nuRnJi8kafBV1an9jkGSxrpRJTCaZ7a3Hbp9Vf2qW0FJ0gB7WZKfNPMBdmqWA1RVvaJ/oUmSJEnttdoERpIPAMfReT/10qa4ADvZkrTm/ku/A5AkSZLGotHcgfEh4KVV9VC3g5GkQVdVv+x3DJKk/kmyX1Vds7oySdLK/mgU29wDPNztQCRJkqRx4EujLJMkrWA0d2D8ApiX5N+AJ5YVVtUXuhaVJEmSNECS7Au8Gpic5CNDVm0OTOhPVJI0towmgfGrZnpOM0mSJElaM88BNqXT/95sSPkjwGF9iUiSxpjVJjCq6pO9CESSxpMkN9MZEHmoh4EFwKccd0iSBktVXQlcmeSUZeMhJfkjYNOqeqS/0UnS2DBiAiPJF6vq6CT/ysqdbKrqzV2NTJIG28XAEuA7zfIRzecjwCnAf+tDTJKk7vtMkvfSuQbcAGyR5AtV9fk+xyVJrbeqOzD+sfn8214EIknjzH5Vtd+Q5ZuTXFNV+yU5sm9RSZK6bXpVPZLkT4HvAn9FJ5FhAkOSVmPEBEZV3dB8Xtm7cCRp3Ng0yd5VdR1Akr3oPBsN8HT/wpIkddkGSTYA3gJ8uaqeSrLS3c6SpJWtdgyMJDsDnwGmA5OWlVfVi7oYlyQNuj8HvpVkUyB0Hh15T5JN6LS5kqTB9DXgbuDHwFVJXkjnGiBJWo3RvIXk28BxwN8DrwPeRaezLUlaS1V1PfDyJFsAqar/HLL6nP5EJUnqtqo6EThxSNEvk7yuX/FI0ljyR6PYZqOquoJOB/uXVXU8cGB3w5KkwZZkiyRfAK4ALk/yd00yQ5I0wJJsm+SbSS5ulqcDc/ocliSNCaNJYDzevOLp50nen+RQYJsuxyVJg+5bwKPA25rpETp3vEmSBtspwCXA9s3yz4Cj+xWMJI0lo0lgHA1sDHwQ2AM4ErPEkrSudqqq46rqF830ScCxhSRpQCVZ9uj286rqHGApQFU9TeeVqpKk1VhlAiPJBOBtVfW7qlpUVe+qqrdW1Q97FJ8kDao/JHnNsoUk+wF/6GM8kqTumt98PpZka6AAkuwDPNy3qCRpDBlxEM8kE6vq6SR7JElV+XonSVp/3gucNmTci//Au9skaZAtGwT/I8CFwE5JrgEmA4f1LSpJGkNW9RaS+cCrgB8BFyT5J+CxZSur6rwuxyZJA6uqfgzslmTzZvmRJEcDP+lrYJKkbpmc5CPN/PnAd+kkNZ4AXo/tvySt1mheo7oV8BCdN48UnYa2gHVKYDSPpywAfl1Vb1qXY0nSWFVVjwxZ/AjwxT6FIknqrgnApjx7J8YyG6+Pg9u3ljQerCqBsU2TJb6FZxMXy6yPx0k+BNwObL4ejiVJg2DFTq0kaXDcV1X/t4vHt28taeCtahDPZVniTYHNhswvm9ZakinAnwDfWJfjSNKAcawhSRpcXUtS27eWNF6s6g6MbmaJvwj8JZ3EyLCSzAXmArzgBS/oUhiS1FtJHmX4REWAjXocjiSpdw7q4rG/iH1rSePAqu7A6EqWOMmbgAeq6oZVbVdVJ1fVjKqaMXny5G6EIkk9V1WbVdXmw0ybVdVoxiVapSTfSvJAkluGlG2V5LIkP28+txyy7tgkdyRZmOTgIeV7JLm5WXdiEh9vkaR1UFW/7cZx7VtLGk9WlcDoVpZ4P+DNSe4GzgIOTHJ6
l84lSePNKcCsFcqOAa6oqp2BK5plkkwHjgB2afb5SjMIHMBX6fxSt3MzrXhMSVI72LeWNG6MmMDoVpa4qo6tqilVNZVOx/n7VXVkN84lSeNNVV0FrNh+HwKc2syfCrxlSPlZVfVEVd0F3AHslWQ7YPOquraqCjhtyD6SpBaxby1pPFnVHRiSpMGwbVXdB9B8btOU7wDcM2S7RU3ZDs38iuUrSTI3yYIkCxYvXrzeA5ckSZKW6WsCo6rm+Z5qSeqb4ca1WPG12UPLVy70mWpJag371pIGnXdgSNLgu795LITm84GmfBGw45DtpgD3NuVThimXJEmS+sYEhiQNvguBOc38HOCCIeVHJNkwyTQ6g3XObx4zeTTJPs3bR94xZB9JkiSpL9b5lX2SpPZIciZwAPC8JIuA44DPAuckeQ/wK+BwgKq6Nck5wG3A08BRVbWkOdT76LzRZCPg4maSJEmS+sYEhiQNkKqaPcKqYV+NXVUnACcMU74A2HU9hiZJkiStEx8hkSRJkiRJrWcCQ5IkSZIktZ4JDEmSJEmS1HomMCRJkiRJUuuZwJAkSZIkSa1nAkOSJEmSJLWeCQxJkiRJktR6JjAkSZIkSVLrmcCQJEmSJEmtZwJDkiRJkiS1ngkMSZIkSZLUeiYwJEmSJElS65nAkCRJkiRJrWcCQ5IkSZIktZ4JDEmSJEmS1HomMCRJkiRJUuuZwJAkSZIkSa1nAkOSJEmSJLWeCQxJkiRJktR6JjAkSZIkSVLrmcCQJEmSJEmtZwJDkiRJkiS1ngkMSZIkSZLUeiYwJEmSJElS65nAkCRJkiRJrWcCQ5IkSZIktZ4JDEnSiJLMSrIwyR1Jjul3PJIkSRq/JvY7AElSOyWZAJwE/FdgEXB9kgur6rb+Rqa1sSRZb79aLAUmVK2no0mSJI2Od2BIkkayF3BHVf2iqp4EzgIO6XNMWgvrM3kBnc7DkmQ9HlGSJGn1TGBIkkayA3DPkOVFTdkzksxNsiDJgsWLF/c0OI1eNy72diAkSVKv2f+QJI1kuJ/Yl3tuoKpOrqoZVTVj8uTJPQpLa2rpGDmmJEnSqpjAkCSNZBGw45DlKcC9fYpF62BC1XpNODgGhiRJ6gcH8ZQkjeR6YOck04BfA0cA/6O/IWltrc+Ew4T1diRJkqTR6/kdGEl2TPKDJLcnuTXJh3odgyRp9arqaeD9wCXA7cA5VXVrf6OSJA1l31rSeNKPOzCeBv6iqm5MshlwQ5LLfC2fJLVPVX0X+G6/45Akjci+taRxo+d3YFTVfVV1YzP/KJ1f9XZY9V6SJEmSVmTfWtJ40tdBPJNMBV4JXNfPOCRJkqSxzr61pEHXtwRGkk2BfwaOrqpHhlk/N8mCJAsWL17c+wAlSZKkMcK+taTxoC8JjCQb0Glgz6iq84bbpqpOrqoZVTVj8uTJvQ1QkiRJGiPsW0saL/rxFpIA3wRur6ov9Pr8kiRJ0qCwby1pPOnHHRj7AX8GHJjkpmZ6Yx/ikCRJksY6+9aSxo2ev0a1qq4G0uvzSpIkSYPGvrWk8aSvbyGRJEmSJEkaDRMYkiRJkiSp9UxgSJIkSZKk1jOBIUmSJEmSWs8EhiRJkiRJaj0TGJIkSZIkqfVMYEiSJEmSpNYzgSFJkiRJklrPBIYkSZIkSWo9ExiSJEmSJKn1TGBIkiRJkqTWM4EhSZIkSZJazwSGJEmSJElqPRMYkiRJkiSp9UxgSJIkSZKk1jOBIUmSJEmSWm9MJDCWLK1+hyBJkiQNhKVl31rS2DQmEhj3/Mfv+x2CJEmSNBB+8/Dj/Q5BktbKmEhgSJIkSZKk8c0EhiRJkiRJaj0TGJI0IJIcnuTWJEuTzFhh3bFJ7kiyMMnBQ8r3SHJzs+7EJGnKN0xydlN+XZKpPa6OJEmStBwTGJI0OG4B/jtw1dDCJNOBI4BdgFnAV5JMaFZ/FZgL7NxMs5ry9wD/UVUvBv4e+JuuRy9JkiStggkMSRoQVXV7VS0cZtUh
wFlV9URV3QXcAeyVZDtg86q6tqoKOA14y5B9Tm3mzwUOWnZ3hiRJktQPJjAkafDtANwzZHlRU7ZDM79i+XL7VNXTwMPA1iseOMncJAuSLFi8eHEXQpckSZI6JvY7AEnS6CW5HHj+MKs+UVUXjLTbMGW1ivJV7bN8QdXJwMkAM2bMWGm9JEmStL6YwJCkMaSqXr8Wuy0CdhyyPAW4tymfMkz50H0WJZkIbAH8di3OLUmSJK0XPkIiSYPvQuCI5s0i0+gM1jm/qu4DHk2yTzO+xTuAC4bsM6eZPwz4fjNOhiRJktQX3oEhSQMiyaHAl4DJwL8luamqDq6qW5OcA9wGPA0cVVVLmt3eB5wCbARc3EwA3wT+MckddO68OKJ3NZEkSZJWZgJDkgZEVZ0PnD/CuhOAE4YpXwDsOkz548Dh6ztGSZIkaW35CIkkSZIkSWo9ExiSJEmSJKn1TGBIkiRJkqTWM4EhSZIkSZJazwSGJEmSJElqPRMYkiRJkiSp9UxgSJIkSZKk1jOBIUmSJEmSWs8EhiRJkiRJar2+JDCSzEqyMMkdSY7pRwySJEnSILBvLWm86HkCI8kE4CTgDcB0YHaS6b2OQ5IkSRrr7FtLGk/6cQfGXsAdVfWLqnoSOAs4pA9xSJIkSWOdfWtJ48bEPpxzB+CeIcuLgL1X3CjJXGBus/hEklt6EFu/PA94sN9BdMkg1w2s31j30n4HMEhuuOGG3yVZ2O84Gm362zWW4RnLytoSB7QrFtvqVbNvvbI2/f2ub4NcN7B+Y1lP2up+JDAyTFmtVFB1MnAyQJIFVTWj24H1yyDXb5DrBtZvrEuyoN8xDJiFbfl7adPfrrEMz1jaGwe0L5Z+x9By9q1XMMj1G+S6gfUby3rVVvfjEZJFwI5DlqcA9/YhDkmSJGmss28tadzoRwLjemDnJNOSPAc4AriwD3FIkiRJY519a0njRs8fIamqp5O8H7gEmAB8q6puXc1uJ3c/sr4a5PoNct3A+o11g16/XmvT92kswzOW4bUllrbEAcYyZti3HtYg12+Q6wbWbyzrSd1StdIjcpIkSZIkSa3Sj0dIJEmSJEmS1ogJDEmSJEmS1HqtTmAkmZVkYZI7khzT73hGkmTHJD9IcnuSW5N8qCnfKsllSX7efG45ZJ9jm3otTHLwkPI9ktzcrDsxSZryDZOc3ZRfl2RqH+o5IcmPklw0aPVL8twk5yb5afPfcd8Bq9+Hm7/NW5KcmWTSWK5fkm8leSBD3mHfq/okmdOc4+dJ5nSznmNVkg803/WtST7Xgng+mqSSPK+PMXy+aV9+kuT8JM/t8flbcT3NCNfLflrx2tbHOFa6DvUxlpWuGT089xq171ozbWkLVmektmIs911GqKd96zFYv+HayLFctzVtd9dnfbI2/eqqauVEZxCiO4EXAc8BfgxM73dcI8S6HfCqZn4z4GfAdOBzwDFN+THA3zTz05v6bAhMa+o5oVk3H9iXzju9Lwbe0JT/L+D/NfNHAGf3oZ4fAb4DXNQsD0z9gFOBP2/mnwM8d1DqB+wA3AVs1CyfA7xzLNcP2B94FXDLkLKu1wfYCvhF87llM79lL/9W2z4BrwMuBzZslrfpczw70hnY7pfA8/oYx0xgYjP/N8v+Pnt07tZcTxnhetnnv5Hlrm19jGOl61Cf4hj2mtHD84+6fXda4++2NW3BKGK1bz3G6zdcmzYI9RupjRzLdWOM9at7+o90Db/IfYFLhiwfCxzb77hGGfsFwH8FFgLbNWXbAQuHqwudzvW+zTY/HVI+G/ja0G2a+YnAgzSDsPaoTlOAK4ADebaRHYj6AZvTaYiyQvmg1G8H4J6mcZgIXETnf6bGdP2AqSzf0Ha9PkO3adZ9DZjdi/+OY2WicyF/fb/jGBLPucBuwN30MYGxQkyHAmf08HytvZ7SXC/7eP6Vrm19imPY61CfYhn2mtHjGEbVvjut8ffa2rZgFLHbtx5D9RupTRuE+o3URo71ujGG
+tVtfoRk2R/HMouaslZrbol5JXAdsG1V3QfQfG7TbDZS3XZo5lcsX26fqnoaeBjYuiuVGN4Xgb8Elg4pG5T6vQhYDHy7uY3vG0k2YUDqV1W/Bv4W+BVwH/BwVV3KgNRviF7UZ0y2Sz32EuCPm1sEr0yyZ78CSfJm4NdV9eN+xTCCd9P5ZaJXWvl3u8L1sl++yMrXtn4Y6TrUc6u4ZvTTSO271kwr24LVsW8NjL36DWzf2n51//vVbU5gZJiy6nkUayDJpsA/A0dX1SOr2nSYslpF+ar26bokbwIeqKobRrvLMGWtrR+dTOCrgK9W1SuBx+jcKjWSMVW/5pm1Q+jc5rU9sEmSI1e1yzBlra3fKKzP+rS5nj2T5PLmuc8Vp0Po/HvaEtgH+BhwzrJnIPsQyyeA/9Otc69hLMu2+QTwNHBGr+KihX+3a3C97GYMa3pt66Y1vQ51zVpcMzR2tK4tWB371s/uMkxZa+vHAPet7Vf3v1/d5gTGIjrPLi8zBbi3T7GsVpIN6DSwZ1TVeU3x/Um2a9ZvBzzQlI9Ut0XN/Irly+2TZCKwBfDb9V+TYe0HvDnJ3cBZwIFJTmdw6rcIWFRVy34FPJdOozso9Xs9cFdVLa6qp4DzgFczOPVbphf1GVPtUrdU1euratdhpgvofEfnVcd8Or8sdW3wzJFiofMc5TTgx03bNQW4Mcnzex1L873QDE71JuBPq7lXskda9Xc7wvWyH0a6tvXDSNehfhjpmtFPI7XvWjOtagtWx771mK7fIPet7Vf3uV/d5gTG9cDOSaYleQ6dAT8u7HNMw2p+XfwmcHtVfWHIqguBOc38HDrP7y0rP6IZkXUasDMwv7k959Ek+zTHfMcK+yw71mHA93vVAa6qY6tqSlVNpfPf4ftVdSSDU7/fAPckeWlTdBBwGwNSPzq3uO2TZOMmroOA2xmc+i3Ti/pcAsxMsmWTgZ/ZlOlZ/0LneV6SvITOwF0P9jqIqrq5qrapqqlN27WIzoBwv+l1LNAZ+R/4K+DNVfX7Hp++NdfTVVwve24V17Z+xDLSdagfRrpm9NNI7bvWTGvagtWxbw2M7foNct/afnW/+9XVg4Fc1nYC3khn1OE7gU/0O55VxPkaOre7/AS4qZneSOfZniuAnzefWw3Z5xNNvRbSjNDalM8AbmnWfZlmwBZgEvBPwB10Rnh9UZ/qegDPDjQ0MPUDdgcWNP8N/4XOLfCDVL9PAj9tYvtHOiMHj9n6AWfSee7wKTr/Y/qeXtWHzvgFdzTTu3r9b7DtE52ExenN93ojcGC/Y2riupv+voXkDjrPed7UTP+vx+dvxfWUEa6XLfj7OID+v4VkpetQH2NZ6ZrRw3OvUfvutMbfbyvaglHEad96jNdvuDZtUOo3XBs5luvGGOtXLzuoJEmSJElSa7X5ERJJkiRJkiTABIYkSZIkSRoDTGBIkiRJkqTWM4EhSZIkSZJazwSGJEmSJElqPRMY6qskv2s+pyb5H+v52B9fYfnf1+fxJWk8SfKJJLcm+UmSm5Ls3cVzzUsyo1vHl6RBZd9ag84EhtpiKrBGjWySCavZZLlGtqpevYYxSZKAJPsCbwJeVVWvAF4P3NPfqCRJqzAV+9YaQCYw1BafBf64+VXvw0kmJPl8kuubX/v+J0CSA5L8IMl3gJubsn9JckPzy+DcpuyzwEbN8c5oypZlpNMc+5YkNyd5+5Bjz0tybpKfJjkjSfrwXUhS22wHPFhVTwBU1YNVdW+S/9O007ckOXlZm9m0pX+f5KoktyfZM8l5SX6e5FPNNlObtvbUpp0/N8nGK544ycwk1ya5Mck/Jdm0Kf9sktuaff+2h9+FJI0F9q01kFJV/Y5B41iS31XVpkkOAD5aVW9qyucC21TVp5JsCFwDHA68EPg3YNequqvZdquq+m2SjYDrgddW1UPLjj3Mud4KvBeYBTyv2Wdv4KXABcAuwL3NOT9WVVd3/5uQpPZqkgZXAxsDlwNnV9WVy9rf
Zpt/BM6pqn9NMg+4rqr+KsmHgL8C9gB+C9wJ7AZsBtwFvKaqrknyLeC2qvrbZv+PAncD5wFvqKrHkvwVsCHwZeBa4GVVVUmeW1X/2ZMvQ5JazL61Bp13YKitZgLvSHITcB2wNbBzs27+sga28cEkPwZ+COw4ZLuRvAY4s6qWVNX9wJXAnkOOvaiqlgI30bn9TpLGtar6HZ0ExFxgMXB2kncCr0tyXZKbgQPpdFKXubD5vBm4tarua+7g+AWdthrgnqq6ppk/nU77PNQ+wHTgmuZ6MIdOZ/sR4HHgG0n+O/D79VVXSRpQ9q01ECb2OwBpBAE+UFWXLFfYySY/tsLy64F9q+r3za92k0Zx7JE8MWR+Cf4bkSQAqmoJMA+Y1yQs/ifwCmBGVd2T5HiWb3+XtadLWb5tXcqzbeuKt4GuuBzgsqqavWI8SfYCDgKOAN5PJ4EiSRqefWsNBO/AUFs8Sud24mUuAd6XZAOAJC9Jsskw+20B/EfTwL6Mzq91yzy1bP8VXAW8vXkWcDKwPzB/vdRCkgZQkpcmGfoL3O7Awmb+weYRk8PW4tAvSGeAUIDZdB5TGeqHwH5JXtzEsXFzPdgU2KKqvgsc3cQjSXqWfWsNJDNgaoufAE83t6udAvwDnVvMbmwG+1kMvGWY/b4HvDfJT+h0pn84ZN3JwE+S3FhVfzqk/HxgX+DHdH7t+8uq+k3TSEuSVrYp8KUkzwWeBu6g8zjJf9J5RORuOs88r6nbgTlJvgb8HPjq0JVVtbh5VOXM5pltgP9Np2N+QZJJdH75+/BanFuSBpl9aw0kB/GUJEk9l2QqcFFV7drvWCRJ0tjgIySSJEmSJKn1vANDkiRJkiS1nndgSJIkSZKk1jOBIUmSJEmSWs8EhiRJkiRJaj0TGJIkSZIkqfVMYEiSJEmSpNb7/wEJoM8YWdIkPgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 1080x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Set up the live-training figure: train loss | log-ratio samples | test loss.\n",
    "fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(15, 4))\n",
    "\n",
    "# x-axis span for the loss panels, in iterations (mini-batches). len(train_dl)\n",
    "# counts the final partial batch too, so this stays correct when\n",
    "# NUM_SAMPLES is not a multiple of BS.\n",
    "n_total_iters = NUM_EPOCHS * len(train_dl)\n",
    "\n",
    "# Placeholder artists, presumably updated in place by the training loop.\n",
    "# NOTE: drawing the random placeholder points advances NumPy's global RNG.\n",
    "line, = ax1.plot([0, 1], [0, 1])\n",
    "x, y = np.random.random((2, 500))\n",
    "scat1 = ax2.scatter(x, y, label='True p/q', alpha=0.9, s=10., c='b')\n",
    "scat2 = ax2.scatter(x, y, label='CoB p/q', alpha=0.9, s=10., c='r')\n",
    "test_line, = ax3.plot([0, 1], [0, 1])\n",
    "\n",
    "ax1.set_xlabel(\"Iteration\")\n",
    "ax1.set_ylabel(\"Train Loss\")\n",
    "ax1.set_xlim([0, n_total_iters])\n",
    "ax1.set_ylim([0, 10])\n",
    "\n",
    "ax2.set_xlabel(\"Samples\")\n",
    "ax2.set_ylabel(\"Log Ratio\")\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-6, 10])\n",
    "ax2.set_ylim([-1500, 5000])\n",
    "\n",
    "ax3.set_xlabel(\"Iteration\")\n",
    "ax3.set_ylabel(\"Test Loss\")\n",
    "ax3.set_xlim([0, n_total_iters])\n",
    "ax3.set_ylim([0, 10])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Loss histories filled in during training.\n",
    "loss_store = []\n",
    "test_loss_store = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch count for the training run below. The loss panels' x-limits were fixed\n",
    "# when the figure was created, so recompute them in case NUM_EPOCHS changed.\n",
    "NUM_EPOCHS = 1500\n",
    "for loss_ax in (ax1, ax3):\n",
    "    loss_ax.set_xlim([0, NUM_EPOCHS * len(train_dl)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 6/1500 [00:21<1:29:21,  3.59s/it]\n",
      "ERROR:root:Internal Python error in the inspect module.\n",
      "Below is the traceback from this internal error.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 3441, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"<ipython-input-10-f85d0c7bbb96>\", line 83, in <module>\n",
      "    display(fig)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/display.py\", line 320, in display\n",
      "    format_dict, md_dict = format(obj, include=include, exclude=exclude)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\", line 180, in format\n",
      "    data = formatter(obj)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/decorator.py\", line 232, in fun\n",
      "    return caller(func, *(extras + args), **kw)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\", line 224, in catch_format_error\n",
      "    r = method(self, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\", line 341, in __call__\n",
      "    return printer(obj)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/pylabtools.py\", line 250, in <lambda>\n",
      "    png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/pylabtools.py\", line 134, in print_figure\n",
      "    fig.canvas.print_figure(bytes_io, **kw)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/backend_bases.py\", line 2193, in print_figure\n",
      "    self.figure.draw(renderer)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\", line 41, in draw_wrapper\n",
      "    return draw(artist, renderer, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/figure.py\", line 1863, in draw\n",
      "    mimage._draw_list_compositing_images(\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/image.py\", line 131, in _draw_list_compositing_images\n",
      "    a.draw(renderer)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\", line 41, in draw_wrapper\n",
      "    return draw(artist, renderer, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/cbook/deprecation.py\", line 411, in wrapper\n",
      "    return func(*inner_args, **inner_kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axes/_base.py\", line 2747, in draw\n",
      "    mimage._draw_list_compositing_images(renderer, self, artists)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/image.py\", line 131, in _draw_list_compositing_images\n",
      "    a.draw(renderer)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\", line 41, in draw_wrapper\n",
      "    return draw(artist, renderer, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axis.py\", line 1169, in draw\n",
      "    tick.draw(renderer)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\", line 41, in draw_wrapper\n",
      "    return draw(artist, renderer, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axis.py\", line 291, in draw\n",
      "    artist.draw(renderer)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\", line 41, in draw_wrapper\n",
      "    return draw(artist, renderer, *args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/text.py\", line 688, in draw\n",
      "    posx, posy = trans.transform((posx, posy))\n",
      "KeyboardInterrupt\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py\", line 2061, in showtraceback\n",
      "    stb = value._render_traceback_()\n",
      "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n",
      "    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 248, in wrapped\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 281, in _fixed_getinnerframes\n",
      "    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/inspect.py\", line 1515, in getinnerframes\n",
      "    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/inspect.py\", line 1477, in getframeinfo\n",
      "    lines, lnum = findsource(frame)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\", line 182, in findsource\n",
      "    lines = linecache.getlines(file, globals_dict)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/linecache.py\", line 47, in getlines\n",
      "    return updatecache(filename, module_globals)\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/linecache.py\", line 137, in updatecache\n",
      "    lines = fp.readlines()\n",
      "  File \"/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/codecs.py\", line 319, in decode\n",
      "    def decode(self, input, final=False):\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "object of type 'NoneType' has no len()",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m<ipython-input-10-f85d0c7bbb96>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     82\u001b[0m                     \u001b[0mclear_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 83\u001b[0;31m                     \u001b[0mdisplay\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     84\u001b[0m                     \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/display.py\u001b[0m in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, *objs, **kwargs)\u001b[0m\n\u001b[1;32m    319\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m             \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minclude\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    321\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mformat\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m    179\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m                 \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    181\u001b[0m             \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/decorator.py\u001b[0m in \u001b[0;36mfun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m    231\u001b[0m                 \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 232\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mcaller\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mextras\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    233\u001b[0m     \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m         \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    225\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m    340\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mprinter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    342\u001b[0m             \u001b[0;31m# Finally look for special method names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(fig)\u001b[0m\n\u001b[1;32m    249\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'png'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 250\u001b[0;31m         \u001b[0mpng_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'png'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    251\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'retina'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m'png2x'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, **kwargs)\u001b[0m\n\u001b[1;32m    133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m     \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbytes_io\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m     \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbytes_io\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m   2192\u001b[0m                     \u001b[0;32mwith\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2193\u001b[0;31m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m   1862\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1863\u001b[0;31m             mimage._draw_list_compositing_images(\n\u001b[0m\u001b[1;32m   1864\u001b[0m                 renderer, self, artists, self.suppressComposite)\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/image.py\u001b[0m in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m             \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/cbook/deprecation.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*inner_args, **inner_kwargs)\u001b[0m\n\u001b[1;32m    410\u001b[0m                 **kwargs)\n\u001b[0;32m--> 411\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minner_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0minner_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    412\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axes/_base.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer, inframe)\u001b[0m\n\u001b[1;32m   2746\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2747\u001b[0;31m         \u001b[0mmimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_draw_list_compositing_images\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/image.py\u001b[0m in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m             \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axis.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1168\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mtick\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mticks_to_draw\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m             \u001b[0mtick\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/axis.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m    290\u001b[0m                        self.label1, self.label2]:\n\u001b[0;32m--> 291\u001b[0;31m             \u001b[0martist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    292\u001b[0m         \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mdraw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0martist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/matplotlib/text.py\u001b[0m in \u001b[0;36mdraw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m    687\u001b[0m             \u001b[0mposy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtextobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_yunits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtextobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 688\u001b[0;31m             \u001b[0mposx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mposy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mposy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    689\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfinite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2060\u001b[0m                         \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2061\u001b[0;31m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'KeyboardInterrupt' object has no attribute '_render_traceback_'",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2061\u001b[0m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2063\u001b[0;31m                         stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m   2064\u001b[0m                                             value, tb, tb_offset=tb_offset)\n\u001b[1;32m   2065\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1365\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1366\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1367\u001b[0;31m         return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1368\u001b[0m             self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m   1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1265\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1266\u001b[0m             \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1267\u001b[0;31m             return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1268\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1269\u001b[0m             )\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1122\u001b[0m         \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1124\u001b[0;31m         formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m   1125\u001b[0m                                                                tb_offset)\n\u001b[1;32m   1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m   1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1081\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1082\u001b[0;31m         \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1083\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1084\u001b[0m         \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m    380\u001b[0m     \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    384\u001b[0m     \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()"
     ]
    }
   ],
   "source": [
    "## CONFIRM q_list_test in validation/visualization in Akash's code\n",
    "#\n",
    "# Train the critic as a 3-way classifier (p -> class 0, q -> 1, m -> 2) and,\n",
    "# every VAL_EVERY steps, redraw the loss curves and the estimated vs. true\n",
    "# log p/q ratio on one test batch.\n",
    "# Uses objects created in earlier cells: train_dl, test_dl, loss_store,\n",
    "# test_loss_store, fig, ax1-ax3, line, test_line, scat1, scat2.\n",
    "\n",
    "VAL_EVERY = 50  # training steps between validation/plot refreshes\n",
    "loss_crit = torch.nn.functional.cross_entropy\n",
    "\n",
    "\n",
    "def to_model_input(*batches):\n",
    "    \"\"\"Unsqueeze each batch at dim 1 and move it to DEVICE.\n",
    "\n",
    "    Applied regardless of device so CPU runs feed the model the same input\n",
    "    shape as CUDA runs; .to('cpu') is a no-op.\n",
    "    \"\"\"\n",
    "    return tuple(b.unsqueeze(1).to(DEVICE) for b in batches)\n",
    "\n",
    "\n",
    "def three_way_loss(log_p, log_q, log_m):\n",
    "    \"\"\"Sum of cross-entropy losses using labels p=0, q=1, m=2.\"\"\"\n",
    "    total = 0.\n",
    "    for label, logits in enumerate((log_p, log_q, log_m)):\n",
    "        target = torch.full((logits.shape[0],), label, dtype=torch.long, device=DEVICE)\n",
    "        total = total + loss_crit(logits, target)\n",
    "    return total\n",
    "\n",
    "\n",
    "def cob_log_ratio(logits):\n",
    "    \"\"\"CoB estimate of log p(x)/q(x): p-class output minus q-class output.\"\"\"\n",
    "    return logits[:, 0] - logits[:, 1]\n",
    "\n",
    "\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "i = 0\n",
    "for _ in trange(NUM_EPOCHS):\n",
    "    for p_batch, q_batch, m_batch in train_dl:\n",
    "        model.train()\n",
    "        i += 1\n",
    "\n",
    "        optim.zero_grad()\n",
    "        p_batch, q_batch, m_batch = to_model_input(p_batch, q_batch, m_batch)\n",
    "        loss = three_way_loss(model(p_batch), model(q_batch), model(m_batch))\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "\n",
    "        if i % VAL_EVERY != 0:\n",
    "            continue\n",
    "\n",
    "        # Validation / test on the first test batch only\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            p_test, q_test, m_test = next(iter(test_dl))\n",
    "            # Ground truth uses the raw batches, before reshaping / moving to DEVICE\n",
    "            log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, m_test, calc_true_kl=True)\n",
    "            _, kl_from_p_q = get_gt_ratio_kl(p, q, p_test)\n",
    "\n",
    "            p_test, q_test, m_test = to_model_input(p_test, q_test, m_test)\n",
    "            logP, logQ, logM = model(p_test), model(q_test), model(m_test)\n",
    "\n",
    "            # Mean estimated log p/q over p samples -> CoB estimate of KL(p || q)\n",
    "            kl_from_cob = torch.mean(cob_log_ratio(logP))\n",
    "            # Estimated log p/q on m samples, plotted against the ground truth\n",
    "            log_ratio_p_q_from_cob = cob_log_ratio(logM)\n",
    "            test_loss = three_way_loss(logP, logQ, logM)\n",
    "\n",
    "        # Visualize\n",
    "        line.set_data(range(len(loss_store)), loss_store)\n",
    "        ax1.set_xlim(0, len(loss_store))\n",
    "\n",
    "        m_x = m_test.cpu().squeeze()\n",
    "        scat1.set_offsets(np.vstack([m_x, log_ratio_p_q.cpu().detach()]).T)\n",
    "        scat2.set_offsets(np.vstack([m_x, log_ratio_p_q_from_cob.cpu()]).T)\n",
    "        ax2.set_xlim(-50., 50.)\n",
    "        ax2.set_ylim(-500, 200)\n",
    "\n",
    "        test_loss_store.append(test_loss.item())\n",
    "        test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "        ax3.set_xlim(0, len(test_loss_store))\n",
    "\n",
    "        # clear_output(wait=True) clears on the NEXT output, so it must come before\n",
    "        # the prints; placed after them it erases the diagnostics immediately.\n",
    "        clear_output(wait=True)\n",
    "        print('iteration: ', i)\n",
    "        print('KLD: ', true_kl_p_q)\n",
    "        print('CoB: ', kl_from_cob.item())\n",
    "        display(fig)\n",
    "\n",
    "        model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-print the diagnostics from the last validation step of the training cell\n",
    "print('iteration: ', i)\n",
    "print('KLD: ', true_kl_p_q)\n",
    "print('CoB: ', kl_from_cob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained critic on one test batch, plus uniform samples over\n",
    "# widening ranges, and compare each CoB estimate with the true KL(p || q).\n",
    "N_UNIFORM = 500  # samples per uniform probe batch\n",
    "\n",
    "model.eval()  # disable dropout; the training cell leaves the model in train mode\n",
    "with torch.no_grad():\n",
    "    p_batch, q_batch, m_batch = next(iter(test_dl))\n",
    "    log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, m_batch, calc_true_kl=True)\n",
    "    _, kl_from_p_q = get_gt_ratio_kl(p, q, p_batch)\n",
    "\n",
    "    # Reshape/move unconditionally so the cell also runs without CUDA\n",
    "    p_batch, q_batch, m_batch = (b.unsqueeze(1).to(DEVICE) for b in (p_batch, q_batch, m_batch))\n",
    "    logP, logQ, logM = model(p_batch), model(q_batch), model(m_batch)\n",
    "\n",
    "    # CoB estimate of log p/q: p-class output minus q-class output\n",
    "    log_ratio_p_q_from_cob_p = logP[:, 0] - logP[:, 1]\n",
    "    log_ratio_p_q_from_cob_q = logQ[:, 0] - logQ[:, 1]\n",
    "    log_ratio_p_q_from_cob_m = logM[:, 0] - logM[:, 1]\n",
    "    kl_from_cob = torch.mean(log_ratio_p_q_from_cob_p)\n",
    "\n",
    "    print('iteration: ', i)\n",
    "    print('True KLD: ', true_kl_p_q)\n",
    "    print('CoB from p samples: ', kl_from_cob.item())\n",
    "    print('CoB from q samples: ', log_ratio_p_q_from_cob_q.mean().item())\n",
    "    print('CoB from m samples: ', log_ratio_p_q_from_cob_m.mean().item())\n",
    "\n",
    "    # Note: torch.FloatTensor((500, 1)) would build the 2-element tensor\n",
    "    # [500., 1.]; torch.empty(N, 1) allocates the intended (N, 1) batch.\n",
    "    for half_width in (5, 10, 20):\n",
    "        u_batch = torch.empty(N_UNIFORM, 1, device=DEVICE).uniform_(-half_width, half_width)\n",
    "        logU = model(u_batch)\n",
    "        print(f'CoB (-{half_width}, {half_width}): ', (logU[:, 0] - logU[:, 1]).mean().item())\n",
    "\n",
    "display(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
