{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "# Summary of results\n",
    "\n",
    "Here we report standard statistics on the validation and test sets, e.g., the evidence lower bound (ELBO), negative log-likelihood (NLL), bits per dimension (BPD), distortion (D), and rate (R)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 37
    },
    "id": "El49NgachrYj",
    "outputId": "867cd0cc-78dd-470c-c8d6-6b55d9daa8dd"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.8.1+cu102'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Dumky9Ut1vHJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributions as td\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple, OrderedDict, defaultdict\n",
    "from tqdm.auto import tqdm\n",
    "from itertools import chain\n",
    "from tabulate import tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "WDWb-ohmPxy_"
   },
   "outputs": [],
   "source": [
    "from components import GenerativeModel, InferenceModel, VAE\n",
    "from data import load_mnist\n",
    "from hparams import load_cfg, make_args\n",
    "from main import make_state, get_batcher, validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from analysis import probe_prior, compare_marginals, compare_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Single source of truth for the seed, so every RNG below is changed together.\n",
    "SEED = 0\n",
    "\n",
    "# Seed the global state of each library used by this notebook.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Separate NumPy generator object, seeded with the same value.\n",
    "rng = np.random.RandomState(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Load the object stored in knnclassifier.pickle (presumably a fitted KNN classifier - confirm).\n",
    "# The context manager closes the file handle as soon as loading finishes.\n",
    "# NOTE: unpickling can execute arbitrary code, so only load pickles from a trusted source.\n",
    "with open('knnclassifier.pickle', 'rb') as f:\n",
    "    knn_model = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from analysis import collect_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model and data\n",
    "\n",
    "* Load hyperparameters\n",
    "* Load model state\n",
    "* Load MNIST data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = load_mnist(\n",
    "    batch_size=100, \n",
    "    save_to='../tmp', \n",
    "    height=28, \n",
    "    width=28\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples_test = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mcategorical\u001b[0m/  \u001b[01;34mgaussian\u001b[0m/           \u001b[01;34mmixed-maxent\u001b[0m/\r\n",
      "\u001b[01;34mdirichlet\u001b[0m/    \u001b[01;34mgaussiansp-maxent\u001b[0m/  \u001b[01;34monehotcat\u001b[0m/\r\n"
     ]
    }
   ],
   "source": [
    "ls ../trained_models/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_results = defaultdict(list)\n",
    "test_results = defaultdict(list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one (family, run_dir, flag) triple per sub-directory of ../trained_models/<family>/.\n",
    "# Families are visited in the order listed; the flag is True for every entry.\n",
    "# NOTE(review): the `ls` output above also shows gaussiansp-maxent, which is not included here - confirm this is intentional.\n",
    "dirs = [\n",
    "    (family, d, True)\n",
    "    for family in ('gaussian', 'dirichlet', 'categorical', 'onehotcat', 'mixed-maxent')\n",
    "    for d in (pathlib.Path('../trained_models') / family).iterdir()\n",
    "    if d.is_dir()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('gaussian', PosixPath('../trained_models/gaussian/bumbling-haze-15'), True),\n",
       " ('gaussian', PosixPath('../trained_models/gaussian/crimson-yogurt-13'), True),\n",
       " ('gaussian',\n",
       "  PosixPath('../trained_models/gaussian/glamorous-music-12'),\n",
       "  True),\n",
       " ('gaussian', PosixPath('../trained_models/gaussian/noble-water-11'), True),\n",
       " ('gaussian', PosixPath('../trained_models/gaussian/dainty-wood-14'), True),\n",
       " ('dirichlet',\n",
       "  PosixPath('../trained_models/dirichlet/gallant-flower-19'),\n",
       "  True),\n",
       " ('dirichlet', PosixPath('../trained_models/dirichlet/dark-meadow-17'), True),\n",
       " ('dirichlet', PosixPath('../trained_models/dirichlet/firm-water-16'), True),\n",
       " ('dirichlet', PosixPath('../trained_models/dirichlet/icy-dust-18'), True),\n",
       " ('dirichlet', PosixPath('../trained_models/dirichlet/jumping-bee-20'), True),\n",
       " ('categorical',\n",
       "  PosixPath('../trained_models/categorical/chocolate-voice-30'),\n",
       "  True),\n",
       " ('categorical',\n",
       "  PosixPath('../trained_models/categorical/lively-smoke-26'),\n",
       "  True),\n",
       " ('categorical',\n",
       "  PosixPath('../trained_models/categorical/earthy-frog-29'),\n",
       "  True),\n",
       " ('categorical',\n",
       "  PosixPath('../trained_models/categorical/iconic-valley-28'),\n",
       "  True),\n",
       " ('categorical',\n",
       "  PosixPath('../trained_models/categorical/splendid-glade-27'),\n",
       "  True),\n",
       " ('onehotcat',\n",
       "  PosixPath('../trained_models/onehotcat/gallant-salad-22'),\n",
       "  True),\n",
       " ('onehotcat', PosixPath('../trained_models/onehotcat/peach-valley-25'), True),\n",
       " ('onehotcat',\n",
       "  PosixPath('../trained_models/onehotcat/pleasant-disco-24'),\n",
       "  True),\n",
       " ('onehotcat',\n",
       "  PosixPath('../trained_models/onehotcat/fluent-violet-23'),\n",
       "  True),\n",
       " ('onehotcat',\n",
       "  PosixPath('../trained_models/onehotcat/confused-sponge-21'),\n",
       "  True),\n",
       " ('mixed-maxent',\n",
       "  PosixPath('../trained_models/mixed-maxent/electric-firefly-5'),\n",
       "  True),\n",
       " ('mixed-maxent',\n",
       "  PosixPath('../trained_models/mixed-maxent/robust-wildflower-4'),\n",
       "  True),\n",
       " ('mixed-maxent',\n",
       "  PosixPath('../trained_models/mixed-maxent/leafy-sky-3'),\n",
       "  True),\n",
       " ('mixed-maxent',\n",
       "  PosixPath('../trained_models/mixed-maxent/lucky-grass-1'),\n",
       "  True),\n",
       " ('mixed-maxent',\n",
       "  PosixPath('../trained_models/mixed-maxent/smooth-armadillo-2'),\n",
       "  True)]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "daa78d6478444c7abc6d160eaecaf6a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: gaussian/bumbling-haze-15\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "879ee2e7e73f4dad8f0ede86d9ea077f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d686f91d5d74459699b8d989a591989e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: gaussian/crimson-yogurt-13\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa40a8bc319e4ebf86c1a0f113aed096",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e48abe61068c4c669fb7d3d37a7d7dc5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: gaussian/glamorous-music-12\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "694630e0108f4e8ca3456bc5e923b3f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "536a1ab2b1904eaaa5f8f061b5a275df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: gaussian/noble-water-11\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b94a4c86019b444e8abcda826eb3b1b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "614a8d2dd9f44ce7a5379a4eafbe97a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: gaussian/dainty-wood-14\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57d96033b7014d1ca8d5c1a3006fc9f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c2793c1cea74cc0ac8da1d213fdd5ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: dirichlet/gallant-flower-19\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b0dd7d0efe04eeea918fb26c7ad0259",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea9014d2a3084ccf8990fb4f5fa788e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: dirichlet/dark-meadow-17\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afe8515513ef489194401f93ffb95053",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e001c09e8a714be9ab8246cd077e6b9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: dirichlet/firm-water-16\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "254cc912f7c643aba6aa1073dc9f82d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2184e9e4e3c460a95c6c143c1d92fff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: dirichlet/icy-dust-18\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9bfc5c93ab94f0e9b122af020973bd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "165db8662a734e34b0e263c9622e765e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: dirichlet/jumping-bee-20\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c496e50c972c4104b66f5bf653d4ef75",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "073db4830bc34673ad33c084f62a01bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: categorical/chocolate-voice-30\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6835fcf5bd5c41a99946091d7d501cb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98d6c31c1f524e40b9cd0a431f79cd4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: categorical/lively-smoke-26\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13e0ac2de92340bf956ca39325eeb2bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4eae473925c54692856a6d54b0f4d9b5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: categorical/earthy-frog-29\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f39d15f6db434da58f0adcf8a2d9a39d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a57c5c3a27b4c798372d151c86c1a91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: categorical/iconic-valley-28\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecdc5c126ad146c19978ef638387dca1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "147370e557b24367907a219c05d3692f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: categorical/splendid-glade-27\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "848332d8ee1f4451a110ee3ad21edba0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4850f0c555ad460aac05e3e375498f4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: onehotcat/gallant-salad-22\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6771727371d84d9c8017df3895472b89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99b351ae540e46ca9d317acb3e001065",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: onehotcat/peach-valley-25\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "801b53c856564b6f8d3711bf8c290dd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d467462f2d04269aab94b95ab751e26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: onehotcat/pleasant-disco-24\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24e407db722d4a4fbe493a4101bab9fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ff8dd845a084437bf2c603dd787d904",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: onehotcat/fluent-violet-23\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e84c47feee224ce0a38857c631076522",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbb61830d1464acaa02afd061f5afabe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: onehotcat/confused-sponge-21\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6c21f65e1884a598f7f5ecaa7b3448e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09ee1fbc5dfa4af3b62124380c171783",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: mixed-maxent/electric-firefly-5\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3de2d7f63ca4f848b6efa837ea514a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d79bbfa599094bf8bb09627cbba8cd26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: mixed-maxent/robust-wildflower-4\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d341c9bed5e48cfa8710a8b1f1dc974",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aefdb37fb1814e4c92ba4aeedc7710cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: mixed-maxent/leafy-sky-3\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6110305489064882b1443456014a7154",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c5000c894634d4eab22dae6a45e576e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: mixed-maxent/lucky-grass-1\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4266452e220462fa8b009e6c8790965",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db5aa2059a924141a90b351f0a1a42b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Overriding device to user choice cuda:0\n",
      "Overriding data_dir to user choice ../tmp\n",
      "Experiment: mixed-maxent/smooth-armadillo-2\n",
      "Validating...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a8918248e9845b39fbabcad2095c69f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b745ea7d561743fb8daf441aa60a557b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for cls, directory, redo in tqdm(dirs):\n",
    "    if not redo:\n",
    "        continue\n",
    "    args = make_args(\n",
    "        load_cfg(\n",
    "            f\"{directory}/cfg.json\", \n",
    "            # use this to specify a decide for analysis\n",
    "            device='cuda:0',\n",
    "            # use this to change paths if you need\n",
    "            data_dir='../tmp',\n",
    "            # you don't really need to change the output_dir\n",
    "        )\n",
    "    )\n",
    "    experiment = directory.name\n",
    "    print(f\"Experiment: {cls}/{experiment}\")\n",
    "\n",
    "    state = make_state(\n",
    "        args, \n",
    "        device=args.device, \n",
    "        ckpt_path=f\"{directory}/ckpt.last\"\n",
    "    )\n",
    "        \n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "    torch.manual_seed(0)\n",
    "    rng = np.random.RandomState(0) \n",
    "        \n",
    "    print('Validating...')\n",
    "    val_metrics = validate(\n",
    "        state.vae, get_batcher(valid_loader, args), \n",
    "        num_samples=num_samples_test, \n",
    "        compute_DR=True,\n",
    "        progressbar=True,\n",
    "    )            \n",
    "    \n",
    "    print()\n",
    "                \n",
    "    r = [\n",
    "        val_metrics[0].numpy(),  # NLL\n",
    "        val_metrics[1].numpy(),  # BPD\n",
    "        val_metrics[2]['ELBO'].mean(),  # ELBO\n",
    "        val_metrics[2]['D'].mean(),  # D\n",
    "        val_metrics[2]['R'].mean(),  # R\n",
    "        val_metrics[2].get('R_F', np.zeros(1)).mean(),  # R\n",
    "        val_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R\n",
    "        val_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R\n",
    "        val_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R\n",
    "    ]\n",
    "    valid_results[cls].append(r)\n",
    "    \n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "    torch.manual_seed(0)\n",
    "    rng = np.random.RandomState(0) \n",
    "    \n",
    "\n",
    "    print('Testing...')\n",
    "    test_metrics = validate(\n",
    "        state.vae, get_batcher(test_loader, args), \n",
    "        num_samples=num_samples_test, \n",
    "        compute_DR=True,\n",
    "        progressbar=True,\n",
    "    )            \n",
    "\n",
    "    print()\n",
    "\n",
    "    r = [\n",
    "        test_metrics[0].numpy(),  # NLL\n",
    "        test_metrics[1].numpy(),  # BPD\n",
    "        test_metrics[2]['ELBO'].mean(),  # ELBO\n",
    "        test_metrics[2]['D'].mean(),  # D\n",
    "        test_metrics[2]['R'].mean(),  # R\n",
    "        test_metrics[2].get('R_F', np.zeros(1)).mean(),  # R\n",
    "        test_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R\n",
    "        test_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R\n",
    "        test_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R\n",
    "    ]\n",
    "    test_results[cls].append(r)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabulate import tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "headers = ['NLL', 'BPD', 'ELBO', 'D', 'R', 'R_F', 'R_Y|f', 'R_Y', 'R_Z']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation - Gaussian\n",
      "  NLL    BPD    ELBO      D      R    R_F    R_Y|f    R_Y    R_Z\n",
      "-----  -----  ------  -----  -----  -----  -------  -----  -----\n",
      "91.67  13.22  -96.93  77.03  19.90   0.00     0.00   0.00  19.90\n",
      "91.69  13.23  -96.90  76.90  20.00   0.00     0.00   0.00  20.00\n",
      "91.69  13.23  -97.06  77.16  19.90   0.00     0.00   0.00  19.90\n",
      "91.72  13.23  -96.94  76.98  19.96   0.00     0.00   0.00  19.96\n",
      "91.63  13.22  -97.03  77.17  19.86   0.00     0.00   0.00  19.86\n",
      "Validation - Dirichlet\n",
      "  NLL    BPD    ELBO      D      R    R_F    R_Y|f    R_Y    R_Z\n",
      "-----  -----  ------  -----  -----  -----  -------  -----  -----\n",
      "94.12  13.58  -98.70  78.58  20.12   0.00     0.00   0.00  20.12\n",
      "94.74  13.67  -99.45  79.56  19.89   0.00     0.00   0.00  19.89\n",
      "94.47  13.63  -99.66  79.29  20.38   0.00     0.00   0.00  20.38\n",
      "94.65  13.66  -99.37  79.26  20.11   0.00     0.00   0.00  20.11\n",
      "94.39  13.62  -99.24  79.07  20.17   0.00     0.00   0.00  20.17\n",
      "Validation - Categorical\n",
      "   NLL    BPD     ELBO       D     R    R_F    R_Y|f    R_Y    R_Z\n",
      "------  -----  -------  ------  ----  -----  -------  -----  -----\n",
      "166.01  23.95  -166.05  163.77  2.28   2.28     0.00   0.00   0.00\n",
      "166.11  23.96  -166.15  163.87  2.28   2.28     0.00   0.00   0.00\n",
      "166.20  23.98  -166.25  163.97  2.28   2.28     0.00   0.00   0.00\n",
      "166.23  23.98  -166.27  163.99  2.28   2.28     0.00   0.00   0.00\n",
      "166.16  23.97  -166.21  163.93  2.28   2.28     0.00   0.00   0.00\n",
      "Validation - GS-ST\n",
      "   NLL    BPD     ELBO       D     R    R_F    R_Y|f    R_Y    R_Z\n",
      "------  -----  -------  ------  ----  -----  -------  -----  -----\n",
      "167.91  24.22  -172.88  171.16  1.73   0.00     0.00   0.00   1.73\n",
      "167.76  24.20  -172.96  171.24  1.72   0.00     0.00   0.00   1.72\n",
      "167.50  24.17  -172.80  171.09  1.71   0.00     0.00   0.00   1.71\n",
      "168.24  24.27  -172.84  171.11  1.73   0.00     0.00   0.00   1.73\n",
      "167.73  24.20  -172.27  170.52  1.75   0.00     0.00   0.00   1.75\n",
      "Validation - Mixed Dir\n",
      "   NLL    BPD     ELBO      D      R    R_F    R_Y|f    R_Y    R_Z\n",
      "------  -----  -------  -----  -----  -----  -------  -----  -----\n",
      "107.70  15.54  -110.64  91.88  18.76   9.90     8.86   0.00   0.00\n",
      "106.60  15.38  -109.69  90.02  19.66  10.22     9.44   0.00   0.00\n",
      "106.75  15.40  -109.78  90.22  19.56  10.16     9.41   0.00   0.00\n",
      "106.87  15.42  -109.95  91.04  18.91   9.87     9.04   0.00   0.00\n",
      "107.69  15.54  -110.59  91.69  18.90   9.90     9.00   0.00   0.00\n"
     ]
    }
   ],
   "source": [
    "print(\"Validation - Gaussian\")\n",
    "print(tabulate(valid_results['gaussian'], headers=headers, floatfmt='.2f'))\n",
    "print(\"Validation - Dirichlet\")\n",
    "print(tabulate(valid_results['dirichlet'], headers=headers, floatfmt='.2f'))\n",
    "print(\"Validation - Categorical\")\n",
    "print(tabulate(valid_results['categorical'], headers=headers, floatfmt='.2f'))\n",
    "print(\"Validation - GS-ST\")\n",
    "print(tabulate(valid_results['onehotcat'], headers=headers, floatfmt='.2f'))\n",
    "print(\"Validation - Mixed Dir\")\n",
    "print(tabulate(valid_results['mixed-maxent'], headers=headers, floatfmt='.2f'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print('mixed-dir')\n",
    "# print(tabulate(\n",
    "#     [\n",
    "#         ['mean'] + [x for x in np.mean(valid_results['mixed-maxent'], 0)],\n",
    "#         ['std'] + [x for x in np.std(valid_results['mixed-maxent'], 0)],\n",
    "#         ['min'] + [x for x in np.min(valid_results['mixed-maxent'], 0)],\n",
    "#         ['max'] + [x for x in np.max(valid_results['mixed-maxent'], 0)]\n",
    "#     ], \n",
    "#     headers=headers, floatfmt='.2f'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model        Dataset    Statistic         D      R     NLL\n",
      "-----------  ---------  -----------  ------  -----  ------\n",
      "gaussian     valid      mean          77.05  19.93   91.68\n",
      "dirichlet    valid      mean          79.15  20.13   94.48\n",
      "categorical  valid      mean         163.90   2.28  166.14\n",
      "onehotcat    valid      mean         171.02   1.73  167.83\n",
      "mixed-dir    valid      mean          90.97  19.16  107.12\n",
      "gaussian     valid      stddev         0.10   0.05    0.03\n",
      "dirichlet    valid      stddev         0.33   0.15    0.22\n",
      "categorical  valid      stddev         0.08   0.00    0.08\n",
      "onehotcat    valid      stddev         0.26   0.01    0.25\n",
      "mixed-dir    valid      stddev         0.75   0.38    0.48\n",
      "gaussian     test       mean          76.67  19.94   91.12\n",
      "dirichlet    test       mean          78.62  19.94   93.81\n",
      "categorical  test       mean         164.72   2.28  166.95\n",
      "onehotcat    test       mean         171.76   1.70  168.50\n",
      "mixed-dir    test       mean          90.34  19.39  106.59\n",
      "gaussian     test       stddev         0.07   0.04    0.05\n",
      "dirichlet    test       stddev         0.37   0.19    0.24\n",
      "categorical  test       stddev         0.22   0.00    0.22\n",
      "onehotcat    test       stddev         0.25   0.01    0.11\n",
      "mixed-dir    test       stddev         0.72   0.36    0.47\n"
     ]
    }
   ],
   "source": [
    "idx = np.array([3, 4, 0], dtype=int)\n",
    "\n",
    "print(tabulate(\n",
    "    [\n",
    "        ['gaussian', 'valid', 'mean'] + [x for x in np.array(valid_results['gaussian'])[:,idx].mean(0)],        \n",
    "        ['dirichlet', 'valid', 'mean'] + [x for x in np.array(valid_results['dirichlet'])[:,idx].mean(0)],\n",
    "        ['categorical', 'valid', 'mean'] + [x for x in np.array(valid_results['categorical'])[:,idx].mean(0)],\n",
    "        ['onehotcat', 'valid', 'mean'] + [x for x in np.array(valid_results['onehotcat'])[:,idx].mean(0)],\n",
    "        ['mixed-dir', 'valid', 'mean'] + [x for x in np.array(valid_results['mixed-maxent'])[:,idx].mean(0)],        \n",
    "\n",
    "        ['gaussian', 'valid', 'stddev'] + [x for x in np.array(valid_results['gaussian'])[:,idx].std(0)],        \n",
    "        ['dirichlet', 'valid', 'stddev'] + [x for x in np.array(valid_results['dirichlet'])[:,idx].std(0)],\n",
    "        ['categorical', 'valid', 'stddev'] + [x for x in np.array(valid_results['categorical'])[:,idx].std(0)],\n",
    "        ['onehotcat', 'valid', 'stddev'] + [x for x in np.array(valid_results['onehotcat'])[:,idx].std(0)],\n",
    "        ['mixed-dir', 'valid', 'stddev'] + [x for x in np.array(valid_results['mixed-maxent'])[:,idx].std(0)],        \n",
    "\n",
    "        ['gaussian', 'test', 'mean'] + [x for x in np.array(test_results['gaussian'])[:,idx].mean(0)],\n",
    "        ['dirichlet', 'test', 'mean'] + [x for x in np.array(test_results['dirichlet'])[:,idx].mean(0)],\n",
    "        ['categorical', 'test', 'mean'] + [x for x in np.array(test_results['categorical'])[:,idx].mean(0)],\n",
    "        ['onehotcat', 'test', 'mean'] + [x for x in np.array(test_results['onehotcat'])[:,idx].mean(0)],\n",
    "        ['mixed-dir', 'test', 'mean'] + [x for x in np.array(test_results['mixed-maxent'])[:,idx].mean(0)],\n",
    "        \n",
    "        ['gaussian', 'test', 'stddev'] + [x for x in np.array(test_results['gaussian'])[:,idx].std(0)],\n",
    "        ['dirichlet', 'test', 'stddev'] + [x for x in np.array(test_results['dirichlet'])[:,idx].std(0)],\n",
    "        ['categorical', 'test', 'stddev'] + [x for x in np.array(test_results['categorical'])[:,idx].std(0)],\n",
    "        ['onehotcat', 'test', 'stddev'] + [x for x in np.array(test_results['onehotcat'])[:,idx].std(0)],\n",
    "        ['mixed-dir', 'test', 'stddev'] + [x for x in np.array(test_results['mixed-maxent'])[:,idx].std(0)],\n",
    "    ], \n",
    "    headers=['Model', 'Dataset', 'Statistic', 'D', 'R', 'NLL'], floatfmt='.2f'))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "mixed-rv-MNIST.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "toc-autonumbering": false,
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "026c452c13f74a0aa071ab5869027a7a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2992689582d94ea0b3e2a166461861d3",
      "placeholder": "​",
      "style": "IPY_MODEL_1158c2e933694df38bce84f1c1e4b8f7",
      "value": " 9913344/? [00:15&lt;00:00, 628891.26it/s]"
     }
    },
    "05389b71a6b545cf9b61d9f63f79c894": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "08101fefdba5474193cae33d3418a2e8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "0946de80201a4d1ba749a53cefa440a6": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "09a30ad3b16444309a2f2591382c5735": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0b74813a212940708ccbe6dc9c353813": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "0e0f5d2863544ecea35d0acf4f81f267": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_db95898f256f474d9264009bd6f5939f",
      "max": 9912422,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_18c5a2181a6b4979ba35dd9bc738ddf8",
      "value": 9912422
     }
    },
    "10e2682a1219417ca2371bc15a90622c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Epoch   3: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_cb2c5222e69d44aa8330691d61edfb61",
      "max": 275,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_213ac6a83e7e409cac2ee43c97e7f02d",
      "value": 275
     }
    },
    "1158c2e933694df38bce84f1c1e4b8f7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "18c5a2181a6b4979ba35dd9bc738ddf8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "1e328c49a4294095af40d05efc47b7e9": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Epoch   2: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_29788d5b56994274852a7f007ceff4aa",
      "max": 275,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_8c7267ec1fc54d9a9a363b3b9503d530",
      "value": 275
     }
    },
    "213ac6a83e7e409cac2ee43c97e7f02d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "29788d5b56994274852a7f007ceff4aa": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2992689582d94ea0b3e2a166461861d3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2ee410127a6540ba85b6d733f50a9df8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_687400af482f43239a161d7a6f2d7a97",
       "IPY_MODEL_5c2c665e1dfc4b2785cf4015421e2ddc"
      ],
      "layout": "IPY_MODEL_a0cd4f136d48492ba89ddc256d21ec54"
     }
    },
    "303a5282789449ca8e7ed903d268f9e1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "3101e39990db4580877323e9ee2954ea": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "32dde89004e541a78ee6d35af5d3f14b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_642b9f5fb1624f2c97344e8e7a6c5d20",
       "IPY_MODEL_a9ff1ca678d44b0b96e0aea1541c8359"
      ],
      "layout": "IPY_MODEL_05389b71a6b545cf9b61d9f63f79c894"
     }
    },
    "3440563d925941f3891b4b1f48d45319": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Epoch   1: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_fbb3221bb61348a38e978a3031d0fefe",
      "max": 275,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_a60c0e8b61b74630845a0c22a30b95df",
      "value": 275
     }
    },
    "3e0fdd4d4d874e4ba72106d0ccfa263f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4ed6b82f4b61457e996fc31e14789257": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "519071893e9f4ed9be9ecb55833d7d07": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5b7868306345485fb17658bcfd1c25e4": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5c2c665e1dfc4b2785cf4015421e2ddc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_519071893e9f4ed9be9ecb55833d7d07",
      "placeholder": "​",
      "style": "IPY_MODEL_0b74813a212940708ccbe6dc9c353813",
      "value": " 29696/? [00:02&lt;00:00, 11538.00it/s]"
     }
    },
    "5c954db7e71f4a669989e1ecf0d89988": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "642b9f5fb1624f2c97344e8e7a6c5d20": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f839264c20cd46feb8ce184e87e85416",
      "max": 4542,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_9328008d582947ee9a7dddef3d6571da",
      "value": 4542
     }
    },
    "6582cc71adc04d45b7687c571614931c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_fc73554bd1054f3b95d157ddaa22a8a9",
      "placeholder": "​",
      "style": "IPY_MODEL_3101e39990db4580877323e9ee2954ea",
      "value": " 1649664/? [00:01&lt;00:00, 1065497.77it/s]"
     }
    },
    "687400af482f43239a161d7a6f2d7a97": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_966328cab55848c6a9605f4828637fed",
      "max": 28881,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_ddfea72bf1334cc98c8f46f969ec2d50",
      "value": 28881
     }
    },
    "7f79b9e35755453c8fbd5a6f20d49498": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_a286bdc1c9a6485ba77b8d79726c4c3f",
       "IPY_MODEL_6582cc71adc04d45b7687c571614931c"
      ],
      "layout": "IPY_MODEL_4ed6b82f4b61457e996fc31e14789257"
     }
    },
    "8be95b95873544a79de9124a8ed267d8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "8c7267ec1fc54d9a9a363b3b9503d530": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "9049691356d04d3c9f79245a12e22c2d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_0e0f5d2863544ecea35d0acf4f81f267",
       "IPY_MODEL_026c452c13f74a0aa071ab5869027a7a"
      ],
      "layout": "IPY_MODEL_a52a3ce8dde147f6b45b684cb57c5eb5"
     }
    },
    "9328008d582947ee9a7dddef3d6571da": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "95c747e3a50e48089ba465949d499f61": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9613f82c5bee4d9889d755989607deed": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_10e2682a1219417ca2371bc15a90622c",
       "IPY_MODEL_ae0ee77c9452406ea76c2012113ac46f"
      ],
      "layout": "IPY_MODEL_95c747e3a50e48089ba465949d499f61"
     }
    },
    "966328cab55848c6a9605f4828637fed": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9bd35ce5090d4252aa65b19d5ea37b42": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a0cd4f136d48492ba89ddc256d21ec54": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a286bdc1c9a6485ba77b8d79726c4c3f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3e0fdd4d4d874e4ba72106d0ccfa263f",
      "max": 1648877,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_08101fefdba5474193cae33d3418a2e8",
      "value": 1648877
     }
    },
    "a52a3ce8dde147f6b45b684cb57c5eb5": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a60c0e8b61b74630845a0c22a30b95df": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "a61874200e1c43bc93cf5f38f4d1602a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a9ff1ca678d44b0b96e0aea1541c8359": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_09a30ad3b16444309a2f2591382c5735",
      "placeholder": "​",
      "style": "IPY_MODEL_8be95b95873544a79de9124a8ed267d8",
      "value": " 5120/? [00:11&lt;00:00, 462.47it/s]"
     }
    },
    "ae0ee77c9452406ea76c2012113ac46f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9bd35ce5090d4252aa65b19d5ea37b42",
      "placeholder": "​",
      "style": "IPY_MODEL_303a5282789449ca8e7ed903d268f9e1",
      "value": " 275/275 [00:22&lt;00:00, 12.43it/s, loss=167, LL=-142, KL_F=6.72, KL_Y=7.09, SFE_base_reward_mean=-149, SFE_base_reward_std=42.8, SFE_reward_mean=33.8, SFE_reward_std=42.8]"
     }
    },
    "cb2c5222e69d44aa8330691d61edfb61": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "cc6a4b8ee959491f8e6901d4b96d0435": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f41c9d0c678c4495b043220fae491099",
      "placeholder": "​",
      "style": "IPY_MODEL_d7d526b7a18741a3a888c2421aea347c",
      "value": " 275/275 [00:57&lt;00:00,  4.82it/s, loss=268, LL=-158, KL_F=4.8, KL_Y=5.79, SFE_base_reward_mean=-164, SFE_base_reward_std=45.1, SFE_reward_mean=33.8, SFE_reward_std=45.1]"
     }
    },
    "d7d526b7a18741a3a888c2421aea347c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "db95898f256f474d9264009bd6f5939f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ddfea72bf1334cc98c8f46f969ec2d50": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "e6235f86d6144ccba6f6459586bf8b9b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_3440563d925941f3891b4b1f48d45319",
       "IPY_MODEL_f295266aeaca4088b34f2c9424f2a8a7"
      ],
      "layout": "IPY_MODEL_a61874200e1c43bc93cf5f38f4d1602a"
     }
    },
    "f295266aeaca4088b34f2c9424f2a8a7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_5b7868306345485fb17658bcfd1c25e4",
      "placeholder": "​",
      "style": "IPY_MODEL_5c954db7e71f4a669989e1ecf0d89988",
      "value": " 275/275 [01:25&lt;00:00,  3.22it/s, loss=384, LL=-184, KL_F=2.6, KL_Y=2.67, SFE_base_reward_mean=-187, SFE_base_reward_std=44.8, SFE_reward_mean=32.7, SFE_reward_std=44.8]"
     }
    },
    "f41c9d0c678c4495b043220fae491099": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f839264c20cd46feb8ce184e87e85416": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fa7034d83d3b43c2a181b51a88db3fec": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_1e328c49a4294095af40d05efc47b7e9",
       "IPY_MODEL_cc6a4b8ee959491f8e6901d4b96d0435"
      ],
      "layout": "IPY_MODEL_0946de80201a4d1ba749a53cefa440a6"
     }
    },
    "fbb3221bb61348a38e978a3031d0fefe": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fc73554bd1054f3b95d157ddaa22a8a9": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
