{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import copy\n",
    "import pdb\n",
    "import time\n",
    "import pickle\n",
    "import warnings\n",
    "sys.path.append('../../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import pandas as pd\n",
    "import gpvae\n",
    "import data.eeg\n",
    "import gpytorch\n",
    "import wbml\n",
    "\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from torch.utils.data import DataLoader\n",
    "from experiments.eeg.train_eeg import train_eeg\n",
    "\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load EEG data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the EEG splits through the project's data.eeg helper (the original\n",
    "# note credits Wessel's wbml package). Only the training split is used here.\n",
    "_, train, test = data.eeg.load()\n",
    "\n",
    "# Inputs are taken from the index of `train`, observations from its values.\n",
    "x = torch.tensor(np.array(train.index))\n",
    "y = np.array(train)\n",
    "y_dim = y.shape[1]\n",
    "\n",
    "# Reciprocal of the fraction of non-NaN entries in y.\n",
    "decoder_scale = 1. / (np.sum(~np.isnan(y)) / y.size)\n",
    "\n",
    "# Standardise each output column with NaN-aware statistics, then hand the\n",
    "# result to torch.\n",
    "y_mean = np.nanmean(y, axis=0)\n",
    "y_std = np.nanstd(y, axis=0)\n",
    "y = torch.tensor((y - y_mean) / y_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by every model variant configured in this notebook.\n",
    "args = dict(\n",
    "    # Mini-batch size and latent dimensionality.\n",
    "    batch_size=100,\n",
    "    latent_dim=5,\n",
    "    # Initial kernel hyperparameters.\n",
    "    init_lengthscale=0.1,\n",
    "    init_scale=1.0,\n",
    "    init_period=0.1,\n",
    "    auxiliary_dim=1,\n",
    "    # Number of k-means inducing points for the sparse GP models.\n",
    "    num_inducing=100,\n",
    "    # Optimisation / evaluation settings (presumably read by the training\n",
    "    # routine; confirm against train_eeg).\n",
    "    epochs=2500,\n",
    "    cache_freq=50,\n",
    "    lr=0.001,\n",
    "    decoder_scale=decoder_scale,\n",
    "    elbo_samples=100,\n",
    "    test_samples=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the (x, y) tensors; y contains NaNs, so the dataset is told to expect them.\n",
    "dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)\n",
    "\n",
    "# Shuffled mini-batches for the amortised models.\n",
    "loader = DataLoader(dataset=dataset,\n",
    "                    shuffle=True,\n",
    "                    batch_size=args['batch_size'])\n",
    "\n",
    "# Models using the VFE approximation get every input in a single batch.\n",
    "vfe_loader = DataLoader(dataset=dataset, batch_size=x.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent GP kernel: an RBF component plus a periodic component, each\n",
    "# initialised with half of the shared initial scale.\n",
    "k1 = gpvae.kernels.RBFKernel(\n",
    "    lengthscale=args['init_lengthscale'],\n",
    "    scale=args['init_scale'] / 2,\n",
    ")\n",
    "k2 = gpvae.kernels.PeriodicKernel(\n",
    "    lengthscale=args['init_lengthscale'],\n",
    "    period=args['init_period'],\n",
    "    scale=args['init_scale'] / 2,\n",
    ")\n",
    "kernel = gpvae.kernels.AdditiveKernel(k1, k2)\n",
    "\n",
    "# Disabled gpytorch equivalent of the two kernels above.\n",
    "# NOTE(review): the period length below is set from 'init_lengthscale',\n",
    "# not 'init_period' -- check before re-enabling.\n",
    "# k1 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "# k2 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel())\n",
    "# k1.base_kernel._set_lengthscale(args['init_lengthscale'])\n",
    "# k1._set_outputscale(args['init_scale'] / 2)\n",
    "# k2.base_kernel._set_lengthscale(args['init_lengthscale'])\n",
    "# k2.base_kernel._set_period_length(args['init_lengthscale'])\n",
    "# k2._set_outputscale(args['init_scale'] / 2)\n",
    "\n",
    "# Likelihood (decoder) hyperparameters: latent_dim -> y_dim.\n",
    "likelihood_args = dict(\n",
    "    in_dim=args['latent_dim'],\n",
    "    out_dim=y_dim,\n",
    "    hidden_dims=[20],\n",
    "    sigma=0.1,\n",
    "    train_sigma=True,\n",
    ")\n",
    "\n",
    "linear_likelihood_args = dict(\n",
    "    in_dim=args['latent_dim'],\n",
    "    out_dim=y_dim,\n",
    "    sigma=0.1,\n",
    ")\n",
    "\n",
    "# Zero imputation inference network hyperparameters.\n",
    "zi_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20, 20],\n",
    "    initial_sigma=0.1,\n",
    "    initial_mu=0.0,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Product of Gaussians inference network hyperparameters.\n",
    "pog_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20],\n",
    "    initial_sigma=0.1,\n",
    "    initial_mu=0.0,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Semi-amortised DeepSet inference network hyperparameters.\n",
    "sads_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    middle_dim=20,\n",
    "    hidden_dims=[20],\n",
    "    shared_hidden_dims=[20],\n",
    "    initial_sigma=1.0,\n",
    "    initial_mu=0.0,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# PointNet inference network hyperparameters (no 'in_dim' entry here).\n",
    "pn_network_args = dict(\n",
    "    out_dim=args['latent_dim'],\n",
    "    middle_dim=20,\n",
    "    first_hidden_dims=[20],\n",
    "    second_hidden_dims=[20],\n",
    "    initial_sigma=1.0,\n",
    "    initial_mu=0.0,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Sparse GP hyperparameters.\n",
    "sgp_network_args = dict(out_dim=args['latent_dim'])\n",
    "\n",
    "# Sparse GP initial inducing points: k-means centroids of the inputs,\n",
    "# started from randomly chosen data points (minit='points'). No seed is set\n",
    "# in this cell, so z can differ between runs.\n",
    "z = torch.tensor(kmeans2(x, args['num_inducing'], minit='points')[0])\n",
    "\n",
    "# Amortised sparse GP hyperparameters.\n",
    "amortised_sgp_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20, 20],\n",
    "    inducing_spacing=0.025,\n",
    "    k=5,\n",
    "    contains_nan=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct GPVAE models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product of Gaussians inference network (the only variant enabled here).\n",
    "# The run configuration is a deep copy of the shared args extended with the\n",
    "# model tag and the two argument dicts used to build the networks.\n",
    "pog_args = {\n",
    "    **copy.deepcopy(args),\n",
    "    'model': 'pog',\n",
    "    'inference_network_args': pog_network_args,\n",
    "    'likelihood_args': likelihood_args,\n",
    "}\n",
    "pog_network = gpvae.networks.FactorNet(**pog_network_args)\n",
    "pog_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "pog_model = gpvae.models.GPVAE(\n",
    "    pog_network, pog_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# ---- Disabled model variants (kept for reference) ----\n",
    "\n",
    "# Semi-amortised DeepSet inference network.\n",
    "# sads_args = copy.deepcopy(args)\n",
    "# sads_args['model'] = 'sads'\n",
    "# sads_args['inference_network_args'] = sads_network_args\n",
    "# sads_args['likelihood_args'] = likelihood_args\n",
    "# sads_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# sads_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# sads_model = gpvae.models.GPVAE(sads_network, sads_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# PointNet inference network.\n",
    "# pn_args = copy.deepcopy(args)\n",
    "# pn_args['model'] = 'pn'\n",
    "# pn_args['inference_network_args'] = pn_network_args\n",
    "# pn_args['likelihood_args'] = likelihood_args\n",
    "# pn_network = gpvae.networks.PointNet(**pn_network_args)\n",
    "# pn_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# pn_model = gpvae.models.GPVAE(pn_network, pn_likelihood, args['latent_dim'], kernel, add_jitter=False)\n",
    "\n",
    "# Zero imputation inference network.\n",
    "# zi_args = copy.deepcopy(args)\n",
    "# zi_args['model'] = 'zi'\n",
    "# zi_args['inference_network_args'] = zi_network_args\n",
    "# zi_args['likelihood_args'] = likelihood_args\n",
    "# zi_network = gpvae.networks.LinearGaussian(**zi_network_args)\n",
    "# zi_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# zi_model = gpvae.models.GPVAE(zi_network, zi_likelihood, args['latent_dim'], kernel, add_jitter=False)\n",
    "\n",
    "# Semi-amortised DeepSet inference network with linear likelihood.\n",
    "# NOTE(review): as written this block stores ll_args inside itself under\n",
    "# 'likelihood_args' (linear_likelihood_args looks intended) and refers to\n",
    "# AffineGaussian without a module prefix -- fix both before re-enabling.\n",
    "# ll_args = copy.deepcopy(args)\n",
    "# ll_args['model'] = 'linear_sads'\n",
    "# ll_args['inference_network_args'] = sads_network_args\n",
    "# ll_args['likelihood_args'] = ll_args\n",
    "# ll_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# ll_likelihood = AffineGaussian(**linear_likelihood_args)\n",
    "# ll_model = gpvae.models.GPVAE(ll_network, ll_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# Semi-amortised DeepSet VAE model.\n",
    "# vae_args = copy.deepcopy(args)\n",
    "# vae_args['model'] = 'vae_sads'\n",
    "# vae_args['inference_network_args'] = sads_network_args\n",
    "# vae_args['likelihood_args'] = likelihood_args\n",
    "# vae_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# vae_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# vae_model = gpvae.models.VAE(vae_network, vae_likelihood, args['latent_dim'])\n",
    "\n",
    "# Sparse GP (VFE) variants.\n",
    "# NOTE(review): vfe_encoder_args, r_vfe_encoder_args and decoder_args are not\n",
    "# defined in the hyperparameter cell above, and these blocks use the keys\n",
    "# 'name' / 'encoder_args' / 'decoder_args' rather than 'model' /\n",
    "# 'inference_network_args' / 'likelihood_args' -- update before re-enabling.\n",
    "# vfe_args = copy.deepcopy(args)\n",
    "# vfe_args['name'] = 'vfe'\n",
    "# vfe_args['encoder_args'] = vfe_encoder_args\n",
    "# vfe_args['decoder_args'] = decoder_args\n",
    "# vfe_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# vfe_model = gpvae.models.TitsiasSparseGPVAE(vfe_decoder, args['latent_dim'], kernel, z, min_sigma=0.001, initial_sigma=0.01)\n",
    "\n",
    "# r_vfe_args = copy.deepcopy(args)\n",
    "# r_vfe_args['name'] = 'r_vfe'\n",
    "# r_vfe_args['encoder_args'] = r_vfe_encoder_args\n",
    "# r_vfe_args['decoder_args'] = decoder_args\n",
    "# r_vfe_encoder = gpvae.networks.FixedSparseNet(**r_vfe_encoder_args)\n",
    "# r_vfe_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# r_vfe_model = gpvae.models.SparseGPVAE(r_vfe_encoder, r_vfe_decoder, args['latent_dim'], kernel=kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/2500 [00:01<56:12,  1.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 372.979\n",
      "ELBO: -73219.839\n",
      "IWAE: -68433.275\n",
      "SMSE: 1.718\n",
      "SMLL: 6.118\n",
      "MLL: 8.326\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:07<08:31,  4.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 86.210\n",
      "ELBO: -18198.918\n",
      "IWAE: -17963.824\n",
      "SMSE: 1.285\n",
      "SMLL: 7.366\n",
      "MLL: 9.574\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 102/2500 [00:13<10:17,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 16.188\n",
      "ELBO: -3203.745\n",
      "IWAE: -3101.826\n",
      "SMSE: 0.562\n",
      "SMLL: 2.075\n",
      "MLL: 4.283\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 152/2500 [00:19<11:35,  3.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 9.139\n",
      "ELBO: -1892.211\n",
      "IWAE: -1819.591\n",
      "SMSE: 0.501\n",
      "SMLL: 1.505\n",
      "MLL: 3.713\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 202/2500 [00:25<08:10,  4.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 5.521\n",
      "ELBO: -1072.493\n",
      "IWAE: -1015.943\n",
      "SMSE: 0.541\n",
      "SMLL: 1.458\n",
      "MLL: 3.666\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:31<10:38,  3.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 4.070\n",
      "ELBO: -907.346\n",
      "IWAE: -851.122\n",
      "SMSE: 0.505\n",
      "SMLL: 0.998\n",
      "MLL: 3.206\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 302/2500 [00:41<36:25,  1.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 3.094\n",
      "ELBO: -786.933\n",
      "IWAE: -734.878\n",
      "SMSE: 0.471\n",
      "SMLL: 0.791\n",
      "MLL: 3.000\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [00:47<07:21,  4.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 2.678\n",
      "ELBO: -727.565\n",
      "IWAE: -666.153\n",
      "SMSE: 0.413\n",
      "SMLL: 0.452\n",
      "MLL: 2.660\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [00:53<09:03,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 1.979\n",
      "ELBO: -594.780\n",
      "IWAE: -394565159.757\n",
      "SMSE: 0.388\n",
      "SMLL: 0.411\n",
      "MLL: 2.619\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 452/2500 [00:59<06:59,  4.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: 0.532\n",
      "ELBO: -455.118\n",
      "IWAE: -416.546\n",
      "SMSE: 0.339\n",
      "SMLL: 0.299\n",
      "MLL: 2.508\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 502/2500 [01:05<08:34,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: -0.513\n",
      "ELBO: -382.831\n",
      "IWAE: -338.709\n",
      "SMSE: 0.328\n",
      "SMLL: 0.306\n",
      "MLL: 2.514\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 552/2500 [01:12<11:42,  2.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: -1.980\n",
      "ELBO: -286.078\n",
      "IWAE: -234.947\n",
      "SMSE: 0.314\n",
      "SMLL: 0.259\n",
      "MLL: 2.467\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 602/2500 [01:18<10:36,  2.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: -3.242\n",
      "ELBO: -225.863\n",
      "IWAE: -196.258\n",
      "SMSE: 0.285\n",
      "SMLL: 0.065\n",
      "MLL: 2.273\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 652/2500 [01:24<10:54,  2.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: -3.980\n",
      "ELBO: -163.053\n",
      "IWAE: -124.856\n",
      "SMSE: 0.276\n",
      "SMLL: 0.031\n",
      "MLL: 2.240\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 703/2500 [01:31<06:35,  4.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: -3.325\n",
      "ELBO: -107.466\n",
      "IWAE: -71.741\n",
      "SMSE: 0.247\n",
      "SMLL: -0.197\n",
      "MLL: 2.011\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 752/2500 [01:37<12:13,  2.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: -4.491\n",
      "ELBO: -14.692\n",
      "IWAE: 21.815\n",
      "SMSE: 0.222\n",
      "SMLL: -0.327\n",
      "MLL: 1.881\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 802/2500 [01:44<11:28,  2.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 3106784.147\n",
      "ELBO: 59.041\n",
      "IWAE: 89.371\n",
      "SMSE: 0.243\n",
      "SMLL: -0.169\n",
      "MLL: 2.039\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 851/2500 [01:52<16:24,  1.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: -3.958\n",
      "ELBO: 151.206\n",
      "IWAE: 182.453\n",
      "SMSE: 0.234\n",
      "SMLL: -0.252\n",
      "MLL: 1.957\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 902/2500 [02:00<10:28,  2.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: -4.114\n",
      "ELBO: 207.159\n",
      "IWAE: 225.633\n",
      "SMSE: 0.224\n",
      "SMLL: -0.260\n",
      "MLL: 1.948\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 952/2500 [02:08<11:16,  2.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: -4.268\n",
      "ELBO: 252.759\n",
      "IWAE: 273.573\n",
      "SMSE: 0.254\n",
      "SMLL: -0.080\n",
      "MLL: 2.128\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1002/2500 [02:14<05:45,  4.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: -4.435\n",
      "ELBO: 255.737\n",
      "IWAE: 272.816\n",
      "SMSE: 0.219\n",
      "SMLL: -0.259\n",
      "MLL: 1.950\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1051/2500 [02:21<12:02,  2.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: -4.389\n",
      "ELBO: 263.848\n",
      "IWAE: 290.056\n",
      "SMSE: 0.258\n",
      "SMLL: -0.023\n",
      "MLL: 2.186\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1102/2500 [02:27<04:59,  4.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: -4.948\n",
      "ELBO: 278.263\n",
      "IWAE: 304.380\n",
      "SMSE: 0.239\n",
      "SMLL: -0.082\n",
      "MLL: 2.127\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1152/2500 [02:32<05:50,  3.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: -4.411\n",
      "ELBO: 273.572\n",
      "IWAE: 293.286\n",
      "SMSE: 0.215\n",
      "SMLL: -0.312\n",
      "MLL: 1.896\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1202/2500 [02:39<05:45,  3.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: -4.799\n",
      "ELBO: 283.204\n",
      "IWAE: 307.438\n",
      "SMSE: 0.257\n",
      "SMLL: -0.034\n",
      "MLL: 2.175\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1252/2500 [02:45<05:14,  3.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: -4.481\n",
      "ELBO: 288.417\n",
      "IWAE: 303.473\n",
      "SMSE: 0.242\n",
      "SMLL: -0.108\n",
      "MLL: 2.100\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1302/2500 [02:51<06:05,  3.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: -4.757\n",
      "ELBO: 298.919\n",
      "IWAE: 319.909\n",
      "SMSE: 0.273\n",
      "SMLL: 0.090\n",
      "MLL: 2.298\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1353/2500 [02:58<04:05,  4.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: -5.255\n",
      "ELBO: 292.777\n",
      "IWAE: 309.557\n",
      "SMSE: 0.264\n",
      "SMLL: 0.015\n",
      "MLL: 2.223\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1402/2500 [03:04<04:50,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: -4.480\n",
      "ELBO: 305.390\n",
      "IWAE: 321.562\n",
      "SMSE: 0.283\n",
      "SMLL: 0.081\n",
      "MLL: 2.289\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1452/2500 [03:10<05:55,  2.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: -4.506\n",
      "ELBO: 292.958\n",
      "IWAE: 306.593\n",
      "SMSE: 0.262\n",
      "SMLL: -0.037\n",
      "MLL: 2.172\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1502/2500 [03:16<04:17,  3.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: -4.697\n",
      "ELBO: 309.976\n",
      "IWAE: 323.515\n",
      "SMSE: 0.254\n",
      "SMLL: -0.093\n",
      "MLL: 2.115\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1552/2500 [03:22<04:15,  3.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: -4.925\n",
      "ELBO: 315.505\n",
      "IWAE: 332.734\n",
      "SMSE: 0.249\n",
      "SMLL: -0.119\n",
      "MLL: 2.089\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1602/2500 [03:27<03:53,  3.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: -4.684\n",
      "ELBO: 317.764\n",
      "IWAE: 332.821\n",
      "SMSE: 0.251\n",
      "SMLL: -0.082\n",
      "MLL: 2.127\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1652/2500 [03:33<04:51,  2.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: -4.863\n",
      "ELBO: 301.985\n",
      "IWAE: 316.126\n",
      "SMSE: 0.228\n",
      "SMLL: -0.256\n",
      "MLL: 1.952\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1702/2500 [03:40<04:01,  3.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: -4.847\n",
      "ELBO: 306.950\n",
      "IWAE: 323.236\n",
      "SMSE: 0.247\n",
      "SMLL: -0.185\n",
      "MLL: 2.023\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1753/2500 [03:53<04:43,  2.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: -5.226\n",
      "ELBO: 298.863\n",
      "IWAE: 315.296\n",
      "SMSE: 0.270\n",
      "SMLL: -0.055\n",
      "MLL: 2.153\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1803/2500 [03:59<03:57,  2.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: -4.728\n",
      "ELBO: 307.590\n",
      "IWAE: 319.104\n",
      "SMSE: 0.279\n",
      "SMLL: 0.066\n",
      "MLL: 2.274\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1852/2500 [04:05<02:18,  4.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: -4.832\n",
      "ELBO: 290.207\n",
      "IWAE: 305.322\n",
      "SMSE: 0.251\n",
      "SMLL: -0.154\n",
      "MLL: 2.054\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1901/2500 [04:14<08:26,  1.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: -4.567\n",
      "ELBO: 321.090\n",
      "IWAE: 333.548\n",
      "SMSE: 0.232\n",
      "SMLL: -0.255\n",
      "MLL: 1.953\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1953/2500 [04:22<03:03,  2.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: -4.664\n",
      "ELBO: 318.312\n",
      "IWAE: 330.089\n",
      "SMSE: 0.307\n",
      "SMLL: 0.179\n",
      "MLL: 2.387\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2002/2500 [04:29<03:54,  2.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: -4.401\n",
      "ELBO: 322.498\n",
      "IWAE: 332.287\n",
      "SMSE: 0.258\n",
      "SMLL: -0.106\n",
      "MLL: 2.102\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2052/2500 [04:36<03:12,  2.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: -4.731\n",
      "ELBO: 314.841\n",
      "IWAE: 331.929\n",
      "SMSE: 0.275\n",
      "SMLL: -0.019\n",
      "MLL: 2.189\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2102/2500 [04:43<01:29,  4.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: -4.957\n",
      "ELBO: 316.918\n",
      "IWAE: 330.360\n",
      "SMSE: 0.245\n",
      "SMLL: -0.177\n",
      "MLL: 2.032\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2153/2500 [04:49<01:13,  4.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: -5.212\n",
      "ELBO: 313.111\n",
      "IWAE: 323.315\n",
      "SMSE: 0.267\n",
      "SMLL: -0.055\n",
      "MLL: 2.153\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2203/2500 [04:57<02:17,  2.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: -4.393\n",
      "ELBO: 301.406\n",
      "IWAE: 316.558\n",
      "SMSE: 0.241\n",
      "SMLL: -0.212\n",
      "MLL: 1.997\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2252/2500 [05:04<01:24,  2.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: -4.613\n",
      "ELBO: 315.009\n",
      "IWAE: 331.153\n",
      "SMSE: 0.288\n",
      "SMLL: 0.026\n",
      "MLL: 2.234\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2302/2500 [05:10<01:15,  2.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: -4.544\n",
      "ELBO: 319.934\n",
      "IWAE: 334.905\n",
      "SMSE: 0.297\n",
      "SMLL: 0.132\n",
      "MLL: 2.340\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2352/2500 [05:17<00:51,  2.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: -5.042\n",
      "ELBO: 324.806\n",
      "IWAE: 339.998\n",
      "SMSE: 0.299\n",
      "SMLL: 0.063\n",
      "MLL: 2.271\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2402/2500 [05:24<00:44,  2.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: -4.866\n",
      "ELBO: 319.147\n",
      "IWAE: 331.739\n",
      "SMSE: 0.277\n",
      "SMLL: 0.001\n",
      "MLL: 2.209\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2452/2500 [05:30<00:12,  3.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: -4.363\n",
      "ELBO: 316.547\n",
      "IWAE: 329.450\n",
      "SMSE: 0.292\n",
      "SMLL: 0.041\n",
      "MLL: 2.249\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [05:36<00:00,  7.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: -4.310\n",
      "ELBO: 312.765\n",
      "IWAE: 323.560\n",
      "SMSE: 0.288\n",
      "SMLL: 0.042\n",
      "MLL: 2.250\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [372.9792527329484,\n",
       "  86.21033029801174,\n",
       "  16.187968120849078,\n",
       "  9.138615276345584,\n",
       "  5.521169857548749,\n",
       "  4.07018896184815,\n",
       "  3.0942968256834056,\n",
       "  2.6779273181480967,\n",
       "  1.9794063260204282,\n",
       "  0.5318044181493702,\n",
       "  -0.5126243449712383,\n",
       "  -1.9804023151401802,\n",
       "  -3.2415290370439975,\n",
       "  -3.9800232087316547,\n",
       "  -3.3249900257108123,\n",
       "  -4.491147752191437,\n",
       "  3106784.146744447,\n",
       "  -3.9582524338501024,\n",
       "  -4.113657019926725,\n",
       "  -4.267877942246927,\n",
       "  -4.435207767778145,\n",
       "  -4.389148686333461,\n",
       "  -4.9484804828326245,\n",
       "  -4.410675610769143,\n",
       "  -4.799289792273887,\n",
       "  -4.48137120556595,\n",
       "  -4.756539419781194,\n",
       "  -5.25465718114861,\n",
       "  -4.480332883069174,\n",
       "  -4.505688917636081,\n",
       "  -4.697011336391275,\n",
       "  -4.924728545898187,\n",
       "  -4.683900065934285,\n",
       "  -4.86344928624679,\n",
       "  -4.8467986121490725,\n",
       "  -5.225777723973668,\n",
       "  -4.727872379752873,\n",
       "  -4.832188259797167,\n",
       "  -4.566713261554229,\n",
       "  -4.663916380747164,\n",
       "  -4.400626312842184,\n",
       "  -4.731165682744213,\n",
       "  -4.957273226333461,\n",
       "  -5.212383481081656,\n",
       "  -4.392774516055167,\n",
       "  -4.612882418718331,\n",
       "  -4.543517722819441,\n",
       "  -5.041725156679983,\n",
       "  -4.865671820671868,\n",
       "  -4.362892987676632,\n",
       "  -4.310105809287523],\n",
       " 'elbos': [tensor(-73219.8388, grad_fn=<AddBackward0>),\n",
       "  tensor(-18198.9175, grad_fn=<AddBackward0>),\n",
       "  tensor(-3203.7449, grad_fn=<AddBackward0>),\n",
       "  tensor(-1892.2108, grad_fn=<AddBackward0>),\n",
       "  tensor(-1072.4933, grad_fn=<AddBackward0>),\n",
       "  tensor(-907.3455, grad_fn=<AddBackward0>),\n",
       "  tensor(-786.9331, grad_fn=<AddBackward0>),\n",
       "  tensor(-727.5654, grad_fn=<AddBackward0>),\n",
       "  tensor(-594.7797, grad_fn=<AddBackward0>),\n",
       "  tensor(-455.1179, grad_fn=<AddBackward0>),\n",
       "  tensor(-382.8306, grad_fn=<AddBackward0>),\n",
       "  tensor(-286.0781, grad_fn=<AddBackward0>),\n",
       "  tensor(-225.8626, grad_fn=<AddBackward0>),\n",
       "  tensor(-163.0534, grad_fn=<AddBackward0>),\n",
       "  tensor(-107.4659, grad_fn=<AddBackward0>),\n",
       "  tensor(-14.6916, grad_fn=<AddBackward0>),\n",
       "  tensor(59.0414, grad_fn=<AddBackward0>),\n",
       "  tensor(151.2062, grad_fn=<AddBackward0>),\n",
       "  tensor(207.1592, grad_fn=<AddBackward0>),\n",
       "  tensor(252.7592, grad_fn=<AddBackward0>),\n",
       "  tensor(255.7366, grad_fn=<AddBackward0>),\n",
       "  tensor(263.8477, grad_fn=<AddBackward0>),\n",
       "  tensor(278.2626, grad_fn=<AddBackward0>),\n",
       "  tensor(273.5722, grad_fn=<AddBackward0>),\n",
       "  tensor(283.2038, grad_fn=<AddBackward0>),\n",
       "  tensor(288.4175, grad_fn=<AddBackward0>),\n",
       "  tensor(298.9187, grad_fn=<AddBackward0>),\n",
       "  tensor(292.7774, grad_fn=<AddBackward0>),\n",
       "  tensor(305.3899, grad_fn=<AddBackward0>),\n",
       "  tensor(292.9581, grad_fn=<AddBackward0>),\n",
       "  tensor(309.9758, grad_fn=<AddBackward0>),\n",
       "  tensor(315.5046, grad_fn=<AddBackward0>),\n",
       "  tensor(317.7642, grad_fn=<AddBackward0>),\n",
       "  tensor(301.9854, grad_fn=<AddBackward0>),\n",
       "  tensor(306.9501, grad_fn=<AddBackward0>),\n",
       "  tensor(298.8628, grad_fn=<AddBackward0>),\n",
       "  tensor(307.5900, grad_fn=<AddBackward0>),\n",
       "  tensor(290.2073, grad_fn=<AddBackward0>),\n",
       "  tensor(321.0899, grad_fn=<AddBackward0>),\n",
       "  tensor(318.3123, grad_fn=<AddBackward0>),\n",
       "  tensor(322.4982, grad_fn=<AddBackward0>),\n",
       "  tensor(314.8413, grad_fn=<AddBackward0>),\n",
       "  tensor(316.9184, grad_fn=<AddBackward0>),\n",
       "  tensor(313.1110, grad_fn=<AddBackward0>),\n",
       "  tensor(301.4063, grad_fn=<AddBackward0>),\n",
       "  tensor(315.0088, grad_fn=<AddBackward0>),\n",
       "  tensor(319.9336, grad_fn=<AddBackward0>),\n",
       "  tensor(324.8063, grad_fn=<AddBackward0>),\n",
       "  tensor(319.1474, grad_fn=<AddBackward0>),\n",
       "  tensor(316.5470, grad_fn=<AddBackward0>),\n",
       "  tensor(312.7647, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-68433.2753, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-17963.8236, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-3101.8255, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1819.5906, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1015.9432, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-851.1219, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-734.8784, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-666.1534, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-3.9457e+08, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-416.5462, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-338.7093, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-234.9466, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-196.2578, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-124.8559, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-71.7410, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(21.8145, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(89.3711, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(182.4532, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(225.6329, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(273.5727, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(272.8162, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(290.0562, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(304.3795, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(293.2856, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(307.4377, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(303.4732, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(319.9085, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(309.5573, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(321.5621, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(306.5930, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(323.5154, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(332.7340, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(332.8207, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(316.1261, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(323.2360, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(315.2964, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(319.1044, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(305.3221, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(333.5483, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(330.0885, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(332.2874, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(331.9286, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(330.3600, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(323.3153, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(316.5583, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(331.1531, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(334.9049, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(339.9985, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(331.7386, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(329.4499, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(323.5604, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [1.7184970509725297,\n",
       "  1.2852187939756277,\n",
       "  0.5624774301117913,\n",
       "  0.5007280621700388,\n",
       "  0.5414244023250063,\n",
       "  0.505480572670572,\n",
       "  0.47059147815424573,\n",
       "  0.4132495162664396,\n",
       "  0.388225930719623,\n",
       "  0.33892322003187275,\n",
       "  0.3275361503656529,\n",
       "  0.3141761333237218,\n",
       "  0.28512175559906733,\n",
       "  0.27566100394489035,\n",
       "  0.2473959386265497,\n",
       "  0.22210754374530609,\n",
       "  0.24253254874909738,\n",
       "  0.23372179241901855,\n",
       "  0.22387383253844476,\n",
       "  0.2535772438691106,\n",
       "  0.2191951417317314,\n",
       "  0.2582975817537758,\n",
       "  0.23880331741771077,\n",
       "  0.21458855230194362,\n",
       "  0.25671126366001557,\n",
       "  0.24187634338054778,\n",
       "  0.27294949792360845,\n",
       "  0.2642132304731411,\n",
       "  0.2826528089054085,\n",
       "  0.26208770099739565,\n",
       "  0.2544869898917039,\n",
       "  0.2494482006628087,\n",
       "  0.25061147943896467,\n",
       "  0.22837885494827245,\n",
       "  0.24650097827414155,\n",
       "  0.2697018027930497,\n",
       "  0.27876028918638207,\n",
       "  0.2511115473071923,\n",
       "  0.23205571048171122,\n",
       "  0.3074579322460111,\n",
       "  0.2576182810201752,\n",
       "  0.2752100675359602,\n",
       "  0.24518358831612677,\n",
       "  0.26707247494933306,\n",
       "  0.24129243511878862,\n",
       "  0.2881206016200029,\n",
       "  0.2974451122749238,\n",
       "  0.29895900095763167,\n",
       "  0.27682357768569527,\n",
       "  0.29218129667938997,\n",
       "  0.2877422034327377],\n",
       " 'smlls': [6.118189595863162,\n",
       "  7.36553370741314,\n",
       "  2.074791603137551,\n",
       "  1.5047448096133236,\n",
       "  1.4577595019545384,\n",
       "  0.9980492825561554,\n",
       "  0.7913261709172237,\n",
       "  0.4519708611007847,\n",
       "  0.41085009375600184,\n",
       "  0.2993777524782693,\n",
       "  0.30605148289479384,\n",
       "  0.25853229490961344,\n",
       "  0.06511776694222271,\n",
       "  0.031458880246278596,\n",
       "  -0.19678801558791767,\n",
       "  -0.32720899075115356,\n",
       "  -0.16941369977367668,\n",
       "  -0.25159435796433877,\n",
       "  -0.26018712261933546,\n",
       "  -0.07983920839482568,\n",
       "  -0.25864885434574875,\n",
       "  -0.022513225841279533,\n",
       "  -0.08168704214698486,\n",
       "  -0.31183926108379983,\n",
       "  -0.03373613318341081,\n",
       "  -0.10825137072218682,\n",
       "  0.08985660489347214,\n",
       "  0.015124021311168597,\n",
       "  0.08087312285708521,\n",
       "  -0.03656806281298756,\n",
       "  -0.09293918142278312,\n",
       "  -0.11921846908043172,\n",
       "  -0.0816227245038188,\n",
       "  -0.25623789227037547,\n",
       "  -0.1852921731705751,\n",
       "  -0.055293449864226275,\n",
       "  0.06591920887485266,\n",
       "  -0.1539012502619467,\n",
       "  -0.2548473847994556,\n",
       "  0.178608588174634,\n",
       "  -0.1059723308411226,\n",
       "  -0.018788745895008303,\n",
       "  -0.17665808420379303,\n",
       "  -0.055074944049539644,\n",
       "  -0.21152770918629185,\n",
       "  0.025715584731529573,\n",
       "  0.13158742708613222,\n",
       "  0.0628255505519167,\n",
       "  0.0007117634468647438,\n",
       "  0.04091647412539823,\n",
       "  0.04153411509557281],\n",
       " 'mlls': [8.32643118749953,\n",
       "  9.573775299049506,\n",
       "  4.283033194773919,\n",
       "  3.712986401249692,\n",
       "  3.6660010935909066,\n",
       "  3.206290874192524,\n",
       "  2.999567762553592,\n",
       "  2.660212452737153,\n",
       "  2.61909168539237,\n",
       "  2.5076193441146377,\n",
       "  2.514293074531162,\n",
       "  2.466773886545982,\n",
       "  2.2733593585785914,\n",
       "  2.239700471882647,\n",
       "  2.0114535760484507,\n",
       "  1.8810326008852147,\n",
       "  2.0388278918626916,\n",
       "  1.9566472336720295,\n",
       "  1.948054469017033,\n",
       "  2.1284023832415424,\n",
       "  1.9495927372906197,\n",
       "  2.185728365795089,\n",
       "  2.1265545494893834,\n",
       "  1.8964023305525686,\n",
       "  2.1745054584529577,\n",
       "  2.0999902209141816,\n",
       "  2.2980981965298404,\n",
       "  2.223365612947537,\n",
       "  2.2891147144934534,\n",
       "  2.1716735288233804,\n",
       "  2.115302410213585,\n",
       "  2.0890231225559366,\n",
       "  2.12661886713255,\n",
       "  1.9520036993659928,\n",
       "  2.0229494184657932,\n",
       "  2.1529481417721423,\n",
       "  2.2741608005112206,\n",
       "  2.0543403413744215,\n",
       "  1.9533942068369126,\n",
       "  2.3868501798110024,\n",
       "  2.1022692607952456,\n",
       "  2.1894528457413602,\n",
       "  2.0315835074325754,\n",
       "  2.1531666475868287,\n",
       "  1.9967138824500763,\n",
       "  2.233957176367898,\n",
       "  2.339829018722501,\n",
       "  2.271067142188285,\n",
       "  2.208953355083233,\n",
       "  2.2491580657617667,\n",
       "  2.2497757067319415]}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_eeg(ll_model, gpvae.estimators.gpvae_estimators.td_estimator, \n",
    "    loader, ll_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/2500 [00:01<35:16,  1.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 337.428\n",
      "ELBO: -69464.492\n",
      "IWAE: -69354.735\n",
      "SMSE: 1.448\n",
      "SMLL: 18.412\n",
      "MLL: 20.620\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 53/2500 [00:06<08:41,  4.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 17.541\n",
      "ELBO: -4730.429\n",
      "IWAE: -4620.958\n",
      "SMSE: 1.024\n",
      "SMLL: 2.712\n",
      "MLL: 4.921\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 103/2500 [00:11<08:07,  4.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 5.139\n",
      "ELBO: -1533.934\n",
      "IWAE: -1464.852\n",
      "SMSE: 0.461\n",
      "SMLL: 0.141\n",
      "MLL: 2.349\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 153/2500 [00:17<08:13,  4.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 3.488\n",
      "ELBO: -1279.025\n",
      "IWAE: -1216.118\n",
      "SMSE: 0.410\n",
      "SMLL: -0.083\n",
      "MLL: 2.125\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 203/2500 [00:22<09:01,  4.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 2.628\n",
      "ELBO: -1122.462\n",
      "IWAE: -1070.394\n",
      "SMSE: 0.395\n",
      "SMLL: -0.157\n",
      "MLL: 2.051\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:29<09:13,  4.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 1.768\n",
      "ELBO: -980.155\n",
      "IWAE: -937.494\n",
      "SMSE: 0.391\n",
      "SMLL: -0.208\n",
      "MLL: 2.000\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 303/2500 [00:35<08:45,  4.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 0.342\n",
      "ELBO: -807.345\n",
      "IWAE: -769.162\n",
      "SMSE: 0.378\n",
      "SMLL: -0.232\n",
      "MLL: 1.976\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [00:41<10:40,  3.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: -1.220\n",
      "ELBO: -596.848\n",
      "IWAE: -566.790\n",
      "SMSE: 0.437\n",
      "SMLL: -0.059\n",
      "MLL: 2.149\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [00:47<09:10,  3.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: -1.119\n",
      "ELBO: -482.028\n",
      "IWAE: -454.725\n",
      "SMSE: 0.502\n",
      "SMLL: 0.179\n",
      "MLL: 2.387\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 452/2500 [00:53<09:29,  3.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: -2.206\n",
      "ELBO: -394.379\n",
      "IWAE: -353.345\n",
      "SMSE: 0.524\n",
      "SMLL: 0.215\n",
      "MLL: 2.423\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 502/2500 [00:59<08:11,  4.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: -3.139\n",
      "ELBO: -315.142\n",
      "IWAE: -290.350\n",
      "SMSE: 0.473\n",
      "SMLL: 0.002\n",
      "MLL: 2.210\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 553/2500 [01:05<06:48,  4.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: -3.612\n",
      "ELBO: -249.194\n",
      "IWAE: -224.548\n",
      "SMSE: 0.383\n",
      "SMLL: -0.241\n",
      "MLL: 1.967\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 602/2500 [01:10<06:50,  4.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: -3.735\n",
      "ELBO: -188.450\n",
      "IWAE: -162.931\n",
      "SMSE: 0.348\n",
      "SMLL: -0.351\n",
      "MLL: 1.857\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 652/2500 [01:16<07:54,  3.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: -4.884\n",
      "ELBO: -148.352\n",
      "IWAE: -129.020\n",
      "SMSE: 0.314\n",
      "SMLL: -0.457\n",
      "MLL: 1.751\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 702/2500 [01:21<08:06,  3.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: -4.841\n",
      "ELBO: -128.149\n",
      "IWAE: -116.120\n",
      "SMSE: 0.311\n",
      "SMLL: -0.459\n",
      "MLL: 1.749\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 752/2500 [01:27<06:23,  4.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: -5.046\n",
      "ELBO: -89.971\n",
      "IWAE: -68.820\n",
      "SMSE: 0.269\n",
      "SMLL: -0.596\n",
      "MLL: 1.612\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 802/2500 [01:33<07:15,  3.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: -5.663\n",
      "ELBO: -80.039\n",
      "IWAE: -55.798\n",
      "SMSE: 0.249\n",
      "SMLL: -0.608\n",
      "MLL: 1.600\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 852/2500 [01:38<07:10,  3.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: -5.802\n",
      "ELBO: -51.518\n",
      "IWAE: -34.304\n",
      "SMSE: 0.238\n",
      "SMLL: -0.635\n",
      "MLL: 1.573\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 902/2500 [01:43<06:54,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: -6.777\n",
      "ELBO: -34.515\n",
      "IWAE: -10.077\n",
      "SMSE: 0.234\n",
      "SMLL: -0.652\n",
      "MLL: 1.556\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 952/2500 [01:49<06:53,  3.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: -5.130\n",
      "ELBO: -19.290\n",
      "IWAE: -7.475\n",
      "SMSE: 0.228\n",
      "SMLL: -0.667\n",
      "MLL: 1.541\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1002/2500 [01:54<06:47,  3.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: -6.233\n",
      "ELBO: -7.845\n",
      "IWAE: 14.316\n",
      "SMSE: 0.231\n",
      "SMLL: -0.607\n",
      "MLL: 1.601\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1052/2500 [02:00<06:12,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: -5.964\n",
      "ELBO: 2.047\n",
      "IWAE: 34.640\n",
      "SMSE: 0.238\n",
      "SMLL: -0.538\n",
      "MLL: 1.670\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1102/2500 [02:05<06:03,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: -6.347\n",
      "ELBO: 23.299\n",
      "IWAE: 42.390\n",
      "SMSE: 0.231\n",
      "SMLL: -0.529\n",
      "MLL: 1.680\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1152/2500 [02:11<05:43,  3.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: -7.670\n",
      "ELBO: 22.837\n",
      "IWAE: 46.173\n",
      "SMSE: 0.247\n",
      "SMLL: -0.455\n",
      "MLL: 1.754\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1202/2500 [02:16<05:52,  3.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: -6.702\n",
      "ELBO: 47.788\n",
      "IWAE: 71.939\n",
      "SMSE: 0.220\n",
      "SMLL: -0.550\n",
      "MLL: 1.658\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1252/2500 [02:21<05:24,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: -6.896\n",
      "ELBO: 44.135\n",
      "IWAE: 63.007\n",
      "SMSE: 0.234\n",
      "SMLL: -0.484\n",
      "MLL: 1.724\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1302/2500 [02:27<05:15,  3.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: -7.359\n",
      "ELBO: 56.790\n",
      "IWAE: 76.032\n",
      "SMSE: 0.250\n",
      "SMLL: -0.373\n",
      "MLL: 1.835\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1352/2500 [02:32<04:49,  3.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: -8.496\n",
      "ELBO: 72.082\n",
      "IWAE: 90.368\n",
      "SMSE: 0.242\n",
      "SMLL: -0.346\n",
      "MLL: 1.863\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1402/2500 [02:38<04:51,  3.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: -7.602\n",
      "ELBO: 69.306\n",
      "IWAE: 94.032\n",
      "SMSE: 0.260\n",
      "SMLL: -0.209\n",
      "MLL: 1.999\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1452/2500 [02:43<04:24,  3.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: -7.486\n",
      "ELBO: 86.029\n",
      "IWAE: 107.648\n",
      "SMSE: 0.246\n",
      "SMLL: -0.265\n",
      "MLL: 1.943\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1502/2500 [02:48<04:19,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: -7.321\n",
      "ELBO: 73.259\n",
      "IWAE: 92.594\n",
      "SMSE: 0.244\n",
      "SMLL: -0.274\n",
      "MLL: 1.934\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1552/2500 [02:54<04:10,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: -7.070\n",
      "ELBO: 81.938\n",
      "IWAE: 105.394\n",
      "SMSE: 0.241\n",
      "SMLL: -0.287\n",
      "MLL: 1.921\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1602/2500 [02:59<03:53,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: -7.432\n",
      "ELBO: 107.498\n",
      "IWAE: 118.671\n",
      "SMSE: 0.258\n",
      "SMLL: -0.169\n",
      "MLL: 2.039\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1652/2500 [03:04<03:49,  3.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: -7.376\n",
      "ELBO: 106.382\n",
      "IWAE: 119.902\n",
      "SMSE: 0.250\n",
      "SMLL: -0.194\n",
      "MLL: 2.014\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1702/2500 [03:10<03:34,  3.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: -7.751\n",
      "ELBO: 91.042\n",
      "IWAE: 112.494\n",
      "SMSE: 0.257\n",
      "SMLL: -0.141\n",
      "MLL: 2.067\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1752/2500 [03:15<03:13,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: -8.151\n",
      "ELBO: 105.945\n",
      "IWAE: 123.344\n",
      "SMSE: 0.239\n",
      "SMLL: -0.221\n",
      "MLL: 1.987\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1802/2500 [03:21<03:25,  3.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: -7.689\n",
      "ELBO: 116.956\n",
      "IWAE: 131.955\n",
      "SMSE: 0.250\n",
      "SMLL: -0.176\n",
      "MLL: 2.032\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1852/2500 [03:26<02:44,  3.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: -7.720\n",
      "ELBO: 112.825\n",
      "IWAE: 136.990\n",
      "SMSE: 0.248\n",
      "SMLL: -0.200\n",
      "MLL: 2.009\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1902/2500 [03:31<02:41,  3.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: -7.349\n",
      "ELBO: 122.915\n",
      "IWAE: 139.299\n",
      "SMSE: 0.263\n",
      "SMLL: -0.006\n",
      "MLL: 2.202\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1952/2500 [03:37<02:22,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: -7.062\n",
      "ELBO: 118.838\n",
      "IWAE: 136.965\n",
      "SMSE: 0.263\n",
      "SMLL: -0.056\n",
      "MLL: 2.152\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2002/2500 [03:42<02:09,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: -6.727\n",
      "ELBO: 122.701\n",
      "IWAE: 145.327\n",
      "SMSE: 0.274\n",
      "SMLL: 0.033\n",
      "MLL: 2.241\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2052/2500 [03:47<01:58,  3.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: -6.482\n",
      "ELBO: 129.157\n",
      "IWAE: 146.511\n",
      "SMSE: 0.281\n",
      "SMLL: 0.028\n",
      "MLL: 2.236\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2102/2500 [03:53<01:45,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: -7.549\n",
      "ELBO: 127.290\n",
      "IWAE: 144.782\n",
      "SMSE: 0.259\n",
      "SMLL: -0.067\n",
      "MLL: 2.141\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2152/2500 [03:58<01:30,  3.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: -7.860\n",
      "ELBO: 121.791\n",
      "IWAE: 139.080\n",
      "SMSE: 0.267\n",
      "SMLL: -0.062\n",
      "MLL: 2.146\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2202/2500 [04:05<02:20,  2.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: -7.895\n",
      "ELBO: 133.702\n",
      "IWAE: 153.320\n",
      "SMSE: 0.248\n",
      "SMLL: -0.128\n",
      "MLL: 2.081\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2252/2500 [04:12<01:38,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: -8.736\n",
      "ELBO: 131.997\n",
      "IWAE: 145.658\n",
      "SMSE: 0.261\n",
      "SMLL: 0.004\n",
      "MLL: 2.212\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2303/2500 [04:18<00:41,  4.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: -7.599\n",
      "ELBO: 146.246\n",
      "IWAE: 160.063\n",
      "SMSE: 0.248\n",
      "SMLL: -0.103\n",
      "MLL: 2.105\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2353/2500 [04:23<00:30,  4.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: -8.227\n",
      "ELBO: 146.433\n",
      "IWAE: 162.411\n",
      "SMSE: 0.274\n",
      "SMLL: 0.029\n",
      "MLL: 2.238\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2403/2500 [04:29<00:20,  4.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: -8.411\n",
      "ELBO: 146.620\n",
      "IWAE: 163.188\n",
      "SMSE: 0.261\n",
      "SMLL: -0.078\n",
      "MLL: 2.130\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2452/2500 [04:35<00:13,  3.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: -7.647\n",
      "ELBO: 140.561\n",
      "IWAE: 157.964\n",
      "SMSE: 0.272\n",
      "SMLL: 0.122\n",
      "MLL: 2.331\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [04:41<00:00,  8.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: -7.547\n",
      "ELBO: 139.758\n",
      "IWAE: 155.855\n",
      "SMSE: 0.287\n",
      "SMLL: 0.243\n",
      "MLL: 2.451\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [337.42842693190704,\n",
       "  17.54051914338881,\n",
       "  5.139487514563753,\n",
       "  3.48760666438317,\n",
       "  2.627604798801595,\n",
       "  1.7679773570216863,\n",
       "  0.34182859472043564,\n",
       "  -1.2204708796641228,\n",
       "  -1.1193831835562793,\n",
       "  -2.2060986057808165,\n",
       "  -3.1388810376830776,\n",
       "  -3.6118022748989147,\n",
       "  -3.7352794552822246,\n",
       "  -4.883593981805524,\n",
       "  -4.841305119053706,\n",
       "  -5.046327048709766,\n",
       "  -5.663161599292238,\n",
       "  -5.802092396745583,\n",
       "  -6.777025642398179,\n",
       "  -5.12965032790816,\n",
       "  -6.23325383784908,\n",
       "  -5.964109836403338,\n",
       "  -6.34749411589722,\n",
       "  -7.670343528256967,\n",
       "  -6.7021643402232405,\n",
       "  -6.89590651485272,\n",
       "  -7.358595137239461,\n",
       "  -8.496065898797132,\n",
       "  -7.601769183726241,\n",
       "  -7.486110675998899,\n",
       "  -7.3209635042619,\n",
       "  -7.069976783641969,\n",
       "  -7.432236274578108,\n",
       "  -7.376166286765027,\n",
       "  -7.75063509866179,\n",
       "  -8.151215302788486,\n",
       "  -7.689477664878763,\n",
       "  -7.7203136179210246,\n",
       "  -7.349135819478128,\n",
       "  -7.061585124933134,\n",
       "  -6.726537406719289,\n",
       "  -6.482010545987591,\n",
       "  -7.548655087480469,\n",
       "  -7.85962703837292,\n",
       "  -7.895373841326433,\n",
       "  -8.73589248641457,\n",
       "  -7.598734375110127,\n",
       "  -8.227045784116747,\n",
       "  -8.41055852303458,\n",
       "  -7.647475650036739,\n",
       "  -7.546677297243765],\n",
       " 'elbos': [tensor(-69464.4918, grad_fn=<AddBackward0>),\n",
       "  tensor(-4730.4288, grad_fn=<AddBackward0>),\n",
       "  tensor(-1533.9338, grad_fn=<AddBackward0>),\n",
       "  tensor(-1279.0250, grad_fn=<AddBackward0>),\n",
       "  tensor(-1122.4618, grad_fn=<AddBackward0>),\n",
       "  tensor(-980.1549, grad_fn=<AddBackward0>),\n",
       "  tensor(-807.3449, grad_fn=<AddBackward0>),\n",
       "  tensor(-596.8482, grad_fn=<AddBackward0>),\n",
       "  tensor(-482.0279, grad_fn=<AddBackward0>),\n",
       "  tensor(-394.3793, grad_fn=<AddBackward0>),\n",
       "  tensor(-315.1417, grad_fn=<AddBackward0>),\n",
       "  tensor(-249.1938, grad_fn=<AddBackward0>),\n",
       "  tensor(-188.4503, grad_fn=<AddBackward0>),\n",
       "  tensor(-148.3517, grad_fn=<AddBackward0>),\n",
       "  tensor(-128.1486, grad_fn=<AddBackward0>),\n",
       "  tensor(-89.9714, grad_fn=<AddBackward0>),\n",
       "  tensor(-80.0393, grad_fn=<AddBackward0>),\n",
       "  tensor(-51.5181, grad_fn=<AddBackward0>),\n",
       "  tensor(-34.5146, grad_fn=<AddBackward0>),\n",
       "  tensor(-19.2898, grad_fn=<AddBackward0>),\n",
       "  tensor(-7.8454, grad_fn=<AddBackward0>),\n",
       "  tensor(2.0474, grad_fn=<AddBackward0>),\n",
       "  tensor(23.2994, grad_fn=<AddBackward0>),\n",
       "  tensor(22.8369, grad_fn=<AddBackward0>),\n",
       "  tensor(47.7875, grad_fn=<AddBackward0>),\n",
       "  tensor(44.1350, grad_fn=<AddBackward0>),\n",
       "  tensor(56.7904, grad_fn=<AddBackward0>),\n",
       "  tensor(72.0819, grad_fn=<AddBackward0>),\n",
       "  tensor(69.3059, grad_fn=<AddBackward0>),\n",
       "  tensor(86.0288, grad_fn=<AddBackward0>),\n",
       "  tensor(73.2593, grad_fn=<AddBackward0>),\n",
       "  tensor(81.9379, grad_fn=<AddBackward0>),\n",
       "  tensor(107.4979, grad_fn=<AddBackward0>),\n",
       "  tensor(106.3822, grad_fn=<AddBackward0>),\n",
       "  tensor(91.0420, grad_fn=<AddBackward0>),\n",
       "  tensor(105.9448, grad_fn=<AddBackward0>),\n",
       "  tensor(116.9563, grad_fn=<AddBackward0>),\n",
       "  tensor(112.8245, grad_fn=<AddBackward0>),\n",
       "  tensor(122.9153, grad_fn=<AddBackward0>),\n",
       "  tensor(118.8381, grad_fn=<AddBackward0>),\n",
       "  tensor(122.7006, grad_fn=<AddBackward0>),\n",
       "  tensor(129.1570, grad_fn=<AddBackward0>),\n",
       "  tensor(127.2904, grad_fn=<AddBackward0>),\n",
       "  tensor(121.7907, grad_fn=<AddBackward0>),\n",
       "  tensor(133.7023, grad_fn=<AddBackward0>),\n",
       "  tensor(131.9972, grad_fn=<AddBackward0>),\n",
       "  tensor(146.2456, grad_fn=<AddBackward0>),\n",
       "  tensor(146.4334, grad_fn=<AddBackward0>),\n",
       "  tensor(146.6199, grad_fn=<AddBackward0>),\n",
       "  tensor(140.5614, grad_fn=<AddBackward0>),\n",
       "  tensor(139.7580, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-69354.7349, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-4620.9580, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1464.8521, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1216.1182, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1070.3940, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-937.4938, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-769.1620, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-566.7897, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-454.7248, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-353.3449, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-290.3500, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-224.5475, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-162.9308, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-129.0202, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-116.1201, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-68.8197, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-55.7983, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-34.3038, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-10.0770, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-7.4750, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(14.3161, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(34.6402, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(42.3903, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(46.1733, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(71.9386, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(63.0073, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(76.0315, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(90.3678, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(94.0319, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(107.6478, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(92.5944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(105.3944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(118.6714, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(119.9018, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(112.4944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(123.3443, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(131.9551, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(136.9895, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(139.2991, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(136.9650, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(145.3272, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(146.5111, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(144.7817, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(139.0801, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(153.3198, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(145.6577, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(160.0634, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(162.4114, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(163.1877, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(157.9641, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(155.8552, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [1.4480312972999643,\n",
       "  1.0236485730104397,\n",
       "  0.4605083408173229,\n",
       "  0.4102557325385812,\n",
       "  0.39529960888932,\n",
       "  0.3907157896330992,\n",
       "  0.37821270678001406,\n",
       "  0.43706703503001226,\n",
       "  0.502123496494079,\n",
       "  0.5239586408193785,\n",
       "  0.4727276731917455,\n",
       "  0.3828864248908826,\n",
       "  0.348069165592545,\n",
       "  0.3136002215082348,\n",
       "  0.31104575595888345,\n",
       "  0.2689274962962999,\n",
       "  0.24939282730146725,\n",
       "  0.23828805588097066,\n",
       "  0.23408968497034044,\n",
       "  0.2279335211458645,\n",
       "  0.230705072550308,\n",
       "  0.23773074515857642,\n",
       "  0.23131270408118063,\n",
       "  0.24714230250108934,\n",
       "  0.2202067132787914,\n",
       "  0.2335466082932005,\n",
       "  0.2500759205285308,\n",
       "  0.24214552532329559,\n",
       "  0.25959476782068575,\n",
       "  0.24607388106746372,\n",
       "  0.24397361998121622,\n",
       "  0.24082860845986773,\n",
       "  0.25835181248275374,\n",
       "  0.25046636832562647,\n",
       "  0.256687433093634,\n",
       "  0.23928165766365975,\n",
       "  0.24993393955503995,\n",
       "  0.24771122948746158,\n",
       "  0.2628795058930948,\n",
       "  0.26258549942597614,\n",
       "  0.274323092351246,\n",
       "  0.28093010711906324,\n",
       "  0.2588694415020112,\n",
       "  0.2669875315754374,\n",
       "  0.24772875991416657,\n",
       "  0.26098115258866356,\n",
       "  0.24756150686042266,\n",
       "  0.2741491284678384,\n",
       "  0.2605772288675335,\n",
       "  0.2717721930966173,\n",
       "  0.28721282002102017],\n",
       " 'smlls': [18.412052932757337,\n",
       "  2.7123033949999953,\n",
       "  0.14114333415895555,\n",
       "  -0.08281970468060192,\n",
       "  -0.15711224420512107,\n",
       "  -0.20817078317777127,\n",
       "  -0.23239402685025268,\n",
       "  -0.059121512828469815,\n",
       "  0.17868473138070096,\n",
       "  0.21453428750005785,\n",
       "  0.0019410579870873985,\n",
       "  -0.2413976474818449,\n",
       "  -0.3510245449252454,\n",
       "  -0.45678354564766965,\n",
       "  -0.45878512124780696,\n",
       "  -0.5961826122179741,\n",
       "  -0.6080864673772225,\n",
       "  -0.6351445682129945,\n",
       "  -0.6521452154978127,\n",
       "  -0.6674121890683965,\n",
       "  -0.607149316181139,\n",
       "  -0.5379716227402414,\n",
       "  -0.5285132805635343,\n",
       "  -0.45466203135986705,\n",
       "  -0.5499652425266407,\n",
       "  -0.4837470892224008,\n",
       "  -0.3727646892431366,\n",
       "  -0.34556299621133757,\n",
       "  -0.20919901748012956,\n",
       "  -0.2647908232623711,\n",
       "  -0.2742290606617239,\n",
       "  -0.28696880566413013,\n",
       "  -0.16909336484365003,\n",
       "  -0.19436167059574258,\n",
       "  -0.14082982204641423,\n",
       "  -0.22123223664346647,\n",
       "  -0.17624302300824782,\n",
       "  -0.19961146393728998,\n",
       "  -0.0064570499819711635,\n",
       "  -0.05590683982000777,\n",
       "  0.03250841140788937,\n",
       "  0.02761428996416604,\n",
       "  -0.06695232337097494,\n",
       "  -0.06233390147713062,\n",
       "  -0.1276671912658891,\n",
       "  0.004227246098852093,\n",
       "  -0.10323803261859495,\n",
       "  0.029446198127061063,\n",
       "  -0.07797740634351251,\n",
       "  0.1224238423726316,\n",
       "  0.24284785802283984],\n",
       " 'mlls': [20.620294524393703,\n",
       "  4.920544986636364,\n",
       "  2.3493849257953237,\n",
       "  2.1254218869557664,\n",
       "  2.0511293474312473,\n",
       "  2.000070808458597,\n",
       "  1.9758475647861156,\n",
       "  2.1491200788078983,\n",
       "  2.3869263230170694,\n",
       "  2.422775879136426,\n",
       "  2.2101826496234556,\n",
       "  1.9668439441545236,\n",
       "  1.857217046711123,\n",
       "  1.7514580459886986,\n",
       "  1.7494564703885613,\n",
       "  1.6120589794183944,\n",
       "  1.6001551242591459,\n",
       "  1.5730970234233739,\n",
       "  1.5560963761385558,\n",
       "  1.5408294025679716,\n",
       "  1.6010922754552295,\n",
       "  1.670269968896127,\n",
       "  1.679728311072834,\n",
       "  1.7535795602765012,\n",
       "  1.6582763491097274,\n",
       "  1.7244945024139675,\n",
       "  1.8354769023932318,\n",
       "  1.862678595425031,\n",
       "  1.999042574156239,\n",
       "  1.9434507683739974,\n",
       "  1.9340125309746445,\n",
       "  1.9212727859722383,\n",
       "  2.0391482267927183,\n",
       "  2.0138799210406257,\n",
       "  2.067411769589954,\n",
       "  1.987009354992902,\n",
       "  2.0319985686281203,\n",
       "  2.008630127699078,\n",
       "  2.2017845416543973,\n",
       "  2.1523347518163605,\n",
       "  2.2407500030442575,\n",
       "  2.2358558816005343,\n",
       "  2.141289268265393,\n",
       "  2.145907690159238,\n",
       "  2.080574400370479,\n",
       "  2.2124688377352206,\n",
       "  2.105003559017774,\n",
       "  2.2376877897634295,\n",
       "  2.1302641852928557,\n",
       "  2.3306654340089996,\n",
       "  2.451089449659208]}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train zi_model; train_eeg prints Loss/ELBO/IWAE/SMSE/SMLL/MLL to stdout\n",
    "# as it runs, and save_model=True is passed through to it.\n",
    "# The call returns a dict of metric histories (keys: 'epochs', 'losses',\n",
    "# 'elbos', 'iwaes', 'smses', 'smlls', 'mlls'). Bind it to a name instead of\n",
    "# leaving the call as the cell's last expression: that keeps the history\n",
    "# available to later cells and stops IPython from echoing the whole dict\n",
    "# into the notebook output.\n",
    "# NOTE(review): the 'elbos'/'iwaes' entries are tensors that still carry a\n",
    "# grad_fn -- presumably train_eeg should detach them; confirm in train_eeg.\n",
    "zi_metrics = train_eeg(\n",
    "    zi_model,\n",
    "    gpvae.estimators.gpvae_estimators.td_estimator,\n",
    "    loader,\n",
    "    zi_args,\n",
    "    gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "    save_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 4/2500 [00:01<28:34,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 346.745\n",
      "ELBO: -70586.053\n",
      "IWAE: -68993.356\n",
      "SMSE: 2.374\n",
      "SMLL: 6.051\n",
      "MLL: 8.259\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 54/2500 [00:03<03:55, 10.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 19.585\n",
      "ELBO: -4472.171\n",
      "IWAE: -4318.618\n",
      "SMSE: 0.509\n",
      "SMLL: 0.115\n",
      "MLL: 2.323\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 106/2500 [00:06<03:58, 10.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 10.607\n",
      "ELBO: -2539.600\n",
      "IWAE: -2445.978\n",
      "SMSE: 0.501\n",
      "SMLL: 0.020\n",
      "MLL: 2.228\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 156/2500 [00:09<02:55, 13.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 8.061\n",
      "ELBO: -2001.187\n",
      "IWAE: -1912.114\n",
      "SMSE: 0.437\n",
      "SMLL: -0.174\n",
      "MLL: 2.034\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 204/2500 [00:11<03:22, 11.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 6.642\n",
      "ELBO: -1744.238\n",
      "IWAE: -1682.610\n",
      "SMSE: 0.439\n",
      "SMLL: -0.233\n",
      "MLL: 1.976\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 255/2500 [00:14<03:42, 10.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 6.155\n",
      "ELBO: -1609.327\n",
      "IWAE: -1557.627\n",
      "SMSE: 0.424\n",
      "SMLL: -0.292\n",
      "MLL: 1.916\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 306/2500 [00:17<02:44, 13.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 5.687\n",
      "ELBO: -1509.665\n",
      "IWAE: -1443.779\n",
      "SMSE: 0.393\n",
      "SMLL: -0.386\n",
      "MLL: 1.822\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 356/2500 [00:19<02:41, 13.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 5.585\n",
      "ELBO: -1495.686\n",
      "IWAE: -1411.597\n",
      "SMSE: 0.332\n",
      "SMLL: -0.505\n",
      "MLL: 1.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 406/2500 [00:21<02:35, 13.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 5.104\n",
      "ELBO: -1363.465\n",
      "IWAE: -1320.558\n",
      "SMSE: 0.337\n",
      "SMLL: -0.500\n",
      "MLL: 1.708\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 456/2500 [00:24<02:35, 13.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: 4.755\n",
      "ELBO: -1283.504\n",
      "IWAE: -1225.269\n",
      "SMSE: 0.344\n",
      "SMLL: -0.505\n",
      "MLL: 1.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 505/2500 [00:27<03:42,  8.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: 4.495\n",
      "ELBO: -1218.630\n",
      "IWAE: -1167.473\n",
      "SMSE: 0.338\n",
      "SMLL: -0.499\n",
      "MLL: 1.709\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 555/2500 [00:29<02:29, 13.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: 4.071\n",
      "ELBO: -1172.603\n",
      "IWAE: -1117.114\n",
      "SMSE: 0.280\n",
      "SMLL: -0.610\n",
      "MLL: 1.598\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 606/2500 [00:32<02:20, 13.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: 3.929\n",
      "ELBO: -1126.949\n",
      "IWAE: -1081.064\n",
      "SMSE: 0.260\n",
      "SMLL: -0.647\n",
      "MLL: 1.561\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 654/2500 [00:34<02:20, 13.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: 3.858\n",
      "ELBO: -1089.525\n",
      "IWAE: -1047.405\n",
      "SMSE: 0.237\n",
      "SMLL: -0.711\n",
      "MLL: 1.497\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 704/2500 [00:36<02:14, 13.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: 3.715\n",
      "ELBO: -1096.349\n",
      "IWAE: -1057.518\n",
      "SMSE: 0.302\n",
      "SMLL: -0.543\n",
      "MLL: 1.665\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 754/2500 [00:39<02:12, 13.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: 3.630\n",
      "ELBO: -1030.551\n",
      "IWAE: -1001.492\n",
      "SMSE: 0.255\n",
      "SMLL: -0.632\n",
      "MLL: 1.576\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 804/2500 [00:41<02:38, 10.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 3.476\n",
      "ELBO: -993.423\n",
      "IWAE: -951.940\n",
      "SMSE: 0.250\n",
      "SMLL: -0.659\n",
      "MLL: 1.550\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 854/2500 [00:44<02:05, 13.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: 3.507\n",
      "ELBO: -968.861\n",
      "IWAE: -939.891\n",
      "SMSE: 0.255\n",
      "SMLL: -0.641\n",
      "MLL: 1.567\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 904/2500 [00:46<02:40,  9.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: 3.252\n",
      "ELBO: -938.379\n",
      "IWAE: -900.532\n",
      "SMSE: 0.233\n",
      "SMLL: -0.687\n",
      "MLL: 1.521\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 954/2500 [00:49<02:28, 10.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: 3.109\n",
      "ELBO: -915.538\n",
      "IWAE: -887.881\n",
      "SMSE: 0.267\n",
      "SMLL: -0.576\n",
      "MLL: 1.632\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1004/2500 [00:52<02:21, 10.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: 3.173\n",
      "ELBO: -901.319\n",
      "IWAE: -872.197\n",
      "SMSE: 0.255\n",
      "SMLL: -0.626\n",
      "MLL: 1.582\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1054/2500 [00:55<02:53,  8.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: 3.073\n",
      "ELBO: -889.690\n",
      "IWAE: -857.514\n",
      "SMSE: 0.247\n",
      "SMLL: -0.605\n",
      "MLL: 1.604\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1105/2500 [00:57<02:18, 10.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: 3.132\n",
      "ELBO: -890.615\n",
      "IWAE: -858.539\n",
      "SMSE: 0.247\n",
      "SMLL: -0.619\n",
      "MLL: 1.589\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1153/2500 [01:00<02:58,  7.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: 2.758\n",
      "ELBO: -871.028\n",
      "IWAE: -852.254\n",
      "SMSE: 0.238\n",
      "SMLL: -0.641\n",
      "MLL: 1.567\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1204/2500 [01:02<01:44, 12.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: 3.031\n",
      "ELBO: -892.668\n",
      "IWAE: -861.895\n",
      "SMSE: 0.264\n",
      "SMLL: -0.514\n",
      "MLL: 1.694\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1255/2500 [01:05<01:39, 12.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: 2.879\n",
      "ELBO: -866.165\n",
      "IWAE: -829.565\n",
      "SMSE: 0.259\n",
      "SMLL: -0.517\n",
      "MLL: 1.691\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1305/2500 [01:08<01:58, 10.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: 2.742\n",
      "ELBO: -855.305\n",
      "IWAE: -826.100\n",
      "SMSE: 0.262\n",
      "SMLL: -0.491\n",
      "MLL: 1.717\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1354/2500 [01:10<01:51, 10.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: 2.795\n",
      "ELBO: -870.581\n",
      "IWAE: -828.842\n",
      "SMSE: 0.278\n",
      "SMLL: -0.412\n",
      "MLL: 1.796\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1406/2500 [01:13<01:46, 10.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: 2.804\n",
      "ELBO: -844.243\n",
      "IWAE: -815.076\n",
      "SMSE: 0.259\n",
      "SMLL: -0.453\n",
      "MLL: 1.755\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1456/2500 [01:16<01:23, 12.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: 2.656\n",
      "ELBO: -837.756\n",
      "IWAE: -810.068\n",
      "SMSE: 0.269\n",
      "SMLL: -0.446\n",
      "MLL: 1.762\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1504/2500 [01:18<01:38, 10.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: 2.614\n",
      "ELBO: -833.742\n",
      "IWAE: -802.311\n",
      "SMSE: 0.281\n",
      "SMLL: -0.308\n",
      "MLL: 1.900\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1555/2500 [01:21<01:32, 10.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: 2.717\n",
      "ELBO: -842.693\n",
      "IWAE: -811.189\n",
      "SMSE: 0.274\n",
      "SMLL: -0.364\n",
      "MLL: 1.844\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1603/2500 [01:24<01:55,  7.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: 2.436\n",
      "ELBO: -808.594\n",
      "IWAE: -784.545\n",
      "SMSE: 0.249\n",
      "SMLL: -0.476\n",
      "MLL: 1.732\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1654/2500 [01:27<01:45,  8.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: 2.530\n",
      "ELBO: -822.864\n",
      "IWAE: -797.674\n",
      "SMSE: 0.249\n",
      "SMLL: -0.425\n",
      "MLL: 1.783\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1704/2500 [01:29<01:12, 10.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: 2.563\n",
      "ELBO: -795.729\n",
      "IWAE: -761.907\n",
      "SMSE: 0.285\n",
      "SMLL: -0.262\n",
      "MLL: 1.946\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1756/2500 [01:32<01:01, 12.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: 2.411\n",
      "ELBO: -806.689\n",
      "IWAE: -775.880\n",
      "SMSE: 0.239\n",
      "SMLL: -0.475\n",
      "MLL: 1.733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1806/2500 [01:35<01:07, 10.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: 2.525\n",
      "ELBO: -797.867\n",
      "IWAE: -773.312\n",
      "SMSE: 0.266\n",
      "SMLL: -0.324\n",
      "MLL: 1.884\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1854/2500 [01:37<01:02, 10.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: 2.543\n",
      "ELBO: -783.469\n",
      "IWAE: -754.802\n",
      "SMSE: 0.263\n",
      "SMLL: -0.357\n",
      "MLL: 1.851\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1906/2500 [01:40<00:54, 10.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: 2.416\n",
      "ELBO: -790.987\n",
      "IWAE: -764.887\n",
      "SMSE: 0.259\n",
      "SMLL: -0.367\n",
      "MLL: 1.842\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1954/2500 [01:43<00:52, 10.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: 2.257\n",
      "ELBO: -775.619\n",
      "IWAE: -756.984\n",
      "SMSE: 0.247\n",
      "SMLL: -0.381\n",
      "MLL: 1.827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2004/2500 [01:46<00:57,  8.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: 2.291\n",
      "ELBO: -768.430\n",
      "IWAE: -741.775\n",
      "SMSE: 0.248\n",
      "SMLL: -0.385\n",
      "MLL: 1.824\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2054/2500 [01:48<00:35, 12.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: 2.241\n",
      "ELBO: -753.864\n",
      "IWAE: -728.408\n",
      "SMSE: 0.257\n",
      "SMLL: -0.361\n",
      "MLL: 1.847\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2104/2500 [01:51<00:30, 13.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: 2.258\n",
      "ELBO: -754.578\n",
      "IWAE: -720.410\n",
      "SMSE: 0.274\n",
      "SMLL: -0.290\n",
      "MLL: 1.918\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2154/2500 [01:53<00:32, 10.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: 2.254\n",
      "ELBO: -743.524\n",
      "IWAE: -721.176\n",
      "SMSE: 0.252\n",
      "SMLL: -0.382\n",
      "MLL: 1.826\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2204/2500 [01:56<00:29, 10.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: 2.278\n",
      "ELBO: -749.416\n",
      "IWAE: -726.735\n",
      "SMSE: 0.256\n",
      "SMLL: -0.350\n",
      "MLL: 1.858\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2254/2500 [01:58<00:22, 11.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: 2.142\n",
      "ELBO: -737.270\n",
      "IWAE: -718.791\n",
      "SMSE: 0.233\n",
      "SMLL: -0.475\n",
      "MLL: 1.733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2306/2500 [02:01<00:15, 12.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: 2.227\n",
      "ELBO: -754.155\n",
      "IWAE: -733.619\n",
      "SMSE: 0.258\n",
      "SMLL: -0.335\n",
      "MLL: 1.874\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2356/2500 [02:03<00:14, 10.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: 2.168\n",
      "ELBO: -753.526\n",
      "IWAE: -731.759\n",
      "SMSE: 0.255\n",
      "SMLL: -0.400\n",
      "MLL: 1.808\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2404/2500 [02:06<00:08, 11.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: 2.020\n",
      "ELBO: -715.928\n",
      "IWAE: -698.848\n",
      "SMSE: 0.267\n",
      "SMLL: -0.253\n",
      "MLL: 1.955\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2454/2500 [02:09<00:03, 11.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: 2.182\n",
      "ELBO: -738.064\n",
      "IWAE: -712.501\n",
      "SMSE: 0.292\n",
      "SMLL: -0.215\n",
      "MLL: 1.993\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [02:11<00:00, 18.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: 2.026\n",
      "ELBO: -723.456\n",
      "IWAE: -693.300\n",
      "SMSE: 0.234\n",
      "SMLL: -0.507\n",
      "MLL: 1.701\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [346.74476053203944,\n",
       "  19.58532280036387,\n",
       "  10.607473773868483,\n",
       "  8.060885151227142,\n",
       "  6.642055849523875,\n",
       "  6.1545152533523675,\n",
       "  5.687167656677102,\n",
       "  5.584887568009737,\n",
       "  5.104042071477486,\n",
       "  4.754536432468518,\n",
       "  4.494762210605408,\n",
       "  4.071044815700114,\n",
       "  3.928623078063365,\n",
       "  3.8582847893919854,\n",
       "  3.7152013786134446,\n",
       "  3.630275983197366,\n",
       "  3.4757992580103116,\n",
       "  3.5074960473514665,\n",
       "  3.2515753452532734,\n",
       "  3.10933987192036,\n",
       "  3.1725345527620994,\n",
       "  3.0733290152301556,\n",
       "  3.13165356992171,\n",
       "  2.7584252510052454,\n",
       "  3.030950945443345,\n",
       "  2.878830987889991,\n",
       "  2.7420718392510737,\n",
       "  2.7948911879130844,\n",
       "  2.803508609740399,\n",
       "  2.65606557926659,\n",
       "  2.6136287813918657,\n",
       "  2.717137569105759,\n",
       "  2.435533699784336,\n",
       "  2.529956904225853,\n",
       "  2.563174819726486,\n",
       "  2.4111459312461263,\n",
       "  2.52503837226447,\n",
       "  2.543405274266306,\n",
       "  2.4155694975543986,\n",
       "  2.256899092947536,\n",
       "  2.2908360118870226,\n",
       "  2.2412405600406418,\n",
       "  2.2581454152387703,\n",
       "  2.253857771711551,\n",
       "  2.27776182026347,\n",
       "  2.142416755474027,\n",
       "  2.226599491178271,\n",
       "  2.1680413153831286,\n",
       "  2.0202267276767545,\n",
       "  2.1818735960958677,\n",
       "  2.026418143926456],\n",
       " 'elbos': [tensor(-70586.0534, grad_fn=<AddBackward0>),\n",
       "  tensor(-4472.1710, grad_fn=<AddBackward0>),\n",
       "  tensor(-2539.5998, grad_fn=<AddBackward0>),\n",
       "  tensor(-2001.1869, grad_fn=<AddBackward0>),\n",
       "  tensor(-1744.2378, grad_fn=<AddBackward0>),\n",
       "  tensor(-1609.3274, grad_fn=<AddBackward0>),\n",
       "  tensor(-1509.6645, grad_fn=<AddBackward0>),\n",
       "  tensor(-1495.6860, grad_fn=<AddBackward0>),\n",
       "  tensor(-1363.4650, grad_fn=<AddBackward0>),\n",
       "  tensor(-1283.5036, grad_fn=<AddBackward0>),\n",
       "  tensor(-1218.6295, grad_fn=<AddBackward0>),\n",
       "  tensor(-1172.6031, grad_fn=<AddBackward0>),\n",
       "  tensor(-1126.9491, grad_fn=<AddBackward0>),\n",
       "  tensor(-1089.5252, grad_fn=<AddBackward0>),\n",
       "  tensor(-1096.3486, grad_fn=<AddBackward0>),\n",
       "  tensor(-1030.5508, grad_fn=<AddBackward0>),\n",
       "  tensor(-993.4229, grad_fn=<AddBackward0>),\n",
       "  tensor(-968.8607, grad_fn=<AddBackward0>),\n",
       "  tensor(-938.3787, grad_fn=<AddBackward0>),\n",
       "  tensor(-915.5384, grad_fn=<AddBackward0>),\n",
       "  tensor(-901.3187, grad_fn=<AddBackward0>),\n",
       "  tensor(-889.6897, grad_fn=<AddBackward0>),\n",
       "  tensor(-890.6147, grad_fn=<AddBackward0>),\n",
       "  tensor(-871.0277, grad_fn=<AddBackward0>),\n",
       "  tensor(-892.6676, grad_fn=<AddBackward0>),\n",
       "  tensor(-866.1652, grad_fn=<AddBackward0>),\n",
       "  tensor(-855.3055, grad_fn=<AddBackward0>),\n",
       "  tensor(-870.5813, grad_fn=<AddBackward0>),\n",
       "  tensor(-844.2427, grad_fn=<AddBackward0>),\n",
       "  tensor(-837.7562, grad_fn=<AddBackward0>),\n",
       "  tensor(-833.7421, grad_fn=<AddBackward0>),\n",
       "  tensor(-842.6932, grad_fn=<AddBackward0>),\n",
       "  tensor(-808.5944, grad_fn=<AddBackward0>),\n",
       "  tensor(-822.8639, grad_fn=<AddBackward0>),\n",
       "  tensor(-795.7288, grad_fn=<AddBackward0>),\n",
       "  tensor(-806.6894, grad_fn=<AddBackward0>),\n",
       "  tensor(-797.8667, grad_fn=<AddBackward0>),\n",
       "  tensor(-783.4691, grad_fn=<AddBackward0>),\n",
       "  tensor(-790.9865, grad_fn=<AddBackward0>),\n",
       "  tensor(-775.6188, grad_fn=<AddBackward0>),\n",
       "  tensor(-768.4303, grad_fn=<AddBackward0>),\n",
       "  tensor(-753.8644, grad_fn=<AddBackward0>),\n",
       "  tensor(-754.5778, grad_fn=<AddBackward0>),\n",
       "  tensor(-743.5243, grad_fn=<AddBackward0>),\n",
       "  tensor(-749.4164, grad_fn=<AddBackward0>),\n",
       "  tensor(-737.2701, grad_fn=<AddBackward0>),\n",
       "  tensor(-754.1548, grad_fn=<AddBackward0>),\n",
       "  tensor(-753.5263, grad_fn=<AddBackward0>),\n",
       "  tensor(-715.9279, grad_fn=<AddBackward0>),\n",
       "  tensor(-738.0637, grad_fn=<AddBackward0>),\n",
       "  tensor(-723.4561, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-68993.3563, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-4318.6183, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-2445.9781, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1912.1143, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1682.6100, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1557.6266, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1443.7794, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1411.5966, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1320.5583, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1225.2692, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1167.4733, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1117.1144, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1081.0642, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1047.4048, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1057.5184, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1001.4919, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-951.9401, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-939.8915, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-900.5317, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-887.8808, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-872.1966, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-857.5137, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-858.5392, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-852.2535, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-861.8946, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-829.5653, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-826.0997, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-828.8422, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-815.0763, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-810.0682, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-802.3108, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-811.1894, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-784.5446, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-797.6742, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-761.9068, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-775.8796, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-773.3117, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-754.8020, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-764.8866, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-756.9838, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-741.7755, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-728.4079, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-720.4101, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-721.1763, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-726.7352, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-718.7913, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-733.6185, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-731.7592, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-698.8477, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-712.5014, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-693.3000, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [2.374138728271379,\n",
       "  0.5094045847461567,\n",
       "  0.5009408664648348,\n",
       "  0.4366111192323337,\n",
       "  0.4385968482733264,\n",
       "  0.42374225525992304,\n",
       "  0.3930103906101734,\n",
       "  0.3322968885425008,\n",
       "  0.33701980706537116,\n",
       "  0.34396044072815846,\n",
       "  0.3377548032044637,\n",
       "  0.28034902762093744,\n",
       "  0.25998172818838733,\n",
       "  0.23748280667716093,\n",
       "  0.30235253844847576,\n",
       "  0.2553873153934216,\n",
       "  0.2503776182432513,\n",
       "  0.25510645252109915,\n",
       "  0.23328949433955073,\n",
       "  0.267075104918187,\n",
       "  0.25454092104052944,\n",
       "  0.2470776208640013,\n",
       "  0.24713975833229643,\n",
       "  0.23760619999629792,\n",
       "  0.2639263755727199,\n",
       "  0.25869048605396666,\n",
       "  0.2624932837092947,\n",
       "  0.27839398719702585,\n",
       "  0.2589923819907612,\n",
       "  0.26872879681989353,\n",
       "  0.2810517870186033,\n",
       "  0.27364790315838006,\n",
       "  0.24912639877912054,\n",
       "  0.24931449831858418,\n",
       "  0.2852294104432995,\n",
       "  0.23882371189301912,\n",
       "  0.2658228355592492,\n",
       "  0.2631408827524185,\n",
       "  0.25929040292296374,\n",
       "  0.2470131984683852,\n",
       "  0.24813614535570105,\n",
       "  0.25731716381297415,\n",
       "  0.2739021848474916,\n",
       "  0.2517761714791271,\n",
       "  0.2556251074075102,\n",
       "  0.23325315192578877,\n",
       "  0.2578523548140437,\n",
       "  0.255071429195385,\n",
       "  0.2669098551296191,\n",
       "  0.29162032492435835,\n",
       "  0.23443884921702118],\n",
       " 'smlls': [6.050669947723779,\n",
       "  0.11455755352936688,\n",
       "  0.020143034654175922,\n",
       "  -0.1740551129812585,\n",
       "  -0.2326367301026544,\n",
       "  -0.2918682108224697,\n",
       "  -0.3861292381770689,\n",
       "  -0.5054400535933957,\n",
       "  -0.5001438616653279,\n",
       "  -0.5048117156894174,\n",
       "  -0.498747200590987,\n",
       "  -0.6100957585283683,\n",
       "  -0.6471468960463281,\n",
       "  -0.7114256063313852,\n",
       "  -0.5433015108505184,\n",
       "  -0.6321991830993908,\n",
       "  -0.6586592516685527,\n",
       "  -0.6411879406876541,\n",
       "  -0.6869857515616703,\n",
       "  -0.5757710732038064,\n",
       "  -0.6258560835748945,\n",
       "  -0.6047104382154241,\n",
       "  -0.6193265159971573,\n",
       "  -0.6407578080993476,\n",
       "  -0.5138762401876907,\n",
       "  -0.5174860884150249,\n",
       "  -0.49084991746105394,\n",
       "  -0.4124652081507247,\n",
       "  -0.45301346260121494,\n",
       "  -0.44592410050046344,\n",
       "  -0.3082137540311573,\n",
       "  -0.3641256101424124,\n",
       "  -0.4757996198706996,\n",
       "  -0.4248217075428719,\n",
       "  -0.2619271228437774,\n",
       "  -0.4747945926730916,\n",
       "  -0.32420315195296,\n",
       "  -0.35682681757653717,\n",
       "  -0.36669326388456563,\n",
       "  -0.3813946462703026,\n",
       "  -0.3847162043739832,\n",
       "  -0.36110066067969915,\n",
       "  -0.29025856114218435,\n",
       "  -0.3820327356501885,\n",
       "  -0.3500268206692856,\n",
       "  -0.4748167804409676,\n",
       "  -0.33458851422075525,\n",
       "  -0.4004297859042829,\n",
       "  -0.25349794276818854,\n",
       "  -0.21547451494203784,\n",
       "  -0.5069757923704039],\n",
       " 'mlls': [8.258911539360147,\n",
       "  2.3227991451657353,\n",
       "  2.2283846262905445,\n",
       "  2.03418647865511,\n",
       "  1.9756048615337136,\n",
       "  1.9163733808138985,\n",
       "  1.8221123534592996,\n",
       "  1.7028015380429729,\n",
       "  1.7080977299710405,\n",
       "  1.703429875946951,\n",
       "  1.7094943910453813,\n",
       "  1.598145833108,\n",
       "  1.5610946955900402,\n",
       "  1.4968159853049832,\n",
       "  1.6649400807858499,\n",
       "  1.5760424085369777,\n",
       "  1.5495823399678157,\n",
       "  1.5670536509487143,\n",
       "  1.521255840074698,\n",
       "  1.632470518432562,\n",
       "  1.582385508061474,\n",
       "  1.603531153420944,\n",
       "  1.5889150756392112,\n",
       "  1.5674837835370206,\n",
       "  1.6943653514486776,\n",
       "  1.6907555032213437,\n",
       "  1.7173916741753146,\n",
       "  1.7957763834856435,\n",
       "  1.7552281290351532,\n",
       "  1.7623174911359047,\n",
       "  1.9000278376052109,\n",
       "  1.844115981493956,\n",
       "  1.7324419717656687,\n",
       "  1.7834198840934965,\n",
       "  1.9463144687925908,\n",
       "  1.7334469989632766,\n",
       "  1.8840384396834082,\n",
       "  1.8514147740598312,\n",
       "  1.8415483277518028,\n",
       "  1.8268469453660658,\n",
       "  1.8235253872623851,\n",
       "  1.8471409309566693,\n",
       "  1.9179830304941838,\n",
       "  1.8262088559861798,\n",
       "  1.8582147709670827,\n",
       "  1.7334248111954007,\n",
       "  1.8736530774156133,\n",
       "  1.8078118057320853,\n",
       "  1.95474364886818,\n",
       "  1.9927670766943306,\n",
       "  1.7012657992659646]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the VAE baseline and keep the returned metrics under a name.\n",
    "# train_eeg returns a dict with keys 'epochs', 'losses', 'elbos', 'iwaes',\n",
    "# 'smses', 'smlls' and 'mlls'; leaving the call as the cell's last bare\n",
    "# expression echoes that whole dict into the cell output and only keeps it\n",
    "# in IPython's output cache. Binding it makes it available for later\n",
    "# plotting/comparison and keeps the output limited to the progress log.\n",
    "vae_metrics = train_eeg(\n",
    "    vae_model,\n",
    "    gpvae.estimators.vae_estimators.td_estimator,  # training loss\n",
    "    loader,\n",
    "    vae_args,\n",
    "    gpvae.estimators.vae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.vae_estimators.iwae_estimator,\n",
    "    save_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/2500 [00:01<39:46,  1.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: -14.896\n",
      "ELBO: -959.808\n",
      "IWAE: -930.113\n",
      "SMSE: 1.992\n",
      "SMLL: 1.293\n",
      "MLL: 3.501\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:06<10:08,  4.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: -15.642\n",
      "ELBO: -860.503\n",
      "IWAE: -835.450\n",
      "SMSE: 1.745\n",
      "SMLL: 0.619\n",
      "MLL: 2.827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 103/2500 [00:12<08:46,  4.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: -16.405\n",
      "ELBO: -826.254\n",
      "IWAE: -796.506\n",
      "SMSE: 1.747\n",
      "SMLL: 0.634\n",
      "MLL: 2.843\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 153/2500 [00:18<08:10,  4.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: -16.749\n",
      "ELBO: -805.053\n",
      "IWAE: -780.142\n",
      "SMSE: 1.679\n",
      "SMLL: 0.625\n",
      "MLL: 2.833\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|▋         | 174/2500 [00:20<04:29,  8.64it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-bcff930cf7ab>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m train_eeg(pn_model, gpvae.estimators.gpvae_estimators.td_estimator, \n\u001b[1;32m      2\u001b[0m     \u001b[0mloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)\n\u001b[0m",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/experiments/eeg/train_eeg.py\u001b[0m in \u001b[0;36mtrain_eeg\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model)\u001b[0m\n\u001b[1;32m     68\u001b[0m                     decoder_scale=args['decoder_scale'])\n\u001b[1;32m     69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     71\u001b[0m             \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    193\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    194\u001b[0m         \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train the pn_model GP-VAE and keep the returned metrics under a name.\n",
    "# Leaving train_eeg(...) as the cell's last bare expression would echo its\n",
    "# full return value into the cell output once training completes and leave\n",
    "# it reachable only through IPython's output cache; binding it keeps the\n",
    "# output limited to the progress log and makes the metrics reusable.\n",
    "pn_metrics = train_eeg(\n",
    "    pn_model,\n",
    "    gpvae.estimators.gpvae_estimators.td_estimator,  # training loss\n",
    "    loader,\n",
    "    pn_args,\n",
    "    gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "    save_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/2500 [00:01<41:07,  1.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 345.282\n",
      "ELBO: -70517.501\n",
      "IWAE: -70363.158\n",
      "SMSE: 1.922\n",
      "SMLL: 25.424\n",
      "MLL: 27.632\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:12<58:29,  1.43s/it]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 15.981\n",
      "ELBO: -3916.869\n",
      "IWAE: -3852.021\n",
      "SMSE: 0.737\n",
      "SMLL: 1.431\n",
      "MLL: 3.640\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 101/2500 [00:24<39:02,  1.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 4.707\n",
      "ELBO: -1461.431\n",
      "IWAE: -1397.548\n",
      "SMSE: 0.387\n",
      "SMLL: -0.179\n",
      "MLL: 2.029\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 152/2500 [00:37<26:51,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 2.017\n",
      "ELBO: -969.603\n",
      "IWAE: -908.553\n",
      "SMSE: 0.330\n",
      "SMLL: -0.392\n",
      "MLL: 1.817\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 202/2500 [00:45<15:56,  2.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 0.826\n",
      "ELBO: -709.524\n",
      "IWAE: -665.426\n",
      "SMSE: 0.385\n",
      "SMLL: -0.184\n",
      "MLL: 2.024\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:54<17:57,  2.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: -1.076\n",
      "ELBO: -523.583\n",
      "IWAE: -479.378\n",
      "SMSE: 0.442\n",
      "SMLL: 0.027\n",
      "MLL: 2.235\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 302/2500 [01:02<14:56,  2.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: -0.855\n",
      "ELBO: -412.024\n",
      "IWAE: -369.960\n",
      "SMSE: 0.405\n",
      "SMLL: -0.137\n",
      "MLL: 2.071\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [01:09<13:04,  2.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: -1.679\n",
      "ELBO: -358.985\n",
      "IWAE: -321.096\n",
      "SMSE: 0.425\n",
      "SMLL: -0.053\n",
      "MLL: 2.156\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [01:16<12:53,  2.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: -2.312\n",
      "ELBO: -297.433\n",
      "IWAE: -251.860\n",
      "SMSE: 0.399\n",
      "SMLL: -0.149\n",
      "MLL: 2.060\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 452/2500 [01:23<12:31,  2.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: -2.829\n",
      "ELBO: -231.237\n",
      "IWAE: -198.323\n",
      "SMSE: 0.410\n",
      "SMLL: -0.118\n",
      "MLL: 2.090\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 502/2500 [01:30<14:49,  2.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: -2.592\n",
      "ELBO: -171.310\n",
      "IWAE: -141.050\n",
      "SMSE: 0.400\n",
      "SMLL: -0.112\n",
      "MLL: 2.096\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 552/2500 [01:39<14:10,  2.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: -2.875\n",
      "ELBO: -96.931\n",
      "IWAE: -64.159\n",
      "SMSE: 0.398\n",
      "SMLL: -0.051\n",
      "MLL: 2.157\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 601/2500 [01:49<23:35,  1.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: -3.079\n",
      "ELBO: -79.751\n",
      "IWAE: -48.163\n",
      "SMSE: 0.423\n",
      "SMLL: 0.085\n",
      "MLL: 2.293\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 651/2500 [02:04<43:24,  1.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: -3.234\n",
      "ELBO: -59.077\n",
      "IWAE: -33.028\n",
      "SMSE: 0.373\n",
      "SMLL: -0.123\n",
      "MLL: 2.086\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 702/2500 [02:17<18:06,  1.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: -3.213\n",
      "ELBO: -42.395\n",
      "IWAE: -9.313\n",
      "SMSE: 0.401\n",
      "SMLL: 0.047\n",
      "MLL: 2.255\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 752/2500 [02:27<12:43,  2.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: -3.711\n",
      "ELBO: -46.066\n",
      "IWAE: -10.303\n",
      "SMSE: 0.397\n",
      "SMLL: 0.075\n",
      "MLL: 2.283\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 802/2500 [02:35<11:00,  2.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: -3.482\n",
      "ELBO: -15.435\n",
      "IWAE: 16.183\n",
      "SMSE: 0.368\n",
      "SMLL: -0.043\n",
      "MLL: 2.166\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 852/2500 [02:43<13:40,  2.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: -3.667\n",
      "ELBO: -10.619\n",
      "IWAE: 9.309\n",
      "SMSE: 0.327\n",
      "SMLL: -0.174\n",
      "MLL: 2.034\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 902/2500 [02:54<16:21,  1.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: -4.007\n",
      "ELBO: 5.760\n",
      "IWAE: 27.757\n",
      "SMSE: 0.337\n",
      "SMLL: -0.116\n",
      "MLL: 2.092\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 952/2500 [03:01<09:21,  2.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: -4.123\n",
      "ELBO: 19.911\n",
      "IWAE: 47.875\n",
      "SMSE: 0.330\n",
      "SMLL: -0.048\n",
      "MLL: 2.160\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1002/2500 [03:09<10:34,  2.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: -4.220\n",
      "ELBO: 39.954\n",
      "IWAE: 63.165\n",
      "SMSE: 0.321\n",
      "SMLL: -0.075\n",
      "MLL: 2.133\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1052/2500 [03:16<10:14,  2.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: -4.267\n",
      "ELBO: 43.547\n",
      "IWAE: 66.393\n",
      "SMSE: 0.329\n",
      "SMLL: 0.030\n",
      "MLL: 2.238\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1102/2500 [03:23<09:51,  2.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: -4.427\n",
      "ELBO: 52.830\n",
      "IWAE: 77.410\n",
      "SMSE: 0.323\n",
      "SMLL: 0.064\n",
      "MLL: 2.273\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1152/2500 [03:31<08:33,  2.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: -3.584\n",
      "ELBO: 69.546\n",
      "IWAE: 93.900\n",
      "SMSE: 0.357\n",
      "SMLL: 0.306\n",
      "MLL: 2.514\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1202/2500 [03:38<08:01,  2.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: -4.231\n",
      "ELBO: 88.014\n",
      "IWAE: 112.284\n",
      "SMSE: 0.339\n",
      "SMLL: 0.243\n",
      "MLL: 2.451\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1252/2500 [03:45<07:46,  2.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: -3.876\n",
      "ELBO: 88.482\n",
      "IWAE: 126.514\n",
      "SMSE: 0.306\n",
      "SMLL: 0.088\n",
      "MLL: 2.296\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1302/2500 [03:52<07:30,  2.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: -3.706\n",
      "ELBO: 100.648\n",
      "IWAE: 130.929\n",
      "SMSE: 0.344\n",
      "SMLL: 0.333\n",
      "MLL: 2.542\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1352/2500 [03:59<07:49,  2.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: -3.846\n",
      "ELBO: 84.451\n",
      "IWAE: 113.325\n",
      "SMSE: 0.357\n",
      "SMLL: 0.522\n",
      "MLL: 2.730\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1402/2500 [04:05<06:34,  2.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: -4.141\n",
      "ELBO: 131.537\n",
      "IWAE: 153.442\n",
      "SMSE: 0.322\n",
      "SMLL: 0.274\n",
      "MLL: 2.482\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1452/2500 [04:12<06:51,  2.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: -3.929\n",
      "ELBO: 111.157\n",
      "IWAE: 144.368\n",
      "SMSE: 0.348\n",
      "SMLL: 0.561\n",
      "MLL: 2.769\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1500/2500 [04:18<01:50,  9.04it/s]"
     ]
    }
   ],
   "source": [
    "train_eeg(pog_model, gpvae.estimators.gpvae_estimators.td_estimator, \n",
    "    loader, pog_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/2500 [00:01<48:19,  1.16s/it]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 7.952\n",
      "ELBO: -1573.169\n",
      "IWAE: -1510.731\n",
      "SMSE: 0.563\n",
      "SMLL: 0.494\n",
      "MLL: 2.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:09<08:42,  4.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 4.512\n",
      "ELBO: -940.892\n",
      "IWAE: -890.770\n",
      "SMSE: 0.455\n",
      "SMLL: -0.331\n",
      "MLL: 1.877\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 102/2500 [00:15<08:32,  4.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 4.016\n",
      "ELBO: -791.629\n",
      "IWAE: -758.042\n",
      "SMSE: 0.424\n",
      "SMLL: -0.406\n",
      "MLL: 1.802\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 152/2500 [00:21<08:54,  4.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 3.643\n",
      "ELBO: -600.141\n",
      "IWAE: -574.826\n",
      "SMSE: 0.374\n",
      "SMLL: -0.462\n",
      "MLL: 1.746\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 202/2500 [00:28<09:24,  4.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 3.008\n",
      "ELBO: -423.643\n",
      "IWAE: -386.221\n",
      "SMSE: 0.341\n",
      "SMLL: -0.480\n",
      "MLL: 1.728\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:37<17:58,  2.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 2.968\n",
      "ELBO: -246.225\n",
      "IWAE: -212.639\n",
      "SMSE: 0.314\n",
      "SMLL: -0.505\n",
      "MLL: 1.704\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 302/2500 [00:44<15:50,  2.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 2.220\n",
      "ELBO: -112.152\n",
      "IWAE: -89.425\n",
      "SMSE: 0.291\n",
      "SMLL: -0.505\n",
      "MLL: 1.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [00:50<09:25,  3.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 1.646\n",
      "ELBO: -54.342\n",
      "IWAE: -30.157\n",
      "SMSE: 0.259\n",
      "SMLL: -0.548\n",
      "MLL: 1.660\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [00:57<14:41,  2.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 1.588\n",
      "ELBO: 77.331\n",
      "IWAE: 107.041\n",
      "SMSE: 0.261\n",
      "SMLL: -0.496\n",
      "MLL: 1.712\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 452/2500 [01:05<12:05,  2.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: 1.049\n",
      "ELBO: 169.964\n",
      "IWAE: 206.604\n",
      "SMSE: 0.233\n",
      "SMLL: -0.469\n",
      "MLL: 1.739\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 502/2500 [01:14<24:43,  1.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: 1.059\n",
      "ELBO: 196.803\n",
      "IWAE: 230.372\n",
      "SMSE: 0.212\n",
      "SMLL: -0.509\n",
      "MLL: 1.699\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 552/2500 [01:22<15:08,  2.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: 1.151\n",
      "ELBO: 238.699\n",
      "IWAE: 269.858\n",
      "SMSE: 0.217\n",
      "SMLL: -0.495\n",
      "MLL: 1.714\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 602/2500 [01:30<12:28,  2.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: 0.922\n",
      "ELBO: 243.672\n",
      "IWAE: 265.299\n",
      "SMSE: 0.218\n",
      "SMLL: -0.470\n",
      "MLL: 1.738\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 652/2500 [01:37<12:11,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: 0.863\n",
      "ELBO: 250.417\n",
      "IWAE: 275.619\n",
      "SMSE: 0.213\n",
      "SMLL: -0.457\n",
      "MLL: 1.752\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 702/2500 [01:46<17:52,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: 0.758\n",
      "ELBO: 266.397\n",
      "IWAE: 307.847\n",
      "SMSE: 0.193\n",
      "SMLL: -0.564\n",
      "MLL: 1.644\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 752/2500 [01:55<17:37,  1.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: 0.689\n",
      "ELBO: 271.719\n",
      "IWAE: 300.171\n",
      "SMSE: 0.206\n",
      "SMLL: -0.493\n",
      "MLL: 1.715\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 802/2500 [02:05<13:07,  2.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 0.874\n",
      "ELBO: 318.715\n",
      "IWAE: 343.491\n",
      "SMSE: 0.200\n",
      "SMLL: -0.406\n",
      "MLL: 1.803\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 852/2500 [02:12<12:37,  2.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: 0.712\n",
      "ELBO: 322.329\n",
      "IWAE: 341.746\n",
      "SMSE: 0.192\n",
      "SMLL: -0.536\n",
      "MLL: 1.672\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 902/2500 [02:19<10:16,  2.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: 0.496\n",
      "ELBO: 342.628\n",
      "IWAE: 365.631\n",
      "SMSE: 0.194\n",
      "SMLL: -0.435\n",
      "MLL: 1.773\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 952/2500 [02:26<10:14,  2.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: 0.705\n",
      "ELBO: 337.308\n",
      "IWAE: 366.309\n",
      "SMSE: 0.194\n",
      "SMLL: -0.497\n",
      "MLL: 1.711\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1002/2500 [02:34<10:07,  2.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: 0.281\n",
      "ELBO: 353.208\n",
      "IWAE: 374.543\n",
      "SMSE: 0.185\n",
      "SMLL: -0.501\n",
      "MLL: 1.708\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1051/2500 [02:41<12:36,  1.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: 0.681\n",
      "ELBO: 309.549\n",
      "IWAE: 338.181\n",
      "SMSE: 0.180\n",
      "SMLL: -0.603\n",
      "MLL: 1.606\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1102/2500 [02:48<09:44,  2.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: 0.324\n",
      "ELBO: 380.302\n",
      "IWAE: 397.274\n",
      "SMSE: 0.187\n",
      "SMLL: -0.449\n",
      "MLL: 1.759\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1152/2500 [02:56<08:53,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: 0.331\n",
      "ELBO: 361.346\n",
      "IWAE: 378.527\n",
      "SMSE: 0.187\n",
      "SMLL: -0.471\n",
      "MLL: 1.738\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1202/2500 [03:03<08:56,  2.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: 0.385\n",
      "ELBO: 344.661\n",
      "IWAE: 374.898\n",
      "SMSE: 0.194\n",
      "SMLL: -0.406\n",
      "MLL: 1.802\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1252/2500 [03:10<08:20,  2.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: 0.659\n",
      "ELBO: 379.522\n",
      "IWAE: 402.361\n",
      "SMSE: 0.180\n",
      "SMLL: -0.503\n",
      "MLL: 1.705\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1302/2500 [03:18<07:56,  2.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: 0.394\n",
      "ELBO: 381.561\n",
      "IWAE: 408.797\n",
      "SMSE: 0.182\n",
      "SMLL: -0.507\n",
      "MLL: 1.702\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1352/2500 [03:25<07:39,  2.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: 0.255\n",
      "ELBO: 382.235\n",
      "IWAE: 410.519\n",
      "SMSE: 0.183\n",
      "SMLL: -0.525\n",
      "MLL: 1.683\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1402/2500 [03:32<07:16,  2.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: 0.404\n",
      "ELBO: 386.345\n",
      "IWAE: 404.439\n",
      "SMSE: 0.195\n",
      "SMLL: -0.386\n",
      "MLL: 1.822\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1452/2500 [03:39<07:00,  2.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: 0.361\n",
      "ELBO: 419.184\n",
      "IWAE: 439.627\n",
      "SMSE: 0.192\n",
      "SMLL: -0.364\n",
      "MLL: 1.844\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1502/2500 [03:47<06:34,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: 0.402\n",
      "ELBO: 405.480\n",
      "IWAE: 422.918\n",
      "SMSE: 0.187\n",
      "SMLL: -0.460\n",
      "MLL: 1.749\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1552/2500 [03:54<06:17,  2.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: 0.332\n",
      "ELBO: 420.447\n",
      "IWAE: 434.328\n",
      "SMSE: 0.195\n",
      "SMLL: -0.331\n",
      "MLL: 1.877\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1602/2500 [04:02<06:47,  2.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: 0.530\n",
      "ELBO: 415.726\n",
      "IWAE: 435.972\n",
      "SMSE: 0.177\n",
      "SMLL: -0.513\n",
      "MLL: 1.695\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1652/2500 [04:10<06:42,  2.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: 0.335\n",
      "ELBO: 411.841\n",
      "IWAE: 432.917\n",
      "SMSE: 0.184\n",
      "SMLL: -0.436\n",
      "MLL: 1.772\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1702/2500 [04:19<07:27,  1.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: 0.350\n",
      "ELBO: 423.815\n",
      "IWAE: 437.522\n",
      "SMSE: 0.191\n",
      "SMLL: -0.396\n",
      "MLL: 1.813\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1752/2500 [04:27<06:18,  1.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: 0.081\n",
      "ELBO: 422.153\n",
      "IWAE: 445.788\n",
      "SMSE: 0.186\n",
      "SMLL: -0.401\n",
      "MLL: 1.807\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1802/2500 [04:36<05:37,  2.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: 0.501\n",
      "ELBO: 405.843\n",
      "IWAE: 430.180\n",
      "SMSE: 0.178\n",
      "SMLL: -0.477\n",
      "MLL: 1.731\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1852/2500 [04:45<05:22,  2.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: 0.430\n",
      "ELBO: 417.580\n",
      "IWAE: 436.066\n",
      "SMSE: 0.177\n",
      "SMLL: -0.458\n",
      "MLL: 1.750\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1902/2500 [04:56<06:23,  1.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: 0.367\n",
      "ELBO: 412.877\n",
      "IWAE: 430.347\n",
      "SMSE: 0.182\n",
      "SMLL: -0.417\n",
      "MLL: 1.791\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1952/2500 [05:05<03:36,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: 0.317\n",
      "ELBO: 393.038\n",
      "IWAE: 423.620\n",
      "SMSE: 0.190\n",
      "SMLL: -0.367\n",
      "MLL: 1.841\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2002/2500 [05:12<02:53,  2.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: 0.129\n",
      "ELBO: 430.204\n",
      "IWAE: 443.514\n",
      "SMSE: 0.176\n",
      "SMLL: -0.449\n",
      "MLL: 1.759\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2052/2500 [05:19<02:42,  2.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: 0.219\n",
      "ELBO: 454.512\n",
      "IWAE: 472.490\n",
      "SMSE: 0.185\n",
      "SMLL: -0.409\n",
      "MLL: 1.799\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2102/2500 [05:25<02:23,  2.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: 0.237\n",
      "ELBO: 426.132\n",
      "IWAE: 441.975\n",
      "SMSE: 0.180\n",
      "SMLL: -0.426\n",
      "MLL: 1.782\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2152/2500 [05:32<01:32,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: 0.196\n",
      "ELBO: 432.890\n",
      "IWAE: 449.639\n",
      "SMSE: 0.196\n",
      "SMLL: -0.303\n",
      "MLL: 1.905\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2202/2500 [05:38<01:19,  3.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: 0.375\n",
      "ELBO: 430.795\n",
      "IWAE: 459.549\n",
      "SMSE: 0.183\n",
      "SMLL: -0.404\n",
      "MLL: 1.804\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2252/2500 [05:44<00:58,  4.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: 0.740\n",
      "ELBO: 422.508\n",
      "IWAE: 448.522\n",
      "SMSE: 0.177\n",
      "SMLL: -0.476\n",
      "MLL: 1.733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2302/2500 [05:50<00:54,  3.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: 0.347\n",
      "ELBO: 436.080\n",
      "IWAE: 454.677\n",
      "SMSE: 0.182\n",
      "SMLL: -0.443\n",
      "MLL: 1.765\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2352/2500 [05:57<00:56,  2.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: 0.082\n",
      "ELBO: 440.753\n",
      "IWAE: 461.892\n",
      "SMSE: 0.181\n",
      "SMLL: -0.425\n",
      "MLL: 1.783\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2402/2500 [06:04<00:41,  2.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: 0.097\n",
      "ELBO: 438.916\n",
      "IWAE: 462.991\n",
      "SMSE: 0.186\n",
      "SMLL: -0.379\n",
      "MLL: 1.829\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2452/2500 [06:12<00:20,  2.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: -0.005\n",
      "ELBO: 452.876\n",
      "IWAE: 476.165\n",
      "SMSE: 0.190\n",
      "SMLL: -0.427\n",
      "MLL: 1.781\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [06:19<00:00,  6.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: 0.329\n",
      "ELBO: 426.070\n",
      "IWAE: 445.443\n",
      "SMSE: 0.189\n",
      "SMLL: -0.401\n",
      "MLL: 1.807\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [7.951560393256571,\n",
       "  4.511696204099647,\n",
       "  4.015930386465222,\n",
       "  3.642990691490683,\n",
       "  3.007542849338172,\n",
       "  2.967567885417619,\n",
       "  2.219837609775176,\n",
       "  1.6462661158564267,\n",
       "  1.5876946874384918,\n",
       "  1.0494613722700066,\n",
       "  1.0588066798135645,\n",
       "  1.1510157501188438,\n",
       "  0.9215578292157801,\n",
       "  0.8627156647161579,\n",
       "  0.7580663233104882,\n",
       "  0.6886249965382362,\n",
       "  0.8736288347744964,\n",
       "  0.7120909111788597,\n",
       "  0.49551156132176405,\n",
       "  0.704684398272225,\n",
       "  0.2810995577847818,\n",
       "  0.6807405304523941,\n",
       "  0.32395354682982797,\n",
       "  0.33059555459966566,\n",
       "  0.38479541113835447,\n",
       "  0.6594397690732213,\n",
       "  0.3936209404260708,\n",
       "  0.25536059832465324,\n",
       "  0.4042998690254466,\n",
       "  0.36133988897106767,\n",
       "  0.4016255442615942,\n",
       "  0.3323178161730789,\n",
       "  0.5300496047651532,\n",
       "  0.33520278150343596,\n",
       "  0.3498853140063949,\n",
       "  0.08105885878847312,\n",
       "  0.5013176853898754,\n",
       "  0.43041746579325396,\n",
       "  0.3668054100609342,\n",
       "  0.31722976125813623,\n",
       "  0.12899725536186732,\n",
       "  0.2188287666098705,\n",
       "  0.23700657072965184,\n",
       "  0.1957835626750499,\n",
       "  0.3751066624096402,\n",
       "  0.7404635411966742,\n",
       "  0.34721755873015187,\n",
       "  0.08205018945615213,\n",
       "  0.09686844999061568,\n",
       "  -0.004606935861035333,\n",
       "  0.3287836863564007],\n",
       " 'elbos': [tensor(-1573.1690, grad_fn=<AddBackward0>),\n",
       "  tensor(-940.8916, grad_fn=<AddBackward0>),\n",
       "  tensor(-791.6286, grad_fn=<AddBackward0>),\n",
       "  tensor(-600.1406, grad_fn=<AddBackward0>),\n",
       "  tensor(-423.6430, grad_fn=<AddBackward0>),\n",
       "  tensor(-246.2251, grad_fn=<AddBackward0>),\n",
       "  tensor(-112.1515, grad_fn=<AddBackward0>),\n",
       "  tensor(-54.3421, grad_fn=<AddBackward0>),\n",
       "  tensor(77.3314, grad_fn=<AddBackward0>),\n",
       "  tensor(169.9635, grad_fn=<AddBackward0>),\n",
       "  tensor(196.8030, grad_fn=<AddBackward0>),\n",
       "  tensor(238.6987, grad_fn=<AddBackward0>),\n",
       "  tensor(243.6720, grad_fn=<AddBackward0>),\n",
       "  tensor(250.4172, grad_fn=<AddBackward0>),\n",
       "  tensor(266.3971, grad_fn=<AddBackward0>),\n",
       "  tensor(271.7195, grad_fn=<AddBackward0>),\n",
       "  tensor(318.7150, grad_fn=<AddBackward0>),\n",
       "  tensor(322.3289, grad_fn=<AddBackward0>),\n",
       "  tensor(342.6280, grad_fn=<AddBackward0>),\n",
       "  tensor(337.3078, grad_fn=<AddBackward0>),\n",
       "  tensor(353.2079, grad_fn=<AddBackward0>),\n",
       "  tensor(309.5491, grad_fn=<AddBackward0>),\n",
       "  tensor(380.3024, grad_fn=<AddBackward0>),\n",
       "  tensor(361.3459, grad_fn=<AddBackward0>),\n",
       "  tensor(344.6608, grad_fn=<AddBackward0>),\n",
       "  tensor(379.5217, grad_fn=<AddBackward0>),\n",
       "  tensor(381.5613, grad_fn=<AddBackward0>),\n",
       "  tensor(382.2349, grad_fn=<AddBackward0>),\n",
       "  tensor(386.3449, grad_fn=<AddBackward0>),\n",
       "  tensor(419.1839, grad_fn=<AddBackward0>),\n",
       "  tensor(405.4799, grad_fn=<AddBackward0>),\n",
       "  tensor(420.4467, grad_fn=<AddBackward0>),\n",
       "  tensor(415.7256, grad_fn=<AddBackward0>),\n",
       "  tensor(411.8407, grad_fn=<AddBackward0>),\n",
       "  tensor(423.8146, grad_fn=<AddBackward0>),\n",
       "  tensor(422.1535, grad_fn=<AddBackward0>),\n",
       "  tensor(405.8434, grad_fn=<AddBackward0>),\n",
       "  tensor(417.5801, grad_fn=<AddBackward0>),\n",
       "  tensor(412.8769, grad_fn=<AddBackward0>),\n",
       "  tensor(393.0378, grad_fn=<AddBackward0>),\n",
       "  tensor(430.2035, grad_fn=<AddBackward0>),\n",
       "  tensor(454.5123, grad_fn=<AddBackward0>),\n",
       "  tensor(426.1322, grad_fn=<AddBackward0>),\n",
       "  tensor(432.8898, grad_fn=<AddBackward0>),\n",
       "  tensor(430.7952, grad_fn=<AddBackward0>),\n",
       "  tensor(422.5080, grad_fn=<AddBackward0>),\n",
       "  tensor(436.0802, grad_fn=<AddBackward0>),\n",
       "  tensor(440.7532, grad_fn=<AddBackward0>),\n",
       "  tensor(438.9163, grad_fn=<AddBackward0>),\n",
       "  tensor(452.8756, grad_fn=<AddBackward0>),\n",
       "  tensor(426.0699, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-1510.7306, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-890.7695, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-758.0424, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-574.8263, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-386.2209, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-212.6387, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-89.4254, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-30.1570, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(107.0409, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(206.6041, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(230.3717, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(269.8578, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(265.2993, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(275.6189, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(307.8474, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(300.1713, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(343.4906, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(341.7460, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(365.6314, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(366.3089, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(374.5432, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(338.1807, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(397.2741, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(378.5274, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(374.8980, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(402.3609, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(408.7971, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(410.5191, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(404.4393, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(439.6269, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(422.9175, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(434.3280, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(435.9716, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(432.9167, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(437.5225, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(445.7880, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(430.1805, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(436.0658, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(430.3474, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(423.6198, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(443.5141, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(472.4903, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(441.9750, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(449.6385, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(459.5485, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(448.5221, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(454.6766, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(461.8919, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(462.9911, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(476.1652, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(445.4428, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [0.5631966386870645,\n",
       "  0.45515475156173063,\n",
       "  0.4240959807875427,\n",
       "  0.37399409502810715,\n",
       "  0.34104566198319164,\n",
       "  0.3139883069249995,\n",
       "  0.2910054446832407,\n",
       "  0.25885238247107817,\n",
       "  0.2610670525554321,\n",
       "  0.23338345394237195,\n",
       "  0.21184920526442705,\n",
       "  0.2166031640447752,\n",
       "  0.21804651067810496,\n",
       "  0.21301638305667914,\n",
       "  0.19298445356818852,\n",
       "  0.20582175512552245,\n",
       "  0.19987633277531422,\n",
       "  0.1921124004099762,\n",
       "  0.19365692232833667,\n",
       "  0.19405785493326297,\n",
       "  0.18495876107799783,\n",
       "  0.1796015114185734,\n",
       "  0.187383469737165,\n",
       "  0.18660471270517262,\n",
       "  0.19378439786491283,\n",
       "  0.1795289837837208,\n",
       "  0.18235041757687567,\n",
       "  0.18278157550316784,\n",
       "  0.19469577913029787,\n",
       "  0.1917799969391748,\n",
       "  0.18722573321923916,\n",
       "  0.19494328625449453,\n",
       "  0.17715724847635184,\n",
       "  0.1836723890303956,\n",
       "  0.1906001520115744,\n",
       "  0.18590271699249702,\n",
       "  0.1784784543091784,\n",
       "  0.17676919322072937,\n",
       "  0.18245870590442012,\n",
       "  0.19012872611766785,\n",
       "  0.17637740311313277,\n",
       "  0.1850872543044936,\n",
       "  0.179733372664619,\n",
       "  0.19552838999563163,\n",
       "  0.18258081236930168,\n",
       "  0.17651320032412698,\n",
       "  0.18193244145097762,\n",
       "  0.1814617316518106,\n",
       "  0.18622735763799123,\n",
       "  0.18965720010526835,\n",
       "  0.18935132696591025],\n",
       " 'smlls': [0.49443201231402156,\n",
       "  -0.3311868385990602,\n",
       "  -0.40597328527851273,\n",
       "  -0.4623074536230613,\n",
       "  -0.4802171357636799,\n",
       "  -0.5046843119205077,\n",
       "  -0.5051104471732296,\n",
       "  -0.5483752562664077,\n",
       "  -0.49624883998773944,\n",
       "  -0.4693914793217379,\n",
       "  -0.5089204779119452,\n",
       "  -0.49471782793056945,\n",
       "  -0.47035577420265523,\n",
       "  -0.45657689543229935,\n",
       "  -0.5644519989154609,\n",
       "  -0.4931537400092892,\n",
       "  -0.4055714236916132,\n",
       "  -0.5358468241915612,\n",
       "  -0.4347910205663799,\n",
       "  -0.49714153463634186,\n",
       "  -0.5005508698393478,\n",
       "  -0.6027315074599823,\n",
       "  -0.4491243698394893,\n",
       "  -0.4706594945103119,\n",
       "  -0.40645147828591566,\n",
       "  -0.5032314667304568,\n",
       "  -0.506615556490099,\n",
       "  -0.5248162905351977,\n",
       "  -0.38648018724841354,\n",
       "  -0.36422540757758587,\n",
       "  -0.45973382244730954,\n",
       "  -0.331330607123252,\n",
       "  -0.5131412098044494,\n",
       "  -0.4361608804231625,\n",
       "  -0.39565158965099356,\n",
       "  -0.4011476891044253,\n",
       "  -0.4768068672786055,\n",
       "  -0.45843245947947425,\n",
       "  -0.41728495917838493,\n",
       "  -0.3669009894276958,\n",
       "  -0.4489211089763887,\n",
       "  -0.4093345654650831,\n",
       "  -0.42641674942135305,\n",
       "  -0.3032738396629822,\n",
       "  -0.40431611824243113,\n",
       "  -0.4757356224597096,\n",
       "  -0.4431383865784251,\n",
       "  -0.4249082607364418,\n",
       "  -0.3787530669066352,\n",
       "  -0.42678131087950105,\n",
       "  -0.40091475005930705],\n",
       " 'mlls': [2.7026736039503896,\n",
       "  1.8770547530373083,\n",
       "  1.8022683063578555,\n",
       "  1.7459341380133069,\n",
       "  1.7280244558726885,\n",
       "  1.7035572797158605,\n",
       "  1.7031311444631385,\n",
       "  1.6598663353699605,\n",
       "  1.7119927516486289,\n",
       "  1.7388501123146305,\n",
       "  1.699321113724423,\n",
       "  1.7135237637057987,\n",
       "  1.7378858174337133,\n",
       "  1.7516646962040692,\n",
       "  1.6437895927209076,\n",
       "  1.7150878516270793,\n",
       "  1.8026701679447552,\n",
       "  1.6723947674448072,\n",
       "  1.7734505710699882,\n",
       "  1.7111000570000268,\n",
       "  1.7076907217970205,\n",
       "  1.6055100841763863,\n",
       "  1.7591172217968791,\n",
       "  1.7375820971260563,\n",
       "  1.8017901133504524,\n",
       "  1.7050101249059113,\n",
       "  1.7016260351462693,\n",
       "  1.6834253011011704,\n",
       "  1.8217614043879546,\n",
       "  1.8440161840587825,\n",
       "  1.748507769189059,\n",
       "  1.8769109845131162,\n",
       "  1.6951003818319188,\n",
       "  1.7720807112132058,\n",
       "  1.8125900019853747,\n",
       "  1.8070939025319432,\n",
       "  1.7314347243577626,\n",
       "  1.749809132156894,\n",
       "  1.7909566324579833,\n",
       "  1.8413406022086727,\n",
       "  1.7593204826599795,\n",
       "  1.7989070261712852,\n",
       "  1.7818248422150151,\n",
       "  1.9049677519733859,\n",
       "  1.8039254733939372,\n",
       "  1.7325059691766589,\n",
       "  1.7651032050579432,\n",
       "  1.7833333308999266,\n",
       "  1.829488524729733,\n",
       "  1.7814602807568674,\n",
       "  1.8073268415770611]}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train j_model: the analytical estimator is the training loss; the ELBO and\n",
    "# IWAE estimators are handed to train_eeg as well, and save_model=True is set.\n",
    "train_eeg(\n",
    "    model=j_model,\n",
    "    loss_fn=gpvae.estimators.gpvae_estimators.analytical_estimator,\n",
    "    loader=loader,\n",
    "    args=j_args,\n",
    "    elbo_estimator=gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    iwae_estimator=gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "    save_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/2500 [00:00<15:59,  2.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 359.346\n",
      "ELBO: -75055.496\n",
      "SMSE: 1.605\n",
      "SMLL: 20.794\n",
      "MLL: 23.002\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:06<07:54,  5.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 161.082\n",
      "ELBO: -34117.173\n",
      "SMSE: 1.632\n",
      "SMLL: 9.781\n",
      "MLL: 11.990\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 102/2500 [00:12<07:33,  5.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 117.793\n",
      "ELBO: -24917.558\n",
      "SMSE: 1.702\n",
      "SMLL: 7.192\n",
      "MLL: 9.400\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 152/2500 [00:19<10:55,  3.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 82.865\n",
      "ELBO: -17626.448\n",
      "SMSE: 1.741\n",
      "SMLL: 5.484\n",
      "MLL: 7.692\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 202/2500 [00:29<10:29,  3.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 63.124\n",
      "ELBO: -13525.054\n",
      "SMSE: 1.627\n",
      "SMLL: 4.193\n",
      "MLL: 6.401\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:34<06:11,  6.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 52.389\n",
      "ELBO: -11135.161\n",
      "SMSE: 1.705\n",
      "SMLL: 3.908\n",
      "MLL: 6.116\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 302/2500 [00:40<06:36,  5.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 45.206\n",
      "ELBO: -9715.282\n",
      "SMSE: 1.794\n",
      "SMLL: 3.671\n",
      "MLL: 5.880\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [00:47<07:14,  4.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 41.127\n",
      "ELBO: -8745.516\n",
      "SMSE: 1.884\n",
      "SMLL: 3.603\n",
      "MLL: 5.811\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [00:53<06:32,  5.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 37.215\n",
      "ELBO: -7993.247\n",
      "SMSE: 1.930\n",
      "SMLL: 3.242\n",
      "MLL: 5.450\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 446/2500 [00:58<04:29,  7.62it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-63-76f1b2fe3f00>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvfe_elbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     save_model=True)\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;31m# train_model(vfe_model, vfe_analytical_estimator, loader, vfe_args,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;31m#             gpvae.estimators.gpvae_estimators.vfe_elbo_estimator, None,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/experiments/eeg/train_eeg.py\u001b[0m in \u001b[0;36mtrain_eeg\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model)\u001b[0m\n\u001b[1;32m     62\u001b[0m                 loss = loss_fn(\n\u001b[1;32m     63\u001b[0m                     \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mm_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m                     decoder_scale=args['decoder_scale'], mf=True, idx=idx_b)\n\u001b[0m\u001b[1;32m     65\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m                 loss = loss_fn(\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/estimators/gpvae_estimators.py\u001b[0m in \u001b[0;36mvfe_analytical_estimator\u001b[0;34m(model, x, y, mask, num_samples, decoder_scale, make_lazy, mf, idx)\u001b[0m\n\u001b[1;32m    379\u001b[0m         \u001b[0;31m# Pass mean-field models the data indeces.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    380\u001b[0m         \u001b[0mqf_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mqf_cov\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mqu_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mqu_cov\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpu_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpu_cov\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlu_y_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlu_y_cov\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m             \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_latent_dists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    382\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m         \u001b[0;31m# Pass amortisation models the observation data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/models/gpvae.py\u001b[0m in \u001b[0;36mget_latent_dists\u001b[0;34m(self, x, x_test, full_cov, **kwargs)\u001b[0m\n\u001b[1;32m    301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    302\u001b[0m             \u001b[0;31m# GP conditional prior.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 303\u001b[0;31m             \u001b[0mkfu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    304\u001b[0m             \u001b[0mkuf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkfu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    305\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/composition_kernels.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x1, x2, diag, embed)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mcovs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkernel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdiag\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/composition_kernels.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mcovs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkernel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdiag\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/composition_kernels.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x1, x2, diag)\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m         \u001b[0mcov\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     38\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mcov\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/composition_kernels.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x1, x2, diag, embed)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mcovs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkernel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdiag\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/composition_kernels.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mcovs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkernel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdiag\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0membed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/kernels/kernels.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x1, x2, diag)\u001b[0m\n\u001b[1;32m     53\u001b[0m         \u001b[0;31m# [M1, M2, D] or [M, D] if diag.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     54\u001b[0m         \u001b[0msd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mx2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m         \u001b[0msd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclamp_min_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m         \u001b[0;31m# Apply lengthscale and sum over dimensions.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train the VFE model: the VFE analytical estimator is the training loss and\n",
    "# the VFE ELBO estimator is passed alongside it. No IWAE estimator is supplied\n",
    "# for this model (iwae_estimator=None).\n",
    "train_eeg(\n",
    "    model=vfe_model,\n",
    "    loss_fn=gpvae.estimators.gpvae_estimators.vfe_analytical_estimator,\n",
    "    loader=vfe_loader,\n",
    "    args=vfe_args,\n",
    "    elbo_estimator=gpvae.estimators.gpvae_estimators.vfe_elbo_estimator,\n",
    "    iwae_estimator=None,\n",
    "    save_model=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train GPVAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:24: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f53240c91acd40259f91043637faf867",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/gpytorch/utils/cholesky.py:43: RuntimeWarning: A not p.d., added jitter of 1e-08 to the diagonal\n",
      "  warnings.warn(f\"A not p.d., added jitter of {jitter_new} to the diagonal\", RuntimeWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 347.364\n",
      "ELBO: -75585.215\n",
      "IWAE: -75321.544\n",
      "SMSE: 1.911\n",
      "\n",
      "Epoch 50\n",
      "Loss: 79.130\n",
      "ELBO: -17991.391\n",
      "IWAE: -17985.733\n",
      "SMSE: 2.164\n",
      "\n",
      "Epoch 100\n",
      "Loss: 19.265\n",
      "ELBO: -3731.229\n",
      "IWAE: -3606.099\n",
      "SMSE: 1.205\n",
      "\n",
      "Epoch 150\n",
      "Loss: 11.863\n",
      "ELBO: -1978.322\n",
      "IWAE: -1840.350\n",
      "SMSE: 0.603\n",
      "\n",
      "Epoch 200\n",
      "Loss: 7.954\n",
      "ELBO: -1234.452\n",
      "IWAE: -1115.534\n",
      "SMSE: 0.517\n",
      "\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m--------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-16-10daeaf83f6a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m train_model(r_model, gpvae.estimators.gpvae_estimators.td_estimator, loader, r_args,\n\u001b[1;32m      3\u001b[0m             \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miwae_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m             train, test, y_std, y_mean)\n\u001b[0m",
      "\u001b[0;32m<ipython-input-14-f5e5397c2d05>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, train, test, y_std, y_mean)\u001b[0m\n\u001b[1;32m     36\u001b[0m             loss = loss_fn(model, x=x_b, y=y_b, mask=m_b, \n\u001b[1;32m     37\u001b[0m                            num_samples=1, decoder_scale=args['decoder_scale'])\n\u001b[0;32m---> 38\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     39\u001b[0m             \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    193\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    194\u001b[0m         \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Rich's method: train r_model with td_estimator as the loss function; the\n",
    "# ELBO and IWAE estimators are passed to train_model together with the\n",
    "# train/test data and the normalisation statistics.\n",
    "train_model(\n",
    "    r_model,\n",
    "    gpvae.estimators.gpvae_estimators.td_estimator,\n",
    "    loader,\n",
    "    r_args,\n",
    "    gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "    train,\n",
    "    test,\n",
    "    y_std,\n",
    "    y_mean,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:26: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2428008d106d46f491db3b9ea6e369a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 389519.897\n",
      "ELBO: -69914.133\n",
      "IWAE: -69697.424\n",
      "SMSE: 1.381\n",
      "\n",
      "Epoch 50\n",
      "Loss: 42.709\n",
      "ELBO: -13221.200\n",
      "IWAE: -13035.966\n",
      "SMSE: 2.075\n",
      "\n",
      "Epoch 100\n",
      "Loss: 8.835\n",
      "ELBO: -7316.364\n",
      "IWAE: -7286.288\n",
      "SMSE: 1.461\n",
      "\n",
      "Epoch 150\n",
      "Loss: 6.564\n",
      "ELBO: -5751.106\n",
      "IWAE: -5593.308\n",
      "SMSE: 1.361\n",
      "\n",
      "Epoch 200\n",
      "Loss: 3.647\n",
      "ELBO: -5102.733\n",
      "IWAE: -4958.983\n",
      "SMSE: 1.247\n",
      "\n",
      "Epoch 250\n",
      "Loss: 2.302\n",
      "ELBO: -4381.873\n",
      "IWAE: -4310.915\n",
      "SMSE: 1.202\n",
      "\n",
      "Epoch 300\n",
      "Loss: 2.247\n",
      "ELBO: -3914.110\n",
      "IWAE: -3790.483\n",
      "SMSE: 1.178\n",
      "\n",
      "Epoch 350\n",
      "Loss: 1.189\n",
      "ELBO: -3565.395\n",
      "IWAE: -3489.890\n",
      "SMSE: 1.140\n",
      "\n",
      "Epoch 400\n",
      "Loss: 0.497\n",
      "ELBO: -3220.932\n",
      "IWAE: -3151.706\n",
      "SMSE: 1.409\n",
      "\n",
      "Epoch 450\n",
      "Loss: -0.054\n",
      "ELBO: -3101.303\n",
      "IWAE: -2989.301\n",
      "SMSE: 1.330\n",
      "\n",
      "Epoch 500\n",
      "Loss: 0.599\n",
      "ELBO: -2915.828\n",
      "IWAE: -2821.845\n",
      "SMSE: 1.429\n",
      "\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m-----------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                     Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-e00980e671db>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m train_model(j_model, gpvae.estimators.gpvae_estimators.td_estimator, loader, j_args,\n\u001b[1;32m      3\u001b[0m             \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miwae_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m             train, test, y_std, y_mean)\n\u001b[0m",
      "\u001b[0;32m<ipython-input-15-ea4fac0960fe>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, train, test, y_std, y_mean)\u001b[0m\n\u001b[1;32m     42\u001b[0m                 loss = loss_fn(\n\u001b[1;32m     43\u001b[0m                     \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mm_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m                     decoder_scale=args['decoder_scale'])\n\u001b[0m\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/estimators/gpvae_estimators.py\u001b[0m in \u001b[0;36mtd_estimator\u001b[0;34m(model, x, y, mask, num_samples, decoder_scale, make_lazy, mf, idx)\u001b[0m\n\u001b[1;32m    144\u001b[0m     \u001b[0;31m# Monte-Carlo estimate of ELBO gradient.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    145\u001b[0m     \u001b[0;31m# See Spatio-Temporal VAEs: ELBO Gradient Estimators.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    147\u001b[0m         \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mqf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrsample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Jonny's method: train j_model with td_estimator as the loss function; the\n",
    "# ELBO and IWAE estimators are passed to train_model together with the\n",
    "# train/test data and the normalisation statistics.\n",
    "train_model(\n",
    "    j_model,\n",
    "    gpvae.estimators.gpvae_estimators.td_estimator,\n",
    "    loader,\n",
    "    j_args,\n",
    "    gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "    train,\n",
    "    test,\n",
    "    y_std,\n",
    "    y_mean,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions and class definitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "JITTER = 1e-5\n",
    "\n",
    "class AffineGaussian(nn.Module):\n",
    "    \"\"\"The mean of the output Gaussian is an affine transformation of the\n",
    "    input.\n",
    "\n",
    "    :param in_dim: An int, the dimension of the input variable.\n",
    "    :param out_dim: An int, the dimension of the output variable.\n",
    "    :param sigma: A float, sets the initial output sigma (shared across all\n",
    "    output dimensions).\n",
    "    :param initial_weight: A float, sets the initial weight.\n",
    "    :param initial_bias: A float, sets the initial bias.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_dim, out_dim, sigma=1., initial_weight=None,\n",
    "                 initial_bias=None):\n",
    "        super().__init__()\n",
    "\n",
    "        self.in_dim = in_dim\n",
    "        self.out_dim = out_dim\n",
    "\n",
    "        if initial_weight is None:\n",
    "            initial_weight = torch.ones(out_dim, in_dim) / in_dim\n",
    "        else:\n",
    "            initial_weight = torch.tensor(initial_weight)\n",
    "\n",
    "        if initial_bias is None:\n",
    "            initial_bias = torch.zeros(out_dim)\n",
    "        else:\n",
    "            initial_bias = torch.tensor(initial_bias)\n",
    "\n",
    "        # Initial weight and bias of the affine transformation.\n",
    "        self.weight = nn.Parameter(initial_weight + JITTER * torch.randn(\n",
    "            out_dim, in_dim), requires_grad=True)\n",
    "        self.bias = nn.Parameter(initial_bias + JITTER * torch.randn(out_dim),\n",
    "                                 requires_grad=True)\n",
    "\n",
    "        self.raw_sigma = nn.Parameter(torch.tensor(sigma).log(),\n",
    "                                      requires_grad=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Returns parameters of a Gaussian distribution.\n",
    "\n",
    "        :param x: A Tensor, input of shape [M, 1].\n",
    "        \"\"\"\n",
    "        mu = self.weight.matmul(x.unsqueeze(2)).squeeze(2) + self.bias\n",
    "        sigma = torch.ones_like(mu) * self.raw_sigma.exp()\n",
    "        \n",
    "        return mu, sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spatio-temporal-vae-env",
   "language": "python",
   "name": "spatio-temporal-vae-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
