{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import copy\n",
    "import pdb\n",
    "import time\n",
    "import pickle\n",
    "import warnings\n",
    "sys.path.append('../../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import pandas as pd\n",
    "import gpvae\n",
    "import data.eeg\n",
    "import gpytorch\n",
    "import wbml\n",
    "\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from torch.utils.data import DataLoader\n",
    "from experiments.eeg.train_eeg import train_eeg\n",
    "from wbml import metric\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load EEG data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the EEG data; the first value returned by data.eeg.load() is unused here.\n",
    "_, train, test = data.eeg.load()\n",
    "\n",
    "# Extract the training index (inputs) and values (observations) as numpy arrays.\n",
    "x = np.array(train.index)\n",
    "y = np.array(train)\n",
    "y_c = copy.deepcopy(y)\n",
    "\n",
    "# Randomly set a fraction p of the entries of the copy y_c to nan.\n",
    "# A seeded, local generator is used so the mask is identical on every\n",
    "# Restart & Run All (the global numpy RNG state is left untouched).\n",
    "SEED = 0\n",
    "p = 0.05\n",
    "rng = np.random.RandomState(SEED)\n",
    "set_to_nan = rng.choice(a=[False, True], size=y.shape, p=[1-p, p])\n",
    "y_c[set_to_nan] = np.nan\n",
    "\n",
    "y_dim = y.shape[1]\n",
    "# Inverse of the fraction of non-nan entries of y (y itself, not y_c).\n",
    "decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))\n",
    "\n",
    "# Normalise observations with the per-column, nan-ignoring mean and std of y;\n",
    "# y_c is normalised with the same statistics.\n",
    "y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)\n",
    "y = (y - y_mean) / y_std\n",
    "y_c = (y_c - y_mean) / y_std\n",
    "\n",
    "# Convert to torch.tensor.\n",
    "x = torch.tensor(x)\n",
    "y = torch.tensor(y)\n",
    "y_c = torch.tensor(y_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by every model configuration in this notebook.\n",
    "args = dict(\n",
    "    batch_size=100,\n",
    "    latent_dim=5,\n",
    "    init_lengthscale=.05,\n",
    "    init_scale=1.,\n",
    "    init_period=.1,\n",
    "    auxiliary_dim=1,\n",
    "    num_inducing=100,\n",
    "    epochs=2000,\n",
    "    cache_freq=50,\n",
    "    lr=0.001,\n",
    "    decoder_scale=decoder_scale,\n",
    "    elbo_samples=100,\n",
    "    test_samples=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)\n",
    "# loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "# ConditionTupleDataset was previously referenced by bare name without being\n",
    "# imported or defined, which raises NameError on a fresh kernel.\n",
    "# NOTE(review): it is assumed to live next to TupleDataset in\n",
    "# gpvae.utils.dataset_utils -- confirm the module path.\n",
    "dataset_c = gpvae.utils.dataset_utils.ConditionTupleDataset(\n",
    "    x, y, y_c, contains_nan=True)\n",
    "loader_c = DataLoader(dataset_c, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "# For models using the VFE approximation: a single batch of len(x) items.\n",
    "vfe_loader = DataLoader(dataset_c, batch_size=len(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent GP kernels. Only k1 is selected as the model kernel below; k2 is\n",
    "# still constructed so the additive alternative can be switched on.\n",
    "k1 = gpvae.kernels.RBFKernel(\n",
    "    lengthscale=args['init_lengthscale'],\n",
    "    scale=args['init_scale'],\n",
    ")\n",
    "k2 = gpvae.kernels.PeriodicKernel(\n",
    "    lengthscale=args['init_lengthscale'],\n",
    "    period=args['init_period'],\n",
    "    scale=args['init_scale']/2,\n",
    ")\n",
    "\n",
    "# gpytorch-based alternative (disabled):\n",
    "# k1 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "# k2 = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel())\n",
    "# k1.base_kernel._set_lengthscale(args['init_lengthscale'])\n",
    "# k1._set_outputscale(args['init_scale'] / 2)\n",
    "# k2.base_kernel._set_lengthscale(args['init_lengthscale'])\n",
    "# k2.base_kernel._set_period_length(args['init_lengthscale'])\n",
    "# k2._set_outputscale(args['init_scale'] / 2)\n",
    "\n",
    "# Additive alternative (disabled):\n",
    "# kernel = gpvae.kernels.AdditiveKernel(k1, k2)\n",
    "kernel = k1\n",
    "\n",
    "# Keyword arguments for the likelihood network.\n",
    "likelihood_args = dict(\n",
    "    in_dim=args['latent_dim'],\n",
    "    out_dim=y_dim,\n",
    "    hidden_dims=[20],\n",
    "    sigma=0.1,\n",
    "    train_sigma=True,\n",
    ")\n",
    "\n",
    "# Keyword arguments for the linear likelihood.\n",
    "linear_likelihood_args = dict(\n",
    "    in_dim=args['latent_dim'],\n",
    "    out_dim=y_dim,\n",
    "    sigma=0.1,\n",
    ")\n",
    "\n",
    "# Zero imputation inference network.\n",
    "zi_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20, 20],\n",
    "    initial_sigma=.1,\n",
    "    initial_mu=0.,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Product of Gaussians inference network.\n",
    "pog_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20],\n",
    "    initial_sigma=.1,\n",
    "    initial_mu=0.,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Semi-amortised DeepSet inference network.\n",
    "sads_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    middle_dim=20,\n",
    "    hidden_dims=[20],\n",
    "    shared_hidden_dims=[20],\n",
    "    initial_sigma=.1,\n",
    "    initial_mu=0.,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# PointNet inference network (note: initial_sigma is 1. here, .1 elsewhere).\n",
    "pn_network_args = dict(\n",
    "    out_dim=args['latent_dim'],\n",
    "    middle_dim=20,\n",
    "    first_hidden_dims=[20],\n",
    "    second_hidden_dims=[20],\n",
    "    initial_sigma=1.,\n",
    "    initial_mu=0.,\n",
    "    min_sigma=0.01,\n",
    ")\n",
    "\n",
    "# Sparse GP.\n",
    "sgp_network_args = dict(out_dim=args['latent_dim'])\n",
    "\n",
    "# Initial inducing points for the sparse GP: the centroids returned by\n",
    "# k-means on the inputs x (first element of the kmeans2 result).\n",
    "z = torch.tensor(kmeans2(x, args['num_inducing'], minit='points')[0])\n",
    "\n",
    "# Amortised sparse GP.\n",
    "amortised_sgp_network_args = dict(\n",
    "    in_dim=y_dim,\n",
    "    out_dim=args['latent_dim'],\n",
    "    hidden_dims=[20, 20, 20],\n",
    "    inducing_spacing=0.025,\n",
    "    k=5,\n",
    "    contains_nan=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct GPVAE models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product of Gaussians inference network: a deep copy of the shared args\n",
    "# extended with the model name and the network/likelihood settings.\n",
    "pog_args = {\n",
    "    **copy.deepcopy(args),\n",
    "    'model': 'pog',\n",
    "    'inference_network_args': pog_network_args,\n",
    "    'likelihood_args': likelihood_args,\n",
    "}\n",
    "pog_network = gpvae.networks.FactorNet(**pog_network_args)\n",
    "pog_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "pog_model = gpvae.models.GPVAE(\n",
    "    pog_network, pog_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# Semi-amortised DeepSet inference network.\n",
    "# sads_args = copy.deepcopy(args)\n",
    "# sads_args['model'] = 'indexnet_composition'\n",
    "# sads_args['inference_network_args'] = sads_network_args\n",
    "# sads_args['likelihood_args'] = likelihood_args\n",
    "# sads_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# sads_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# sads_model = gpvae.models.GPVAE(sads_network, sads_likelihood, args['latent_dim'], kernel, add_jitter=True)\n",
    "\n",
    "# PointNet inference network.\n",
    "# pn_args = copy.deepcopy(args)\n",
    "# pn_args['model'] = 'pn'\n",
    "# pn_args['inference_network_args'] = pn_network_args\n",
    "# pn_args['likelihood_args'] = likelihood_args\n",
    "# pn_network = gpvae.networks.PointNet(**pn_network_args)\n",
    "# pn_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# pn_model = gpvae.models.GPVAE(pn_network, pn_likelihood, args['latent_dim'], kernel, add_jitter=False)\n",
    "\n",
    "# Zero imputation inference network.\n",
    "# zi_args = copy.deepcopy(args)\n",
    "# zi_args['model'] = 'zi'\n",
    "# zi_args['inference_network_args'] = zi_network_args\n",
    "# zi_args['likelihood_args'] = likelihood_args\n",
    "# zi_network = gpvae.networks.LinearGaussian(**zi_network_args)\n",
    "# zi_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# zi_model = gpvae.models.GPVAE(zi_network, zi_likelihood, args['latent_dim'], kernel, add_jitter=False)\n",
    "\n",
    "# Semi-amortised DeepSet inference network with linear likelihood.\n",
    "# ll_args = copy.deepcopy(args)\n",
    "# ll_args['model'] = 'linear_sads'\n",
    "# ll_args['inference_network_args'] = sads_network_args\n",
    "# ll_args['likelihood_args'] = linear_likelihood_args\n",
    "# ll_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# ll_likelihood = AffineGaussian(**linear_likelihood_args)\n",
    "# ll_model = gpvae.models.GPVAE(ll_network, ll_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# Semi-amortised DeepSet VAE model.\n",
    "# vae_args = copy.deepcopy(args)\n",
    "# vae_args['model'] = 'vae_sads'\n",
    "# vae_args['inference_network_args'] = sads_network_args\n",
    "# vae_args['likelihood_args'] = likelihood_args\n",
    "# vae_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# vae_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# vae_model = gpvae.models.VAE(vae_network, vae_likelihood, args['latent_dim'])\n",
    "\n",
    "# vfe_args = copy.deepcopy(args)\n",
    "# vfe_args['name'] = 'vfe'\n",
    "# vfe_args['encoder_args'] = vfe_encoder_args\n",
    "# vfe_args['decoder_args'] = decoder_args\n",
    "# vfe_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# vfe_model = gpvae.models.TitsiasSparseGPVAE(vfe_decoder, args['latent_dim'], kernel, z, min_sigma=0.001, initial_sigma=0.01)\n",
    "\n",
    "# r_vfe_args = copy.deepcopy(args)\n",
    "# r_vfe_args['name'] = 'r_vfe'\n",
    "# r_vfe_args['encoder_args'] = r_vfe_encoder_args\n",
    "# r_vfe_args['decoder_args'] = decoder_args\n",
    "# r_vfe_encoder = gpvae.networks.FixedSparseNet(**r_vfe_encoder_args)\n",
    "# r_vfe_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# r_vfe_model = gpvae.models.SparseGPVAE(r_vfe_encoder, r_vfe_decoder, args['latent_dim'], kernel=kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'll_model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-3eabd6d76aa2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m train_eeg(ll_model, gpvae.estimators.gpvae_estimators.td_estimator, \n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mll_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'll_model' is not defined"
     ]
    }
   ],
   "source": [
    "# ll_model, ll_args and loader were only defined in commented-out code, so\n",
    "# this cell raised NameError on a fresh kernel. They are built here from the\n",
    "# hyperparameters defined in the cells above.\n",
    "dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)\n",
    "loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "# Semi-amortised DeepSet inference network with linear likelihood.\n",
    "ll_args = copy.deepcopy(args)\n",
    "ll_args['model'] = 'linear_sads'\n",
    "ll_args['inference_network_args'] = sads_network_args\n",
    "ll_args['likelihood_args'] = linear_likelihood_args\n",
    "ll_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "# NOTE(review): AffineGaussian is assumed to be exposed by gpvae.networks like\n",
    "# the other networks used in this notebook -- confirm.\n",
    "ll_likelihood = gpvae.networks.AffineGaussian(**linear_likelihood_args)\n",
    "ll_model = gpvae.models.GPVAE(ll_network, ll_likelihood, args['latent_dim'], kernel)\n",
    "\n",
    "# Train with the td_estimator loss; the ELBO and IWAE estimators are passed\n",
    "# as the elbo_estimator / iwae_estimator arguments of train_eeg.\n",
    "train_eeg(ll_model, gpvae.estimators.gpvae_estimators.td_estimator,\n",
    "          loader, ll_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "          gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/2000 [00:00<18:26,  1.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 362.560\n",
      "ELBO: -70219.433\n",
      "SMSE: 1.492\n",
      "SMLL: 18.682\n",
      "MLL: 20.890\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 51/2000 [00:09<19:59,  1.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 13.623\n",
      "ELBO: -2630.214\n",
      "SMSE: 0.575\n",
      "SMLL: 1.050\n",
      "MLL: 3.258\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 102/2000 [00:20<07:51,  4.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 7.323\n",
      "ELBO: -1368.795\n",
      "SMSE: 0.481\n",
      "SMLL: 0.506\n",
      "MLL: 2.715\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 151/2000 [00:28<12:39,  2.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 5.018\n",
      "ELBO: -965.717\n",
      "SMSE: 0.418\n",
      "SMLL: 0.117\n",
      "MLL: 2.325\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 201/2000 [00:39<09:23,  3.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 4.022\n",
      "ELBO: -739.809\n",
      "SMSE: 0.400\n",
      "SMLL: 0.020\n",
      "MLL: 2.228\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█▎        | 252/2000 [00:48<07:25,  3.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 3.426\n",
      "ELBO: -617.934\n",
      "SMSE: 0.395\n",
      "SMLL: -0.006\n",
      "MLL: 2.202\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 301/2000 [00:56<10:24,  2.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 2.784\n",
      "ELBO: -578.114\n",
      "SMSE: 0.393\n",
      "SMLL: -0.065\n",
      "MLL: 2.143\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 352/2000 [01:04<07:01,  3.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 2.892\n",
      "ELBO: -539.710\n",
      "SMSE: 0.373\n",
      "SMLL: -0.150\n",
      "MLL: 2.058\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 402/2000 [01:12<06:19,  4.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 2.469\n",
      "ELBO: -493.872\n",
      "SMSE: 0.370\n",
      "SMLL: -0.178\n",
      "MLL: 2.031\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 452/2000 [01:19<05:55,  4.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: 2.295\n",
      "ELBO: -412.493\n",
      "SMSE: 0.357\n",
      "SMLL: -0.265\n",
      "MLL: 1.943\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 502/2000 [01:31<10:08,  2.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: 1.957\n",
      "ELBO: -363.651\n",
      "SMSE: 0.342\n",
      "SMLL: -0.306\n",
      "MLL: 1.903\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 551/2000 [01:44<13:14,  1.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: 1.917\n",
      "ELBO: -273.263\n",
      "SMSE: 0.314\n",
      "SMLL: -0.408\n",
      "MLL: 1.801\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 602/2000 [01:53<06:44,  3.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: 1.952\n",
      "ELBO: -204.043\n",
      "SMSE: 0.323\n",
      "SMLL: -0.346\n",
      "MLL: 1.862\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 652/2000 [02:02<06:47,  3.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: 1.627\n",
      "ELBO: -164.805\n",
      "SMSE: 0.330\n",
      "SMLL: -0.339\n",
      "MLL: 1.869\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▌      | 702/2000 [02:09<06:31,  3.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: 1.496\n",
      "ELBO: -128.765\n",
      "SMSE: 0.339\n",
      "SMLL: -0.310\n",
      "MLL: 1.898\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 752/2000 [02:18<05:48,  3.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: 1.592\n",
      "ELBO: -77.067\n",
      "SMSE: 0.325\n",
      "SMLL: -0.343\n",
      "MLL: 1.865\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 802/2000 [02:32<14:24,  1.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 1.454\n",
      "ELBO: -12.912\n",
      "SMSE: 0.299\n",
      "SMLL: -0.372\n",
      "MLL: 1.836\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 43%|████▎     | 851/2000 [02:42<08:38,  2.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: 1.037\n",
      "ELBO: 59.810\n",
      "SMSE: 0.295\n",
      "SMLL: -0.381\n",
      "MLL: 1.827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 902/2000 [02:52<07:16,  2.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: 1.263\n",
      "ELBO: 62.675\n",
      "SMSE: 0.298\n",
      "SMLL: -0.345\n",
      "MLL: 1.864\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 951/2000 [03:08<17:26,  1.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: 1.198\n",
      "ELBO: 77.694\n",
      "SMSE: 0.314\n",
      "SMLL: -0.273\n",
      "MLL: 1.935\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1002/2000 [03:26<06:07,  2.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: 1.208\n",
      "ELBO: 114.578\n",
      "SMSE: 0.330\n",
      "SMLL: -0.192\n",
      "MLL: 2.016\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 51%|█████     | 1011/2000 [03:30<03:25,  4.81it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-27-f1e84616096f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m train_eeg(sads_model, gpvae.estimators.gpvae_estimators.analytical_estimator, \n\u001b[1;32m      2\u001b[0m           \u001b[0mloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msads_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m           None, save_model=True)\n\u001b[0m",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/experiments/eeg/train_eeg.py\u001b[0m in \u001b[0;36mtrain_eeg\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model)\u001b[0m\n\u001b[1;32m     66\u001b[0m                 loss = loss_fn(\n\u001b[1;32m     67\u001b[0m                     \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mm_b\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m                     decoder_scale=args['decoder_scale'])\n\u001b[0m\u001b[1;32m     69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/estimators/gpvae_estimators.py\u001b[0m in \u001b[0;36manalytical_estimator\u001b[0;34m(model, x, y, mask, num_samples, decoder_scale, make_lazy, mf, idx)\u001b[0m\n\u001b[1;32m    297\u001b[0m         \u001b[0;31m# Pass amortisation models the observation data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    298\u001b[0m         \u001b[0mqf_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mqf_cov\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpf_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpf_cov\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlf_y_mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlf_y_cov\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 299\u001b[0;31m             \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_latent_dists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    300\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    301\u001b[0m     \u001b[0msum_cov\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpf_cov\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlf_y_cov\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/gpvae/models/gpvae.py\u001b[0m in \u001b[0;36mget_latent_dists\u001b[0;34m(self, x, y, mask, x_test)\u001b[0m\n\u001b[1;32m     72\u001b[0m         \u001b[0mw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlf_y_root_precision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m         \u001b[0mw\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0madd_diagonal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m         \u001b[0mwinv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mx_test\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# sads_model, sads_args and loader were only defined in commented-out code,\n",
    "# so this cell could not run on a fresh kernel. They are built here from the\n",
    "# hyperparameters defined in the cells above.\n",
    "dataset = gpvae.utils.dataset_utils.TupleDataset(x, y, contains_nan=True)\n",
    "loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "# Semi-amortised DeepSet inference network.\n",
    "sads_args = copy.deepcopy(args)\n",
    "sads_args['model'] = 'indexnet_composition'\n",
    "sads_args['inference_network_args'] = sads_network_args\n",
    "sads_args['likelihood_args'] = likelihood_args\n",
    "sads_network = gpvae.networks.IndexNet(**sads_network_args)\n",
    "sads_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "sads_model = gpvae.models.GPVAE(sads_network, sads_likelihood, args['latent_dim'],\n",
    "                                kernel, add_jitter=True)\n",
    "\n",
    "# Train with the analytical_estimator loss; no IWAE estimator is supplied.\n",
    "train_eeg(sads_model, gpvae.estimators.gpvae_estimators.analytical_estimator,\n",
    "          loader, sads_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "          None, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/2500 [00:01<35:16,  1.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 337.428\n",
      "ELBO: -69464.492\n",
      "IWAE: -69354.735\n",
      "SMSE: 1.448\n",
      "SMLL: 18.412\n",
      "MLL: 20.620\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 53/2500 [00:06<08:41,  4.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 17.541\n",
      "ELBO: -4730.429\n",
      "IWAE: -4620.958\n",
      "SMSE: 1.024\n",
      "SMLL: 2.712\n",
      "MLL: 4.921\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 103/2500 [00:11<08:07,  4.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 5.139\n",
      "ELBO: -1533.934\n",
      "IWAE: -1464.852\n",
      "SMSE: 0.461\n",
      "SMLL: 0.141\n",
      "MLL: 2.349\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 153/2500 [00:17<08:13,  4.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 3.488\n",
      "ELBO: -1279.025\n",
      "IWAE: -1216.118\n",
      "SMSE: 0.410\n",
      "SMLL: -0.083\n",
      "MLL: 2.125\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 203/2500 [00:22<09:01,  4.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 2.628\n",
      "ELBO: -1122.462\n",
      "IWAE: -1070.394\n",
      "SMSE: 0.395\n",
      "SMLL: -0.157\n",
      "MLL: 2.051\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 252/2500 [00:29<09:13,  4.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 1.768\n",
      "ELBO: -980.155\n",
      "IWAE: -937.494\n",
      "SMSE: 0.391\n",
      "SMLL: -0.208\n",
      "MLL: 2.000\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 303/2500 [00:35<08:45,  4.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 0.342\n",
      "ELBO: -807.345\n",
      "IWAE: -769.162\n",
      "SMSE: 0.378\n",
      "SMLL: -0.232\n",
      "MLL: 1.976\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 352/2500 [00:41<10:40,  3.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: -1.220\n",
      "ELBO: -596.848\n",
      "IWAE: -566.790\n",
      "SMSE: 0.437\n",
      "SMLL: -0.059\n",
      "MLL: 2.149\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 402/2500 [00:47<09:10,  3.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: -1.119\n",
      "ELBO: -482.028\n",
      "IWAE: -454.725\n",
      "SMSE: 0.502\n",
      "SMLL: 0.179\n",
      "MLL: 2.387\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 452/2500 [00:53<09:29,  3.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: -2.206\n",
      "ELBO: -394.379\n",
      "IWAE: -353.345\n",
      "SMSE: 0.524\n",
      "SMLL: 0.215\n",
      "MLL: 2.423\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 502/2500 [00:59<08:11,  4.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: -3.139\n",
      "ELBO: -315.142\n",
      "IWAE: -290.350\n",
      "SMSE: 0.473\n",
      "SMLL: 0.002\n",
      "MLL: 2.210\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 553/2500 [01:05<06:48,  4.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: -3.612\n",
      "ELBO: -249.194\n",
      "IWAE: -224.548\n",
      "SMSE: 0.383\n",
      "SMLL: -0.241\n",
      "MLL: 1.967\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 602/2500 [01:10<06:50,  4.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: -3.735\n",
      "ELBO: -188.450\n",
      "IWAE: -162.931\n",
      "SMSE: 0.348\n",
      "SMLL: -0.351\n",
      "MLL: 1.857\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 652/2500 [01:16<07:54,  3.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: -4.884\n",
      "ELBO: -148.352\n",
      "IWAE: -129.020\n",
      "SMSE: 0.314\n",
      "SMLL: -0.457\n",
      "MLL: 1.751\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 702/2500 [01:21<08:06,  3.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: -4.841\n",
      "ELBO: -128.149\n",
      "IWAE: -116.120\n",
      "SMSE: 0.311\n",
      "SMLL: -0.459\n",
      "MLL: 1.749\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 752/2500 [01:27<06:23,  4.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: -5.046\n",
      "ELBO: -89.971\n",
      "IWAE: -68.820\n",
      "SMSE: 0.269\n",
      "SMLL: -0.596\n",
      "MLL: 1.612\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 802/2500 [01:33<07:15,  3.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: -5.663\n",
      "ELBO: -80.039\n",
      "IWAE: -55.798\n",
      "SMSE: 0.249\n",
      "SMLL: -0.608\n",
      "MLL: 1.600\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 852/2500 [01:38<07:10,  3.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: -5.802\n",
      "ELBO: -51.518\n",
      "IWAE: -34.304\n",
      "SMSE: 0.238\n",
      "SMLL: -0.635\n",
      "MLL: 1.573\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 902/2500 [01:43<06:54,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: -6.777\n",
      "ELBO: -34.515\n",
      "IWAE: -10.077\n",
      "SMSE: 0.234\n",
      "SMLL: -0.652\n",
      "MLL: 1.556\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 952/2500 [01:49<06:53,  3.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: -5.130\n",
      "ELBO: -19.290\n",
      "IWAE: -7.475\n",
      "SMSE: 0.228\n",
      "SMLL: -0.667\n",
      "MLL: 1.541\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1002/2500 [01:54<06:47,  3.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: -6.233\n",
      "ELBO: -7.845\n",
      "IWAE: 14.316\n",
      "SMSE: 0.231\n",
      "SMLL: -0.607\n",
      "MLL: 1.601\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1052/2500 [02:00<06:12,  3.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: -5.964\n",
      "ELBO: 2.047\n",
      "IWAE: 34.640\n",
      "SMSE: 0.238\n",
      "SMLL: -0.538\n",
      "MLL: 1.670\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1102/2500 [02:05<06:03,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: -6.347\n",
      "ELBO: 23.299\n",
      "IWAE: 42.390\n",
      "SMSE: 0.231\n",
      "SMLL: -0.529\n",
      "MLL: 1.680\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1152/2500 [02:11<05:43,  3.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: -7.670\n",
      "ELBO: 22.837\n",
      "IWAE: 46.173\n",
      "SMSE: 0.247\n",
      "SMLL: -0.455\n",
      "MLL: 1.754\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1202/2500 [02:16<05:52,  3.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: -6.702\n",
      "ELBO: 47.788\n",
      "IWAE: 71.939\n",
      "SMSE: 0.220\n",
      "SMLL: -0.550\n",
      "MLL: 1.658\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1252/2500 [02:21<05:24,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: -6.896\n",
      "ELBO: 44.135\n",
      "IWAE: 63.007\n",
      "SMSE: 0.234\n",
      "SMLL: -0.484\n",
      "MLL: 1.724\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1302/2500 [02:27<05:15,  3.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: -7.359\n",
      "ELBO: 56.790\n",
      "IWAE: 76.032\n",
      "SMSE: 0.250\n",
      "SMLL: -0.373\n",
      "MLL: 1.835\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1352/2500 [02:32<04:49,  3.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: -8.496\n",
      "ELBO: 72.082\n",
      "IWAE: 90.368\n",
      "SMSE: 0.242\n",
      "SMLL: -0.346\n",
      "MLL: 1.863\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1402/2500 [02:38<04:51,  3.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: -7.602\n",
      "ELBO: 69.306\n",
      "IWAE: 94.032\n",
      "SMSE: 0.260\n",
      "SMLL: -0.209\n",
      "MLL: 1.999\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1452/2500 [02:43<04:24,  3.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: -7.486\n",
      "ELBO: 86.029\n",
      "IWAE: 107.648\n",
      "SMSE: 0.246\n",
      "SMLL: -0.265\n",
      "MLL: 1.943\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1502/2500 [02:48<04:19,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: -7.321\n",
      "ELBO: 73.259\n",
      "IWAE: 92.594\n",
      "SMSE: 0.244\n",
      "SMLL: -0.274\n",
      "MLL: 1.934\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1552/2500 [02:54<04:10,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: -7.070\n",
      "ELBO: 81.938\n",
      "IWAE: 105.394\n",
      "SMSE: 0.241\n",
      "SMLL: -0.287\n",
      "MLL: 1.921\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1602/2500 [02:59<03:53,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: -7.432\n",
      "ELBO: 107.498\n",
      "IWAE: 118.671\n",
      "SMSE: 0.258\n",
      "SMLL: -0.169\n",
      "MLL: 2.039\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1652/2500 [03:04<03:49,  3.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: -7.376\n",
      "ELBO: 106.382\n",
      "IWAE: 119.902\n",
      "SMSE: 0.250\n",
      "SMLL: -0.194\n",
      "MLL: 2.014\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1702/2500 [03:10<03:34,  3.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: -7.751\n",
      "ELBO: 91.042\n",
      "IWAE: 112.494\n",
      "SMSE: 0.257\n",
      "SMLL: -0.141\n",
      "MLL: 2.067\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1752/2500 [03:15<03:13,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: -8.151\n",
      "ELBO: 105.945\n",
      "IWAE: 123.344\n",
      "SMSE: 0.239\n",
      "SMLL: -0.221\n",
      "MLL: 1.987\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1802/2500 [03:21<03:25,  3.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: -7.689\n",
      "ELBO: 116.956\n",
      "IWAE: 131.955\n",
      "SMSE: 0.250\n",
      "SMLL: -0.176\n",
      "MLL: 2.032\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1852/2500 [03:26<02:44,  3.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: -7.720\n",
      "ELBO: 112.825\n",
      "IWAE: 136.990\n",
      "SMSE: 0.248\n",
      "SMLL: -0.200\n",
      "MLL: 2.009\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1902/2500 [03:31<02:41,  3.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: -7.349\n",
      "ELBO: 122.915\n",
      "IWAE: 139.299\n",
      "SMSE: 0.263\n",
      "SMLL: -0.006\n",
      "MLL: 2.202\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1952/2500 [03:37<02:22,  3.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: -7.062\n",
      "ELBO: 118.838\n",
      "IWAE: 136.965\n",
      "SMSE: 0.263\n",
      "SMLL: -0.056\n",
      "MLL: 2.152\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2002/2500 [03:42<02:09,  3.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: -6.727\n",
      "ELBO: 122.701\n",
      "IWAE: 145.327\n",
      "SMSE: 0.274\n",
      "SMLL: 0.033\n",
      "MLL: 2.241\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2052/2500 [03:47<01:58,  3.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: -6.482\n",
      "ELBO: 129.157\n",
      "IWAE: 146.511\n",
      "SMSE: 0.281\n",
      "SMLL: 0.028\n",
      "MLL: 2.236\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2102/2500 [03:53<01:45,  3.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: -7.549\n",
      "ELBO: 127.290\n",
      "IWAE: 144.782\n",
      "SMSE: 0.259\n",
      "SMLL: -0.067\n",
      "MLL: 2.141\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2152/2500 [03:58<01:30,  3.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: -7.860\n",
      "ELBO: 121.791\n",
      "IWAE: 139.080\n",
      "SMSE: 0.267\n",
      "SMLL: -0.062\n",
      "MLL: 2.146\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2202/2500 [04:05<02:20,  2.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: -7.895\n",
      "ELBO: 133.702\n",
      "IWAE: 153.320\n",
      "SMSE: 0.248\n",
      "SMLL: -0.128\n",
      "MLL: 2.081\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2252/2500 [04:12<01:38,  2.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: -8.736\n",
      "ELBO: 131.997\n",
      "IWAE: 145.658\n",
      "SMSE: 0.261\n",
      "SMLL: 0.004\n",
      "MLL: 2.212\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2303/2500 [04:18<00:41,  4.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: -7.599\n",
      "ELBO: 146.246\n",
      "IWAE: 160.063\n",
      "SMSE: 0.248\n",
      "SMLL: -0.103\n",
      "MLL: 2.105\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2353/2500 [04:23<00:30,  4.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: -8.227\n",
      "ELBO: 146.433\n",
      "IWAE: 162.411\n",
      "SMSE: 0.274\n",
      "SMLL: 0.029\n",
      "MLL: 2.238\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2403/2500 [04:29<00:20,  4.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: -8.411\n",
      "ELBO: 146.620\n",
      "IWAE: 163.188\n",
      "SMSE: 0.261\n",
      "SMLL: -0.078\n",
      "MLL: 2.130\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2452/2500 [04:35<00:13,  3.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: -7.647\n",
      "ELBO: 140.561\n",
      "IWAE: 157.964\n",
      "SMSE: 0.272\n",
      "SMLL: 0.122\n",
      "MLL: 2.331\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [04:41<00:00,  8.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: -7.547\n",
      "ELBO: 139.758\n",
      "IWAE: 155.855\n",
      "SMSE: 0.287\n",
      "SMLL: 0.243\n",
      "MLL: 2.451\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [337.42842693190704,\n",
       "  17.54051914338881,\n",
       "  5.139487514563753,\n",
       "  3.48760666438317,\n",
       "  2.627604798801595,\n",
       "  1.7679773570216863,\n",
       "  0.34182859472043564,\n",
       "  -1.2204708796641228,\n",
       "  -1.1193831835562793,\n",
       "  -2.2060986057808165,\n",
       "  -3.1388810376830776,\n",
       "  -3.6118022748989147,\n",
       "  -3.7352794552822246,\n",
       "  -4.883593981805524,\n",
       "  -4.841305119053706,\n",
       "  -5.046327048709766,\n",
       "  -5.663161599292238,\n",
       "  -5.802092396745583,\n",
       "  -6.777025642398179,\n",
       "  -5.12965032790816,\n",
       "  -6.23325383784908,\n",
       "  -5.964109836403338,\n",
       "  -6.34749411589722,\n",
       "  -7.670343528256967,\n",
       "  -6.7021643402232405,\n",
       "  -6.89590651485272,\n",
       "  -7.358595137239461,\n",
       "  -8.496065898797132,\n",
       "  -7.601769183726241,\n",
       "  -7.486110675998899,\n",
       "  -7.3209635042619,\n",
       "  -7.069976783641969,\n",
       "  -7.432236274578108,\n",
       "  -7.376166286765027,\n",
       "  -7.75063509866179,\n",
       "  -8.151215302788486,\n",
       "  -7.689477664878763,\n",
       "  -7.7203136179210246,\n",
       "  -7.349135819478128,\n",
       "  -7.061585124933134,\n",
       "  -6.726537406719289,\n",
       "  -6.482010545987591,\n",
       "  -7.548655087480469,\n",
       "  -7.85962703837292,\n",
       "  -7.895373841326433,\n",
       "  -8.73589248641457,\n",
       "  -7.598734375110127,\n",
       "  -8.227045784116747,\n",
       "  -8.41055852303458,\n",
       "  -7.647475650036739,\n",
       "  -7.546677297243765],\n",
       " 'elbos': [tensor(-69464.4918, grad_fn=<AddBackward0>),\n",
       "  tensor(-4730.4288, grad_fn=<AddBackward0>),\n",
       "  tensor(-1533.9338, grad_fn=<AddBackward0>),\n",
       "  tensor(-1279.0250, grad_fn=<AddBackward0>),\n",
       "  tensor(-1122.4618, grad_fn=<AddBackward0>),\n",
       "  tensor(-980.1549, grad_fn=<AddBackward0>),\n",
       "  tensor(-807.3449, grad_fn=<AddBackward0>),\n",
       "  tensor(-596.8482, grad_fn=<AddBackward0>),\n",
       "  tensor(-482.0279, grad_fn=<AddBackward0>),\n",
       "  tensor(-394.3793, grad_fn=<AddBackward0>),\n",
       "  tensor(-315.1417, grad_fn=<AddBackward0>),\n",
       "  tensor(-249.1938, grad_fn=<AddBackward0>),\n",
       "  tensor(-188.4503, grad_fn=<AddBackward0>),\n",
       "  tensor(-148.3517, grad_fn=<AddBackward0>),\n",
       "  tensor(-128.1486, grad_fn=<AddBackward0>),\n",
       "  tensor(-89.9714, grad_fn=<AddBackward0>),\n",
       "  tensor(-80.0393, grad_fn=<AddBackward0>),\n",
       "  tensor(-51.5181, grad_fn=<AddBackward0>),\n",
       "  tensor(-34.5146, grad_fn=<AddBackward0>),\n",
       "  tensor(-19.2898, grad_fn=<AddBackward0>),\n",
       "  tensor(-7.8454, grad_fn=<AddBackward0>),\n",
       "  tensor(2.0474, grad_fn=<AddBackward0>),\n",
       "  tensor(23.2994, grad_fn=<AddBackward0>),\n",
       "  tensor(22.8369, grad_fn=<AddBackward0>),\n",
       "  tensor(47.7875, grad_fn=<AddBackward0>),\n",
       "  tensor(44.1350, grad_fn=<AddBackward0>),\n",
       "  tensor(56.7904, grad_fn=<AddBackward0>),\n",
       "  tensor(72.0819, grad_fn=<AddBackward0>),\n",
       "  tensor(69.3059, grad_fn=<AddBackward0>),\n",
       "  tensor(86.0288, grad_fn=<AddBackward0>),\n",
       "  tensor(73.2593, grad_fn=<AddBackward0>),\n",
       "  tensor(81.9379, grad_fn=<AddBackward0>),\n",
       "  tensor(107.4979, grad_fn=<AddBackward0>),\n",
       "  tensor(106.3822, grad_fn=<AddBackward0>),\n",
       "  tensor(91.0420, grad_fn=<AddBackward0>),\n",
       "  tensor(105.9448, grad_fn=<AddBackward0>),\n",
       "  tensor(116.9563, grad_fn=<AddBackward0>),\n",
       "  tensor(112.8245, grad_fn=<AddBackward0>),\n",
       "  tensor(122.9153, grad_fn=<AddBackward0>),\n",
       "  tensor(118.8381, grad_fn=<AddBackward0>),\n",
       "  tensor(122.7006, grad_fn=<AddBackward0>),\n",
       "  tensor(129.1570, grad_fn=<AddBackward0>),\n",
       "  tensor(127.2904, grad_fn=<AddBackward0>),\n",
       "  tensor(121.7907, grad_fn=<AddBackward0>),\n",
       "  tensor(133.7023, grad_fn=<AddBackward0>),\n",
       "  tensor(131.9972, grad_fn=<AddBackward0>),\n",
       "  tensor(146.2456, grad_fn=<AddBackward0>),\n",
       "  tensor(146.4334, grad_fn=<AddBackward0>),\n",
       "  tensor(146.6199, grad_fn=<AddBackward0>),\n",
       "  tensor(140.5614, grad_fn=<AddBackward0>),\n",
       "  tensor(139.7580, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-69354.7349, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-4620.9580, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1464.8521, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1216.1182, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1070.3940, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-937.4938, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-769.1620, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-566.7897, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-454.7248, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-353.3449, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-290.3500, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-224.5475, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-162.9308, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-129.0202, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-116.1201, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-68.8197, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-55.7983, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-34.3038, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-10.0770, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-7.4750, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(14.3161, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(34.6402, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(42.3903, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(46.1733, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(71.9386, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(63.0073, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(76.0315, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(90.3678, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(94.0319, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(107.6478, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(92.5944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(105.3944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(118.6714, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(119.9018, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(112.4944, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(123.3443, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(131.9551, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(136.9895, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(139.2991, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(136.9650, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(145.3272, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(146.5111, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(144.7817, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(139.0801, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(153.3198, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(145.6577, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(160.0634, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(162.4114, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(163.1877, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(157.9641, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(155.8552, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [1.4480312972999643,\n",
       "  1.0236485730104397,\n",
       "  0.4605083408173229,\n",
       "  0.4102557325385812,\n",
       "  0.39529960888932,\n",
       "  0.3907157896330992,\n",
       "  0.37821270678001406,\n",
       "  0.43706703503001226,\n",
       "  0.502123496494079,\n",
       "  0.5239586408193785,\n",
       "  0.4727276731917455,\n",
       "  0.3828864248908826,\n",
       "  0.348069165592545,\n",
       "  0.3136002215082348,\n",
       "  0.31104575595888345,\n",
       "  0.2689274962962999,\n",
       "  0.24939282730146725,\n",
       "  0.23828805588097066,\n",
       "  0.23408968497034044,\n",
       "  0.2279335211458645,\n",
       "  0.230705072550308,\n",
       "  0.23773074515857642,\n",
       "  0.23131270408118063,\n",
       "  0.24714230250108934,\n",
       "  0.2202067132787914,\n",
       "  0.2335466082932005,\n",
       "  0.2500759205285308,\n",
       "  0.24214552532329559,\n",
       "  0.25959476782068575,\n",
       "  0.24607388106746372,\n",
       "  0.24397361998121622,\n",
       "  0.24082860845986773,\n",
       "  0.25835181248275374,\n",
       "  0.25046636832562647,\n",
       "  0.256687433093634,\n",
       "  0.23928165766365975,\n",
       "  0.24993393955503995,\n",
       "  0.24771122948746158,\n",
       "  0.2628795058930948,\n",
       "  0.26258549942597614,\n",
       "  0.274323092351246,\n",
       "  0.28093010711906324,\n",
       "  0.2588694415020112,\n",
       "  0.2669875315754374,\n",
       "  0.24772875991416657,\n",
       "  0.26098115258866356,\n",
       "  0.24756150686042266,\n",
       "  0.2741491284678384,\n",
       "  0.2605772288675335,\n",
       "  0.2717721930966173,\n",
       "  0.28721282002102017],\n",
       " 'smlls': [18.412052932757337,\n",
       "  2.7123033949999953,\n",
       "  0.14114333415895555,\n",
       "  -0.08281970468060192,\n",
       "  -0.15711224420512107,\n",
       "  -0.20817078317777127,\n",
       "  -0.23239402685025268,\n",
       "  -0.059121512828469815,\n",
       "  0.17868473138070096,\n",
       "  0.21453428750005785,\n",
       "  0.0019410579870873985,\n",
       "  -0.2413976474818449,\n",
       "  -0.3510245449252454,\n",
       "  -0.45678354564766965,\n",
       "  -0.45878512124780696,\n",
       "  -0.5961826122179741,\n",
       "  -0.6080864673772225,\n",
       "  -0.6351445682129945,\n",
       "  -0.6521452154978127,\n",
       "  -0.6674121890683965,\n",
       "  -0.607149316181139,\n",
       "  -0.5379716227402414,\n",
       "  -0.5285132805635343,\n",
       "  -0.45466203135986705,\n",
       "  -0.5499652425266407,\n",
       "  -0.4837470892224008,\n",
       "  -0.3727646892431366,\n",
       "  -0.34556299621133757,\n",
       "  -0.20919901748012956,\n",
       "  -0.2647908232623711,\n",
       "  -0.2742290606617239,\n",
       "  -0.28696880566413013,\n",
       "  -0.16909336484365003,\n",
       "  -0.19436167059574258,\n",
       "  -0.14082982204641423,\n",
       "  -0.22123223664346647,\n",
       "  -0.17624302300824782,\n",
       "  -0.19961146393728998,\n",
       "  -0.0064570499819711635,\n",
       "  -0.05590683982000777,\n",
       "  0.03250841140788937,\n",
       "  0.02761428996416604,\n",
       "  -0.06695232337097494,\n",
       "  -0.06233390147713062,\n",
       "  -0.1276671912658891,\n",
       "  0.004227246098852093,\n",
       "  -0.10323803261859495,\n",
       "  0.029446198127061063,\n",
       "  -0.07797740634351251,\n",
       "  0.1224238423726316,\n",
       "  0.24284785802283984],\n",
       " 'mlls': [20.620294524393703,\n",
       "  4.920544986636364,\n",
       "  2.3493849257953237,\n",
       "  2.1254218869557664,\n",
       "  2.0511293474312473,\n",
       "  2.000070808458597,\n",
       "  1.9758475647861156,\n",
       "  2.1491200788078983,\n",
       "  2.3869263230170694,\n",
       "  2.422775879136426,\n",
       "  2.2101826496234556,\n",
       "  1.9668439441545236,\n",
       "  1.857217046711123,\n",
       "  1.7514580459886986,\n",
       "  1.7494564703885613,\n",
       "  1.6120589794183944,\n",
       "  1.6001551242591459,\n",
       "  1.5730970234233739,\n",
       "  1.5560963761385558,\n",
       "  1.5408294025679716,\n",
       "  1.6010922754552295,\n",
       "  1.670269968896127,\n",
       "  1.679728311072834,\n",
       "  1.7535795602765012,\n",
       "  1.6582763491097274,\n",
       "  1.7244945024139675,\n",
       "  1.8354769023932318,\n",
       "  1.862678595425031,\n",
       "  1.999042574156239,\n",
       "  1.9434507683739974,\n",
       "  1.9340125309746445,\n",
       "  1.9212727859722383,\n",
       "  2.0391482267927183,\n",
       "  2.0138799210406257,\n",
       "  2.067411769589954,\n",
       "  1.987009354992902,\n",
       "  2.0319985686281203,\n",
       "  2.008630127699078,\n",
       "  2.2017845416543973,\n",
       "  2.1523347518163605,\n",
       "  2.2407500030442575,\n",
       "  2.2358558816005343,\n",
       "  2.141289268265393,\n",
       "  2.145907690159238,\n",
       "  2.080574400370479,\n",
       "  2.2124688377352206,\n",
       "  2.105003559017774,\n",
       "  2.2376877897634295,\n",
       "  2.1302641852928557,\n",
       "  2.3306654340089996,\n",
       "  2.451089449659208]}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_eeg(zi_model, gpvae.estimators.gpvae_estimators.td_estimator, \n",
    "    loader, zi_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 4/2500 [00:01<28:34,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 346.745\n",
      "ELBO: -70586.053\n",
      "IWAE: -68993.356\n",
      "SMSE: 2.374\n",
      "SMLL: 6.051\n",
      "MLL: 8.259\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 54/2500 [00:03<03:55, 10.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: 19.585\n",
      "ELBO: -4472.171\n",
      "IWAE: -4318.618\n",
      "SMSE: 0.509\n",
      "SMLL: 0.115\n",
      "MLL: 2.323\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 106/2500 [00:06<03:58, 10.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 10.607\n",
      "ELBO: -2539.600\n",
      "IWAE: -2445.978\n",
      "SMSE: 0.501\n",
      "SMLL: 0.020\n",
      "MLL: 2.228\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 156/2500 [00:09<02:55, 13.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: 8.061\n",
      "ELBO: -2001.187\n",
      "IWAE: -1912.114\n",
      "SMSE: 0.437\n",
      "SMLL: -0.174\n",
      "MLL: 2.034\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 204/2500 [00:11<03:22, 11.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: 6.642\n",
      "ELBO: -1744.238\n",
      "IWAE: -1682.610\n",
      "SMSE: 0.439\n",
      "SMLL: -0.233\n",
      "MLL: 1.976\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 255/2500 [00:14<03:42, 10.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250\n",
      "Loss: 6.155\n",
      "ELBO: -1609.327\n",
      "IWAE: -1557.627\n",
      "SMSE: 0.424\n",
      "SMLL: -0.292\n",
      "MLL: 1.916\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█▏        | 306/2500 [00:17<02:44, 13.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 5.687\n",
      "ELBO: -1509.665\n",
      "IWAE: -1443.779\n",
      "SMSE: 0.393\n",
      "SMLL: -0.386\n",
      "MLL: 1.822\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 356/2500 [00:19<02:41, 13.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 350\n",
      "Loss: 5.585\n",
      "ELBO: -1495.686\n",
      "IWAE: -1411.597\n",
      "SMSE: 0.332\n",
      "SMLL: -0.505\n",
      "MLL: 1.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 406/2500 [00:21<02:35, 13.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 5.104\n",
      "ELBO: -1363.465\n",
      "IWAE: -1320.558\n",
      "SMSE: 0.337\n",
      "SMLL: -0.500\n",
      "MLL: 1.708\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 456/2500 [00:24<02:35, 13.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 450\n",
      "Loss: 4.755\n",
      "ELBO: -1283.504\n",
      "IWAE: -1225.269\n",
      "SMSE: 0.344\n",
      "SMLL: -0.505\n",
      "MLL: 1.703\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 505/2500 [00:27<03:42,  8.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: 4.495\n",
      "ELBO: -1218.630\n",
      "IWAE: -1167.473\n",
      "SMSE: 0.338\n",
      "SMLL: -0.499\n",
      "MLL: 1.709\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 555/2500 [00:29<02:29, 13.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 550\n",
      "Loss: 4.071\n",
      "ELBO: -1172.603\n",
      "IWAE: -1117.114\n",
      "SMSE: 0.280\n",
      "SMLL: -0.610\n",
      "MLL: 1.598\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 606/2500 [00:32<02:20, 13.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: 3.929\n",
      "ELBO: -1126.949\n",
      "IWAE: -1081.064\n",
      "SMSE: 0.260\n",
      "SMLL: -0.647\n",
      "MLL: 1.561\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██▌       | 654/2500 [00:34<02:20, 13.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 650\n",
      "Loss: 3.858\n",
      "ELBO: -1089.525\n",
      "IWAE: -1047.405\n",
      "SMSE: 0.237\n",
      "SMLL: -0.711\n",
      "MLL: 1.497\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|██▊       | 704/2500 [00:36<02:14, 13.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: 3.715\n",
      "ELBO: -1096.349\n",
      "IWAE: -1057.518\n",
      "SMSE: 0.302\n",
      "SMLL: -0.543\n",
      "MLL: 1.665\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 754/2500 [00:39<02:12, 13.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 750\n",
      "Loss: 3.630\n",
      "ELBO: -1030.551\n",
      "IWAE: -1001.492\n",
      "SMSE: 0.255\n",
      "SMLL: -0.632\n",
      "MLL: 1.576\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 804/2500 [00:41<02:38, 10.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 3.476\n",
      "ELBO: -993.423\n",
      "IWAE: -951.940\n",
      "SMSE: 0.250\n",
      "SMLL: -0.659\n",
      "MLL: 1.550\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|███▍      | 854/2500 [00:44<02:05, 13.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 850\n",
      "Loss: 3.507\n",
      "ELBO: -968.861\n",
      "IWAE: -939.891\n",
      "SMSE: 0.255\n",
      "SMLL: -0.641\n",
      "MLL: 1.567\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 904/2500 [00:46<02:40,  9.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: 3.252\n",
      "ELBO: -938.379\n",
      "IWAE: -900.532\n",
      "SMSE: 0.233\n",
      "SMLL: -0.687\n",
      "MLL: 1.521\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 954/2500 [00:49<02:28, 10.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 950\n",
      "Loss: 3.109\n",
      "ELBO: -915.538\n",
      "IWAE: -887.881\n",
      "SMSE: 0.267\n",
      "SMLL: -0.576\n",
      "MLL: 1.632\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1004/2500 [00:52<02:21, 10.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: 3.173\n",
      "ELBO: -901.319\n",
      "IWAE: -872.197\n",
      "SMSE: 0.255\n",
      "SMLL: -0.626\n",
      "MLL: 1.582\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 1054/2500 [00:55<02:53,  8.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1050\n",
      "Loss: 3.073\n",
      "ELBO: -889.690\n",
      "IWAE: -857.514\n",
      "SMSE: 0.247\n",
      "SMLL: -0.605\n",
      "MLL: 1.604\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 1105/2500 [00:57<02:18, 10.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: 3.132\n",
      "ELBO: -890.615\n",
      "IWAE: -858.539\n",
      "SMSE: 0.247\n",
      "SMLL: -0.619\n",
      "MLL: 1.589\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 1153/2500 [01:00<02:58,  7.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1150\n",
      "Loss: 2.758\n",
      "ELBO: -871.028\n",
      "IWAE: -852.254\n",
      "SMSE: 0.238\n",
      "SMLL: -0.641\n",
      "MLL: 1.567\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 1204/2500 [01:02<01:44, 12.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: 3.031\n",
      "ELBO: -892.668\n",
      "IWAE: -861.895\n",
      "SMSE: 0.264\n",
      "SMLL: -0.514\n",
      "MLL: 1.694\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1255/2500 [01:05<01:39, 12.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1250\n",
      "Loss: 2.879\n",
      "ELBO: -866.165\n",
      "IWAE: -829.565\n",
      "SMSE: 0.259\n",
      "SMLL: -0.517\n",
      "MLL: 1.691\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|█████▏    | 1305/2500 [01:08<01:58, 10.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: 2.742\n",
      "ELBO: -855.305\n",
      "IWAE: -826.100\n",
      "SMSE: 0.262\n",
      "SMLL: -0.491\n",
      "MLL: 1.717\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 1354/2500 [01:10<01:51, 10.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1350\n",
      "Loss: 2.795\n",
      "ELBO: -870.581\n",
      "IWAE: -828.842\n",
      "SMSE: 0.278\n",
      "SMLL: -0.412\n",
      "MLL: 1.796\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 1406/2500 [01:13<01:46, 10.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: 2.804\n",
      "ELBO: -844.243\n",
      "IWAE: -815.076\n",
      "SMSE: 0.259\n",
      "SMLL: -0.453\n",
      "MLL: 1.755\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 1456/2500 [01:16<01:23, 12.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1450\n",
      "Loss: 2.656\n",
      "ELBO: -837.756\n",
      "IWAE: -810.068\n",
      "SMSE: 0.269\n",
      "SMLL: -0.446\n",
      "MLL: 1.762\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1504/2500 [01:18<01:38, 10.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: 2.614\n",
      "ELBO: -833.742\n",
      "IWAE: -802.311\n",
      "SMSE: 0.281\n",
      "SMLL: -0.308\n",
      "MLL: 1.900\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 1555/2500 [01:21<01:32, 10.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1550\n",
      "Loss: 2.717\n",
      "ELBO: -842.693\n",
      "IWAE: -811.189\n",
      "SMSE: 0.274\n",
      "SMLL: -0.364\n",
      "MLL: 1.844\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 1603/2500 [01:24<01:55,  7.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: 2.436\n",
      "ELBO: -808.594\n",
      "IWAE: -784.545\n",
      "SMSE: 0.249\n",
      "SMLL: -0.476\n",
      "MLL: 1.732\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 1654/2500 [01:27<01:45,  8.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1650\n",
      "Loss: 2.530\n",
      "ELBO: -822.864\n",
      "IWAE: -797.674\n",
      "SMSE: 0.249\n",
      "SMLL: -0.425\n",
      "MLL: 1.783\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|██████▊   | 1704/2500 [01:29<01:12, 10.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: 2.563\n",
      "ELBO: -795.729\n",
      "IWAE: -761.907\n",
      "SMSE: 0.285\n",
      "SMLL: -0.262\n",
      "MLL: 1.946\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 1756/2500 [01:32<01:01, 12.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1750\n",
      "Loss: 2.411\n",
      "ELBO: -806.689\n",
      "IWAE: -775.880\n",
      "SMSE: 0.239\n",
      "SMLL: -0.475\n",
      "MLL: 1.733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|███████▏  | 1806/2500 [01:35<01:07, 10.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: 2.525\n",
      "ELBO: -797.867\n",
      "IWAE: -773.312\n",
      "SMSE: 0.266\n",
      "SMLL: -0.324\n",
      "MLL: 1.884\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 74%|███████▍  | 1854/2500 [01:37<01:02, 10.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1850\n",
      "Loss: 2.543\n",
      "ELBO: -783.469\n",
      "IWAE: -754.802\n",
      "SMSE: 0.263\n",
      "SMLL: -0.357\n",
      "MLL: 1.851\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 1906/2500 [01:40<00:54, 10.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1900\n",
      "Loss: 2.416\n",
      "ELBO: -790.987\n",
      "IWAE: -764.887\n",
      "SMSE: 0.259\n",
      "SMLL: -0.367\n",
      "MLL: 1.842\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 1954/2500 [01:43<00:52, 10.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1950\n",
      "Loss: 2.257\n",
      "ELBO: -775.619\n",
      "IWAE: -756.984\n",
      "SMSE: 0.247\n",
      "SMLL: -0.381\n",
      "MLL: 1.827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 2004/2500 [01:46<00:57,  8.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2000\n",
      "Loss: 2.291\n",
      "ELBO: -768.430\n",
      "IWAE: -741.775\n",
      "SMSE: 0.248\n",
      "SMLL: -0.385\n",
      "MLL: 1.824\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 2054/2500 [01:48<00:35, 12.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2050\n",
      "Loss: 2.241\n",
      "ELBO: -753.864\n",
      "IWAE: -728.408\n",
      "SMSE: 0.257\n",
      "SMLL: -0.361\n",
      "MLL: 1.847\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 2104/2500 [01:51<00:30, 13.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2100\n",
      "Loss: 2.258\n",
      "ELBO: -754.578\n",
      "IWAE: -720.410\n",
      "SMSE: 0.274\n",
      "SMLL: -0.290\n",
      "MLL: 1.918\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|████████▌ | 2154/2500 [01:53<00:32, 10.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2150\n",
      "Loss: 2.254\n",
      "ELBO: -743.524\n",
      "IWAE: -721.176\n",
      "SMSE: 0.252\n",
      "SMLL: -0.382\n",
      "MLL: 1.826\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████▊ | 2204/2500 [01:56<00:29, 10.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2200\n",
      "Loss: 2.278\n",
      "ELBO: -749.416\n",
      "IWAE: -726.735\n",
      "SMSE: 0.256\n",
      "SMLL: -0.350\n",
      "MLL: 1.858\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 2254/2500 [01:58<00:22, 11.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2250\n",
      "Loss: 2.142\n",
      "ELBO: -737.270\n",
      "IWAE: -718.791\n",
      "SMSE: 0.233\n",
      "SMLL: -0.475\n",
      "MLL: 1.733\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 2306/2500 [02:01<00:15, 12.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2300\n",
      "Loss: 2.227\n",
      "ELBO: -754.155\n",
      "IWAE: -733.619\n",
      "SMSE: 0.258\n",
      "SMLL: -0.335\n",
      "MLL: 1.874\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|█████████▍| 2356/2500 [02:03<00:14, 10.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2350\n",
      "Loss: 2.168\n",
      "ELBO: -753.526\n",
      "IWAE: -731.759\n",
      "SMSE: 0.255\n",
      "SMLL: -0.400\n",
      "MLL: 1.808\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 2404/2500 [02:06<00:08, 11.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2400\n",
      "Loss: 2.020\n",
      "ELBO: -715.928\n",
      "IWAE: -698.848\n",
      "SMSE: 0.267\n",
      "SMLL: -0.253\n",
      "MLL: 1.955\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 2454/2500 [02:09<00:03, 11.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2450\n",
      "Loss: 2.182\n",
      "ELBO: -738.064\n",
      "IWAE: -712.501\n",
      "SMSE: 0.292\n",
      "SMLL: -0.215\n",
      "MLL: 1.993\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2500/2500 [02:11<00:00, 18.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2499\n",
      "Loss: 2.026\n",
      "ELBO: -723.456\n",
      "IWAE: -693.300\n",
      "SMSE: 0.234\n",
      "SMLL: -0.507\n",
      "MLL: 1.701\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [346.74476053203944,\n",
       "  19.58532280036387,\n",
       "  10.607473773868483,\n",
       "  8.060885151227142,\n",
       "  6.642055849523875,\n",
       "  6.1545152533523675,\n",
       "  5.687167656677102,\n",
       "  5.584887568009737,\n",
       "  5.104042071477486,\n",
       "  4.754536432468518,\n",
       "  4.494762210605408,\n",
       "  4.071044815700114,\n",
       "  3.928623078063365,\n",
       "  3.8582847893919854,\n",
       "  3.7152013786134446,\n",
       "  3.630275983197366,\n",
       "  3.4757992580103116,\n",
       "  3.5074960473514665,\n",
       "  3.2515753452532734,\n",
       "  3.10933987192036,\n",
       "  3.1725345527620994,\n",
       "  3.0733290152301556,\n",
       "  3.13165356992171,\n",
       "  2.7584252510052454,\n",
       "  3.030950945443345,\n",
       "  2.878830987889991,\n",
       "  2.7420718392510737,\n",
       "  2.7948911879130844,\n",
       "  2.803508609740399,\n",
       "  2.65606557926659,\n",
       "  2.6136287813918657,\n",
       "  2.717137569105759,\n",
       "  2.435533699784336,\n",
       "  2.529956904225853,\n",
       "  2.563174819726486,\n",
       "  2.4111459312461263,\n",
       "  2.52503837226447,\n",
       "  2.543405274266306,\n",
       "  2.4155694975543986,\n",
       "  2.256899092947536,\n",
       "  2.2908360118870226,\n",
       "  2.2412405600406418,\n",
       "  2.2581454152387703,\n",
       "  2.253857771711551,\n",
       "  2.27776182026347,\n",
       "  2.142416755474027,\n",
       "  2.226599491178271,\n",
       "  2.1680413153831286,\n",
       "  2.0202267276767545,\n",
       "  2.1818735960958677,\n",
       "  2.026418143926456],\n",
       " 'elbos': [tensor(-70586.0534, grad_fn=<AddBackward0>),\n",
       "  tensor(-4472.1710, grad_fn=<AddBackward0>),\n",
       "  tensor(-2539.5998, grad_fn=<AddBackward0>),\n",
       "  tensor(-2001.1869, grad_fn=<AddBackward0>),\n",
       "  tensor(-1744.2378, grad_fn=<AddBackward0>),\n",
       "  tensor(-1609.3274, grad_fn=<AddBackward0>),\n",
       "  tensor(-1509.6645, grad_fn=<AddBackward0>),\n",
       "  tensor(-1495.6860, grad_fn=<AddBackward0>),\n",
       "  tensor(-1363.4650, grad_fn=<AddBackward0>),\n",
       "  tensor(-1283.5036, grad_fn=<AddBackward0>),\n",
       "  tensor(-1218.6295, grad_fn=<AddBackward0>),\n",
       "  tensor(-1172.6031, grad_fn=<AddBackward0>),\n",
       "  tensor(-1126.9491, grad_fn=<AddBackward0>),\n",
       "  tensor(-1089.5252, grad_fn=<AddBackward0>),\n",
       "  tensor(-1096.3486, grad_fn=<AddBackward0>),\n",
       "  tensor(-1030.5508, grad_fn=<AddBackward0>),\n",
       "  tensor(-993.4229, grad_fn=<AddBackward0>),\n",
       "  tensor(-968.8607, grad_fn=<AddBackward0>),\n",
       "  tensor(-938.3787, grad_fn=<AddBackward0>),\n",
       "  tensor(-915.5384, grad_fn=<AddBackward0>),\n",
       "  tensor(-901.3187, grad_fn=<AddBackward0>),\n",
       "  tensor(-889.6897, grad_fn=<AddBackward0>),\n",
       "  tensor(-890.6147, grad_fn=<AddBackward0>),\n",
       "  tensor(-871.0277, grad_fn=<AddBackward0>),\n",
       "  tensor(-892.6676, grad_fn=<AddBackward0>),\n",
       "  tensor(-866.1652, grad_fn=<AddBackward0>),\n",
       "  tensor(-855.3055, grad_fn=<AddBackward0>),\n",
       "  tensor(-870.5813, grad_fn=<AddBackward0>),\n",
       "  tensor(-844.2427, grad_fn=<AddBackward0>),\n",
       "  tensor(-837.7562, grad_fn=<AddBackward0>),\n",
       "  tensor(-833.7421, grad_fn=<AddBackward0>),\n",
       "  tensor(-842.6932, grad_fn=<AddBackward0>),\n",
       "  tensor(-808.5944, grad_fn=<AddBackward0>),\n",
       "  tensor(-822.8639, grad_fn=<AddBackward0>),\n",
       "  tensor(-795.7288, grad_fn=<AddBackward0>),\n",
       "  tensor(-806.6894, grad_fn=<AddBackward0>),\n",
       "  tensor(-797.8667, grad_fn=<AddBackward0>),\n",
       "  tensor(-783.4691, grad_fn=<AddBackward0>),\n",
       "  tensor(-790.9865, grad_fn=<AddBackward0>),\n",
       "  tensor(-775.6188, grad_fn=<AddBackward0>),\n",
       "  tensor(-768.4303, grad_fn=<AddBackward0>),\n",
       "  tensor(-753.8644, grad_fn=<AddBackward0>),\n",
       "  tensor(-754.5778, grad_fn=<AddBackward0>),\n",
       "  tensor(-743.5243, grad_fn=<AddBackward0>),\n",
       "  tensor(-749.4164, grad_fn=<AddBackward0>),\n",
       "  tensor(-737.2701, grad_fn=<AddBackward0>),\n",
       "  tensor(-754.1548, grad_fn=<AddBackward0>),\n",
       "  tensor(-753.5263, grad_fn=<AddBackward0>),\n",
       "  tensor(-715.9279, grad_fn=<AddBackward0>),\n",
       "  tensor(-738.0637, grad_fn=<AddBackward0>),\n",
       "  tensor(-723.4561, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [tensor(-68993.3563, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-4318.6183, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-2445.9781, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1912.1143, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1682.6100, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1557.6266, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1443.7794, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1411.5966, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1320.5583, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1225.2692, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1167.4733, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1117.1144, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1081.0642, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1047.4048, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1057.5184, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-1001.4919, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-951.9401, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-939.8915, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-900.5317, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-887.8808, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-872.1966, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-857.5137, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-858.5392, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-852.2535, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-861.8946, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-829.5653, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-826.0997, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-828.8422, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-815.0763, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-810.0682, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-802.3108, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-811.1894, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-784.5446, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-797.6742, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-761.9068, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-775.8796, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-773.3117, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-754.8020, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-764.8866, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-756.9838, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-741.7755, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-728.4079, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-720.4101, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-721.1763, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-726.7352, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-718.7913, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-733.6185, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-731.7592, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-698.8477, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-712.5014, grad_fn=<LogsumexpBackward>),\n",
       "  tensor(-693.3000, grad_fn=<LogsumexpBackward>)],\n",
       " 'smses': [2.374138728271379,\n",
       "  0.5094045847461567,\n",
       "  0.5009408664648348,\n",
       "  0.4366111192323337,\n",
       "  0.4385968482733264,\n",
       "  0.42374225525992304,\n",
       "  0.3930103906101734,\n",
       "  0.3322968885425008,\n",
       "  0.33701980706537116,\n",
       "  0.34396044072815846,\n",
       "  0.3377548032044637,\n",
       "  0.28034902762093744,\n",
       "  0.25998172818838733,\n",
       "  0.23748280667716093,\n",
       "  0.30235253844847576,\n",
       "  0.2553873153934216,\n",
       "  0.2503776182432513,\n",
       "  0.25510645252109915,\n",
       "  0.23328949433955073,\n",
       "  0.267075104918187,\n",
       "  0.25454092104052944,\n",
       "  0.2470776208640013,\n",
       "  0.24713975833229643,\n",
       "  0.23760619999629792,\n",
       "  0.2639263755727199,\n",
       "  0.25869048605396666,\n",
       "  0.2624932837092947,\n",
       "  0.27839398719702585,\n",
       "  0.2589923819907612,\n",
       "  0.26872879681989353,\n",
       "  0.2810517870186033,\n",
       "  0.27364790315838006,\n",
       "  0.24912639877912054,\n",
       "  0.24931449831858418,\n",
       "  0.2852294104432995,\n",
       "  0.23882371189301912,\n",
       "  0.2658228355592492,\n",
       "  0.2631408827524185,\n",
       "  0.25929040292296374,\n",
       "  0.2470131984683852,\n",
       "  0.24813614535570105,\n",
       "  0.25731716381297415,\n",
       "  0.2739021848474916,\n",
       "  0.2517761714791271,\n",
       "  0.2556251074075102,\n",
       "  0.23325315192578877,\n",
       "  0.2578523548140437,\n",
       "  0.255071429195385,\n",
       "  0.2669098551296191,\n",
       "  0.29162032492435835,\n",
       "  0.23443884921702118],\n",
       " 'smlls': [6.050669947723779,\n",
       "  0.11455755352936688,\n",
       "  0.020143034654175922,\n",
       "  -0.1740551129812585,\n",
       "  -0.2326367301026544,\n",
       "  -0.2918682108224697,\n",
       "  -0.3861292381770689,\n",
       "  -0.5054400535933957,\n",
       "  -0.5001438616653279,\n",
       "  -0.5048117156894174,\n",
       "  -0.498747200590987,\n",
       "  -0.6100957585283683,\n",
       "  -0.6471468960463281,\n",
       "  -0.7114256063313852,\n",
       "  -0.5433015108505184,\n",
       "  -0.6321991830993908,\n",
       "  -0.6586592516685527,\n",
       "  -0.6411879406876541,\n",
       "  -0.6869857515616703,\n",
       "  -0.5757710732038064,\n",
       "  -0.6258560835748945,\n",
       "  -0.6047104382154241,\n",
       "  -0.6193265159971573,\n",
       "  -0.6407578080993476,\n",
       "  -0.5138762401876907,\n",
       "  -0.5174860884150249,\n",
       "  -0.49084991746105394,\n",
       "  -0.4124652081507247,\n",
       "  -0.45301346260121494,\n",
       "  -0.44592410050046344,\n",
       "  -0.3082137540311573,\n",
       "  -0.3641256101424124,\n",
       "  -0.4757996198706996,\n",
       "  -0.4248217075428719,\n",
       "  -0.2619271228437774,\n",
       "  -0.4747945926730916,\n",
       "  -0.32420315195296,\n",
       "  -0.35682681757653717,\n",
       "  -0.36669326388456563,\n",
       "  -0.3813946462703026,\n",
       "  -0.3847162043739832,\n",
       "  -0.36110066067969915,\n",
       "  -0.29025856114218435,\n",
       "  -0.3820327356501885,\n",
       "  -0.3500268206692856,\n",
       "  -0.4748167804409676,\n",
       "  -0.33458851422075525,\n",
       "  -0.4004297859042829,\n",
       "  -0.25349794276818854,\n",
       "  -0.21547451494203784,\n",
       "  -0.5069757923704039],\n",
       " 'mlls': [8.258911539360147,\n",
       "  2.3227991451657353,\n",
       "  2.2283846262905445,\n",
       "  2.03418647865511,\n",
       "  1.9756048615337136,\n",
       "  1.9163733808138985,\n",
       "  1.8221123534592996,\n",
       "  1.7028015380429729,\n",
       "  1.7080977299710405,\n",
       "  1.703429875946951,\n",
       "  1.7094943910453813,\n",
       "  1.598145833108,\n",
       "  1.5610946955900402,\n",
       "  1.4968159853049832,\n",
       "  1.6649400807858499,\n",
       "  1.5760424085369777,\n",
       "  1.5495823399678157,\n",
       "  1.5670536509487143,\n",
       "  1.521255840074698,\n",
       "  1.632470518432562,\n",
       "  1.582385508061474,\n",
       "  1.603531153420944,\n",
       "  1.5889150756392112,\n",
       "  1.5674837835370206,\n",
       "  1.6943653514486776,\n",
       "  1.6907555032213437,\n",
       "  1.7173916741753146,\n",
       "  1.7957763834856435,\n",
       "  1.7552281290351532,\n",
       "  1.7623174911359047,\n",
       "  1.9000278376052109,\n",
       "  1.844115981493956,\n",
       "  1.7324419717656687,\n",
       "  1.7834198840934965,\n",
       "  1.9463144687925908,\n",
       "  1.7334469989632766,\n",
       "  1.8840384396834082,\n",
       "  1.8514147740598312,\n",
       "  1.8415483277518028,\n",
       "  1.8268469453660658,\n",
       "  1.8235253872623851,\n",
       "  1.8471409309566693,\n",
       "  1.9179830304941838,\n",
       "  1.8262088559861798,\n",
       "  1.8582147709670827,\n",
       "  1.7334248111954007,\n",
       "  1.8736530774156133,\n",
       "  1.8078118057320853,\n",
       "  1.95474364886818,\n",
       "  1.9927670766943306,\n",
       "  1.7012657992659646]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_eeg(vae_model, gpvae.estimators.vae_estimators.td_estimator, \n",
    "    loader, vae_args, gpvae.estimators.vae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.vae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/2500 [00:01<39:46,  1.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: -14.896\n",
      "ELBO: -959.808\n",
      "IWAE: -930.113\n",
      "SMSE: 1.992\n",
      "SMLL: 1.293\n",
      "MLL: 3.501\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 52/2500 [00:06<10:08,  4.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50\n",
      "Loss: -15.642\n",
      "ELBO: -860.503\n",
      "IWAE: -835.450\n",
      "SMSE: 1.745\n",
      "SMLL: 0.619\n",
      "MLL: 2.827\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 103/2500 [00:12<08:46,  4.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: -16.405\n",
      "ELBO: -826.254\n",
      "IWAE: -796.506\n",
      "SMSE: 1.747\n",
      "SMLL: 0.634\n",
      "MLL: 2.843\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 153/2500 [00:18<08:10,  4.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 150\n",
      "Loss: -16.749\n",
      "ELBO: -805.053\n",
      "IWAE: -780.142\n",
      "SMSE: 1.679\n",
      "SMLL: 0.625\n",
      "MLL: 2.833\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|▋         | 174/2500 [00:20<04:29,  8.64it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-bcff930cf7ab>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m train_eeg(pn_model, gpvae.estimators.gpvae_estimators.td_estimator, \n\u001b[1;32m      2\u001b[0m     \u001b[0mloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)\n\u001b[0m",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/experiments/eeg/train_eeg.py\u001b[0m in \u001b[0;36mtrain_eeg\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model)\u001b[0m\n\u001b[1;32m     68\u001b[0m                     decoder_scale=args['decoder_scale'])\n\u001b[1;32m     69\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     71\u001b[0m             \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    193\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    194\u001b[0m         \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train pn_model with train_eeg: the GPVAE td_estimator is the loss function,\n",
    "# the GPVAE ELBO and IWAE estimators are passed for evaluation, and\n",
    "# save_model=True is requested.\n",
    "train_eeg(pn_model, gpvae.estimators.gpvae_estimators.td_estimator, \n",
    "    loader, pn_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    gpvae.estimators.gpvae_estimators.iwae_estimator, save_model=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3069a2e37c804d339f806d117f8340ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 103696.897\n",
      "ELBO: -69171.075\n",
      "SMSE: 1.644\n",
      "SMLL: 21.328\n",
      "MLL: 23.536\n",
      "\n",
      "Epoch 50\n",
      "Loss: -0.384\n",
      "ELBO: -6518.537\n",
      "SMSE: 0.923\n",
      "SMLL: 3.079\n",
      "MLL: 5.287\n",
      "\n",
      "Epoch 100\n",
      "Loss: -2.640\n",
      "ELBO: -3019.526\n",
      "SMSE: 0.709\n",
      "SMLL: 1.297\n",
      "MLL: 3.505\n",
      "\n",
      "Epoch 150\n",
      "Loss: -4.285\n",
      "ELBO: -2023.412\n",
      "SMSE: 0.642\n",
      "SMLL: 0.726\n",
      "MLL: 2.934\n",
      "\n",
      "Epoch 200\n",
      "Loss: 3.129\n",
      "ELBO: -533.920\n",
      "SMSE: 0.487\n",
      "SMLL: 0.056\n",
      "MLL: 2.264\n",
      "\n",
      "Epoch 250\n",
      "Loss: 1.523\n",
      "ELBO: -460.166\n",
      "SMSE: 0.506\n",
      "SMLL: 0.018\n",
      "MLL: 2.226\n",
      "\n",
      "Epoch 300\n",
      "Loss: 1.535\n",
      "ELBO: -394.677\n",
      "SMSE: 0.573\n",
      "SMLL: 0.121\n",
      "MLL: 2.329\n",
      "\n",
      "Epoch 350\n",
      "Loss: 0.537\n",
      "ELBO: -345.241\n",
      "SMSE: 0.623\n",
      "SMLL: 0.098\n",
      "MLL: 2.306\n",
      "\n",
      "Epoch 400\n",
      "Loss: -0.057\n",
      "ELBO: -306.234\n",
      "SMSE: 0.695\n",
      "SMLL: 0.187\n",
      "MLL: 2.396\n",
      "\n",
      "Epoch 450\n",
      "Loss: 0.733\n",
      "ELBO: -275.027\n",
      "SMSE: 0.815\n",
      "SMLL: 0.348\n",
      "MLL: 2.556\n",
      "\n",
      "Epoch 500\n",
      "Loss: 0.282\n",
      "ELBO: -247.775\n",
      "SMSE: 0.976\n",
      "SMLL: 0.591\n",
      "MLL: 2.799\n",
      "\n",
      "Epoch 550\n",
      "Loss: 0.404\n",
      "ELBO: -225.578\n",
      "SMSE: 1.391\n",
      "SMLL: 1.182\n",
      "MLL: 3.390\n",
      "\n",
      "Epoch 600\n",
      "Loss: 0.369\n",
      "ELBO: -207.476\n",
      "SMSE: 1.426\n",
      "SMLL: 1.083\n",
      "MLL: 3.292\n",
      "\n",
      "Epoch 650\n",
      "Loss: -0.264\n",
      "ELBO: -193.162\n",
      "SMSE: 1.270\n",
      "SMLL: 0.768\n",
      "MLL: 2.977\n",
      "\n",
      "Epoch 700\n",
      "Loss: -0.179\n",
      "ELBO: -182.199\n",
      "SMSE: 1.277\n",
      "SMLL: 0.821\n",
      "MLL: 3.029\n",
      "\n",
      "Epoch 750\n",
      "Loss: -1.072\n",
      "ELBO: -163.693\n",
      "SMSE: 1.210\n",
      "SMLL: 0.452\n",
      "MLL: 2.660\n",
      "\n",
      "Epoch 800\n",
      "Loss: -1.038\n",
      "ELBO: -157.650\n",
      "SMSE: 1.754\n",
      "SMLL: 1.398\n",
      "MLL: 3.606\n",
      "\n",
      "Epoch 850\n",
      "Loss: -0.720\n",
      "ELBO: -139.162\n",
      "SMSE: 1.636\n",
      "SMLL: 1.187\n",
      "MLL: 3.395\n",
      "\n",
      "Epoch 900\n",
      "Loss: -0.687\n",
      "ELBO: -124.214\n",
      "SMSE: 1.858\n",
      "SMLL: 1.481\n",
      "MLL: 3.690\n",
      "\n",
      "Epoch 950\n",
      "Loss: -1.288\n",
      "ELBO: -103.094\n",
      "SMSE: 2.001\n",
      "SMLL: 1.705\n",
      "MLL: 3.913\n",
      "\n",
      "Epoch 1000\n",
      "Loss: -1.204\n",
      "ELBO: -89.300\n",
      "SMSE: 1.655\n",
      "SMLL: 1.463\n",
      "MLL: 3.671\n",
      "\n",
      "Epoch 1050\n",
      "Loss: -1.684\n",
      "ELBO: -81.337\n",
      "SMSE: 1.625\n",
      "SMLL: 1.378\n",
      "MLL: 3.586\n",
      "\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-60-e6be27bdeb0e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m condition_train_eeg(pog_model, td_estimator, \n\u001b[1;32m      2\u001b[0m     \u001b[0mloader_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0melbo_estimator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     None, save_model=False)\n\u001b[0m",
      "\u001b[0;32m<ipython-input-35-1a8cd5e14582>\u001b[0m in \u001b[0;36mcondition_train_eeg\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model)\u001b[0m\n\u001b[1;32m     40\u001b[0m                 decoder_scale=args['decoder_scale'])\n\u001b[1;32m     41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     43\u001b[0m             \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    193\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    194\u001b[0m         \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     97\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     98\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# NOTE: condition_train_eeg and the notebook-local td_estimator are defined in\n",
    "# the \"Functions and class definitions\" cells further down; run those first.\n",
    "# The IWAE estimator is None (no IWAE evaluation) and the model is not saved.\n",
    "condition_train_eeg(pog_model, td_estimator, \n",
    "    loader_c, pog_args, gpvae.estimators.gpvae_estimators.elbo_estimator,\n",
    "    None, save_model=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions and class definitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpvae.utils.gaussian_utils import gaussian_diagonal_ll, gaussian_diagonal_kl\n",
    "\n",
    "from gpytorch.distributions.multivariate_normal import MultivariateNormal\n",
    "from gpytorch.lazy import lazify\n",
    "\n",
    "def td_estimator(model, x, y, y_c, mask=None, mask_c=None,\n",
    "                 num_samples=1, decoder_scale=None, make_lazy=True, mf=False, idx=None):\n",
    "    \"\"\"Returns the negative ELBO surrogate whose gradient is the total\n",
    "    derivative reparameterisation-trick estimator for the GPVAE model. The\n",
    "    value is averaged over the Monte-Carlo samples and over the batch.\n",
    "\n",
    "    :param model: A nn.Module, the model to evaluate on.\n",
    "    :param x: A torch.Tensor, the input data.\n",
    "    :param y: A torch.Tensor, the output data scored by the decoder term.\n",
    "    :param y_c: A torch.Tensor, the observations handed to\n",
    "    model.get_latent_dists when mf is False.\n",
    "    :param mask: None or a torch.Tensor, the mask applied to y in the\n",
    "    decoder term. When None, decoder_scale is forced to 1.\n",
    "    :param mask_c: None or a torch.Tensor, the mask handed to\n",
    "    model.get_latent_dists together with y_c.\n",
    "    :param num_samples: An int, the number of samples drawn from q(f).\n",
    "    :param decoder_scale: None or a float, the factor multiplying the\n",
    "    decoder term, log p(y|f). When None and a mask is given, it is set to\n",
    "    1 / (1 - sum(|1 - mask|) / (y.shape[0] * y.shape[1])).\n",
    "    :param make_lazy: A bool, whether to lazify the covariances before\n",
    "    building the GPyTorch MultivariateNormal distributions.\n",
    "    :param mf: A bool, whether the model uses mean-field variational\n",
    "    inference; if so the latent distributions are requested with idx.\n",
    "    :param idx: A torch.Tensor, the data indices (only used when mf is True).\n",
    "    \"\"\"\n",
    "    # Resolve the decoder scaling first.\n",
    "    if mask is None:\n",
    "        decoder_scale = 1.\n",
    "    elif decoder_scale is None:\n",
    "        # Reciprocal of the fraction of entries of y that the mask keeps.\n",
    "        n_masked = 1. * torch.sum(abs(1 - mask))\n",
    "        n_total = y.shape[0] * y.shape[1]\n",
    "        decoder_scale = 1. / (1. - n_masked / n_total)\n",
    "\n",
    "    # Mean-field models are indexed by data point; amortised models are\n",
    "    # conditioned on the (masked) observations.\n",
    "    if mf:\n",
    "        latent_dists = model.get_latent_dists(x, idx=idx)\n",
    "    else:\n",
    "        latent_dists = model.get_latent_dists(x, y_c, mask_c)\n",
    "\n",
    "    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = latent_dists\n",
    "\n",
    "    # Optionally wrap the covariances as GPyTorch lazy tensors.\n",
    "    if make_lazy:\n",
    "        qf_cov, pf_cov = lazify(qf_cov), lazify(pf_cov)\n",
    "\n",
    "    qf = MultivariateNormal(qf_mu, qf_cov)\n",
    "    pf = MultivariateNormal(pf_mu, pf_cov)\n",
    "\n",
    "    # Only the diagonal of each approximate likelihood covariance is used.\n",
    "    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])\n",
    "    lf_y_mu_fixed, lf_y_var_fixed = lf_y_mu.detach(), lf_y_var.detach()\n",
    "\n",
    "    # Monte-Carlo accumulation over samples from q(f).\n",
    "    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.\n",
    "    estimator = 0\n",
    "    for _ in range(num_samples):\n",
    "        f = qf.rsample()\n",
    "\n",
    "        # log p(y|f), scaled to compensate for masked outputs.\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        py_f_term = decoder_scale * gaussian_diagonal_ll(\n",
    "            y, py_f_mu, py_f_sigma.pow(2), mask).sum()\n",
    "\n",
    "        # log l(f|y), with the factor parameters held fixed.\n",
    "        lf_y_term = gaussian_diagonal_ll(\n",
    "            f, lf_y_mu_fixed, lf_y_var_fixed).sum()\n",
    "\n",
    "        # log p(f), evaluated at the detached sample.\n",
    "        pf_term = pf.log_prob(f.detach()).sum()\n",
    "\n",
    "        estimator = estimator + py_f_term - lf_y_term + pf_term\n",
    "\n",
    "    # Average over the samples, then over the batch.\n",
    "    estimator = estimator / num_samples\n",
    "    estimator = estimator / x.shape[0]\n",
    "\n",
    "    return - estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-field models. For model types listed here, condition_train_eeg calls\n",
    "# the ELBO/IWAE estimators with mf=True and idx, and predict_y with idx,\n",
    "# instead of passing the observations and mask.\n",
    "mf_models = [gpvae.models.TitsiasSparseGPVAE]\n",
    "\n",
    "def condition_train_eeg(model, loss_fn, loader, args, elbo_estimator=None,\n",
    "              iwae_estimator=None, normalised=True, save_model=True, p=0.1):\n",
    "    \"\"\"Trains model while randomly dropping conditioning observations.\n",
    "\n",
    "    For every batch, each entry of the targets is removed from the\n",
    "    conditioning observations with probability p. The loss still scores the\n",
    "    full batch targets; only y_c / mask_c passed to loss_fn are thinned.\n",
    "\n",
    "    Evaluation reads the notebook globals mf_models, test, train, y_mean\n",
    "    and y_std.\n",
    "\n",
    "    :param model: A nn.Module, the model to train.\n",
    "    :param loss_fn: A callable, invoked as loss_fn(model, x=, y=, y_c=,\n",
    "    mask=, mask_c=, num_samples=1, decoder_scale=) and returning the loss.\n",
    "    :param loader: A DataLoader whose dataset exposes contains_nan and\n",
    "    dataset() (e.g. ConditionTupleDataset).\n",
    "    :param args: A dict with keys 'lr', 'epochs', 'decoder_scale',\n",
    "    'cache_freq', 'elbo_samples' and 'test_samples' (plus the keys save\n",
    "    needs when save_model is True).\n",
    "    :param elbo_estimator: None or a callable, the ELBO estimator used for\n",
    "    evaluation.\n",
    "    :param iwae_estimator: None or a callable, the IWAE estimator used for\n",
    "    evaluation.\n",
    "    :param normalised: A bool, whether predictions are mapped back with\n",
    "    y_std and y_mean before scoring.\n",
    "    :param save_model: A bool, whether to call save at the end of training.\n",
    "    :param p: A float, the probability with which each batch entry is\n",
    "    dropped from the conditioning observations.\n",
    "    :return: A dict of lists holding the cached metrics.\n",
    "    \"\"\"\n",
    "    metrics = {'epochs': [],\n",
    "               'losses': [],\n",
    "               'elbos': [],\n",
    "               'iwaes': [],\n",
    "               'smses': [],\n",
    "               'smlls': [],\n",
    "               'mlls': []\n",
    "               }\n",
    "\n",
    "    model.train(True)\n",
    "    optimiser = optim.Adam(model.parameters(), lr=args['lr'])\n",
    "\n",
    "    # Get dataset. The stored conditioning data and mask are not needed here:\n",
    "    # they are regenerated at random for every batch.\n",
    "    dataset = loader.dataset.dataset()\n",
    "    if loader.dataset.contains_nan:\n",
    "        x, y, _, m, _, idx = dataset\n",
    "    else:\n",
    "        x, y, _, _, idx = dataset\n",
    "        m = None\n",
    "\n",
    "    # Training.\n",
    "    for epoch in tqdm(range(args['epochs'])):\n",
    "        epoch_losses = []\n",
    "        for batch in loader:\n",
    "            if loader.dataset.contains_nan:\n",
    "                x_b, y_b, _, m_b, _, _ = batch\n",
    "            else:\n",
    "                x_b, y_b, _, _, _ = batch\n",
    "                m_b = None\n",
    "\n",
    "            # Drop a random proportion p of THIS batch from the conditioning\n",
    "            # set. Dropped entries are zeroed in y_c_b and flagged with 0 in\n",
    "            # m_c_b, the same convention ConditionTupleDataset uses.\n",
    "            set_to_nan = np.random.choice(\n",
    "                a=[False, True], size=tuple(y_b.shape), p=[1 - p, p])\n",
    "            m_c_b = torch.ones_like(y_b)\n",
    "            m_c_b[torch.from_numpy(set_to_nan)] = 0.\n",
    "            if m_b is not None:\n",
    "                # Entries missing from the targets cannot be conditioned on.\n",
    "                m_c_b = m_c_b * m_b\n",
    "\n",
    "            y_c_b = y_b.clone()\n",
    "            y_c_b[m_c_b == 0] = 0.\n",
    "\n",
    "            optimiser.zero_grad()\n",
    "\n",
    "            loss = loss_fn(\n",
    "                model, x=x_b, y=y_b, y_c=y_c_b, mask=m_b, mask_c=m_c_b,\n",
    "                num_samples=1, decoder_scale=args['decoder_scale'])\n",
    "\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "\n",
    "            epoch_losses.append(loss.item())\n",
    "\n",
    "        # Evaluate model.\n",
    "        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):\n",
    "            model.eval()\n",
    "\n",
    "            metrics['epochs'].append(epoch)\n",
    "            report = 'Epoch {}\\n'.format(epoch)\n",
    "\n",
    "            # Average loss over previous epoch.\n",
    "            mean_loss = np.mean(epoch_losses)\n",
    "            metrics['losses'].append(mean_loss)\n",
    "            report += 'Loss: {:.3f}\\n'.format(mean_loss)\n",
    "\n",
    "            if elbo_estimator is not None:\n",
    "                # ELBO estimate.\n",
    "                if type(model) in mf_models:\n",
    "                    elbo = elbo_estimator(\n",
    "                        model, x, y, mask=m, num_samples=args['elbo_samples'],\n",
    "                        mf=True, idx=idx)\n",
    "                else:\n",
    "                    elbo = elbo_estimator(\n",
    "                        model, x, y, mask=m, num_samples=args['elbo_samples'])\n",
    "\n",
    "                metrics['elbos'].append(elbo)\n",
    "                report += 'ELBO: {:.3f}\\n'.format(elbo)\n",
    "\n",
    "            if iwae_estimator is not None:\n",
    "                # IWAE estimate.\n",
    "                if type(model) in mf_models:\n",
    "                    iwae = iwae_estimator(\n",
    "                        model, x, y, mask=m, num_samples=args['elbo_samples'],\n",
    "                        mf=True, idx=idx)\n",
    "                else:\n",
    "                    iwae = iwae_estimator(\n",
    "                        model, x, y, mask=m, num_samples=args['elbo_samples'])\n",
    "\n",
    "                metrics['iwaes'].append(iwae)\n",
    "                report += 'IWAE: {:.3f}\\n'.format(iwae)\n",
    "\n",
    "            # `test` (like train, y_mean and y_std below) is a notebook global.\n",
    "            if test is not None:\n",
    "                # Test predictions.\n",
    "                if type(model) in mf_models:\n",
    "                    mean, sigma = model.predict_y(\n",
    "                        x=x, idx=idx, num_samples=args['test_samples'])[:2]\n",
    "                else:\n",
    "                    mean, sigma = model.predict_y(\n",
    "                        x=x, y=y, mask=m, num_samples=args['test_samples'])[:2]\n",
    "\n",
    "                mean, sigma = mean.numpy(), sigma.numpy()\n",
    "\n",
    "                if normalised:\n",
    "                    # Undo the normalisation applied to the observations.\n",
    "                    mean = mean * y_std + y_mean\n",
    "                    sigma = sigma * y_std\n",
    "\n",
    "                # Evaluate test predictions.\n",
    "                pred = pd.DataFrame(mean, index=train.index,\n",
    "                                    columns=train.columns)\n",
    "                var = pd.DataFrame(sigma ** 2, index=train.index,\n",
    "                                   columns=train.columns)\n",
    "\n",
    "                smse = metric.smse(pred, test).mean()\n",
    "                smll = metric.smll(pred, var, test).mean()\n",
    "                mll = metric.mll(pred, var, test).mean()\n",
    "\n",
    "                metrics['smses'].append(smse)\n",
    "                metrics['smlls'].append(smll)\n",
    "                metrics['mlls'].append(mll)\n",
    "                report += 'SMSE: {:.3f}\\n'.format(smse)\n",
    "                report += 'SMLL: {:.3f}\\n'.format(smll)\n",
    "                report += 'MLL: {:.3f}\\n'.format(mll)\n",
    "\n",
    "            # Report model performance.\n",
    "            tqdm.write(report)\n",
    "\n",
    "            model.train(True)\n",
    "\n",
    "    if save_model:\n",
    "        # Save model, hyperparameters and metrics.\n",
    "        save(model, args, metrics)\n",
    "\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def save(model, args, metrics):\n",
    "    \"\"\"Writes the model state dict, args and metrics to a fresh results directory.\n",
    "\n",
    "    The directory is args['results_dir'] + args['model'] (or\n",
    "    '_results/' + args['model'] when no 'results_dir' is given). If it already\n",
    "    exists, the first free '_<i>' suffix (i = 1, 2, ...) is appended.\n",
    "\n",
    "    :param model: A nn.Module, whose state_dict() is saved.\n",
    "    :param args: A dict, the hyperparameters; must contain 'model'.\n",
    "    :param metrics: A dict of lists; the last entry of each non-empty list is\n",
    "    written to report.txt.\n",
    "    \"\"\"\n",
    "    if 'model' not in args:\n",
    "        print(\"Error: 'model' does not exist in args. Aborting save.\")\n",
    "        return\n",
    "\n",
    "    # Base location, joined by plain string concatenation.\n",
    "    base_dir = args['results_dir'] if 'results_dir' in args.keys() else '_results/'\n",
    "    results_dir = base_dir + args['model']\n",
    "\n",
    "    # Never overwrite an earlier run: look for the first unused suffix.\n",
    "    if os.path.isdir(results_dir):\n",
    "        suffix = 1\n",
    "        candidate = results_dir + '_' + str(suffix)\n",
    "        while os.path.isdir(candidate):\n",
    "            suffix += 1\n",
    "            candidate = results_dir + '_' + str(suffix)\n",
    "\n",
    "        results_dir = candidate\n",
    "\n",
    "    os.makedirs(results_dir)\n",
    "\n",
    "    # Pickle args and metrics.\n",
    "    for file_name, payload in (('/args.pkl', args), ('/metrics.pkl', metrics)):\n",
    "        with open(results_dir + file_name, 'wb') as pkl_file:\n",
    "            pickle.dump(payload, pkl_file)\n",
    "\n",
    "    # Save args and results in text format.\n",
    "    with open(results_dir + '/report.txt', 'w') as report:\n",
    "        report.write('Args: \\n')\n",
    "        arg_entries = args if isinstance(args, list) else [args]\n",
    "        for entry in arg_entries:\n",
    "            report.write(str(entry) + '\\n')\n",
    "\n",
    "        report.write('\\nPerformance: \\n')\n",
    "        for key, values in metrics.items():\n",
    "            try:\n",
    "                report.write('{}: {}\\n'.format(key, values[-1]))\n",
    "            except IndexError:\n",
    "                # Nothing recorded for this metric; leave it out.\n",
    "                continue\n",
    "\n",
    "    # Save model.state_dict().\n",
    "    torch.save(model.state_dict(), results_dir + '/model_state_dict.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConditionTupleDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Dataset of inputs x, targets y and conditioning observations y_c.\n",
    "\n",
    "    nan entries of y_c (and of y when contains_nan is True) are replaced by 0\n",
    "    in a copy of the data and recorded in a mask of the same shape that is 1\n",
    "    where a value was present and 0 where it was nan.\n",
    "\n",
    "    Items are (x, y, y_c, m, m_c, idx) when contains_nan is True and\n",
    "    (x, y, y_c, m_c, idx) otherwise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, y_c, contains_nan=False):\n",
    "        super().__init__()\n",
    "\n",
    "        assert len(x) == len(y), 'x and y must be the same length.'\n",
    "\n",
    "        # Ensure inputs are 2-dimensional.\n",
    "        self.x = x.unsqueeze(1) if len(x.shape) == 1 else x\n",
    "\n",
    "        if contains_nan:\n",
    "            self.y, self.m = self._zero_fill(y)\n",
    "        else:\n",
    "            # No nan values expected in y: keep it as is, without a mask.\n",
    "            self.y = y\n",
    "\n",
    "        # The conditioning observations always get a mask.\n",
    "        self.y_c, self.m_c = self._zero_fill(y_c)\n",
    "\n",
    "        self.contains_nan = contains_nan\n",
    "\n",
    "    @staticmethod\n",
    "    def _zero_fill(values):\n",
    "        \"\"\"Returns a copy of values with nan set to 0, and the matching mask.\"\"\"\n",
    "        filled = copy.deepcopy(values)\n",
    "        observed = torch.ones_like(values).fill_(True)\n",
    "\n",
    "        # Identify nan values and replace with 0.\n",
    "        nan_idx = torch.isnan(values)\n",
    "        observed[nan_idx] = False\n",
    "        filled[nan_idx] = 0.\n",
    "\n",
    "        return filled, observed\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.x)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = [self.x[idx], self.y[idx], self.y_c[idx]]\n",
    "\n",
    "        # The target mask only exists when y may contain nan values.\n",
    "        if self.contains_nan:\n",
    "            item.append(self.m[idx])\n",
    "\n",
    "        item.extend([self.m_c[idx], idx])\n",
    "\n",
    "        return tuple(item)\n",
    "\n",
    "    def dataset(self):\n",
    "        \"\"\"Returns the item tuple for the whole dataset at once.\"\"\"\n",
    "        return self[list(range(len(self)))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spatio-temporal-vae-env",
   "language": "python",
   "name": "spatio-temporal-vae-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
