{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import copy\n",
    "import pdb\n",
    "import time\n",
    "import pickle\n",
    "import warnings\n",
    "sys.path.append('../../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "import wbml\n",
    "import wbml.out\n",
    "import wbml.experiment\n",
    "import wbml.metric\n",
    "import data.eeg\n",
    "import gpytorch\n",
    "import gpvae\n",
    "# import data\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from torch.distributions.kl import kl_divergence\n",
    "\n",
    "# from old_gpvae.gpvae import *\n",
    "from old_gpvae.vae import *\n",
    "from old_gpvae.hvi.hvi_gpvae import *\n",
    "from old_gpvae.networks.networks import *\n",
    "from old_gpvae.mf.mf_networks import *\n",
    "from old_gpvae.kernels.kernels import *\n",
    "\n",
    "from old_gpvae.gpvae_estimators import *\n",
    "from old_gpvae.gpvae_elbo import *\n",
    "from old_gpvae.vae_estimators import *\n",
    "from old_gpvae.vae_elbo import *\n",
    "from old_gpvae.hvi.hvi_gpvae_estimators import *\n",
    "from old_gpvae.hvi.hvi_gpvae_elbo import *\n",
    "\n",
    "from old_gpvae.utils.metrics import *\n",
    "from old_gpvae.utils.datasets import *\n",
    "from old_gpvae.utils.matrix_utils import *\n",
    "from old_gpvae.utils.gaussian import *\n",
    "\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global experiment hyperparameters, read by the cells below.\n",
    "args = dict(\n",
    "    loss_fn=gpvae_analytical_estimator,\n",
    "    batch_size=100,         # minibatch size of the training DataLoaders\n",
    "    latent_dim=5,           # latent dimensionality (encoder out_dim / decoder in_dim)\n",
    "    init_lengthscale=.1,    # initial lengthscale of both kernel components\n",
    "    init_scale=1.,          # initial total kernel scale, split equally between components\n",
    "    init_period=.1,         # initial period of the periodic kernel component\n",
    "    auxiliary_dim=1,        # dimensionality of the HVI auxiliary variable\n",
    "    num_inducing=100,       # number of k-means inducing points for VFE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load EEG data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, train, test = data.eeg.load()\n",
    "\n",
    "x = np.array(train.index)\n",
    "y = np.array(train)\n",
    "y_dim = y.shape[1]\n",
    "# Inverse of the fraction of observed (non-NaN) training entries.\n",
    "decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))\n",
    "\n",
    "# normalise data per output column, ignoring missing values\n",
    "y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)\n",
    "y = (y - y_mean) / y_std\n",
    "\n",
    "# set up dataset and data loaders\n",
    "# NB: deliberately not called `data` -- that would shadow the `data` package\n",
    "# used by `data.eeg.load()` above and make this cell fail when re-run.\n",
    "train_tensors = (torch.tensor(x), torch.tensor(y))\n",
    "\n",
    "new_dataset = NewTupleDataset(torch.tensor(x), torch.tensor(y), contains_nan=True)\n",
    "new_loader = DataLoader(new_dataset, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "amortised_dataset = TupleDataset(train_tensors, contains_nan=True)\n",
    "amortised_loader = DataLoader(amortised_dataset, batch_size=args['batch_size'],\n",
    "                              shuffle=True)\n",
    "\n",
    "mf_dataset = TupleDataset(train_tensors, contains_nan=True, return_index=True)\n",
    "mf_loader = DataLoader(mf_dataset, batch_size=args['batch_size'], shuffle=True)\n",
    "\n",
    "# Batch size equal to the number of training points, i.e. full-batch loading.\n",
    "vfe_loader = DataLoader(amortised_dataset, batch_size=x.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kernel: additive RBF + periodic kernel, each component starting at half of\n",
    "# the total initial scale\n",
    "k1 = gpvae.kernels.RBFKernel(lengthscale=args['init_lengthscale'],\n",
    "                             scale=args['init_scale'] / 2)\n",
    "k2 = gpvae.kernels.PeriodicKernel(lengthscale=args['init_lengthscale'],\n",
    "                                  period=args['init_period'],\n",
    "                                  scale=args['init_scale'] / 2)\n",
    "kernel = gpvae.kernels.AdditiveKernel(k1, k2)\n",
    "\n",
    "# rich encoder hyperparameters\n",
    "r_encoder_args = {'in_dim': y_dim,\n",
    "                  'out_dim': args['latent_dim'],\n",
    "                  'hidden_dims': [20, 20, 20, 20],\n",
    "                  'initial_sigma': 1.,\n",
    "                  'initial_mu': None\n",
    "                 }\n",
    "\n",
    "# jonny encoder hyperparameters\n",
    "j_encoder_args = {'in_dim': y_dim,\n",
    "                  'out_dim': args['latent_dim'],\n",
    "                  'middle_dim': 20,\n",
    "                  'hidden_dims': [20, 20],\n",
    "                  'shared_hidden_dims': [20, 20],\n",
    "                  'initial_sigma': 1.,\n",
    "                  'initial_mu': None\n",
    "                 }\n",
    "\n",
    "# deepset encoder hyperparameters\n",
    "ds_encoder_args = {'in_dim': y_dim,\n",
    "                   'out_dim': args['latent_dim'],\n",
    "                   'middle_dim': 20,\n",
    "                   'first_hidden_dims': [20, 20],\n",
    "                   'second_hidden_dims': [20, 20],\n",
    "                   'initial_sigma': 1.,\n",
    "                   'initial_mu': None\n",
    "                  }\n",
    "\n",
    "# decoder hyperparameters (fixed, untrained output noise sigma)\n",
    "decoder_args = {'in_dim': args['latent_dim'],\n",
    "                'out_dim': y_dim,\n",
    "                'hidden_dims': [20, 20],\n",
    "                'sigma': 0.1,\n",
    "                'train_sigma': False\n",
    "               }\n",
    "\n",
    "# hvi encoder/decoder hyperparameters\n",
    "latent_encoder_args = {'in_dims': [args['auxiliary_dim'], y_dim],\n",
    "                       'out_dim': args['latent_dim'],\n",
    "                       'hidden_dims': [20, 20, 20, 20],\n",
    "                       'initial_sigma': 1.,\n",
    "                       'initial_mu': None,\n",
    "                       'contains_nans': [False, True]\n",
    "                      }\n",
    "auxiliary_encoder_args = {'in_dim': y_dim,\n",
    "                          'out_dim': args['auxiliary_dim'],\n",
    "                          'middle_dim': 20,\n",
    "                          'hidden_dims': [20, 20],\n",
    "                          'shared_hidden_dims': [20, 20],\n",
    "                          'initial_sigma': .1,\n",
    "                          'initial_mu': None\n",
    "                         }\n",
    "\n",
    "latent_decoder_args = {'in_dim': args['latent_dim'],\n",
    "                       'out_dim': y_dim,\n",
    "                       'hidden_dims': [20, 20],\n",
    "                       'sigma': 0.1,\n",
    "                       'train_sigma': False\n",
    "                      }\n",
    "auxiliary_decoder_args = {'in_dims': [y_dim, args['latent_dim']],\n",
    "                          'out_dim': args['auxiliary_dim'],\n",
    "                          'hidden_dims': [20, 20],\n",
    "                          'initial_sigma': 1.,\n",
    "                          'initial_mu': None,\n",
    "                          'contains_nans': [True, False]\n",
    "                         }\n",
    "\n",
    "# vfe encoder/decoder hyperparameters\n",
    "vfe_encoder_args = {'out_dim': args['latent_dim']\n",
    "                   }\n",
    "\n",
    "# vfe inducing points: k-means centres of the training inputs.\n",
    "# NOTE(review): with minit='points' kmeans2 picks random initial centres (from the\n",
    "# global NumPy RNG when no seed is given), so z can differ between runs.\n",
    "z = kmeans2(x, args['num_inducing'], minit='points')[0]\n",
    "z = torch.tensor(z)\n",
    "\n",
    "r_vfe_encoder_args = {'in_dim': y_dim,\n",
    "                      'out_dim': args['latent_dim'],\n",
    "                      'hidden_dims': [20, 20, 20, 20],\n",
    "                      'inducing_spacing': 0.025,\n",
    "                      'k': 5,\n",
    "                      'contains_nan': True\n",
    "                     }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct GPVAE models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPVAE with the 'jonny' (IndexNet) encoder and a linear-Gaussian decoder.\n",
    "j_args = [args, j_encoder_args, decoder_args]\n",
    "j_encoder = gpvae.networks.IndexNet(**j_encoder_args)\n",
    "j_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# Give the model its own copy of `kernel`: the kernel object is built in an\n",
    "# earlier cell, so re-running only this cell would otherwise hand the new model\n",
    "# a kernel whose hyperparameters a previous training run may already have changed.\n",
    "j_model = NewGPVAE(j_encoder, j_decoder, args['latent_dim'],\n",
    "                   kernel=copy.deepcopy(kernel), add_jitter=False)\n",
    "\n",
    "# Independent copy with identical initial parameters (trained separately below).\n",
    "j_model_copy = copy.deepcopy(j_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Things to check\n",
    "* Kernels - looked like a likely culprit, but no\n",
    "* Encoder - no\n",
    "* Decoder - no\n",
    "* Model - yes (fixed), but the results are still different.\n",
    "* TD estimator - no\n",
    "* ELBO estimator - no\n",
    "* Data loader - not yet known"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train GPVAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:25: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3225dae54724cdca73d23f7fc09c1a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 42.030\n",
      "ELBO: -74133.924\n",
      "IWAE: 0.000\n",
      "SMSE: 2.011\n",
      "\n",
      "Epoch 50\n",
      "Loss: 5.270\n",
      "ELBO: -15614.654\n",
      "IWAE: 0.000\n",
      "SMSE: 1.293\n",
      "\n",
      "Epoch 100\n",
      "Loss: -1.836\n",
      "ELBO: -7004.293\n",
      "IWAE: 0.000\n",
      "SMSE: 0.474\n",
      "\n",
      "Epoch 150\n",
      "Loss: -2.053\n",
      "ELBO: -4875.128\n",
      "IWAE: 0.000\n",
      "SMSE: 0.406\n",
      "\n",
      "Epoch 200\n",
      "Loss: -2.087\n",
      "ELBO: -3099.121\n",
      "IWAE: 0.000\n",
      "SMSE: 0.354\n",
      "\n",
      "Epoch 250\n",
      "Loss: -2.837\n",
      "ELBO: -2639.443\n",
      "IWAE: 0.000\n",
      "SMSE: 0.365\n",
      "\n",
      "Epoch 300\n",
      "Loss: -3.328\n",
      "ELBO: -2363.898\n",
      "IWAE: 0.000\n",
      "SMSE: 0.326\n",
      "\n",
      "Epoch 350\n",
      "Loss: -3.747\n",
      "ELBO: -2177.769\n",
      "IWAE: 0.000\n",
      "SMSE: 0.310\n",
      "\n",
      "Epoch 400\n",
      "Loss: -4.477\n",
      "ELBO: -2165.067\n",
      "IWAE: 0.000\n",
      "SMSE: 0.325\n",
      "\n",
      "Epoch 450\n",
      "Loss: -3.268\n",
      "ELBO: -2107.385\n",
      "IWAE: 0.000\n",
      "SMSE: 0.310\n",
      "\n",
      "Epoch 500\n",
      "Loss: -3.671\n",
      "ELBO: -2020.958\n",
      "IWAE: 0.000\n",
      "SMSE: 0.332\n",
      "\n",
      "Epoch 550\n",
      "Loss: -3.712\n",
      "ELBO: -1973.444\n",
      "IWAE: 0.000\n",
      "SMSE: 0.304\n",
      "\n",
      "Epoch 600\n",
      "Loss: -3.713\n",
      "ELBO: -1937.683\n",
      "IWAE: 0.000\n",
      "SMSE: 0.327\n",
      "\n",
      "Epoch 650\n",
      "Loss: -4.128\n",
      "ELBO: -1768.519\n",
      "IWAE: 0.000\n",
      "SMSE: 0.305\n",
      "\n",
      "Epoch 700\n",
      "Loss: -4.124\n",
      "ELBO: -1742.151\n",
      "IWAE: 0.000\n",
      "SMSE: 0.332\n",
      "\n",
      "Epoch 750\n",
      "Loss: -3.616\n",
      "ELBO: -1779.831\n",
      "IWAE: 0.000\n",
      "SMSE: 0.330\n",
      "\n",
      "Epoch 800\n",
      "Loss: -4.035\n",
      "ELBO: -1754.750\n",
      "IWAE: 0.000\n",
      "SMSE: 0.313\n",
      "\n",
      "Epoch 850\n",
      "Loss: -4.052\n",
      "ELBO: -1753.223\n",
      "IWAE: 0.000\n",
      "SMSE: 0.307\n",
      "\n",
      "Epoch 900\n",
      "Loss: -4.413\n",
      "ELBO: -1679.441\n",
      "IWAE: 0.000\n",
      "SMSE: 0.275\n",
      "\n",
      "Epoch 950\n",
      "Loss: -4.429\n",
      "ELBO: -1677.702\n",
      "IWAE: 0.000\n",
      "SMSE: 0.274\n",
      "\n",
      "Epoch 1000\n",
      "Loss: -4.060\n",
      "ELBO: -1555.281\n",
      "IWAE: 0.000\n",
      "SMSE: 0.272\n",
      "\n",
      "Epoch 1050\n",
      "Loss: -4.403\n",
      "ELBO: -1686.031\n",
      "IWAE: 0.000\n",
      "SMSE: 0.255\n",
      "\n",
      "Epoch 1100\n",
      "Loss: -4.176\n",
      "ELBO: -1649.286\n",
      "IWAE: 0.000\n",
      "SMSE: 0.254\n",
      "\n",
      "Epoch 1150\n",
      "Loss: -4.114\n",
      "ELBO: -1545.806\n",
      "IWAE: 0.000\n",
      "SMSE: 0.259\n",
      "\n",
      "Epoch 1200\n",
      "Loss: -4.010\n",
      "ELBO: -1626.819\n",
      "IWAE: 0.000\n",
      "SMSE: 0.251\n",
      "\n",
      "Epoch 1250\n",
      "Loss: -4.068\n",
      "ELBO: -1515.356\n",
      "IWAE: 0.000\n",
      "SMSE: 0.239\n",
      "\n",
      "Epoch 1300\n",
      "Loss: -4.133\n",
      "ELBO: -1547.169\n",
      "IWAE: 0.000\n",
      "SMSE: 0.247\n",
      "\n",
      "Epoch 1350\n",
      "Loss: -4.251\n",
      "ELBO: -1381.913\n",
      "IWAE: 0.000\n",
      "SMSE: 0.241\n",
      "\n",
      "Epoch 1400\n",
      "Loss: -4.246\n",
      "ELBO: -1524.689\n",
      "IWAE: 0.000\n",
      "SMSE: 0.216\n",
      "\n",
      "Epoch 1450\n",
      "Loss: -4.222\n",
      "ELBO: -1421.409\n",
      "IWAE: 0.000\n",
      "SMSE: 0.226\n",
      "\n",
      "Epoch 1500\n",
      "Loss: -4.256\n",
      "ELBO: -1470.014\n",
      "IWAE: 0.000\n",
      "SMSE: 0.238\n",
      "\n",
      "Epoch 1550\n",
      "Loss: -4.420\n",
      "ELBO: -1298.303\n",
      "IWAE: 0.000\n",
      "SMSE: 0.239\n",
      "\n",
      "Epoch 1600\n",
      "Loss: -4.324\n",
      "ELBO: -1334.881\n",
      "IWAE: 0.000\n",
      "SMSE: 0.218\n",
      "\n",
      "Epoch 1650\n",
      "Loss: -4.413\n",
      "ELBO: -1386.169\n",
      "IWAE: 0.000\n",
      "SMSE: 0.204\n",
      "\n",
      "Epoch 1700\n",
      "Loss: -4.453\n",
      "ELBO: -1394.527\n",
      "IWAE: 0.000\n",
      "SMSE: 0.229\n",
      "\n",
      "Epoch 1750\n",
      "Loss: -4.605\n",
      "ELBO: -1315.827\n",
      "IWAE: 0.000\n",
      "SMSE: 0.213\n",
      "\n",
      "Epoch 1800\n",
      "Loss: -4.344\n",
      "ELBO: -1302.255\n",
      "IWAE: 0.000\n",
      "SMSE: 0.216\n",
      "\n",
      "Epoch 1850\n",
      "Loss: -4.458\n",
      "ELBO: -1302.053\n",
      "IWAE: 0.000\n",
      "SMSE: 0.228\n",
      "\n",
      "Epoch 1900\n",
      "Loss: -4.604\n",
      "ELBO: -1355.257\n",
      "IWAE: 0.000\n",
      "SMSE: 0.228\n",
      "\n",
      "Epoch 1950\n",
      "Loss: -4.690\n",
      "ELBO: -1267.042\n",
      "IWAE: 0.000\n",
      "SMSE: 0.238\n",
      "\n",
      "Epoch 1999\n",
      "Loss: -4.320\n",
      "ELBO: -1216.028\n",
      "IWAE: 0.000\n",
      "SMSE: 0.230\n",
      "\n",
      "\n",
      "Save model? (yes/no)no\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epochs': [],\n",
       " 'losses': [42.02961286129406,\n",
       "  5.270494185674203,\n",
       "  -1.8361056474746735,\n",
       "  -2.0526902174201296,\n",
       "  -2.087219792940939,\n",
       "  -2.8367282718843865,\n",
       "  -3.3277210020607613,\n",
       "  -3.747445313193794,\n",
       "  -4.476658578300886,\n",
       "  -3.2682572691169685,\n",
       "  -3.6706917892987985,\n",
       "  -3.7118883912360783,\n",
       "  -3.7129110017688753,\n",
       "  -4.128155957046746,\n",
       "  -4.124059769586735,\n",
       "  -3.615692942258663,\n",
       "  -4.035464783925122,\n",
       "  -4.0521647756884684,\n",
       "  -4.413389574626398,\n",
       "  -4.429082957236709,\n",
       "  -4.059675095788019,\n",
       "  -4.403372943857877,\n",
       "  -4.176251229294267,\n",
       "  -4.113904317147335,\n",
       "  -4.009704683585465,\n",
       "  -4.067531074603673,\n",
       "  -4.133060095647795,\n",
       "  -4.251199853902767,\n",
       "  -4.246191706839249,\n",
       "  -4.222036979333792,\n",
       "  -4.2557924849919075,\n",
       "  -4.419566452271607,\n",
       "  -4.3241336825749945,\n",
       "  -4.41287384110687,\n",
       "  -4.453487867504409,\n",
       "  -4.60534747829789,\n",
       "  -4.344232184316459,\n",
       "  -4.457542160801601,\n",
       "  -4.604201841470574,\n",
       "  -4.690382948307846,\n",
       "  -4.320188803696621],\n",
       " 'elbos': [tensor(-74133.9242, grad_fn=<AddBackward0>),\n",
       "  tensor(-15614.6545, grad_fn=<AddBackward0>),\n",
       "  tensor(-7004.2930, grad_fn=<AddBackward0>),\n",
       "  tensor(-4875.1285, grad_fn=<AddBackward0>),\n",
       "  tensor(-3099.1208, grad_fn=<AddBackward0>),\n",
       "  tensor(-2639.4434, grad_fn=<AddBackward0>),\n",
       "  tensor(-2363.8975, grad_fn=<AddBackward0>),\n",
       "  tensor(-2177.7688, grad_fn=<AddBackward0>),\n",
       "  tensor(-2165.0671, grad_fn=<AddBackward0>),\n",
       "  tensor(-2107.3846, grad_fn=<AddBackward0>),\n",
       "  tensor(-2020.9579, grad_fn=<AddBackward0>),\n",
       "  tensor(-1973.4441, grad_fn=<AddBackward0>),\n",
       "  tensor(-1937.6825, grad_fn=<AddBackward0>),\n",
       "  tensor(-1768.5188, grad_fn=<AddBackward0>),\n",
       "  tensor(-1742.1508, grad_fn=<AddBackward0>),\n",
       "  tensor(-1779.8314, grad_fn=<AddBackward0>),\n",
       "  tensor(-1754.7501, grad_fn=<AddBackward0>),\n",
       "  tensor(-1753.2228, grad_fn=<AddBackward0>),\n",
       "  tensor(-1679.4410, grad_fn=<AddBackward0>),\n",
       "  tensor(-1677.7019, grad_fn=<AddBackward0>),\n",
       "  tensor(-1555.2812, grad_fn=<AddBackward0>),\n",
       "  tensor(-1686.0315, grad_fn=<AddBackward0>),\n",
       "  tensor(-1649.2858, grad_fn=<AddBackward0>),\n",
       "  tensor(-1545.8064, grad_fn=<AddBackward0>),\n",
       "  tensor(-1626.8187, grad_fn=<AddBackward0>),\n",
       "  tensor(-1515.3565, grad_fn=<AddBackward0>),\n",
       "  tensor(-1547.1693, grad_fn=<AddBackward0>),\n",
       "  tensor(-1381.9130, grad_fn=<AddBackward0>),\n",
       "  tensor(-1524.6891, grad_fn=<AddBackward0>),\n",
       "  tensor(-1421.4090, grad_fn=<AddBackward0>),\n",
       "  tensor(-1470.0143, grad_fn=<AddBackward0>),\n",
       "  tensor(-1298.3027, grad_fn=<AddBackward0>),\n",
       "  tensor(-1334.8812, grad_fn=<AddBackward0>),\n",
       "  tensor(-1386.1688, grad_fn=<AddBackward0>),\n",
       "  tensor(-1394.5266, grad_fn=<AddBackward0>),\n",
       "  tensor(-1315.8273, grad_fn=<AddBackward0>),\n",
       "  tensor(-1302.2553, grad_fn=<AddBackward0>),\n",
       "  tensor(-1302.0526, grad_fn=<AddBackward0>),\n",
       "  tensor(-1355.2573, grad_fn=<AddBackward0>),\n",
       "  tensor(-1267.0421, grad_fn=<AddBackward0>),\n",
       "  tensor(-1216.0278, grad_fn=<AddBackward0>)],\n",
       " 'iwaes': [0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0],\n",
       " 'smses': [2.011295374124305,\n",
       "  1.2925360270468544,\n",
       "  0.4741501495896343,\n",
       "  0.4059121183160648,\n",
       "  0.3538287187793365,\n",
       "  0.3654505292145975,\n",
       "  0.3260885404687838,\n",
       "  0.3104537092187931,\n",
       "  0.3249039191595799,\n",
       "  0.3097100860140544,\n",
       "  0.332169922257203,\n",
       "  0.3036553550278395,\n",
       "  0.3274197665825513,\n",
       "  0.30540142175246937,\n",
       "  0.3319556708810225,\n",
       "  0.3296056183779334,\n",
       "  0.3129399336887326,\n",
       "  0.30700427838875294,\n",
       "  0.27477967969045963,\n",
       "  0.27355980389009715,\n",
       "  0.272161160511495,\n",
       "  0.2547278178098971,\n",
       "  0.2541698016677348,\n",
       "  0.2591609861384757,\n",
       "  0.2509854879107731,\n",
       "  0.23854635075477673,\n",
       "  0.2466387762317922,\n",
       "  0.2414650033076259,\n",
       "  0.2160074902113749,\n",
       "  0.2257864958940203,\n",
       "  0.23847447598264426,\n",
       "  0.238609249302439,\n",
       "  0.21821987629227693,\n",
       "  0.2038101638198991,\n",
       "  0.22897598811185463,\n",
       "  0.21320523677697212,\n",
       "  0.21577935581471339,\n",
       "  0.22806161846907377,\n",
       "  0.22779795287803253,\n",
       "  0.2378967443257832,\n",
       "  0.23012909685181823],\n",
       " 'smlls': [31.030172245755622,\n",
       "  10.043860983970715,\n",
       "  3.383892711688619,\n",
       "  1.8272766087378691,\n",
       "  1.5867054697152552,\n",
       "  1.3944360053872498,\n",
       "  1.0052087669409973,\n",
       "  0.7341415748075066,\n",
       "  1.1384281790567643,\n",
       "  0.6481615848706469,\n",
       "  0.7376123444072582,\n",
       "  0.5381250965063156,\n",
       "  0.6049987499595811,\n",
       "  0.5805033598326602,\n",
       "  0.5494336811399432,\n",
       "  0.9167545868163293,\n",
       "  0.574742910568235,\n",
       "  0.210721831102661,\n",
       "  0.19935628988074935,\n",
       "  0.08968395749665392,\n",
       "  -0.007605081873901076,\n",
       "  0.1383137719390506,\n",
       "  0.09246949917907059,\n",
       "  0.17269931309896971,\n",
       "  -0.18610502812992902,\n",
       "  0.008978557628844142,\n",
       "  -0.1353622187897677,\n",
       "  0.13061757606701038,\n",
       "  0.09727411614488142,\n",
       "  0.01956086507811435,\n",
       "  -0.00029652759938388523,\n",
       "  0.28566697510488276,\n",
       "  0.14527960981342805,\n",
       "  -0.15927603614614805,\n",
       "  0.16981234446418303,\n",
       "  -0.13391414056163561,\n",
       "  -0.2161116626486396,\n",
       "  0.5345247451481207,\n",
       "  -0.09397033135950277,\n",
       "  0.02115929805341299,\n",
       "  0.07378236613666662],\n",
       " 'mlls': [33.23841383739199,\n",
       "  12.252102575607083,\n",
       "  5.592134303324987,\n",
       "  4.035518200374238,\n",
       "  3.794947061351623,\n",
       "  3.602677597023618,\n",
       "  3.2134503585773655,\n",
       "  2.942383166443875,\n",
       "  3.3466697706931328,\n",
       "  2.8564031765070155,\n",
       "  2.9458539360436267,\n",
       "  2.746366688142684,\n",
       "  2.8132403415959497,\n",
       "  2.788744951469029,\n",
       "  2.757675272776312,\n",
       "  3.1249961784526974,\n",
       "  2.7829845022046036,\n",
       "  2.4189634227390293,\n",
       "  2.407597881517118,\n",
       "  2.2979255491330224,\n",
       "  2.200636509762467,\n",
       "  2.346555363575419,\n",
       "  2.300711090815439,\n",
       "  2.380940904735338,\n",
       "  2.0221365635064394,\n",
       "  2.2172201492652124,\n",
       "  2.0728793728466006,\n",
       "  2.3388591677033785,\n",
       "  2.30551570778125,\n",
       "  2.2278024567144827,\n",
       "  2.2079450640369847,\n",
       "  2.493908566741251,\n",
       "  2.353521201449796,\n",
       "  2.0489655554902204,\n",
       "  2.3780539361005513,\n",
       "  2.074327451074733,\n",
       "  1.9921299289877288,\n",
       "  2.7427663367844892,\n",
       "  2.1142712602768654,\n",
       "  2.2294008896897815,\n",
       "  2.2820239577730352]}"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train `j_model` with `new_train_model`; the returned metrics dict is displayed.\n",
    "new_train_model(j_model,\n",
    "                loss_fn=new_td_estimator,\n",
    "                loader=new_loader,\n",
    "                args=j_args,\n",
    "                elbo_estimator=new_elbo_estimator,\n",
    "                iwae_estimator=gpvae.estimators.gpvae_estimators.iwae_estimator,\n",
    "                train=train, test=test,\n",
    "                y_std=y_std, y_mean=y_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:12: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  if sys.path[0] == '':\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dff2414032be4dfc8f51cd0e9bfbf1d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 344.006\n",
      "(New) ELBO: -74206.035\n",
      "(Old) ELBO: -74240.370\n",
      "IWAE: 0.000\n",
      "SMSE: 2.012\n",
      "\n",
      "Epoch 50\n",
      "Loss: 68.987\n",
      "(New) ELBO: -14469.169\n",
      "(Old) ELBO: -14423.006\n",
      "IWAE: 0.000\n",
      "SMSE: 1.446\n",
      "\n",
      "Epoch 100\n",
      "Loss: 19.987\n",
      "(New) ELBO: -5508.677\n",
      "(Old) ELBO: -5499.198\n",
      "IWAE: 0.000\n",
      "SMSE: 0.492\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train `j_model_copy` with `train_model` for comparison with the run above.\n",
    "train_model(j_model_copy, 'jonny',\n",
    "            loss_fn=new_td_estimator,\n",
    "            loader=new_loader,\n",
    "            epochs=2000,\n",
    "            cache_freq=50,\n",
    "            hyperparameters=j_args,\n",
    "            elbo_estimator=new_elbo_estimator,\n",
    "            evaluate_model=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
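    "### Training functions\n",
    "\n",
    "`new_train_model` and `train_model` are called by the training cells above, so run this section before them on a fresh kernel."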
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def new_train_model(model, loss_fn, loader, args, elbo_estimator=None, iwae_estimator=None,\n",
    "                    train=None, test=None, y_std=1., y_mean=1., epochs=2000, cache_freq=50,\n",
    "                    save=None):\n",
    "    \"\"\"Train `model` with Adam, evaluating it every `cache_freq` epochs.\n",
    "\n",
    "    Args:\n",
    "        model: Model exposing `train`, `eval`, `parameters` and\n",
    "            `predict_y(data=..., contains_nan=..., num_samples=...)`.\n",
    "        loss_fn: Called as `loss_fn(model, batch=batch, contains_nan=..., num_samples=1)`;\n",
    "            must return a scalar tensor to minimise.\n",
    "        loader: DataLoader whose dataset has `get_dataset()` and `contains_nan`.\n",
    "        args: Hyperparameters handed to `save_model` if the model is saved.\n",
    "        elbo_estimator: Called as `elbo_estimator(model, dataset, num_samples=100,\n",
    "            contains_nan=..., decoder_scale=1.)` at each evaluation; skipped if None.\n",
    "        iwae_estimator: Not called yet (IWAE evaluation is disabled); if given, a\n",
    "            placeholder value of 0. is recorded.\n",
    "        train, test: DataFrames. Predictions are indexed like `train` and scored\n",
    "            against `test`; test metrics are only computed when `test` is given.\n",
    "        y_std, y_mean: Undo the output normalisation before scoring.\n",
    "            NOTE(review): the default `y_mean=1.` shifts predictions by one; 0. is\n",
    "            probably intended -- always pass it explicitly.\n",
    "        epochs: Number of passes over `loader`.\n",
    "        cache_freq: Evaluate every `cache_freq` epochs and after the last epoch.\n",
    "        save: True/False to save (or skip saving) without prompting; None asks\n",
    "            interactively, which blocks an unattended Run All.\n",
    "\n",
    "    Returns:\n",
    "        dict of per-evaluation lists: 'epochs', 'losses', 'elbos', 'iwaes',\n",
    "        'smses', 'smlls' and 'mlls'.\n",
    "    \"\"\"\n",
    "    metrics = {'epochs': [],\n",
    "               'losses': [],\n",
    "               'elbos': [],\n",
    "               'iwaes': [],\n",
    "               'smses': [],\n",
    "               'smlls': [],\n",
    "               'mlls': []\n",
    "              }\n",
    "\n",
    "    model.train(True)\n",
    "    optimiser = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "    # Full dataset, used for evaluation.\n",
    "    dataset = loader.dataset.get_dataset()\n",
    "    contains_nan = loader.dataset.contains_nan\n",
    "\n",
    "    # Training.\n",
    "    for epoch in tqdm_notebook(range(epochs)):\n",
    "        epoch_losses = []\n",
    "        for batch in loader:\n",
    "            optimiser.zero_grad()\n",
    "\n",
    "            # TODO: this doesn't work with mean-field encoders.\n",
    "            loss = loss_fn(model, batch=batch, contains_nan=contains_nan, num_samples=1)\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "\n",
    "            epoch_losses.append(loss.item())\n",
    "\n",
    "        # Evaluate model.\n",
    "        if (epoch % cache_freq == 0) or (epoch == epochs - 1):\n",
    "            model.eval()\n",
    "            # Defaults so the report below works when a metric is skipped.\n",
    "            elbo = iwae = smse = float('nan')\n",
    "\n",
    "            # Average loss over previous epoch.\n",
    "            mean_loss = np.mean(epoch_losses)\n",
    "            metrics['epochs'].append(epoch)\n",
    "            metrics['losses'].append(mean_loss)\n",
    "\n",
    "            # Evaluation needs no gradients; this also stops stored ELBOs from\n",
    "            # keeping their autograd graphs alive.\n",
    "            with torch.no_grad():\n",
    "                if elbo_estimator is not None:\n",
    "                    # ELBO estimate (stored as a plain float).\n",
    "                    elbo = float(elbo_estimator(model, dataset, num_samples=100,\n",
    "                                                contains_nan=contains_nan,\n",
    "                                                decoder_scale=1.))\n",
    "                    metrics['elbos'].append(elbo)\n",
    "\n",
    "                if iwae_estimator is not None:\n",
    "                    # TODO: IWAE evaluation is disabled; record a placeholder.\n",
    "                    iwae = 0.\n",
    "                    metrics['iwaes'].append(iwae)\n",
    "\n",
    "                if test is not None:\n",
    "                    # Test predictions, mapped back to the original data scale.\n",
    "                    mean, sigma = model.predict_y(data=dataset, contains_nan=contains_nan,\n",
    "                                                  num_samples=100)[:2]\n",
    "                    mean = mean.numpy() * y_std + y_mean\n",
    "                    sigma = sigma.numpy() * y_std\n",
    "\n",
    "                    # Evaluate test predictions.\n",
    "                    mean = pd.DataFrame(mean, index=train.index, columns=train.columns)\n",
    "                    var = pd.DataFrame(sigma ** 2, index=train.index, columns=train.columns)\n",
    "\n",
    "                    smse = gpvae.utils.wbml_metrics.smse(mean, test).mean()\n",
    "                    smll = gpvae.utils.wbml_metrics.smll(mean, var, test).mean()\n",
    "                    mll = gpvae.utils.wbml_metrics.mll(mean, var, test).mean()\n",
    "\n",
    "                    metrics['smses'].append(smse)\n",
    "                    metrics['smlls'].append(smll)\n",
    "                    metrics['mlls'].append(mll)\n",
    "\n",
    "            print('Epoch {}\\nLoss: {:.3f}\\nELBO: {:.3f}\\n'\n",
    "                  'IWAE: {:.3f}\\nSMSE: {:.3f}\\n'.format(epoch, mean_loss, elbo, iwae, smse))\n",
    "\n",
    "            model.train(True)\n",
    "\n",
    "    if save is None:\n",
    "        # Ask until the answer is exactly 'yes' or 'no'.\n",
    "        response = None\n",
    "        while response not in ('yes', 'no'):\n",
    "            response = input('Save model? (yes/no)')\n",
    "        save = response == 'yes'\n",
    "    if save:\n",
    "        save_model(model, args, metrics)\n",
    "\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, name, loss_fn, loader, epochs, cache_freq, hyperparameters, elbo_estimator, evaluate_model=False, lr=0.001):\n",
    "    \"\"\"Trains `model` with Adam, evaluating it every `cache_freq` epochs.\n",
    "\n",
    "    At epoch 0, every `cache_freq` epochs and at the final epoch the model is\n",
    "    put in eval mode and the (new and old) ELBO estimates plus the test\n",
    "    SMSE, SMLL and MLL are computed on `loader.dataset.get_dataset()`. The\n",
    "    user may then be prompted (via `input`) to save the model.\n",
    "\n",
    "    NOTE(review): relies on the notebook globals `decoder_scale`, `y_mean`,\n",
    "    `y_std`, `train` and `test`, and on `save_model` being defined.\n",
    "\n",
    "    :param model: the model to train.\n",
    "    :param name: a str, the results directory name passed to `save_model`.\n",
    "    :param loss_fn: a callable returning a scalar training loss for a batch.\n",
    "    :param loader: a DataLoader whose dataset implements `get_dataset()`.\n",
    "    :param epochs: an int, the number of epochs to train for.\n",
    "    :param cache_freq: an int, how often (in epochs) to evaluate.\n",
    "    :param hyperparameters: stored with the model by `save_model`.\n",
    "    :param elbo_estimator: unused; kept for backward compatibility.\n",
    "    :param evaluate_model: a bool, whether to also offer to save the model\n",
    "        whenever the mean SMSE increases between two evaluations.\n",
    "    :param lr: a float, the Adam learning rate.\n",
    "    :return: (losses, elbos, smses, smlls, mlls), one entry per evaluation;\n",
    "        ELBOs are stored as Python floats.\n",
    "    \"\"\"\n",
    "    losses = []\n",
    "    elbos = []\n",
    "    smses = []\n",
    "    smlls = []\n",
    "    mlls = []\n",
    "\n",
    "    model.train(True)\n",
    "    optimiser = optim.Adam(model.parameters(), lr=lr)\n",
    "    dataset = loader.dataset.get_dataset()\n",
    "\n",
    "    for epoch in tqdm_notebook(range(epochs)):\n",
    "        running_losses = []\n",
    "        for batch in loader:\n",
    "            optimiser.zero_grad()\n",
    "            loss = loss_fn(model, batch, num_samples=1,\n",
    "                           contains_nan=True, decoder_scale=decoder_scale)\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "\n",
    "            running_losses.append(loss.item())\n",
    "\n",
    "        if (epoch % cache_freq == 0) or (epoch == epochs - 1):\n",
    "            model.eval()\n",
    "\n",
    "            # Evaluation needs no gradients. Without no_grad every stored\n",
    "            # ELBO tensor would keep its whole autograd graph alive.\n",
    "            with torch.no_grad():\n",
    "                # ELBO estimates on the full training set, as plain floats.\n",
    "                elbo = float(new_elbo_estimator(model, dataset, num_samples=100,\n",
    "                                                contains_nan=True, decoder_scale=1.))\n",
    "                old_elbo = float(old_elbo_estimator(model, dataset, num_samples=100,\n",
    "                                                    contains_nan=True, decoder_scale=1.))\n",
    "                # The IWAE estimate is disabled; 0 keeps the log format stable.\n",
    "                iwae = 0.\n",
    "\n",
    "                # Predictive mean and standard deviation.\n",
    "                mean, sigma = model.predict_y(data=dataset, contains_nan=True,\n",
    "                                              num_samples=100)[:2]\n",
    "\n",
    "            # Map predictions back to the original (un-normalised) scale.\n",
    "            mean = mean.numpy() * y_std + y_mean\n",
    "            sigma = sigma.numpy() * y_std\n",
    "            pred = pd.DataFrame(mean, index=train.index, columns=train.columns)\n",
    "            var = pd.DataFrame(sigma ** 2, index=train.index,\n",
    "                               columns=train.columns)\n",
    "\n",
    "            # Test performance, averaged over output columns.\n",
    "            mean_smse = gpvae.utils.wbml_metrics.smse(pred, test).mean()\n",
    "            mean_smll = gpvae.utils.wbml_metrics.smll(pred, var, test).mean()\n",
    "            mean_mll = gpvae.utils.wbml_metrics.mll(pred, var, test).mean()\n",
    "\n",
    "            mean_loss = np.mean(running_losses)\n",
    "\n",
    "            print('Epoch {}\\nLoss: {:.3f}\\n(New) ELBO: {:.3f}\\n(Old) ELBO: {:.3f}\\n'\n",
    "                  'IWAE: {:.3f}\\nSMSE: {:.3f}\\n'.format(epoch, mean_loss, elbo, old_elbo, iwae, mean_smse))\n",
    "\n",
    "            losses.append(mean_loss)\n",
    "            elbos.append(elbo)\n",
    "            smses.append(mean_smse)\n",
    "            smlls.append(mean_smll)\n",
    "            mlls.append(mean_mll)\n",
    "\n",
    "            # The first evaluation is always at epoch 0, so smses[-2] exists\n",
    "            # whenever epoch > 0 here.\n",
    "            smse_got_worse = epoch > 0 and mean_smse > smses[-2]\n",
    "            if (smse_got_worse and evaluate_model) or (epoch == epochs - 1):\n",
    "                while True:\n",
    "                    response = input('Save model? (yes/no/quit)')\n",
    "                    if response == 'yes':\n",
    "                        save_model(model, name, epoch, losses, elbos,\n",
    "                                   smses, smlls, mlls, cache_freq,\n",
    "                                   hyperparameters)\n",
    "                        break\n",
    "                    if response == 'no':\n",
    "                        break\n",
    "                    if response == 'quit':\n",
    "                        return losses, elbos, smses, smlls, mlls\n",
    "                    # Any other answer: ask again.\n",
    "\n",
    "            model.train(True)\n",
    "\n",
    "    return losses, elbos, smses, smlls, mlls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_model(model, name, epoch, losses, elbos, smses, smlls, mlls, cache_freq, hyperparameters):\n",
    "    \"\"\"Saves a model, its hyperparameters and its training metrics.\n",
    "\n",
    "    Everything is written to `complex_saved_models/<name>`; if that directory\n",
    "    already exists, the first free `complex_saved_models/<name>_<i>`\n",
    "    (i = 1, 2, ...) is used instead. The directory receives\n",
    "    `hyperparameters.pkl`, `metrics.pkl`, a human-readable `report.txt` and\n",
    "    `model_state_dict.pt`.\n",
    "\n",
    "    :param model: the nn.Module whose `state_dict` is saved.\n",
    "    :param name: a str, the base name of the results directory.\n",
    "    :param epoch: an int, the last epoch at which metrics were recorded.\n",
    "    :param losses: a list, the mean training loss at each evaluation.\n",
    "    :param elbos: a list, the ELBO estimate at each evaluation.\n",
    "    :param smses: a list, the mean test SMSE at each evaluation.\n",
    "    :param smlls: a list, the mean test SMLL at each evaluation.\n",
    "    :param mlls: a list, the mean test MLL at each evaluation.\n",
    "    :param cache_freq: an int, how often (in epochs) metrics were recorded.\n",
    "    :param hyperparameters: a picklable object, or a list of them,\n",
    "        describing the run.\n",
    "    \"\"\"\n",
    "    results_dir = 'complex_saved_models/' + name\n",
    "    if os.path.isdir(results_dir):\n",
    "        # Never overwrite an earlier run: find the first free suffix.\n",
    "        suffix = 1\n",
    "        while os.path.isdir(results_dir + '_' + str(suffix)):\n",
    "            suffix += 1\n",
    "        results_dir = results_dir + '_' + str(suffix)\n",
    "\n",
    "    os.makedirs(results_dir)\n",
    "    results_path = results_dir + '/report.txt'\n",
    "    model_path = results_dir + '/model_state_dict.pt'\n",
    "\n",
    "    # Metrics are recorded every `cache_freq` epochs *and* at the final\n",
    "    # epoch, which need not be a multiple of `cache_freq`.\n",
    "    epoch_labels = list(range(0, epoch + 1, cache_freq))\n",
    "    if not epoch_labels or epoch_labels[-1] != epoch:\n",
    "        epoch_labels.append(epoch)\n",
    "    epochs = np.array(epoch_labels)\n",
    "\n",
    "    metrics = {'epochs': epochs,\n",
    "               'losses': losses,\n",
    "               'elbos': elbos,\n",
    "               'smses': smses,\n",
    "               'smlls': smlls,\n",
    "               'mlls': mlls}\n",
    "\n",
    "    # Pickle hyperparameters and metrics for later use.\n",
    "    with open(results_dir + '/hyperparameters.pkl', 'wb') as f:\n",
    "        pickle.dump(hyperparameters, f)\n",
    "\n",
    "    with open(results_dir + '/metrics.pkl', 'wb') as f:\n",
    "        pickle.dump(metrics, f)\n",
    "\n",
    "    # Save hyperparameters and results in text format.\n",
    "    with open(results_path, 'w') as f:\n",
    "        f.write('Hyperparameters: \\n')\n",
    "        if isinstance(hyperparameters, list):\n",
    "            for d in hyperparameters:\n",
    "                f.write(str(d) + '\\n')\n",
    "        else:\n",
    "            f.write(str(hyperparameters) + '\\n')\n",
    "\n",
    "        f.write('\\nPerformance: \\n')\n",
    "        for epoch_label, loss, elbo, smse, smll, mll in zip(epochs, losses, elbos,\n",
    "                                                            smses, smlls, mlls):\n",
    "            f.write('\\nEpoch: {}\\n'\n",
    "                    'Loss (batch): {}\\n'\n",
    "                    'ELBO (train): {}\\n'\n",
    "                    'SMSE (test): {}\\n'\n",
    "                    'SMLL (test): {}\\n'\n",
    "                    'MLL (test): {}\\n'.format(epoch_label, loss, elbo, smse, smll, mll))\n",
    "\n",
    "    # Save the model weights for later use.\n",
    "    # NOTE(review): an earlier note said this \"doesn't seem to load properly\"; unverified.\n",
    "    torch.save(model.state_dict(), model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPVAE(VAE):\n",
    "    \"\"\"Variational autoencoder with GP prior.\n",
    "\n",
    "    Each latent dimension has its own GP kernel. The approximate posterior\n",
    "    combines the GP prior with Gaussian likelihood factors l(f|y) whose\n",
    "    means and standard deviations are produced by the encoder.\n",
    "\n",
    "    :param encoder: the encoder network; called as encoder(y) or\n",
    "        encoder(y, mask) and expected to return (mean, sigma).\n",
    "    :param decoder: the decoder network.\n",
    "    :param latent_dim: the dimension of latent space.\n",
    "    :param kernel: the GP kernel (deep-copied once per latent dimension),\n",
    "        or a list of exactly `latent_dim` kernels.\n",
    "    \"\"\"\n",
    "    def __init__(self, encoder, decoder, latent_dim, kernel):\n",
    "        super().__init__(encoder, decoder, latent_dim)\n",
    "\n",
    "        if not isinstance(kernel, list):\n",
    "            self.kernels = nn.ModuleList(\n",
    "                [copy.deepcopy(kernel) for _ in range(latent_dim)])\n",
    "        else:\n",
    "            assert len(kernel) == latent_dim, ('Number of kernels must be '\n",
    "                                               'equal to the latent dimension.')\n",
    "            self.kernels = nn.ModuleList(copy.deepcopy(kernel))\n",
    "\n",
    "    def get_latent_prior(self, x):\n",
    "        \"\"\"Returns the zero-mean GP prior evaluated at inputs x.\n",
    "\n",
    "        :param x: the inputs; x.shape[0] is the number of points.\n",
    "        :return: (pf_mu, pf_cov): zeros of shape [latent_dim, x.shape[0]]\n",
    "            and the kernel matrices k(x, x) stacked over latent dimensions.\n",
    "        \"\"\"\n",
    "        batch_size = x.shape[0]\n",
    "        pf_mu = torch.zeros(self.latent_dim, batch_size)\n",
    "        pf_cov = torch.stack([kernel.forward(x, x) for kernel in self.kernels])\n",
    "\n",
    "        return pf_mu, pf_cov\n",
    "\n",
    "    def get_latent_dists(self, data, x_test=None, contains_nan=False):\n",
    "        \"\"\"Computes the latent GP posterior (and prior) distributions.\n",
    "\n",
    "        :param data: a tuple (x, y); when contains_nan, y is (y, mask).\n",
    "        :param x_test: optional test inputs. If given, the posterior is\n",
    "            evaluated at x_test instead of at the training inputs.\n",
    "        :param contains_nan: a bool, whether y is a (y, mask) pair.\n",
    "        :return: (qs_mu, qs_cov, ps_mu, kss) if x_test is given, otherwise\n",
    "            (qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov).\n",
    "        \"\"\"\n",
    "        x, y = data\n",
    "        if contains_nan:\n",
    "            y, mask = y\n",
    "            lf_y_mu, lf_y_sigma = self.encoder(y, mask)\n",
    "        else:\n",
    "            lf_y_mu, lf_y_sigma = self.encoder(y)\n",
    "\n",
    "        # Reshape: presumably [N, latent_dim] -> [latent_dim, N], matching\n",
    "        # the prior mean's layout.\n",
    "        lf_y_mu = lf_y_mu.transpose(0, 1)\n",
    "        lf_y_sigma = lf_y_sigma.transpose(0, 1)\n",
    "        lf_y_cov = lf_y_sigma.pow(2).diag_embed()\n",
    "        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()\n",
    "        lf_y_root_precision = lf_y_sigma.pow(-1).diag_embed()\n",
    "\n",
    "        # GP prior.\n",
    "        pf_mu, kff = self.get_latent_prior(x)\n",
    "\n",
    "        # See GPML section 3.4.3.\n",
    "        a = kff.matmul(lf_y_root_precision)\n",
    "        at = a.transpose(-1, -2)\n",
    "        w = lf_y_root_precision.matmul(a)\n",
    "        w = add_diagonal(w, 1)\n",
    "        winv = w.inverse()\n",
    "\n",
    "        if x_test is not None:\n",
    "            # GP prior.\n",
    "            ps_mu, kss = self.get_latent_prior(x_test)\n",
    "\n",
    "            # GP conditional prior. self.kernels is a plain nn.ModuleList,\n",
    "            # which has no forward(), so evaluate and stack each kernel.\n",
    "            ksf = torch.stack([kernel.forward(x_test, x)\n",
    "                               for kernel in self.kernels])\n",
    "            kfs = ksf.transpose(-1, -2)\n",
    "\n",
    "            # GP test posterior.\n",
    "            b = lf_y_root_precision.matmul(winv.matmul(lf_y_root_precision))\n",
    "            c = ksf.matmul(b)\n",
    "            qs_cov = kss - c.matmul(kfs)\n",
    "            qs_mu = c.matmul(lf_y_mu.unsqueeze(2))\n",
    "            qs_mu = qs_mu.squeeze(2)\n",
    "\n",
    "            return qs_mu, qs_cov, ps_mu, kss\n",
    "        else:\n",
    "            # GP training posterior.\n",
    "            qf_cov = kff - a.matmul(winv.matmul(at))\n",
    "            qf_mu = qf_cov.matmul(lf_y_precision.matmul(lf_y_mu.unsqueeze(2)))\n",
    "            qf_mu = qf_mu.squeeze(2)\n",
    "\n",
    "            return qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov\n",
    "\n",
    "    def sample_latent_posterior(self, data, num_samples=1, full_cov=True,\n",
    "                                **kwargs):\n",
    "        \"\"\"Draws samples from the latent posterior q(f).\n",
    "\n",
    "        :param data: passed to get_latent_dists.\n",
    "        :param num_samples: an int, the number of samples to draw.\n",
    "        :param full_cov: a bool; if False, each point is sampled\n",
    "            independently using only the posterior marginal variances.\n",
    "        :param kwargs: passed to get_latent_dists (e.g. x_test, contains_nan).\n",
    "        :return: a list of num_samples samples.\n",
    "        \"\"\"\n",
    "        qf_mu, qf_cov = self.get_latent_dists(data, **kwargs)[:2]\n",
    "\n",
    "        samples = []\n",
    "        if full_cov:\n",
    "            qf = MultivariateNormal(qf_mu, lazify(qf_cov))\n",
    "\n",
    "            for i in range(num_samples):\n",
    "                samples.append(qf.sample())\n",
    "        else:\n",
    "            qf_sigma = torch.stack([torch.diag(cov) for cov in qf_cov]) ** 0.5\n",
    "            for i in range(num_samples):\n",
    "                samples.append(qf_mu + qf_sigma * torch.randn_like(qf_mu))\n",
    "\n",
    "        return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "JITTER = 1e-5\n",
    "\n",
    "def add_diagonal(x, val=1.):\n",
    "    \"\"\"Adds a value to the diagonal of a (batch of) square matrices.\n",
    "\n",
    "    :param x: a torch.Tensor of shape [..., n, n], the matrix to modify.\n",
    "    :param val: A float, the value to add to the diagonal.\n",
    "    :return: a new tensor equal to x with val added to every diagonal\n",
    "        entry; it has the same dtype and device as x.\n",
    "    \"\"\"\n",
    "    assert x.shape[-2] == x.shape[-1], 'x must be square.'\n",
    "\n",
    "    # Build the diagonal on x's dtype/device: a default-dtype CPU tensor\n",
    "    # would upcast float32 inputs and fail for CUDA inputs.\n",
    "    d = (torch.ones(x.shape[-2], dtype=x.dtype, device=x.device) * val).diag_embed()\n",
    "\n",
    "    return x + d\n",
    "\n",
    "\n",
    "class KernelList(nn.ModuleList):\n",
    "    \"\"\"An nn.ModuleList of kernels that can be evaluated all at once.\n",
    "\n",
    "    :param kernels: an iterable of kernel modules, each exposing\n",
    "        forward(x1, x2, diag=...).\n",
    "    \"\"\"\n",
    "    def __init__(self, kernels):\n",
    "        super().__init__(kernels)\n",
    "\n",
    "    def forward(self, x1, x2, diag=False):\n",
    "        \"\"\"Evaluates every kernel on (x1, x2) and stacks the results.\n",
    "\n",
    "        :param x1: the first set of inputs.\n",
    "        :param x2: the second set of inputs.\n",
    "        :param diag: a bool, passed to each kernel. When True, each kernel's\n",
    "            output is turned into a diagonal matrix with diag_embed before\n",
    "            stacking.\n",
    "        :return: a tensor whose leading dimension indexes the kernels.\n",
    "        \"\"\"\n",
    "        per_kernel = [k.forward(x1, x2, diag=diag) for k in self]\n",
    "\n",
    "        if not diag:\n",
    "            return torch.stack(per_kernel)\n",
    "\n",
    "        # Presumably each kernel returns only the diagonal here; embed it so\n",
    "        # the stacked result has the full-matrix layout.\n",
    "        return torch.stack([out.diag_embed() for out in per_kernel])\n",
    "    \n",
    "\n",
    "class NewGPVAE(VAE):\n",
    "    \"\"\"VAE with a GP prior over each latent dimension.\n",
    "\n",
    "    :param encoder: the encoder network; called as encoder(y) or\n",
    "        encoder(y, mask) and expected to return (mean, sigma).\n",
    "    :param decoder: the decoder network.\n",
    "    :param latent_dim: the dimension of latent space.\n",
    "    :param kernel: the GP kernel (deep-copied once per latent dimension),\n",
    "        or a list of exactly `latent_dim` kernels.\n",
    "    :param add_jitter: a bool, whether get_latent_prior adds JITTER to the\n",
    "        diagonal of the prior covariances.\n",
    "    \"\"\"\n",
    "    def __init__(self, encoder, decoder, latent_dim, kernel, add_jitter=False):\n",
    "        super().__init__(encoder, decoder, latent_dim)\n",
    "\n",
    "        self.add_jitter = add_jitter\n",
    "\n",
    "        if isinstance(kernel, list):\n",
    "            assert len(kernel) == latent_dim, ('Number of kernels must be '\n",
    "                                               'equal to the latent dimension.')\n",
    "            per_dim_kernels = copy.deepcopy(kernel)\n",
    "        else:\n",
    "            per_dim_kernels = [copy.deepcopy(kernel) for _ in range(latent_dim)]\n",
    "        self.kernels = KernelList(per_dim_kernels)\n",
    "\n",
    "    def get_latent_prior(self, x, diag=False):\n",
    "        \"\"\"Evaluates the zero-mean GP prior at inputs x.\n",
    "\n",
    "        :param x: the inputs; x.shape[0] is the number of points.\n",
    "        :param diag: a bool, forwarded to KernelList.forward.\n",
    "        :return: (mean, cov): zeros of shape [latent_dim, x.shape[0]] and the\n",
    "            stacked kernel matrices (with JITTER added to the diagonal when\n",
    "            add_jitter is set).\n",
    "        \"\"\"\n",
    "        prior_mean = torch.zeros(self.latent_dim, x.shape[0])\n",
    "        prior_cov = self.kernels.forward(x, x, diag)\n",
    "\n",
    "        if not self.add_jitter:\n",
    "            return prior_mean, prior_cov\n",
    "        return prior_mean, add_diagonal(prior_cov, JITTER)\n",
    "\n",
    "    def get_latent_dists(self, x, y, mask=None, x_test=None):\n",
    "        \"\"\"Computes the latent GP posterior from the encoder's likelihood factors.\n",
    "\n",
    "        The encoder output defines Gaussian factors l(f|y) with mean lf_y_mu\n",
    "        and diagonal covariance lf_y_sigma ** 2.\n",
    "\n",
    "        :param x: the training inputs.\n",
    "        :param y: the observations passed to the encoder.\n",
    "        :param mask: optional mask, passed to the encoder when not None.\n",
    "        :param x_test: optional test inputs at which to evaluate the posterior.\n",
    "        :return: (qs_mu, qs_cov, ps_mu, kss) if x_test is given, otherwise\n",
    "            (qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov).\n",
    "        \"\"\"\n",
    "        encoder_inputs = (y,) if mask is None else (y, mask)\n",
    "        enc_mu, enc_sigma = self.encoder(*encoder_inputs)\n",
    "\n",
    "        # Swap the first two axes so the latent dimension leads, matching the\n",
    "        # prior mean's [latent_dim, N] layout.\n",
    "        lf_y_mu = enc_mu.transpose(0, 1)\n",
    "        lf_y_sigma = enc_sigma.transpose(0, 1)\n",
    "        lf_y_cov = lf_y_sigma.pow(2).diag_embed()\n",
    "        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()\n",
    "        root_precision = lf_y_sigma.pow(-1).diag_embed()\n",
    "\n",
    "        pf_mu, kff = self.get_latent_prior(x)\n",
    "\n",
    "        # Invert W = I + S K S (S = root precision) rather than K + L^2;\n",
    "        # W's eigenvalues are >= 1 for PSD K. See GPML section 3.4.3.\n",
    "        k_root_precision = kff.matmul(root_precision)\n",
    "        w_inv = add_diagonal(root_precision.matmul(k_root_precision), 1).inverse()\n",
    "\n",
    "        if x_test is None:\n",
    "            # GP posterior at the training inputs.\n",
    "            qf_cov = kff - k_root_precision.matmul(\n",
    "                w_inv.matmul(k_root_precision.transpose(-1, -2)))\n",
    "            qf_mu = qf_cov.matmul(\n",
    "                lf_y_precision.matmul(lf_y_mu.unsqueeze(2))).squeeze(2)\n",
    "            return qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov\n",
    "\n",
    "        # GP prior and cross-covariance at the test inputs.\n",
    "        ps_mu, kss = self.get_latent_prior(x_test)\n",
    "        ksf = self.kernels.forward(x_test, x)\n",
    "\n",
    "        # GP posterior at the test inputs.\n",
    "        gain = ksf.matmul(root_precision.matmul(w_inv.matmul(root_precision)))\n",
    "        qs_cov = kss - gain.matmul(ksf.transpose(-1, -2))\n",
    "        qs_mu = gain.matmul(lf_y_mu.unsqueeze(2)).squeeze(2)\n",
    "        return qs_mu, qs_cov, ps_mu, kss\n",
    "\n",
    "    def sample_latent_posterior(self, data, contains_nan=False, num_samples=1,\n",
    "                                full_cov=True, **kwargs):\n",
    "        \"\"\"Draws samples from the latent posterior q(f).\n",
    "\n",
    "        :param data: a tuple (x, y); when contains_nan, data[1] is (y, mask).\n",
    "        :param contains_nan: a bool, whether the observations come with a mask.\n",
    "        :param num_samples: an int, the number of samples to draw.\n",
    "        :param full_cov: a bool; if False, each point is sampled\n",
    "            independently using only the posterior marginal variances.\n",
    "        :param kwargs: passed to get_latent_dists (e.g. x_test).\n",
    "        :return: a list of num_samples samples.\n",
    "        \"\"\"\n",
    "        if contains_nan:\n",
    "            x, (y, mask) = data[0], data[1]\n",
    "        else:\n",
    "            (x, y), mask = data, None\n",
    "\n",
    "        qf_mu, qf_cov = self.get_latent_dists(x, y, mask, **kwargs)[:2]\n",
    "\n",
    "        if not full_cov:\n",
    "            qf_sigma = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5\n",
    "            return [qf_mu + qf_sigma * torch.randn_like(qf_mu)\n",
    "                    for _ in range(num_samples)]\n",
    "\n",
    "        # Sample using the full posterior covariance matrix.\n",
    "        qf = MultivariateNormal(qf_mu, lazify(qf_cov))\n",
    "        return [qf.sample() for _ in range(num_samples)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def old_td_estimator(model, data, num_samples, contains_nan=False,\n",
    "                     decoder_scale=None, debug=False, make_lazy=True):\n",
    "    \"\"\"Estimates the negative ELBO using the total derivative of the\n",
    "    reparameterisation trick, for models with an encoder/decoder\n",
    "    architecture, a GP prior over latent variables and an approximate\n",
    "    posterior of the form q(f) = 1/Z p(f)l(f|y).\n",
    "\n",
    "    Uses the older `model.get_latent_dists(data, contains_nan=...)` calling\n",
    "    convention. In the l(f|y) term the likelihood mean/variance are\n",
    "    detached; in the p(f) term the sample f is detached.\n",
    "\n",
    "    :param: model: A nn.Module, the model to evaluate on.\n",
    "    :param: data: A tuple (x, y); when contains_nan, y is itself (y, mask).\n",
    "    :param: num_samples: An int, the number of samples from q(f).\n",
    "    :param: contains_nan: A bool, whether the observations contain NaN\n",
    "        values to ignore.\n",
    "    :param: decoder_scale: A float scaling the decoder term. Forced to 1\n",
    "        when contains_nan is False; derived from the mask when None.\n",
    "    :param: debug: A bool, whether to also return the per-sample terms.\n",
    "    :param: make_lazy: A bool, whether to wrap covariances with lazify.\n",
    "    :return: the loss (negative ELBO estimate divided by the batch size),\n",
    "        or (loss, terms) when debug is True.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "    n_points = x.shape[0]\n",
    "\n",
    "    terms = {key: [] for key in ('py_f_terms', 'lf_y_terms', 'norm_terms')}\n",
    "\n",
    "    if not contains_nan:\n",
    "        mask = None\n",
    "        decoder_scale = 1.\n",
    "    else:\n",
    "        y, mask = y\n",
    "        if decoder_scale is None:\n",
    "            # NOTE(review): whether sum(mask) counts missing or observed\n",
    "            # entries depends on the dataset's mask convention - confirm.\n",
    "            n_masked = 1. * torch.sum(mask)\n",
    "            n_entries = y.shape[0] * y.shape[1]\n",
    "            decoder_scale = 1. - n_masked / n_entries\n",
    "\n",
    "    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(\n",
    "        data, contains_nan=contains_nan)\n",
    "\n",
    "    wrap = lazify if make_lazy else (lambda cov: cov)\n",
    "    qf = MultivariateNormal(qf_mu, wrap(qf_cov))\n",
    "    pf = MultivariateNormal(pf_mu, wrap(pf_cov))\n",
    "\n",
    "    lf_y_var = torch.stack([torch.diag(cov) for cov in lf_y_cov])\n",
    "\n",
    "    total = 0\n",
    "    for _ in range(num_samples):\n",
    "        f = qf.rsample()\n",
    "\n",
    "        # log p(y|f), scaled.\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        target = y[0] if isinstance(y, (list, tuple)) else y\n",
    "        decoder_term = decoder_scale * torch.sum(\n",
    "            gaussian_diagonal_ll(target, py_f_mu, py_f_sigma.pow(2), mask=mask))\n",
    "\n",
    "        # log l(f|y) and log p(f).\n",
    "        encoder_term = torch.sum(gaussian_diagonal_ll(f, lf_y_mu.detach(),\n",
    "                                                      lf_y_var.detach()))\n",
    "        prior_term = torch.sum(pf.log_prob(f.detach()))\n",
    "\n",
    "        total += decoder_term - encoder_term + prior_term\n",
    "\n",
    "        for key, term in zip(terms, (decoder_term, encoder_term, prior_term)):\n",
    "            terms[key].append(term.item())\n",
    "\n",
    "    # Average over posterior samples, then over the batch.\n",
    "    total /= num_samples\n",
    "    total /= n_points\n",
    "\n",
    "    loss = -total\n",
    "    return (loss, terms) if debug else loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def new_td_estimator(model, batch, num_samples=1, contains_nan=False, decoder_scale=None,\n",
    "                     make_lazy=True):\n",
    "    \"\"\"Estimates the negative ELBO using the total derivative of the\n",
    "    reparameterisation trick, for models with an encoder/decoder\n",
    "    architecture, a GP prior over latent variables and an approximate\n",
    "    posterior of the form q(f) = 1/Z p(f)l(f|y).\n",
    "\n",
    "    :param: model: A nn.Module exposing get_latent_dists(x, y, mask) and a\n",
    "        decoder.\n",
    "    :param: batch: A tuple (x, y); when contains_nan, y is itself (y, mask).\n",
    "    :param: num_samples: An int, the number of samples to estimate the\n",
    "        ELBO gradient with.\n",
    "    :param: contains_nan: A bool, whether y comes with a mask of entries\n",
    "        to ignore.\n",
    "    :param: decoder_scale: A float scaling the decoder term. If None and\n",
    "        contains_nan, it is derived from the mask; when contains_nan is\n",
    "        False it is set to 1.\n",
    "    :param: make_lazy: A bool, whether to wrap covariances with lazify.\n",
    "    :return: A scalar tensor, the negative ELBO estimate divided by the\n",
    "        batch size.\n",
    "    \"\"\"\n",
    "    x, y = batch\n",
    "    if contains_nan:\n",
    "        y, mask = y\n",
    "        if decoder_scale is None:\n",
    "            # decoder_scale = 1 - sum(mask) / (N * D). NOTE(review): whether\n",
    "            # this is the observed or the missing fraction depends on the\n",
    "            # mask convention - confirm.\n",
    "            num_nan = 1. * torch.sum(mask)\n",
    "            num_observations = y.shape[0] * y.shape[1]\n",
    "            decoder_scale = 1. - num_nan / num_observations\n",
    "    else:\n",
    "        # Fully observed data: there is no mask to apply.\n",
    "        mask = None\n",
    "        decoder_scale = 1.\n",
    "\n",
    "    estimator = 0\n",
    "\n",
    "    # Latent distributions.\n",
    "    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(\n",
    "        x, y, mask)\n",
    "\n",
    "    # Required distributions.\n",
    "    if make_lazy:\n",
    "        qf = MultivariateNormal(qf_mu, lazify(qf_cov))\n",
    "        pf = MultivariateNormal(pf_mu, lazify(pf_cov))\n",
    "    else:\n",
    "        qf = MultivariateNormal(qf_mu, qf_cov)\n",
    "        pf = MultivariateNormal(pf_mu, pf_cov)\n",
    "\n",
    "    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])\n",
    "\n",
    "    # Monte-Carlo estimate of ELBO gradient.\n",
    "    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.\n",
    "    for _ in range(num_samples):\n",
    "        f = qf.rsample()\n",
    "\n",
    "        # log p(y|f) term.\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)\n",
    "        py_f_term = decoder_scale * py_f_term.sum()\n",
    "        estimator += py_f_term\n",
    "\n",
    "        # log l(f|y) term, with the likelihood parameters detached.\n",
    "        lf_y_term = gaussian_diagonal_ll(f, lf_y_mu.detach(),\n",
    "                                         lf_y_var.detach())\n",
    "        lf_y_term = lf_y_term.sum()\n",
    "        estimator += - lf_y_term\n",
    "\n",
    "        # log p(f) term, with the sample detached.\n",
    "        pf_term = pf.log_prob(f.detach()).sum()\n",
    "        estimator += pf_term\n",
    "\n",
    "    # Inner summation over samples from q(f).\n",
    "    estimator /= num_samples\n",
    "\n",
    "    # Outer summation over batch.\n",
    "    estimator /= x.shape[0]\n",
    "\n",
    "    return - estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def new_elbo_estimator(model, batch, num_samples=1, contains_nan=False, make_lazy=True, decoder_scale=1.):\n",
    "    \"\"\"Estimates the ELBO using analytical results where possible, and the\n",
    "    reparameterisation trick for the decoder term, for models with an\n",
    "    encoder/decoder architecture, a GP prior over latent variables and an\n",
    "    approximate posterior of the form q(f) = 1/Z p(f)l(f|y).\n",
    "\n",
    "    The returned value is the ELBO itself (not its negative) and is not\n",
    "    divided by the batch size.\n",
    "\n",
    "    :param: model: A nn.Module exposing get_latent_dists(x, y, mask) and a\n",
    "        decoder.\n",
    "    :param: batch: A tuple (x, y); when contains_nan, y is itself (y, mask).\n",
    "    :param: num_samples: An int, the number of samples from q(f) used for\n",
    "        the decoder term.\n",
    "    :param: contains_nan: A bool, whether y comes with a mask of entries\n",
    "        to ignore.\n",
    "    :param: make_lazy: A bool, whether to wrap covariances with lazify.\n",
    "    :param: decoder_scale: A float scaling the decoder term when\n",
    "        contains_nan is True (it is 1 otherwise). If None, it is derived\n",
    "        from the mask as in the other estimators.\n",
    "    :return: A scalar tensor, the ELBO estimate.\n",
    "    \"\"\"\n",
    "    x, y = batch\n",
    "    if contains_nan:\n",
    "        y, mask = y\n",
    "        if decoder_scale is None:\n",
    "            # NOTE(review): whether sum(mask) counts missing or observed\n",
    "            # entries depends on the dataset's mask convention - confirm.\n",
    "            num_nan = 1. * torch.sum(mask)\n",
    "            num_observations = y.shape[0] * y.shape[1]\n",
    "            decoder_scale = 1. - num_nan / num_observations\n",
    "    else:\n",
    "        # Fully observed data: there is no mask to apply.\n",
    "        mask = None\n",
    "        decoder_scale = 1.\n",
    "\n",
    "    # Latent distributions.\n",
    "    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(\n",
    "        x, y, mask)\n",
    "    sum_cov = pf_cov + lf_y_cov\n",
    "\n",
    "    # q(f) and the normaliser Zq = N(0; lf_y_mu, pf_cov + lf_y_cov).\n",
    "    if make_lazy:\n",
    "        qf = MultivariateNormal(qf_mu, lazify(qf_cov))\n",
    "        zq = MultivariateNormal(lf_y_mu, lazify(sum_cov))\n",
    "    else:\n",
    "        qf = MultivariateNormal(qf_mu, qf_cov)\n",
    "        zq = MultivariateNormal(lf_y_mu, sum_cov)\n",
    "\n",
    "    qf_var = torch.stack([cov.diag() for cov in qf_cov])\n",
    "    lf_y_var = torch.stack([cov.diag() for cov in lf_y_cov])\n",
    "\n",
    "    # Monte-Carlo estimate of E_q[log p(y|f)].\n",
    "    # See Spatio-Temporal VAEs: ELBO\n",
    "    elbo = 0\n",
    "    for _ in range(num_samples):\n",
    "        f = qf.rsample()\n",
    "\n",
    "        # log p(y|f) term.\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask=mask)\n",
    "        py_f_term = decoder_scale * py_f_term.sum()\n",
    "        elbo += py_f_term\n",
    "\n",
    "    # Inner summation over samples from q(f).\n",
    "    elbo /= num_samples\n",
    "\n",
    "    # E_q[log l(f|y)] in closed form from q's mean and marginal variances.\n",
    "    lf_y_term = gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var).sum()\n",
    "    lf_y_term += - 0.5 * (qf_var / lf_y_var).sum()\n",
    "    elbo += - lf_y_term\n",
    "\n",
    "    # log Zq term.\n",
    "    zq_term = zq.log_prob(torch.zeros_like(lf_y_mu)).sum()\n",
    "    elbo += zq_term\n",
    "\n",
    "    return elbo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def old_elbo_estimator(model, data, num_samples, contains_nan=False,\n",
    "                       decoder_scale=None, make_lazy=True):\n",
    "    \"\"\"Estimates the ELBO analytically where possible, using the\n",
    "    reparameterisation trick only for the decoder term, for models with an\n",
    "    encoder/decoder architecture, a GP prior over latent variables and an\n",
    "    approximate posterior of the form q(f) = 1/Z p(f)l(f|y).\n",
    "\n",
    "    The returned value is the ELBO itself (not its negative) and is not\n",
    "    divided by the batch size.\n",
    "\n",
    "    :param: model: A nn.Module exposing get_latent_dists(x, y, mask) and a\n",
    "        decoder.\n",
    "    :param: data: A tuple (x, y); when contains_nan, y is itself (y, mask).\n",
    "    :param: num_samples: An int, the number of samples from q(f) used for\n",
    "        the decoder term.\n",
    "    :param: contains_nan: A bool, whether the observations contain NaN\n",
    "        values to ignore.\n",
    "    :param: decoder_scale: A float scaling the decoder term. Forced to 1\n",
    "        when contains_nan is False; derived from the mask when None.\n",
    "    :param: make_lazy: A bool, whether to wrap covariances with lazify.\n",
    "    :return: A scalar tensor, the ELBO estimate.\n",
    "    \"\"\"\n",
    "    x, y = data\n",
    "\n",
    "    if contains_nan:\n",
    "        y, mask = y\n",
    "        if decoder_scale is None:\n",
    "            # NOTE(review): whether sum(mask) counts missing or observed\n",
    "            # entries depends on the dataset's mask convention - confirm.\n",
    "            n_masked = 1. * torch.sum(mask)\n",
    "            n_entries = y.shape[0] * y.shape[1]\n",
    "            decoder_scale = 1. - n_masked / n_entries\n",
    "    else:\n",
    "        mask, decoder_scale = None, 1.\n",
    "\n",
    "    qf_mu, qf_cov, pf_mu, pf_cov, lf_y_mu, lf_y_cov = model.get_latent_dists(\n",
    "        x, y, mask)\n",
    "\n",
    "    # q(f) and the normaliser Zq = N(0; lf_y_mu, pf_cov + lf_y_cov).\n",
    "    wrap = lazify if make_lazy else (lambda cov: cov)\n",
    "    sum_cov = pf_cov + lf_y_cov\n",
    "    qf = MultivariateNormal(qf_mu, wrap(qf_cov))\n",
    "    log_zq_dist = MultivariateNormal(lf_y_mu, wrap(sum_cov))\n",
    "\n",
    "    qf_var = torch.stack([torch.diag(c) for c in qf_cov])\n",
    "    lf_y_var = torch.stack([torch.diag(c) for c in lf_y_cov])\n",
    "\n",
    "    # Monte-Carlo estimate of E_q[log p(y|f)], scaled.\n",
    "    total = 0\n",
    "    for _ in range(num_samples):\n",
    "        f = qf.rsample()\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        target = y[0] if isinstance(y, (list, tuple)) else y\n",
    "        log_lik = gaussian_diagonal_ll(target, py_f_mu, py_f_sigma.pow(2),\n",
    "                                       mask=mask)\n",
    "        total += decoder_scale * torch.sum(log_lik)\n",
    "    total /= num_samples\n",
    "\n",
    "    # E_q[log l(f|y)] in closed form, and log Zq.\n",
    "    encoder_term = (torch.sum(gaussian_diagonal_ll(qf_mu, lf_y_mu, lf_y_var))\n",
    "                    - 0.5 * torch.sum(qf_var / lf_y_var))\n",
    "    log_zq_term = torch.sum(log_zq_dist.log_prob(torch.zeros_like(lf_y_mu)))\n",
    "\n",
    "    total += - encoder_term + log_zq_term\n",
    "    return total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NewTupleDataset(Dataset):\n",
    "    \"\"\"(x, y) dataset whose targets may contain NaN values to be masked out.\n",
    "\n",
    "    :param: x: Inputs, indexed along the first dimension. A 1-D ``x`` is\n",
    "    reshaped to 2-D by appending a trailing dimension.\n",
    "    :param: y: Targets; must have the same length as ``x``.\n",
    "    :param: contains_nan: A bool. If True, NaN entries of ``y`` are set to\n",
    "    0 in a deep copy of ``y`` and recorded in a mask ``m`` (1 where the\n",
    "    value is observed, 0 where it was NaN); items are then returned as\n",
    "    ``(x, (y, m))`` instead of ``(x, y)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, contains_nan=False):\n",
    "        super().__init__()\n",
    "\n",
    "        assert len(x) == len(y), 'x and y must be the same length.'\n",
    "\n",
    "        # Promote 1-D inputs to 2-D so every item is a vector.\n",
    "        self.x = x.unsqueeze(1) if len(x.shape) == 1 else x\n",
    "        self.contains_nan = contains_nan\n",
    "\n",
    "        if not contains_nan:\n",
    "            self.y = y\n",
    "            return\n",
    "\n",
    "        # The mask has y's dtype (not bool): 1 = observed, 0 = NaN.\n",
    "        nan_positions = torch.isnan(y)\n",
    "        observed = torch.ones_like(y)\n",
    "        observed[nan_positions] = False\n",
    "\n",
    "        # Zero the NaNs in a copy so the caller's tensor is left untouched.\n",
    "        targets = copy.deepcopy(y)\n",
    "        targets[nan_positions] = 0.\n",
    "\n",
    "        self.y = targets\n",
    "        self.m = observed\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of (x, y) pairs.\"\"\"\n",
    "        return len(self.x)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(x, y)``, or ``(x, (y, m))`` when NaN masking is on.\"\"\"\n",
    "        inputs, targets = self.x[idx], self.y[idx]\n",
    "        if not self.contains_nan:\n",
    "            return inputs, targets\n",
    "        return inputs, (targets, self.m[idx])\n",
    "\n",
    "    def get_dataset(self):\n",
    "        \"\"\"Return all items at once by indexing with every position.\"\"\"\n",
    "        return self.__getitem__(list(range(len(self))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spatio-temporal-vae-env",
   "language": "python",
   "name": "spatio-temporal-vae-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
