{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auto-reload edited modules (e.g. the local `src` package) before each cell executes.\n",
    "# %reload_ext (rather than %load_ext) lets this cell be re-run without IPython\n",
    "# printing an 'extension is already loaded' notice.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Local package: put the parent directory first on sys.path so `src` can be imported.\n",
    "# The path is resolved to an absolute one (a later os.chdir cannot break it) and is\n",
    "# only inserted when it is not already first, so re-running this cell does not keep\n",
    "# prepending duplicate entries.\n",
    "_PARENT_DIR = os.path.abspath('..')\n",
    "if sys.path[:1] != [_PARENT_DIR]:\n",
    "    sys.path.insert(0, _PARENT_DIR)\n",
    "\n",
    "# NOTE(review): the star import presumably supplies RatioCritic1D, get_dists_1d and\n",
    "# DistDataset used below; prefer importing those names explicitly once confirmed.\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment configuration ---\n",
    "N_DIMS = 1             # input dimensionality of the critic (dim_input)\n",
    "NUM_SAMPLES = 100_000  # samples drawn for each of the train and test datasets\n",
    "BS = 500               # DataLoader batch size\n",
    "NUM_EPOCHS = 800       # length of the training loop\n",
    "SEED = 10              # seeds python `random`, numpy and torch\n",
    "LR = 1e-2              # Adam learning rate\n",
    "DROPOUT = 0.20         # dropout passed to RatioCritic1D\n",
    "\n",
    "# Use the first GPU when one is available, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = 'cuda:0'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "# Author's note: ways to 'break' this setup are changing the number of datapoints,\n",
    "# the scales or means of the distributions, or moving to 2D."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook touches so runs are reproducible.\n",
    "random.seed(SEED)        # Python stdlib RNG\n",
    "np.random.seed(SEED)     # NumPy global RNG\n",
    "torch.manual_seed(SEED); # torch generators on all devices; ';' hides the returned Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Critic network mapping an N_DIMS-dimensional input to 3 outputs.\n",
    "# NOTE(review): presumably one output per distribution (p, q, m); confirm in src.\n",
    "model = RatioCritic1D(\n",
    "    dim_input=N_DIMS,\n",
    "    dim_output=3,\n",
    "    dropout=DROPOUT,\n",
    ")\n",
    "# Optional custom initialisation, currently disabled:\n",
    "# model.apply(weights_init)\n",
    "\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# Distributions p, q, m built by the project helper from location arguments\n",
    "# (mu1, mu2, mu3) and scale arguments (scale_p, scale_q, scale_m).\n",
    "p, q, m = get_dists_1d(\n",
    "    mu1=-2., mu2=2., mu3=0,\n",
    "    scale_p=0.08, scale_q=0.15, scale_m=1.0,\n",
    ")\n",
    "\n",
    "# Other settings noted by the author (presumably mu1, mu2 and the m spread):\n",
    "#   -5, 5, m_var=3.0\n",
    "#   -10, 10, m_var=3.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling p\n",
      "Sampling q\n",
      "Cauchy(loc: 0.0, scale: 1.0)\n",
      "Sampling m\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "Sampling p\n",
      "Sampling q\n",
      "Cauchy(loc: 0.0, scale: 1.0)\n",
      "Sampling m\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n",
      "torch.Size([100000])\n"
     ]
    }
   ],
   "source": [
    "# Define train and test datasets from the same distributions p, q, m.\n",
    "# Both use num_samples=NUM_SAMPLES: the test set is a separate DistDataset instance\n",
    "# (sampled again, as the output shows), not a single batch of size BS.\n",
    "train_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)\n",
    "test_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap both datasets in shuffling mini-batch loaders of size BS (train first, then test).\n",
    "# NOTE(review): the test loader also shuffles; if evaluation should use a fixed order,\n",
    "# switch it to shuffle=False after checking how the training loop consumes test_dl.\n",
    "train_dl, test_dl = (\n",
    "    DataLoader(dataset, batch_size=BS, shuffle=True)\n",
    "    for dataset in (train_ds, test_ds)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABDAAAAEYCAYAAACqUwbqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyQklEQVR4nO3de7gcVZno/+87SSDIHYzcgiYiwoRbhAAiyMSAIYA/kRGFnEGDw5yIBxR0BEE9B3RwBPEgXhjHPIKAIAQRBg6DXCUiqIQEuWMkQJBEkBgQ5A7J+/uja4fOTnf2zt67u2r3/n6ep5+uWlVd6+21u6vWfrtqVWQmkiRJkiRJVfZ3ZQcgSZIkSZLUExMYkiRJkiSp8kxgSJIkSZKkyjOBIUmSJEmSKs8EhiRJkiRJqjwTGJIkSZIkqfJalsCIiHMj4qmIuK+ubKOIuCEiHiqeN2xV/ZI0VEXEgoi4NyLuiog5RVnD/W/UfCci5kfEPRGxc912phXrPxQR08p6P5Ik+9aSBK09A+M8YEq3shOBmzJza+CmYl6SNPDel5njM3NCMd9s/7s/sHXxmA58H2qdYuBkYHdgN+BkO8aSVKrzsG8taYhrWQIjM28Bnu5WfBBwfjF9PvChVtUvSVpBs/3vQcAFWfNbYIOI2AzYD7ghM5/OzGeAG1i54yxJahP71pIEw9tc3yaZ+UQx/SSwSbMVI2I6tV8DWXvttXfZdttt2xCepKFk7ty5f8nMUWXH0QIJXB8RCfwgM2fQfP+7BfB43WsXFmXNylfgvlpSq3Xwvnog2LeWVBnt2F+3O4GxXGZm0blutnwGMANgwoQJOWfOnLbFJmloiIjHyo6hRfbKzEUR8Rbghoj4ff3Cnva/q8N9taRW6+B99YCyby2pbO3YX7f7LiR/Lk5Npnh+qs31S1LHy8xFxfNTwBXUxrBotv9dBGxZ9/LRRVmzcklSddi3ljSktDuBcRXQNZL9NODKNtcvSR0tItaOiHW7poHJwH003/9eBXy8uBvJu4Fni9ORrwMmR8SGxeCdk4sySVJ12LeWNKS07BKSiLgYmAi8OSIWUhvN/jTg0og4EngM+Gir6pekIWoT4IqIgNo+/ieZeW1E3EHj/e81wAHAfOBF4BMAmfl0RPwbcEex3lczs/vgcZKkNrFvLUktTGBk5tQmi/ZpVZ2SNNRl5iPATg3Kl9Bg/5uZCRzdZFvnAucOdIzSUPLaa6+xcOFCXn755bJDqbyRI0cyevRoRowYUXYolWTfWpJKHMRTkiSp0y1cuJB1112XMWPGUJwZpQYykyVLlrBw4ULGjh1bdjiSpIpq9xgYkiRJQ8bLL7/MxhtvbPKiBxHBxhtv7JkqkqRVMoEhSZLUQiYvesd2kiT1xASGJEmSJEmqPBMYkiRJHWrJkiWMHz+e8ePHs+mmm7LFFlssn3/11VdbXv8TTzzB5MmTW16PJGlocBBPSZKkDrXxxhtz1113AXDKKaewzjrr8PnPf3758tdff53hw1vXHbz22mvZb7/9WrZ9SdLQ4hkYkiRJQ8gRRxzBUUcdxe67784JJ5zAKaecwje/+c3ly7fffnsWLFgAwIUXXshuu+3G+PHj+eQnP8nSpUtX2t6YMWM44YQT2GGHHdhtt92YP3/+8mXXXnst+++/P5nJMcccwzbbbMO+++7LAQccwGWXXdby9ypJ6iwmMCRJkirk+uvhi1+sPbfKwoUL+fWvf82ZZ57ZdJ0HH3yQmTNnctttt3HXXXcxbNgwLrrooobrrr/++tx7770cc8wxHHfccQAsXbqUefPmMW7cOK644grmzZvHAw88wAUXXMCvf/3rVrwtSVKH8xISSZKkirj+ejj8cHjlFfjhD+HCC6EVQ0h85CMfYdiwYatc56abbmLu3LnsuuuuALz00ku85S1vabju1KlTlz9/9rOfBeD2229n9913B+CWW25h6tSpDBs2jM0335xJkyYN1FuRJA0h
JjAkSZIqYtasWvJi7bXhhRdq861IYKy99trLp4cPH86yZcuWz7/88ssAZCbTpk3j61//eo/bq78Fatf0z3/+c6ZMmTJQIUuS5CUkkiRJVTFxIqy5Zi15seaatflWGzNmDHfeeScAd955J48++igA++yzD5dddhlPPfUUAE8//TSPPfZYw23MnDlz+fMee+wB1M7g2HfffQHYe++9mTlzJkuXLuWJJ57g5ptvbul7kiR1Js/AkCRJqojJk2uXjcyaVUtetOMOpB/+8Ie54IIL2G677dh999155zvfCcC4ceM49dRTmTx5MsuWLWPEiBGcffbZvO1tb1tpG8888ww77rgja665JhdffDGLFy9m5MiRrLvuugAcfPDB/OIXv2DcuHG89a1vXZ7kkCRpdZjAkCRJqpDJk1uTuDjllFMalq+11lpc32TE0EMPPZRDDz20x20ff/zxnH766cvnL7zwQibXvYmI4Hvf+97y+SOOOKJ3QUuSVMcEhiRJkgbU4YcfXnYIkqQOZAJDkiRJfbZgwYLVfs1555034HFIkjqfg3hKkiRJkqTKM4EhSZIkSZIqzwSGJEmSJEmqPBMYkiRJkiSp8kxgSJIkdbAnn3ySww47jK222opddtmFAw44gD/84Q9N11+wYAFrrbUW48ePZ6edduI973kP8+bNW+16TzvtNC666KL+hC5J0gpMYEiSJHWozOTggw9m4sSJPPzww8ydO5evf/3r/PnPf17l67baaivuuusu7r77bqZNm8a///u/r3bd1113HZMnT+5r6JIkrcQEhiRJUoe6+eabGTFiBEcdddTysp122on3vve9ZCbHH38822+/PTvssAMzZ85suI3nnnuODTfccKXyWbNmsffee3PggQeyzTbbcNRRR7Fs2bLlr3n11VcZNWoUjz76KHvssQc77LADX/7yl1lnnXVa82YlSR1veNkBSJIkqc7118OsWTBxIvTzDIb77ruPXXbZpeGyyy+/fPlZFn/5y1/Ydddd2XvvvQF4+OGHGT9+PH/729948cUXuf322xtuY/bs2TzwwAO87W1vY8qUKVx++eUccsgh3Hjjjeyzzz4AHHvssXzqU5/i4x//OGeffXa/3o8kaWjzDAxJkqSquP56OPxwOPvs2vP117esqltvvZWpU6cybNgwNtlkE/7hH/6BO+64A3jjEpKHH36Ys846i+nTpzfcxm677cbb3/52hg0bxtSpU7n11lsBuPbaa9l///0BuO2225g6dSoAH/vYx1r2fiRJnc8EhiRJUlXMmgWvvAJrr117njWrX5vbbrvtmDt3br+28cEPfpBbbrml4bKIaDg/e/Zsdtttt6brSZLUFyYwJEmSqmLiRFhzTXjhhdrzxIn92tykSZN45ZVXmDFjxvKye+65h1/96le8973vZebMmSxdupTFixdzyy23rJB06HLrrbey1VZbNdz+7NmzefTRR1m2bBkzZ85kr7324v7772fbbbdl2LBhAOy5555ccsklAN6VRJLULyYwJEmSqmLyZLjwQjj66NpzP8fAiAiuuOIKbrzxRrbaaiu22247TjrpJDbddFMOPvhgdtxxR3baaScmTZrEN77xDTbddFPgjTEwdtppJ774xS/ywx/+sOH2d911V4455hj+/u//nrFjx3LwwQfz85//nClTpixf59vf/jZnn302O+ywA4sWLerX+5EkDW0O4ilJHSgihgFzgEWZ+YGIGAtcAmwMzAU+lpmvRsSawAXALsAS4NDMXFBs4yTgSGAp8JnMvK7970QagiZP7nfiot7mm2/OpZde2nDZGWecwRlnnLFC2ZgxY3jppZd6te311luPq6++eoWy6667jgsuuGD5/NixY/nNb36zfP6ss87qZeSSJK3IMzAkqTMdCzxYN3868K3MfAfwDLXEBMXzM0X5t4r1iIhxwGHAdsAU4D+KpIgkrdINN9zAZpttVnYYkqQOZAJDkjpMRIwGDgR+WMwHMAm4rFjlfOBD
xfRBxTzF8n2K9Q8CLsnMVzLzUWA+sPLF8ZKGrIkTJ6509kVvPP/88y2IRpI0FJjAkKTOcxZwArCsmN8Y+Gtmvl7MLwS2KKa3AB4HKJY/W6y/vLzBa5aLiOkRMSci5ixevHiA34bUGTKz7BAGBdtJktQTExiS1EEi4gPAU5nZv/sm9lJmzsjMCZk5YdSoUe2oUhpURo4cyZIlS/znvAeZyZIlSxg5cmTZoUiSKsxBPCWps+wJfDAiDgBGAusB3wY2iIjhxVkWo4GuWwEsArYEFkbEcGB9aoN5dpV3qX+NpF4aPXo0CxcuxDOUejZy5EhGjx5ddhiSpAozgSFJHSQzTwJOAoiIicDnM/OfIuKnwCHU7kQyDbiyeMlVxfxviuW/yMyMiKuAn0TEmcDmwNbA7Da+FakjjBgxgrFjx5YdhiRJHcEEhiQNDV8ALomIU4HfAecU5ecAP46I+cDT1O48QmbeHxGXAg8ArwNHZ+bS9octSZIk1ZjAkKQOlZmzgFnF9CM0uItIZr4MfKTJ678GfK11EUqSJEm95yCekiRJkiSp8kpJYETEZyPi/oi4LyIujgiHnJYkSZL6wL61pKGi7QmMiNgC+AwwITO3B4ZRXHMtSZIkqffsW0saSsq6hGQ4sFZxy743AX8qKQ5JkiRpsLNvLWlIaHsCIzMXAd8E/gg8ATybmde3Ow5JkiRpsLNvLWkoKeMSkg2Bg4CxwObA2hFxeIP1pkfEnIiYs3jx4naHKUmSJFWefWtJQ0kZl5DsCzyamYsz8zXgcuA93VfKzBmZOSEzJ4waNartQUqSJEmDgH1rSUNGGQmMPwLvjog3RUQA+wAPlhCHJEmSNNjZt5Y0ZJQxBsbtwGXAncC9RQwz2h2HJEmSNNjZt5Y0lAwvo9LMPBk4uYy6JUmSpE5i31rSUFHWbVQlSZIkSZJ6zQSGJEmSJEmqPBMYkiRJkiSp8kxgSJIkSZKkyjOBIUmSJEmSKs8EhiRJkiRJqjwTGJIkSZIkqfJMYEiSJEmSpMozgSFJkiRJkirPBIYkSZIkSao8ExiSJEmSJKnyTGBIkiRJkqTKM4EhSZIkSZIqzwSGJEmSJEmqPBMYkiRJkiSp8kxgSJIkSZKkyjOBIUkdJCJGRsTsiLg7Iu6PiK8U5WMj4vaImB8RMyNijaJ8zWJ+frF8TN22TirK50XEfiW9JUmSJAkwgSFJneYVYFJm7gSMB6ZExLuB04FvZeY7gGeAI4v1jwSeKcq/VaxHRIwDDgO2A6YA/xERw9r5RiRJkqR6JjAkqYNkzfPF7IjikcAk4LKi/HzgQ8X0QcU8xfJ9IiKK8ksy85XMfBSYD+zW+ncgSZIkNWYCQ5I6TEQMi4i7gKeAG4CHgb9m5uvFKguBLYrpLYDHAYrlzwIb15c3eI0kSZLUdiYwJKnDZObSzBwPjKZ21sS2raorIqZHxJyImLN48eJWVSNJkiSZwJCkTpWZfwVuBvYANoiI4cWi0cCiYnoRsCVAsXx9YEl9eYPX1NcxIzMnZOaEUaNGteJtSJIkSYAJDEnqKBExKiI2KKbXAt4PPEgtkXFIsdo04Mpi+qpinmL5LzIzi/LDiruUjAW2Bma35U1IkiRJDQzveRVJ0iCyGXB+cceQvwMuzcyrI+IB4JKIOBX4HXBOsf45wI8jYj7wNLU7j5CZ90fEpcADwOvA0Zm5tM3vRZIkSVrOBIYkdZDMvAd4V4PyR2hwF5HMfBn4SJNtfQ342kDHKEmSJPWFl5BIkiRJkqTKM4EhSZIkSZIqzwSGJEmSJEmqPBMYkiRJkiSp8kxgSJIkSZKkyjOBIUmSJEmSKs8EhiRJkiRJqjwTGJIkSZIkqfJMYEiSJEmSpMozgSFJkiRJkirPBIYkSZIkSao8ExiSJEmSJKnyTGBIkiRJkqTKKyWBEREbRMRlEfH7iHgwIvYoIw5JkiRpsLNvLWmoGF5Svd8Grs3MQyJiDeBNJcUhSZIkDXb2rSUNCW1PYETE+sDe
wBEAmfkq8Gq745AkSZIGO/vWkoaSHi8hiYitImLNYnpiRHwmIjboR51jgcXAjyLidxHxw4hYu0G90yNiTkTMWbx4cT+qkyRJkjqWfWtJQ0ZvxsD4GbA0It4BzAC2BH7SjzqHAzsD38/MdwEvACd2XykzZ2TmhMycMGrUqH5UJ0mSJHUs+9aShozeJDCWZebrwMHAdzPzeGCzftS5EFiYmbcX85dR2+lK0pAREetHxLe6fg2LiP9bnAYsSdLqsG8tacjoTQLjtYiYCkwDri7KRvS1wsx8Eng8IrYpivYBHujr9iRpkDoXeA74aPF4DvhRqRFJklouIr4REetFxIiIuCkiFkfE4X3dnn1rSUNJbwbx/ARwFPC1zHw0IsYCP+5nvZ8GLipGSX6kqEOShpKtMvPDdfNfiYi7ygpGktQ2kzPzhIg4GFgA/CNwC3BhP7Zp31rSkNBjAiMzHwA+AxARGwLrZubp/ak0M+8CJvRnG5I0yL0UEXtl5q0AEbEn8FLJMUmSWq+r/30g8NPMfDYi+rVB+9aShooeExgRMQv4YLHuXOCpiLgtMz/X4tgkqZN9Cji/GPcigKcpboEnSepoV0fE76klrT8VEaOAl0uOSZIGhd5cQrJ+Zj4XEf8CXJCZJ0fEPa0OTJI6WfFr2U4RsV4x/1y5EUmS2iEzT4yIbwDPZubSiHgBOKjsuCRpMOhNAmN4RGxGbZC5L7U4HknqaBFxeGZeGBGf61YOQGaeWUpgkqS2iIiPANcWyYsvU7tjyKnAk+VGJknV15u7kHwVuA54ODPviIi3Aw+1NixJ6lhrF8/rNnisU1ZQkqS2+d+Z+beI2AvYFzgH+H7JMUnSoNCbQTx/Cvy0bv4R4MPNXyFJaiYzf1BM3piZt9UvKwbylCR1tqXF84HAjMz874g4tcyAJGmw6PEMjIgYHRFXRMRTxeNnETG6HcFJUgf7bi/LVktEbBkRN0fEAxFxf0QcW5RvFBE3RMRDxfOGRXlExHciYn5E3BMRO9dta1qx/kMRMa2/sUmSAFgUET8ADgWuiYg16d1Z0ZI05PVmDIwfAT8BPlLMH16Uvb9VQUlSp4qIPYD3AKO6jYOxHjBsAKp4HfjXzLwzItYF5kbEDdTucHJTZp4WEScCJwJfAPYHti4eu1M7jXn3iNgIOJnabfmy2M5VmfnMAMQoSUPZR4EpwDcz86/FWHPHlxyTJA0Kvcn2jsrMH2Xm68XjPGBUi+OSpE61BrWxLoaz4vgXzwGH9HfjmflEZt5ZTP8NeBDYgtoI9+cXq50PfKiYPojaHaYyM38LbFB0pvcDbsjMp4ukxQ3UOtySpH7IzBeBh4H9IuIY4C2ZeX3JYUnSoNCbMzCWRMThwMXF/FRgSetCkqTOlZm/BH4ZEedl5mOtrCsixgDvAm4HNsnMJ4pFTwKbFNNbAI/XvWxhUdasvHsd04HpAG9961sHMHpJ6kzFpX3/E7i8KLowImZkZr8vI5SkTtebBMY/U7su+1vUTiP+NbVTkSVJffdiRJwBbAeM7CrMzEkDsfGIWAf4GXBcZj7XdZvWoo6MiByIejJzBjADYMKECQOyTUnqcEcCu2fmCwARcTrwGwZgHCRJ6nQ9XkKSmY9l5gczc1RmviUzPwQc2/rQJKmjXQT8HhgLfAVYANwxEBuOiBHUkhcXZWbXL3x/Li4NoXh+qihfBGxZ9/LRRVmzcklS/wRv3ImEYjqarCtJqtPXEY8/OqBRSNLQs3FmngO8lpm/zMx/Bvp99kXUTrU4B3gwM8+sW3QV0HUnkWnAlXXlHy/uRvJu4NniUpPrgMkRsWFxx5LJRZkkqX9+BNweEadExCnAb6nttyVJPejNJSSNmCWWpP55rXh+IiIOBP4EbDQA290T+Bhwb0TcVZR9ETgNuDQijgQe441E9DXAAcB84EXgEwCZ+XRE/BtvnBXy1cx8egDik6QhLTPPjIhZwF5F0SeAP5cXkSQNHk0TGMUt9BouwgSGJPXXqRGx
PvCv1K57Xg84rr8bzcxbab6P3qfB+gkc3WRb5wLn9jcmSdKKirtF3dk1HxF/BBwJWZJ6sKozMOZSG7SzUUf41daEI0lDQ2ZeXUw+C7wPICL2LC8iSVKJ/HFQknqhaQIjM8e2MxBJGgoiYhi1yze2AK7NzPsi4gPULvNYi9ptTyVJQ4t3cZKkXujrGBiSpL45h9rdPWYD34mIPwETgBMz87/KDEyS1DoR8V0aJyoC2KC90UjS4GQCQ5LaawKwY2Yui4iRwJPAVpm5pOS4JEmtNaePyyRJBRMYktRer2bmMoDMfDkiHjF5IUmdLzPPLzsGSRrsepXAKK7Z3qR+/cz8Y6uCkqQOtm1E3FNMB7BVMR/UbgqyY3mhSZIkSdXVYwIjIj4NnEzt/tTLiuIE7GRL0ur7+7IDkCRJkgaj3pyBcSywjac4S1L/ZeZjZccgSSpPROyZmbf1VCZJWtnf9WKdx4FnWx2IJEmSNAR8t5dlkqRuenMGxiPArIj4b+CVrsLMPLNlUUmSJEkdJCL2AN4DjIqIz9UtWg8YVk5UkjS49CaB8cfisUbxkCRJkrR61gDWodb/Xreu/DngkFIikqRBpscERmZ+pR2BSNJQEhH3UhsQud6zwBzgVMcdkqTOkpm/BH4ZEed1jYcUEX8HrJOZz5UbnSQNDk0TGBFxVmYeFxH/j5U72WTmB1samSR1tp8DS4GfFPOHAW8CngTOA/6/csKSJLXY1yPiKGrHgDuA9SLi25l5RslxSVLlreoMjB8Xz99sRyCSNMTsm5k7183fGxF3ZubOEXF4aVFJklptXGY+FxH/RC2ZfSIwFzCBIUk9aJrAyMy5xfMv2xeOJA0ZwyJit8ycDRARu/LGIG6vlxeWJKnFRkTECOBDwPcy87WIWOlsZ0nSynocAyMitga+DowDRnaVZ+bbWxiXJHW6fwHOjYh1gKA2iNuREbE2tX2uJKkz/QBYANwN3BIRb6N2DJAk9aA3dyH5EXAy8C3gfcAngL9rZVCS1Oky8w5gh4hYv5h/tm7xpeVEJUlqtcz8DvCduqLHIuJ9ZcUjSYNJbxIRa2XmTUBk5mOZeQpwYGvDkqTOFhHrR8SZwE3ATRHxf7uSGZKkzhURm0TEORHx82J+HDCt5LAkaVDoTQLjleIWTw9FxDERcTC1e1hLkvruXOBvwEeLx3PUzniTJHW284DrgM2L+T8Ax5UVjCQNJr1JYBxL7dZ+nwF2AQ7HLLEk9ddWmXlyZj5SPL4COLaQJHWoiOi6dPvNmXkpsAwgM1+ndktVSVIPVpnAiIhhwKGZ+XxmLszMT2TmhzPzt22KT5I61UsRsVfXTETsCbxUYjySpNaaXTy/EBEbAwkQEe8Gnm36KknSck0H8YyI4Zn5en0HW5I0YI4CLqgb9+IZPLtNkjpZFM+fA64CtoqI24BRwCGlRSVJg8iq7kIyG9gZ+F1EXAX8FHiha2FmXt7i2CSpY2Xm3cBOEbFeMf9cRBwH3FNqYJKkVhkVEZ8rpq8ArqGW1HgF2Bf3/5LUo97cRnUksASYRO1Utyie+5XAKC5PmQMsyswP9GdbkjRYZeZzdbOfA84qKRRJUmsNozYQfnQrf9NAbNy+taShYFUJjLcUWeL7eCNx0SUHoO5jgQeB9QZgW5LUCbp3aiVJneOJzPxqC7dv31pSx1vVIJ5dWeJ1gHXrprsefRYRo4EDgR/2ZzuS1GEGIjksSaqmliWp7VtLGipWdQZGK7PEZwEnUEuMNBQR04HpAG9961tbFIYktVdE/I3GiYoA1mpzOJKk9tmnhds+C/vWkoaAVZ2B0ZIscUR8AHgqM+euar3MnJGZEzJzwqhRo1oRiiS1XWaum5nrNXism5m9GZdolSLi3Ih4KiLuqyvbKCJuiIiHiucNi/KIiO9ExPyIuCcidq57zbRi/YciwrujSFI/ZebTrdiufWtJQ8mqEhityhLvCXwwIhYAlwCTIuLCFtUlSUPNecCUbmUnAjdl5tbATcU8
wP7A1sVjOvB9qCU8gJOB3YHdgJO7kh6SpMqxby1pyGiawGhVljgzT8rM0Zk5BjgM+EVmHt6KuiRpqMnMW4Du+++DgPOL6fOBD9WVX5A1vwU2iIjNgP2AGzLz6cx8BriBlZMikqQKsG8taShZ1RkYkqTOsElmPlFMPwlsUkxvATxet97CoqxZ+UoiYnpEzImIOYsXLx7YqCVJkqQ6pSYwMnOW96mWpPbJzGQA73biNdWSVB32rSV1Os/AkKTO9+fi0hCK56eK8kXAlnXrjS7KmpVLkiRJpTGBIUmd7yqg604i04Ar68o/XtyN5N3As8WlJtcBkyNiw2LwzslFmSRJklSaft+yT5JUHRFxMTAReHNELKR2N5HTgEsj4kjgMeCjxerXAAcA84EXgU9AbRDniPg34I5iva+2amBnSZIkqbdMYEhSB8nMqU0WrXRr7GI8jKObbOdc4NwBDE2SJEnqFy8hkSRJkiRJlWcCQ5IkSZIkVZ4JDEmSJEmSVHkmMCRJkiRJUuWZwJAkSZIkSZVnAkOSJEmSJFWeCQxJkiRJklR5JjAkSZIkSVLlmcCQJEmSJEmVZwJDkiRJkiRVngkMSZIkSZJUeSYwJEmSJElS5ZnAkCRJkiRJlWcCQ5IkSZIkVZ4JDEmSJEmSVHkmMCRJkiRJUuWZwJAkSZIkSZVnAkOSJEmSJFWeCQxJkiRJklR5JjAkSZIkSVLlmcCQJEmSJEmVZwJDkiRJkiRVngkMSZIkSZJUeSYwJEmSJElS5ZnAkCRJkiRJlWcCQ5IkSZIkVZ4JDElSUxExJSLmRcT8iDix7HgkSZI0dA0vOwBJUjVFxDDgbOD9wELgjoi4KjMfKDcy9cXSiAH71WIZMCxzgLYmSZLUO56BIUlqZjdgfmY+kpmvApcAB5Uck/pgIJMXUOs8LI0YwC1KkiT1zASGJKmZLYDH6+YXFmXLRcT0iJgTEXMWL17c1uDUe6042NuBkCRJ7Wb/Q5LUZ5k5IzMnZOaEUaNGlR2Omlg2SLYpSZK0KiYwJEnNLAK2rJsfXZRpkBmWOaAJB8fAkCRJZXAQT0lSM3cAW0fEWGqJi8OA/1FuSOqrgUw4DBuwLUmSJPVe28/AiIgtI+LmiHggIu6PiGPbHYMkqWeZ+TpwDHAd8CBwaWbeX25UkqR69q0lDSVlnIHxOvCvmXlnRKwLzI2IG7wtnyRVT2ZeA1xTdhySpKbsW0saMtp+BkZmPpGZdxbTf6P2q94Wq36VJEmSpO7sW0saSkodxDMixgDvAm4vMw5JkiRpsLNvLanTlZbAiIh1gJ8Bx2Xmcw2WT4+IORExZ/Hixe0PUJIkSRok7FtLGgpKSWBExAhqO9iLMvPyRutk5ozMnJCZE0aNGtXeACVJkqRBwr61pKGijLuQBHAO8GBmntnu+iVJkqROYd9a0lBSxhkYewIfAyZFxF3F44AS4pAkSZIGO/vWkoaMtt9GNTNvBaLd9UqSJEmdxr61pKGk1LuQSJIkSZIk9YYJDEmSJEmSVHkmMCRJkiRJUuWZwJAkSZIkSZVnAkOSJEmSJFWeCQxJkiRJklR5JjAkSZIkSVLlmcCQJEmSJEmVZwJDkiRJkiRVngkMSZIkSZJUeSYwJEmSJElS5ZnAkCRJkiRJlWcCQ5IkSZIkVZ4JDEmSJEmSVHkmMCRJkiRJUuWZwJAkSZIkSZU3KBIYyzLLDkGSJEnqCHatJQ1WgyKB8eSzL5cdgiRJktQRnn/ltbJDkKQ+GRQJDEmSJEmSNLSZwJAkSZIkSZVnAkOSOkREfCQi7o+IZRExoduykyJifkTMi4j96sqnFGXzI+LEuvKxEXF7UT4zItZo53uRJEmSujOBIUmd4z7gH4Fb6gsjYhxwGLAdMAX4j4gYFhHDgLOB/YFxwNRiXYDTgW9l5juAZ4Aj2/MWJEmSpMZMYEhSh8jMBzNzXoNFBwGXZOYrmfkoMB/YrXjMz8xHMvNV4BLgoIgIYBJwWfH684EPtfwNSJIk
SatgAkOSOt8WwON18wuLsmblGwN/zczXu5WvJCKmR8SciJizePHiAQ9ckiRJ6jK87AAkSb0XETcCmzZY9KXMvLLd8WTmDGAGwIQJE7Ld9UuSJGnoMIEhSYNIZu7bh5ctArasmx9dlNGkfAmwQUQML87CqF9fkiRJKoWXkEhS57sKOCwi1oyIscDWwGzgDmDr4o4ja1Ab6POqzEzgZuCQ4vXTgLaf3SFJkiTVM4EhSR0iIg6OiIXAHsB/R8R1AJl5P3Ap8ABwLXB0Zi4tzq44BrgOeBC4tFgX4AvA5yJiPrUxMc5p77uRJEmSVuQlJJLUITLzCuCKJsu+BnytQfk1wDUNyh+hdpcSSZIkqRI8A0OSJEmSJFWeCQxJkiRJklR5JjAkSZIkSVLlmcCQJEmSJEmVZwJDkiRJkiRVngkMSZIkSZJUeSYwJEmSJElS5ZnAkCRJkiRJlWcCQ5IkSZIkVV4pCYyImBIR8yJifkScWEYMkiRJUiewby1pqGh7AiMihgFnA/sD44CpETGu3XFIkiRJg519a0lDSRlnYOwGzM/MRzLzVeAS4KAS4pAkSZIGO/vWkoaM4SXUuQXweN38QmD37itFxHRgejH7SkTc14bYeuPNwF/KDqJgLCurShxgLM1UKZZtyg6gk8ydO/f5iJhXdhyFKn3OjKUxY2msKrFUJQ5wX92Twdy3rtLnzFgaM5bGjKWxlu+vy0hg9EpmzgBmAETEnMycUHJIgLE0U5VYqhIHGEszVYul7Bg6zLwq/W2NZWXG0pixVDcOcF89UKrYt65KHGAszRhLY8bSWDv212VcQrII2LJufnRRJkmSJGn12LeWNGSUkcC4A9g6IsZGxBrAYcBVJcQhSZIkDXb2rSUNGW2/hCQzX4+IY4DrgGHAuZl5fw8vm9H6yHrNWBqrSixViQOMpRlj6VxVak9jacxYGjOWlVUlDqhWLJUzyPvWVYkDjKUZY2nMWBpreSyRma2uQ5IkSZIkqV/KuIREkiRJkiRptZjAkCRJkiRJlVfpBEZETImIeRExPyJOHMDtbhkRN0fEAxFxf0QcW5RvFBE3RMRDxfOGRXlExHeKOO6JiJ3rtjWtWP+hiJhWV75LRNxbvOY7ERGriGdYRPwuIq4u5sdGxO3Fa2cWAzIREWsW8/OL5WPqtnFSUT4vIvbraxtGxAYRcVlE/D4iHoyIPcpol4j4bPG3uS8iLo6Ike1sl4g4NyKeirp7pLejHRrUcWGDOM4o/j73RMQVEbFBX9/v6rRpozap286/RkRGxJvb0CYbNoslIj5dtM39EfGNdrRL97ZQ879DifGs8PksKYam39s21d+S42kf4mh4/C1TdDsGlxjHSsffEmP5bHQ7Brex7l4ff7X6WrUvaPbdbva3G8h+QpN4KtG3bvS9LrFNVvpet6tdVud7PZDt0KiOJrHYtx5MfevMrOSD2iBEDwNvB9YA7gbGDdC2NwN2LqbXBf4AjAO+AZxYlJ8InF5MHwD8HAjg3cDtRflGwCPF84bF9IbFstnFulG8dv9VxPM54CfA1cX8pcBhxfR/Ap8qpv8X8J/F9GHAzGJ6XNE+awJji3Yb1pc2BM4H/qWYXgPYoN3tAmwBPAqsVdceR7SzXYC9gZ2B++rKWt4ODeq4qEEck4HhxfTpdXGs9vtdnTZt1CbF8i2pDRz2GPDmNrTJ6U3+Pu8DbgTWLObf0o52KXtfWbVHs79DifGs9PksKY6G39s21d2y42kfYml4/C35M7LCMbjEOFY6/pYUR8NjcBvr7/Xx18dqt61962xv37rR97qMNmn2vW5Xu1CdfnWzPqR960HUty59Z7qKL/wewHV18ycBJ7WoriuB9wPzgM2Kss2AecX0D4CpdevPK5ZPBX5QV/6Domwz4Pd15Sus163u0cBNwCTg6uID9hfe+BItb4fig7xHMT28WC+6t03XeqvbhsD61HZu0a28re1CbSf7OLUv4vCiXfZrd7sAY1jxS9zydmhUR/c4
usV4MHBRo/fR0/vt42dtpViAy4CdgAW8sZNtaZs0+ftcCuzboI1a3i6t2C8N1kezv0OJ8az0+Sz7Uf+9bVN9bTue9iG2K4H3l1j/CsfgEuNoePwtKZZGx+DJbY6h+/694XHAx2q3q33r7N1xfXX7Dk3iqES/ulhWet+6t9/rgWyHVdSxQizd4rRv/cZrKtm3rvIlJF1ftC4Li7IBVZym8i7gdmCTzHyiWPQksEkPsayqfGEvYz8LOAFYVsxvDPw1M19v8Nrl9RXLny3WX934mhkLLAZ+FLXT7n4YEWvT5nbJzEXAN4E/Ak8U73Mu5bVLl3a0Q7M6mvlnahnVvsTRl8/aCiLiIGBRZt7dbVEZbfJO4L3F6We/jIhd+xhLv9tliGv2d2i7VXw+y1b/vW2HthxPV1e3429ZzmLFY3BZmh1/267RMTgzry8jljqre2xUY/atV35tK/uQlehXF++tin3rKvarwb51vUr2raucwGi5iFgH+BlwXGY+V78sa2mgbHH9HwCeysy5raxnNQyndurQ9zPzXcAL1E4rWq5N7bIhcBC1Hf/mwNrAlFbWubra0Q491RERXwJep3aZSdtFxJuALwL/p1119tAmw6n9svBu4Hjg0q5r/TSwIuLGqF1D2/1xEG3+O/QQS1s/nz3E0rVOqd/bqljV8beNMVTpGNzj8bddGh2DI+LwMmJppB3HX/WdfesVVKJfDdXvW1ehXw3lH6PtW/dOlRMYi6hd/9NldFE2ICJiBLUd7EWZeXlR/OeI2KxYvhnwVA+xrKp8dC9i3xP4YEQsAC6hdqrbt4ENImJ4g9cur69Yvj6wpA/xNbMQWJiZXb+GXUZtx9vudtkXeDQzF2fma8Dl1NqqrHbp0o52aFbHCiLiCOADwD8VO56+xLGE1W/TeltROxDeXXyGRwN3RsSmfYil321C7fN7edbMpvbLy5tLaJeOl5n7Zub2DR5X0vzv0NZYqF0D2uzz2dZYinZp9r1th5YeT1dXk+NvGVY6BkfEhSXF0uz4W4ZGx+D3lBRLl94eB7Rq9q1Xfm0r+5BV6VdDNfvWlelXF8uPwL51d9XsW6/q+pIyH9QyPl0d0K5BQLYboG0HcAFwVrfyM1hxQJNvFNMHsuKgKbOL8o2oXdu2YfF4FNioWNZ90JQDeohpIm8MNPRTVhzk5H8V00ez4iAnlxbT27HiQCqPUBtEZbXbEPgVsE0xfUrRJm1tF2B34H7gTcV65wOfbne7sPJ1YC1vh0Z1NIhjCvAAMKpbvKv9fvvQpivE0q3+BbxxnV5L26TJ3+co4KvF9DupnY4W7WgXHyt8Dhr+HSoQ1/LPZ0n1N/zetqnulh1P+xBLw+Nv2Q/qjsElxrDS8bekOBoeg9scQ/f9e8PjgI/Vblf71tnevnWj73UZbdLse93OdqEi/eomsdi3btwulexbl74z7WHHcwC1UYwfBr40gNvdi9qpMvcAdxWPA6hdb3MT8BC1EVe7/vgBnF3EcS8woW5b/wzMLx6fqCufANxXvOZ79NCBZ8Wd7NuLD9z84o/dNfLryGJ+frH87XWv/1JR1zzqRiBe3TYExgNzirb5r+KL0PZ2Ab4C/L5Y98fFF6Rt7QJcTO0awdeoZR+PbEc7NKjjZw3imE9tB3JX8fjPvr7f1WnTRm3Src0W8MZOtpVtslGTv88awIXFNu4EJrWjXcreT1btsaq/Q8lxLf98llR/0+9tm+pvyfG0D3E0PP5W4PMxkfITGOPpdvwtMZaVjsFtrLvXx18ffWpf+9Zt7Fs3+l6X1SaNvtftaheq069u1oe0bz2I+tZdb0KSJEmSJKmyqjwGhiRJkiRJEmACQ5IkSZIkDQImMCRJkiRJUuWZwJAkSZIkSZVnAkOSJEmSJFWeCQyVKiKeL57HRMT/GOBtf7Hb/K8HcvuSNJRExJci4v6IuCci7oqI3VtY16yImNCq7UtSp7JvrU5nAkNVMQZYrZ1sRAzvYZUVdrKZ+Z7VjEmSBETE
HsAHgJ0zc0dgX+DxcqOSJK3CGOxbqwOZwFBVnAa8t/hV77MRMSwizoiIO4pf+z4JEBETI+JXEXEV8EBR9l8RMbf4ZXB6UXYasFaxvYuKsq6MdBTbvi8i7o2IQ+u2PSsiLouI30fERRERJbSFJFXNZsBfMvMVgMz8S2b+KSL+T7Gfvi8iZnTtM4t96bciYk5EPBgRu0bE5RHxUEScWqwzpm5f+2Cx731T94ojYnJE/CYi7oyIn0bEOkX5aRHxQHGM+GYb20KSBgP71upIkZllx6AhLCKez8x1ImIi8PnM/EBRPh14S2aeGhFrArcBHwHeBvw3sH1mPlqsu1FmPh0RawF3AP+QmUu6tt2grg8DRwFTgDcXr9kd2Aa4EtgO+FNR5/GZeWvrW0KSqqtIGtwKvAm4EZiZmb/s2v8W6/wYuDQz/19EzAJuz8wvRMSxwBeAXYCngYeBnYB1gUeBvTLztog4F3ggM79ZvP7zwALgcmD/zHwhIr4ArAmcDfwa2DYzMyI2yMy/tqUxJKnC7Fur03kGhqpqMvDxiLgLuB3YGNi6WDa7awdb+ExE3A38Ftiybr1m9gIuzsylmfln4JfArnXbXpiZy4C7qJ1+J0lDWmY+Ty0BMR1YDMyMiCOA90XE7RFxLzCJWie1y1XF873A/Zn5RHEGxyPU9tUAj2fmbcX0hdT2z/XeDYwDbiuOB9OodbafBV4GzomIfwReHKj3Kkkdyr61OkJP1zlJZQng05l53QqFtWzyC93m9wX2yMwXi1/tRvaj3lfqppfid0SSAMjMpcAsYFaRsPgksCMwITMfj4hTWHH/27U/XcaK+9ZlvLFv7X4aaPf5AG7IzKnd44mI3YB9gEOAY6glUCRJjdm3VkfwDAxVxd+onU7c5TrgUxExAiAi3hkRazd43frAM8UOdltqv9Z1ea3r9d38Cji0uBZwFLA3MHtA3oUkdaCI2CYi6n+BGw/MK6b/UlxickgfNv3WqA0QCrXB5rqfVvxbYM+IeEcRx9rF8WAdYP3MvAb4LLVLUiRJb7BvrY5kBkxVcQ+wtDhd7Tzg29ROMbuzGOxnMfChBq+7FjgqIh6k1pn+bd2yGcA9EXFnZv5TXfkVwB7A3dR+7TshM58sdtKSpJWtA3w3IjYAXgfmU7uc5K/AfcCT1K55Xl3zgKO7xr8Avl+/MDMXF5eqXFxcsw3wZWod8ysjYiS1XxU/14e6JamT2bdWR3IQT0mS1HYRMQa4OjO3LzsWSZI0OHgJiSRJkiRJqjzPwJAkSZIkSZXnGRiSJEmSJKnyTGBIkiRJkqTKM4EhSZIkSZIqzwSGJEmSJEmqPBMYkiRJkiSp8v5/BQdRteCSOxYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1080x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Set up a 3-panel live figure: train loss, log-ratio samples, test loss.\n",
    "fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(15, 4))\n",
    "\n",
    "# Placeholder artists; presumably their data is replaced during training.\n",
    "line, = ax1.plot([0, 1], [0, 1])\n",
    "x, y = np.random.random((2, 500))\n",
    "scat1 = ax2.scatter(x, y, label='True p/q', alpha=0.9, s=10., c='b')\n",
    "scat2 = ax2.scatter(x, y, label='CoB p/q', alpha=0.9, s=10., c='r')\n",
    "test_line, = ax3.plot([0, 1], [0, 1])\n",
    "\n",
    "# Iterations = epochs x batches per epoch (assuming one logged point per training batch).\n",
    "# len(train_dl) counts the final partial batch, unlike NUM_EPOCHS*NUM_SAMPLES//BS, so the\n",
    "# x-range stays correct when NUM_SAMPLES is not a multiple of BS.\n",
    "NUM_ITERS = NUM_EPOCHS * len(train_dl)\n",
    "\n",
    "ax1.set_xlabel(\"Iteration\")\n",
    "ax1.set_ylabel(\"Train Loss\")\n",
    "ax1.set_xlim([0, NUM_ITERS])\n",
    "ax1.set_ylim([0, 10])\n",
    "\n",
    "ax2.set_xlabel(\"Samples\")\n",
    "ax2.set_ylabel(\"Log Ratio\")\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-6, 10])\n",
    "ax2.set_ylim([-1500, 5000])\n",
    "\n",
    "ax3.set_xlabel(\"Iteration\")\n",
    "ax3.set_ylabel(\"Test Loss\")\n",
    "ax3.set_xlim([0, NUM_ITERS])\n",
    "ax3.set_ylim([0, 10])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Loss histories, presumably filled by the training loop below.\n",
    "loss_store = []\n",
    "test_loss_store = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 143/800 [09:16<42:36,  3.89s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-25dcf4c74e33>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     82\u001b[0m                     \u001b[0mclear_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 83\u001b[0;31m                     \u001b[0mdisplay\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     84\u001b[0m                     \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/display.py\u001b[0m in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, *objs, **kwargs)\u001b[0m\n\u001b[1;32m    311\u001b[0m             \u001b[0mpublish_display_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    312\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m             \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minclude\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    314\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    315\u001b[0m                 \u001b[0;31m# nothing to display (e.g. _ipython_display_ took over)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mformat\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mmd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    179\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m                 \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    181\u001b[0m             \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    182\u001b[0m                 \u001b[0;31m# FIXME: log the exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<decorator-gen-2>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    222\u001b[0m     \u001b[0;34m\"\"\"show traceback on failed format call\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m         \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    225\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    226\u001b[0m         \u001b[0;31m# don't warn on NotImplementedErrors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m    339\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    340\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mprinter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    342\u001b[0m             \u001b[0;31m# Finally look for special method names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    343\u001b[0m             \u001b[0mmethod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_real_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_method\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(fig)\u001b[0m\n\u001b[1;32m    246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    247\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'png'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 248\u001b[0;31m         \u001b[0mpng_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'png'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    249\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;34m'retina'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m'png2x'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mformats\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    250\u001b[0m         \u001b[0mpng_formatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfor_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mFigure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mretina_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, **kwargs)\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0mFigureCanvasBase\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    131\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m     \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbytes_io\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    133\u001b[0m     \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbytes_io\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetvalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfmt\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'svg'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m   2194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2195\u001b[0m                     bbox_inches = self.figure.get_tightbbox(\n\u001b[0;32m-> 2196\u001b[0;31m                         renderer, bbox_extra_artists=bbox_extra_artists)\n\u001b[0m\u001b[1;32m   2197\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0mpad_inches\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2198\u001b[0m                         \u001b[0mpad_inches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrcParams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'savefig.pad_inches'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36mget_tightbbox\u001b[0;34m(self, renderer, bbox_extra_artists)\u001b[0m\n\u001b[1;32m   2504\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2505\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0martists\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2506\u001b[0;31m             \u001b[0mbbox\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tightbbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2507\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mbbox\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbbox\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mbbox\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheight\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2508\u001b[0m                 \u001b[0mbb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbbox\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/axes/_base.py\u001b[0m in \u001b[0;36mget_tightbbox\u001b[0;34m(self, renderer, call_axes_locator, bbox_extra_artists, for_layout_only)\u001b[0m\n\u001b[1;32m   4201\u001b[0m                     \u001b[0;31m# this artist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4202\u001b[0m                     \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4203\u001b[0;31m             \u001b[0mbbox\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tightbbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   4204\u001b[0m             if (bbox is not None\n\u001b[1;32m   4205\u001b[0m                     \u001b[0;32mand\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mbbox\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/artist.py\u001b[0m in \u001b[0;36mget_tightbbox\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m    276\u001b[0m             \u001b[0mThe\u001b[0m \u001b[0menclosing\u001b[0m \u001b[0mbounding\u001b[0m \u001b[0mbox\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32min\u001b[0m \u001b[0mfigure\u001b[0m \u001b[0mpixel\u001b[0m \u001b[0mcoordinates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    277\u001b[0m         \"\"\"\n\u001b[0;32m--> 278\u001b[0;31m         \u001b[0mbbox\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_window_extent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    279\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_clip_on\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    280\u001b[0m             \u001b[0mclip_box\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_clip_box\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/spines.py\u001b[0m in \u001b[0;36mget_window_extent\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m    166\u001b[0m         \u001b[0;31m# make sure the location is updated so that transforms etc are correct:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    167\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_adjust_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m         \u001b[0mbb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_window_extent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    169\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    170\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mbb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/patches.py\u001b[0m in \u001b[0;36mget_window_extent\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m    596\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    597\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_window_extent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_extents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    600\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_convert_xy_units\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/path.py\u001b[0m in \u001b[0;36mget_extents\u001b[0;34m(self, transform, **kwargs)\u001b[0m\n\u001b[1;32m    588\u001b[0m         \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBbox\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    589\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 590\u001b[0;31m             \u001b[0mself\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    591\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcodes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    592\u001b[0m             \u001b[0mxys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/transforms.py\u001b[0m in \u001b[0;36mtransform_path\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m   1525\u001b[0m         \u001b[0mthat\u001b[0m \u001b[0mbegan\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mline\u001b[0m \u001b[0msegments\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1526\u001b[0m         \"\"\"\n\u001b[0;32m-> 1527\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform_path_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform_path_non_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1529\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mtransform_path_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/sr/lib/python3.7/site-packages/matplotlib/transforms.py\u001b[0m in \u001b[0;36mtransform_path_non_affine\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m   1537\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform_path_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1538\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1539\u001b[0;31m     \u001b[0;32mdef\u001b[0m \u001b[0mtransform_path_non_affine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1540\u001b[0m         \"\"\"\n\u001b[1;32m   1541\u001b[0m         \u001b[0mApply\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mnon\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0maffine\u001b[0m \u001b[0mpart\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mtransform\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPath\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# TODO: confirm how q_list_test is used for validation/visualization in Akash's code.\n",
    "# Presumably relies on train_dl, test_dl, loss_store, test_loss_store and the live-plot handles\n",
    "# (fig, line, ax1, scat1, scat2, ax2, test_line, ax3) created in earlier cells.\n",
    "\n",
    "VAL_EVERY = 50  # run validation and refresh the live plot every VAL_EVERY training steps\n",
    "loss_crit = torch.nn.functional.cross_entropy\n",
    "\n",
    "\n",
    "def _class_labels(logits, label):\n",
    "    \"\"\"Return a long tensor on DEVICE with one entry per row of `logits`, all equal to `label`.\"\"\"\n",
    "    return torch.full((logits.shape[0],), label, dtype=torch.long, device=DEVICE)\n",
    "\n",
    "\n",
    "def _three_class_loss(logP, logQ, logM):\n",
    "    \"\"\"Summed cross-entropy with class 0 for p samples, 1 for q samples and 2 for m samples.\"\"\"\n",
    "    return (loss_crit(logP, _class_labels(logP, 0))\n",
    "            + loss_crit(logQ, _class_labels(logQ, 1))\n",
    "            + loss_crit(logM, _class_labels(logM, 2)))\n",
    "\n",
    "\n",
    "def _to_model_input(batch):\n",
    "    \"\"\"Add a feature axis at dim 1 and move to DEVICE, on every device (not only CUDA).\n",
    "\n",
    "    Assumes the loaders yield 1-D batches of shape (batch,), as the unsqueeze(1) implies.\n",
    "    \"\"\"\n",
    "    return batch.unsqueeze(1).to(DEVICE)\n",
    "\n",
    "\n",
    "model = model.to(DEVICE)  # DEVICE already falls back to 'cpu', so this is safe everywhere\n",
    "model.train()\n",
    "\n",
    "i = 0\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    # Training batches get their own names so that, after the loop, p_batch/q_batch/m_batch\n",
    "    # always refer to the validation batch that log_ratio_p_q and friends were computed on.\n",
    "    for p_train, q_train, m_train in train_dl:\n",
    "        i += 1\n",
    "        p_train, q_train, m_train = map(_to_model_input, (p_train, q_train, m_train))\n",
    "\n",
    "        optim.zero_grad()\n",
    "        loss = _three_class_loss(model(p_train), model(q_train), model(m_train))\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "\n",
    "        if i % VAL_EVERY != 0:\n",
    "            continue\n",
    "\n",
    "        # Validation/Test on a single test batch\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for p_batch, q_batch, m_batch in test_dl:\n",
    "                # Ground-truth log p/q at the m samples and the true KL(p||q), per get_gt_ratio_kl.\n",
    "                log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, m_batch, calc_true_kl=True)\n",
    "                p_batch, q_batch, m_batch = map(_to_model_input, (p_batch, q_batch, m_batch))\n",
    "\n",
    "                logP, logQ, logM = model(p_batch), model(q_batch), model(m_batch)\n",
    "                test_loss = _three_class_loss(logP, logQ, logM)\n",
    "                test_loss_store.append(test_loss.item())\n",
    "\n",
    "                # CoB log p/q is the difference of the p and q logits. Averaged over p samples it\n",
    "                # estimates KL(p||q); on m samples it is compared against the ground truth below.\n",
    "                kl_from_cob = torch.mean(logP[:, 0] - logP[:, 1])\n",
    "                log_ratio_p_q_from_cob = logM[:, 0] - logM[:, 1]\n",
    "\n",
    "                # Visualize\n",
    "                line.set_data(range(len(loss_store)), loss_store)\n",
    "                ax1.set_xlim(0, len(loss_store))\n",
    "                m_x = m_batch.cpu().squeeze()\n",
    "                scat1.set_offsets(np.vstack([m_x, log_ratio_p_q.cpu().detach()]).T)\n",
    "                scat2.set_offsets(np.vstack([m_x, log_ratio_p_q_from_cob.cpu().detach()]).T)\n",
    "                ax2.set_xlim(-25., 25.)\n",
    "                ax2.set_ylim(-1000, 1000)\n",
    "                test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "                ax3.set_xlim(0, len(test_loss_store))\n",
    "\n",
    "                # clear_output(wait=True) only clears once display(fig) produces new output, so\n",
    "                # printing after display keeps these lines visible instead of wiping them.\n",
    "                clear_output(wait=True)\n",
    "                display(fig)\n",
    "                print('iteration: ', i)\n",
    "                print('KLD: ', float(true_kl_p_q))\n",
    "                print('CoB: ', float(kl_from_cob))\n",
    "                break\n",
    "\n",
    "        model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-print the latest validation summary left over from the training cell.\n",
    "print('iteration: ', i)\n",
    "print('KLD: ', float(true_kl_p_q))\n",
    "print('CoB: ', float(kl_from_cob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final snapshot: ground-truth vs. CoB log p/q on the last validation batch of the training loop.\n",
    "# New names (fig_final, ax_final) keep the live-plot handles fig/ax2/scat1/scat2 used by the\n",
    "# training cell intact, so that cell can still be re-run afterwards.\n",
    "fig_final, ax_final = plt.subplots(1, 1, figsize=(6, 4))\n",
    "\n",
    "m_samples = np.asarray(m_batch.cpu().squeeze())\n",
    "ax_final.scatter(m_samples, np.asarray(log_ratio_p_q.cpu().detach()),\n",
    "                 label=f'True Log p/q, KL = {float(true_kl_p_q):.2f}', alpha=0.9, s=10., c='b')\n",
    "ax_final.scatter(m_samples, np.asarray(log_ratio_p_q_from_cob.cpu().detach()),\n",
    "                 label=f'CoB Log p/q, KL = {float(kl_from_cob):.2f}', alpha=0.9, s=10., c='r')\n",
    "\n",
    "ax_final.set_xlabel('Samples')\n",
    "ax_final.set_ylabel('Log Ratio')\n",
    "ax_final.legend(loc='best')\n",
    "ax_final.set_xlim([-25, 25])\n",
    "ax_final.set_ylim([-1000, 1000])\n",
    "\n",
    "fig_final.tight_layout()\n",
    "fig_final.savefig('../plots/cob_mu2.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sr",
   "language": "python",
   "name": "sr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
