{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import pdb\n",
    "import time\n",
    "import pickle\n",
    "import copy\n",
    "sys.path.append('../../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "import data.jura\n",
    "import gpytorch\n",
    "import gpvae\n",
    "import wbml.metric\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from torch.utils.data import DataLoader\n",
    "from experiments.jura.train_jura import train_jura\n",
    "\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Jura data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Jura data as a (train, test) pair using the project's loader.\n",
    "train, test = data.jura.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs: one [i, j] pair per entry of the training index.\n",
    "x = [[i, j] for (i, j) in train.index]\n",
    "x = np.array(x)\n",
    "# Outputs: the training table as an array (NaN-aware statistics are used\n",
    "# below, so entries may be NaN).\n",
    "y = np.array(train)\n",
    "y_dim = y.shape[1]\n",
    "# Reciprocal of the fraction of non-NaN entries in y.\n",
    "decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))\n",
    "\n",
    "# Log-transform the data.\n",
    "# y = np.log(y)\n",
    "\n",
    "# Normalise data.\n",
    "# Column-wise mean and standard deviation, ignoring NaNs; y_mean and y_std\n",
    "# are kept so the normalisation can be undone later.\n",
    "y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)\n",
    "y = (y - y_mean) / y_std"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct GPVAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "# Shared training configuration; the model-construction cell deep-copies\n",
    "# this dict and adds model-specific entries to the copy.\n",
    "args = {'batch_size': 100,\n",
    "        'gp_latent_dim': 2,\n",
    "        'sn_latent_dim': 1,\n",
    "        'init_lengthscale': 1.,\n",
    "        'init_scale': 1.,\n",
    "        'lr': 0.001,\n",
    "        'num_samples': 1,\n",
    "        'elbo_samples': 100,\n",
    "        'test_samples': 100,\n",
    "        'epochs': 3000,\n",
    "        'cache_freq': 100,\n",
    "        'decoder_scale': decoder_scale\n",
    "       }\n",
    "\n",
    "# Set up dataset and dataloaders.\n",
    "# contains_nan=True is passed because y may hold NaN entries; presumably the\n",
    "# dataset derives a mask from them - confirm against TupleDataset.\n",
    "dataset = gpvae.utils.dataset_utils.TupleDataset(torch.tensor(x), torch.tensor(y), contains_nan=True)\n",
    "loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kernel.\n",
    "# A single RBF kernel; HybridGPVAE deep-copies it once per latent GP.\n",
    "kernel = gpvae.kernels.RBFKernel(lengthscale=args['init_lengthscale'], scale=args['init_scale'])\n",
    "\n",
    "# Likelihood function hyperparameters.\n",
    "# The decoder input is the concatenation of the GP and standard normal\n",
    "# latents, hence in_dim is the sum of the two latent dimensions.\n",
    "likelihood_args = {'in_dim': args['gp_latent_dim'] + args['sn_latent_dim'],\n",
    "                  'out_dim': y_dim,\n",
    "                  'hidden_dims': [20, 20],\n",
    "                  'sigma': .1,\n",
    "                  'train_sigma': True\n",
    "                 }\n",
    "\n",
    "# Semi-amortised DeepSet inference network hyperparameters.\n",
    "# Encoder for the GP latents: maps observations to gp_latent_dim outputs.\n",
    "sads_gp_network_args = {'in_dim': y_dim, \n",
    "                        'out_dim': args['gp_latent_dim'],\n",
    "                        'middle_dim': 20,\n",
    "                        'hidden_dims': [20, 20],\n",
    "                        'shared_hidden_dims': [20, 20],\n",
    "                        'initial_sigma': .1,\n",
    "                        'initial_mu': 0.,\n",
    "                        'min_sigma': 0.01\n",
    "                        }\n",
    "\n",
    "# Encoder for the standard normal latents: same settings as the GP encoder\n",
    "# apart from the output dimension.\n",
    "sads_sn_network_args = {'in_dim': y_dim, \n",
    "                        'out_dim': args['sn_latent_dim'],\n",
    "                        'middle_dim': 20,\n",
    "                        'hidden_dims': [20, 20],\n",
    "                        'shared_hidden_dims': [20, 20],\n",
    "                        'initial_sigma': .1,\n",
    "                        'initial_mu': 0.,\n",
    "                        'min_sigma': 0.01\n",
    "                        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-specific copy of the shared configuration, extended with the\n",
    "# network hyperparameters used to build this model.\n",
    "sads_args = copy.deepcopy(args)\n",
    "sads_args['model'] = 'sads_hybrid'\n",
    "sads_args['gp_inference_network_args'] = sads_gp_network_args\n",
    "sads_args['sn_inference_network_args'] = sads_sn_network_args\n",
    "sads_args['likelihood_args'] = likelihood_args\n",
    "# Encoders for the GP and standard normal latents, and the decoder.\n",
    "sads_gp_network = gpvae.networks.IndexNet(**sads_gp_network_args)\n",
    "sads_sn_network = gpvae.networks.IndexNet(**sads_sn_network_args)\n",
    "sads_likelihood = gpvae.networks.LinearGaussian(**likelihood_args)\n",
    "# NOTE: HybridGPVAE is defined in a later cell of this notebook; that cell\n",
    "# must have been run before this one.\n",
    "sads_model = HybridGPVAE(sads_gp_network, sads_sn_network, sads_likelihood, \n",
    "                         args['gp_latent_dim'], args['sn_latent_dim'], kernel,\n",
    "                         add_jitter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 3/3000 [00:00<18:44,  2.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 148.664\n",
      "ELBO: -45980.001\n",
      "SMSE: 1.016\n",
      "SMLL: 24.980\n",
      "MLL: 26.029\n",
      "MAE: 0.559\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 103/3000 [00:08<05:54,  8.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100\n",
      "Loss: 2.777\n",
      "ELBO: -2476.340\n",
      "SMSE: 0.815\n",
      "SMLL: 3.014\n",
      "MLL: 4.063\n",
      "MAE: 0.444\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|▋         | 203/3000 [00:17<05:32,  8.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 200\n",
      "Loss: -0.018\n",
      "ELBO: -1578.310\n",
      "SMSE: 1.145\n",
      "SMLL: 2.666\n",
      "MLL: 3.715\n",
      "MAE: 0.464\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 303/3000 [00:25<05:19,  8.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 300\n",
      "Loss: 0.788\n",
      "ELBO: -1342.764\n",
      "SMSE: 1.077\n",
      "SMLL: 1.491\n",
      "MLL: 2.540\n",
      "MAE: 0.459\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█▎        | 403/3000 [00:33<05:07,  8.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 400\n",
      "Loss: 1.390\n",
      "ELBO: -1224.228\n",
      "SMSE: 0.858\n",
      "SMLL: 1.000\n",
      "MLL: 2.049\n",
      "MAE: 0.430\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 501/3000 [00:41<06:05,  6.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500\n",
      "Loss: 1.725\n",
      "ELBO: -1151.278\n",
      "SMSE: 0.689\n",
      "SMLL: 0.377\n",
      "MLL: 1.427\n",
      "MAE: 0.406\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 603/3000 [00:49<04:44,  8.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 600\n",
      "Loss: 1.908\n",
      "ELBO: -1116.404\n",
      "SMSE: 0.691\n",
      "SMLL: 0.180\n",
      "MLL: 1.229\n",
      "MAE: 0.405\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 703/3000 [00:58<04:34,  8.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 700\n",
      "Loss: 2.030\n",
      "ELBO: -1064.891\n",
      "SMSE: 0.647\n",
      "SMLL: -0.037\n",
      "MLL: 1.012\n",
      "MAE: 0.399\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|██▋       | 803/3000 [01:06<04:22,  8.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 800\n",
      "Loss: 2.094\n",
      "ELBO: -1042.904\n",
      "SMSE: 0.650\n",
      "SMLL: -0.096\n",
      "MLL: 0.953\n",
      "MAE: 0.395\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 902/3000 [01:15<05:00,  6.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 900\n",
      "Loss: 1.934\n",
      "ELBO: -1037.044\n",
      "SMSE: 0.671\n",
      "SMLL: -0.010\n",
      "MLL: 1.039\n",
      "MAE: 0.397\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 1003/3000 [01:24<04:43,  7.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000\n",
      "Loss: 2.086\n",
      "ELBO: -1031.382\n",
      "SMSE: 0.683\n",
      "SMLL: -0.057\n",
      "MLL: 0.992\n",
      "MAE: 0.401\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 37%|███▋      | 1102/3000 [01:34<05:10,  6.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1100\n",
      "Loss: 2.179\n",
      "ELBO: -1025.978\n",
      "SMSE: 0.691\n",
      "SMLL: -0.108\n",
      "MLL: 0.941\n",
      "MAE: 0.403\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 1202/3000 [01:44<05:03,  5.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1200\n",
      "Loss: 2.061\n",
      "ELBO: -1016.724\n",
      "SMSE: 0.671\n",
      "SMLL: -0.157\n",
      "MLL: 0.892\n",
      "MAE: 0.398\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 43%|████▎     | 1302/3000 [01:54<04:42,  6.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1300\n",
      "Loss: 2.141\n",
      "ELBO: -1018.369\n",
      "SMSE: 0.679\n",
      "SMLL: -0.052\n",
      "MLL: 0.997\n",
      "MAE: 0.402\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████▋     | 1403/3000 [02:05<03:54,  6.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1400\n",
      "Loss: 2.097\n",
      "ELBO: -1018.966\n",
      "SMSE: 0.680\n",
      "SMLL: -0.130\n",
      "MLL: 0.919\n",
      "MAE: 0.401\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1502/3000 [02:15<03:42,  6.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1500\n",
      "Loss: 2.190\n",
      "ELBO: -1018.601\n",
      "SMSE: 0.686\n",
      "SMLL: -0.033\n",
      "MLL: 1.016\n",
      "MAE: 0.401\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|█████▎    | 1602/3000 [02:25<03:27,  6.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1600\n",
      "Loss: 2.205\n",
      "ELBO: -1021.643\n",
      "SMSE: 0.702\n",
      "SMLL: -0.040\n",
      "MLL: 1.009\n",
      "MAE: 0.407\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|█████▋    | 1702/3000 [02:36<03:31,  6.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1700\n",
      "Loss: 2.264\n",
      "ELBO: -1018.761\n",
      "SMSE: 0.698\n",
      "SMLL: -0.008\n",
      "MLL: 1.041\n",
      "MAE: 0.402\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 1802/3000 [02:46<03:18,  6.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1800\n",
      "Loss: 2.396\n",
      "ELBO: -1054.762\n",
      "SMSE: 0.696\n",
      "SMLL: 0.006\n",
      "MLL: 1.055\n",
      "MAE: 0.401\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 61%|██████▏   | 1842/3000 [02:50<01:47, 10.77it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-82-4ff19dea7f63>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m train_jura(sads_model, td_estimator, loader, sads_args, \n\u001b[0;32m----> 2\u001b[0;31m            elbo_estimator, None, save_model=True)\n\u001b[0m",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/experiments/jura/train_jura.py\u001b[0m in \u001b[0;36mtrain_jura\u001b[0;34m(model, loss_fn, loader, args, elbo_estimator, iwae_estimator, normalised, save_model, log_transform)\u001b[0m\n\u001b[1;32m     76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m             \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     80\u001b[0m             \u001b[0mepoch_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    101\u001b[0m                     \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m                     \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m                 \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m 
\u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train with td_estimator as the loss and elbo_estimator for ELBO reporting;\n",
    "# no IWAE estimator is supplied. NOTE: both estimators are defined in later\n",
    "# cells of this notebook and must have been run before this cell.\n",
    "train_jura(sads_model, td_estimator, loader, sads_args, \n",
    "           elbo_estimator, None, save_model=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
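    "## Class and function definitions\n",
    "\n",
    "**Note:** `HybridGPVAE`, `td_estimator`, `elbo_estimator` and `gaussian_diagonal_kl` are defined in this section but used by the model-construction and training cells above. On a fresh kernel, run the cells in this section before those cells."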
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpvae.utils.gaussian_utils import gaussian_diagonal_ll\n",
    "\n",
    "def td_estimator(model, x, y, mask=None, num_samples=1, decoder_scale=None,\n",
    "                 make_lazy=True, mf=False, idx=None):\n",
    "    \"\"\"Estimates the gradient of the negative ELBO using the total\n",
    "    derivative of the reparameterisation trick for the GPVAE model.\n",
    "\n",
    "    :param model: A nn.Module, the model to evaluate on.\n",
    "    :param x: A torch.Tensor, the input data.\n",
    "    :param y: A torch.Tensor, the output data.\n",
    "    :param mask: A torch.Tensor, the mask to apply to the output data.\n",
    "    :param num_samples: An int, the number of samples to estimate the ELBO\n",
    "    gradient with.\n",
    "    :param decoder_scale: None or a float, the amount by which to scale the\n",
    "    decoder term, p(y|f), by. Relevant in the presence of missing values.\n",
    "    Ignored (treated as 1) when mask is None.\n",
    "    :param make_lazy: A bool, whether to wrap the covariances with lazify\n",
    "    before building the GPyTorch MultivariateNormal distributions.\n",
    "    :param mf: A bool, whether the model uses mean-field variational\n",
    "    inference or not.\n",
    "    :param idx: A torch.Tensor, the data indices (used only when mf is True).\n",
    "    :returns: The negated estimator, averaged over the samples and divided\n",
    "    by x.shape[0].\n",
    "    \"\"\"\n",
    "    if mask is not None:\n",
    "        # When no decoder_scale is supplied, derive one from the mask.\n",
    "        # NOTE(review): the default below is 1 - sum(mask) / size, which is\n",
    "        # not a reciprocal and differs from the notebook-level decoder_scale\n",
    "        # (1 / fraction of non-NaN entries) - confirm the intended value and\n",
    "        # whether mask marks observed or missing entries.\n",
    "        if decoder_scale is None:\n",
    "            num_nan = 1. * torch.sum(mask)\n",
    "            num_observations = y.shape[0] * y.shape[1]\n",
    "            decoder_scale = 1. - num_nan / num_observations\n",
    "    else:\n",
    "        # Without a mask any supplied decoder_scale is overridden and the\n",
    "        # decoder term is left unscaled.\n",
    "        decoder_scale = 1.\n",
    "\n",
    "    estimator = 0\n",
    "\n",
    "    # Latent distributions.\n",
    "    if mf:\n",
    "        # Pass mean-field models the data indices.\n",
    "        qf_gp_mu, qf_gp_cov, pf_gp_mu, pf_gp_cov, lf_gp_y_mu, lf_gp_y_cov = \\\n",
    "            model.get_latent_gp_dists(x, idx=idx)\n",
    "        qf_sn_mu, qf_sn_cov, pf_sn_mu, pf_sn_cov = model.get_latent_sn_dists(x, idx=idx)\n",
    "    else:\n",
    "        # Pass amortisation models the observation data.\n",
    "        qf_gp_mu, qf_gp_cov, pf_gp_mu, pf_gp_cov, lf_gp_y_mu, lf_gp_y_cov = \\\n",
    "            model.get_latent_gp_dists(x, y, mask)\n",
    "        qf_sn_mu, qf_sn_cov, pf_sn_mu, pf_sn_cov = model.get_latent_sn_dists(x, y, mask)\n",
    "\n",
    "    # Required distributions.\n",
    "    if make_lazy:\n",
    "        # Use GPyTorch MultivariateNormal class for sampling.\n",
    "        qf_gp = MultivariateNormal(qf_gp_mu, lazify(qf_gp_cov))\n",
    "        pf_gp = MultivariateNormal(pf_gp_mu, lazify(pf_gp_cov))\n",
    "    else:\n",
    "        qf_gp = MultivariateNormal(qf_gp_mu, qf_gp_cov)\n",
    "        pf_gp = MultivariateNormal(pf_gp_mu, pf_gp_cov)\n",
    "\n",
    "    # Marginal variances: the diagonal of each covariance matrix, stacked.\n",
    "    lf_gp_y_var = torch.stack([cov.diag() for cov in lf_gp_y_cov])\n",
    "    qf_sn_var = torch.stack([cov.diag() for cov in qf_sn_cov])\n",
    "    pf_sn_var = torch.stack([cov.diag() for cov in pf_sn_cov])\n",
    "\n",
    "    # Monte-Carlo estimate of ELBO gradient.\n",
    "    # See Spatio-Temporal VAEs: ELBO Gradient Estimators.\n",
    "    for _ in range(num_samples):\n",
    "        # Reparameterised samples of the GP and standard normal latents,\n",
    "        # concatenated along the latent dimension.\n",
    "        f_gp = qf_gp.rsample()\n",
    "        f_sn = qf_sn_mu + qf_sn_var ** 0.5 * torch.randn_like(qf_sn_mu)\n",
    "        f = torch.cat([f_gp, f_sn], dim=0)\n",
    "\n",
    "        # log p(y|f) term.\n",
    "        py_f_mu, py_f_sigma = model.decoder(f.transpose(0, 1))\n",
    "        py_f_term = gaussian_diagonal_ll(y, py_f_mu, py_f_sigma.pow(2), mask)\n",
    "        py_f_term = decoder_scale * py_f_term.sum()\n",
    "        estimator += py_f_term\n",
    "\n",
    "        # log l(f_gp|y) term.\n",
    "        # The factor's parameters are detached, so gradients reach this term\n",
    "        # only through the sample f_gp.\n",
    "        lf_gp_y_term = gaussian_diagonal_ll(f_gp, lf_gp_y_mu.detach(),\n",
    "                                            lf_gp_y_var.detach())\n",
    "        lf_gp_y_term = lf_gp_y_term.sum()\n",
    "        estimator += - lf_gp_y_term\n",
    "        \n",
    "        # log q(f_sn) term.\n",
    "        qf_sn_term = gaussian_diagonal_ll(f_sn, qf_sn_mu, qf_sn_var)\n",
    "        qf_sn_term = qf_sn_term.sum()\n",
    "        estimator += - qf_sn_term\n",
    "\n",
    "        # log p(f_gp) term.\n",
    "        # The sample is detached here, so gradients reach this term only\n",
    "        # through the prior's own parameters.\n",
    "        pf_gp_term = pf_gp.log_prob(f_gp.detach()).sum()\n",
    "        estimator += pf_gp_term\n",
    "        \n",
    "        # log p(f_sn) term.\n",
    "        pf_sn_term = gaussian_diagonal_ll(f_sn, pf_sn_mu, pf_sn_var)\n",
    "        pf_sn_term = pf_sn_term.sum()\n",
    "        estimator += pf_sn_term\n",
    "\n",
    "    # Inner summation over samples from q(f).\n",
    "    estimator /= num_samples\n",
    "\n",
    "    # Outer summation over batch.\n",
    "    estimator /= x.shape[0]\n",
    "\n",
    "    # Negated so that the result can be minimised.\n",
    "    return - estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "def elbo_estimator(model, x, y, mask=None, num_samples=1, make_lazy=True,\n",
    "                   mf=False, idx=None):\n",
    "    \"\"\"Estimates the ELBO of a model with GP and standard normal latents.\n",
    "\n",
    "    The reconstruction term is a Monte-Carlo average over num_samples draws\n",
    "    from the approximate posterior; the remaining terms are evaluated\n",
    "    without sampling.\n",
    "\n",
    "    :param model: the model to evaluate; it must provide\n",
    "    get_latent_gp_dists, get_latent_sn_dists and decoder.\n",
    "    :param x: the input data.\n",
    "    :param y: the output data.\n",
    "    :param mask: optional mask, forwarded to the model and to the\n",
    "    reconstruction log-likelihood.\n",
    "    :param num_samples: the number of posterior samples to average over.\n",
    "    :param make_lazy: whether to wrap covariances with lazify before\n",
    "    building the MultivariateNormal distributions.\n",
    "    :param mf: whether the model is mean-field, in which case the latent\n",
    "    distributions are requested by idx instead of from y and mask.\n",
    "    :param idx: the data indices, used only when mf is True.\n",
    "    :returns: the ELBO estimate, summed over all entries.\n",
    "    \"\"\"\n",
    "    def stack_diagonals(covs):\n",
    "        # One row of marginal variances per covariance matrix.\n",
    "        return torch.stack([cov.diag() for cov in covs])\n",
    "\n",
    "    # Latent distributions: GP latents first, then standard normal latents.\n",
    "    if mf:\n",
    "        (post_gp_mean, post_gp_cov, prior_gp_mean, prior_gp_cov,\n",
    "         lik_gp_mean, lik_gp_cov) = model.get_latent_gp_dists(x, idx=idx)\n",
    "        (post_sn_mean, post_sn_cov,\n",
    "         prior_sn_mean, prior_sn_cov) = model.get_latent_sn_dists(x, idx=idx)\n",
    "    else:\n",
    "        (post_gp_mean, post_gp_cov, prior_gp_mean, prior_gp_cov,\n",
    "         lik_gp_mean, lik_gp_cov) = model.get_latent_gp_dists(x, y, mask)\n",
    "        (post_sn_mean, post_sn_cov,\n",
    "         prior_sn_mean, prior_sn_cov) = model.get_latent_sn_dists(x, y, mask)\n",
    "\n",
    "    # Covariance of the Gaussian used for the log Zq_gp term.\n",
    "    normaliser_cov = prior_gp_cov + lik_gp_cov\n",
    "\n",
    "    # Optionally hand GPyTorch lazy covariances.\n",
    "    wrap = lazify if make_lazy else (lambda cov: cov)\n",
    "    post_gp = MultivariateNormal(post_gp_mean, wrap(post_gp_cov))\n",
    "    normaliser = MultivariateNormal(lik_gp_mean, wrap(normaliser_cov))\n",
    "\n",
    "    post_gp_var = stack_diagonals(post_gp_cov)\n",
    "    lik_gp_var = stack_diagonals(lik_gp_cov)\n",
    "    post_sn_var = stack_diagonals(post_sn_cov)\n",
    "    prior_sn_var = stack_diagonals(prior_sn_cov)\n",
    "\n",
    "    # Monte-Carlo estimate of the log p(y|f) term.\n",
    "    # See Spatio-Temporal VAEs: ELBO\n",
    "    reconstruction = 0\n",
    "    for _ in range(num_samples):\n",
    "        gp_draw = post_gp.rsample()\n",
    "        sn_noise = torch.randn_like(post_sn_mean)\n",
    "        sn_draw = post_sn_mean + post_sn_var ** 0.5 * sn_noise\n",
    "        latent = torch.cat([gp_draw, sn_draw], dim=0)\n",
    "\n",
    "        dec_mean, dec_sigma = model.decoder(latent.transpose(0, 1))\n",
    "        sample_ll = gaussian_diagonal_ll(y, dec_mean, dec_sigma.pow(2), mask)\n",
    "        reconstruction = reconstruction + sample_ll.sum()\n",
    "\n",
    "    # Average over the samples from q(f).\n",
    "    elbo = reconstruction / num_samples\n",
    "\n",
    "    # log l(f_gp|y) term, evaluated from the posterior mean and marginal\n",
    "    # variances rather than from samples.\n",
    "    lik_term = gaussian_diagonal_ll(post_gp_mean, lik_gp_mean, lik_gp_var).sum()\n",
    "    lik_term = lik_term - 0.5 * (post_gp_var / lik_gp_var).sum()\n",
    "    elbo = elbo - lik_term\n",
    "\n",
    "    # log Zq_gp term.\n",
    "    normaliser_term = normaliser.log_prob(torch.zeros_like(lik_gp_mean)).sum()\n",
    "    elbo = elbo + normaliser_term\n",
    "\n",
    "    # KL(qf_sn || pf_sn) term.\n",
    "    sn_kl = gaussian_diagonal_kl(post_sn_mean, post_sn_var,\n",
    "                                 prior_sn_mean, prior_sn_var).sum()\n",
    "    elbo = elbo - sn_kl\n",
    "\n",
    "    return elbo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "from gpvae.utils.matrix_utils import add_diagonal\n",
    "from gpvae.kernels.composition_kernels import KernelList\n",
    "\n",
    "from gpytorch.distributions.multivariate_normal import MultivariateNormal\n",
    "from gpytorch.lazy import lazify\n",
    "\n",
    "JITTER = 1e-5\n",
    "\n",
    "class HybridGPVAE(nn.Module):\n",
    "    \"\"\"VAE with a mixture of GP and Gaussian latent prior.\n",
    "\n",
    "    :param gp_encoder: the encoder network for latent GPs.\n",
    "    :param sn_encoder: the encoder network for standard normal latent\n",
    "    dimensions.\n",
    "    :param decoder: the decoder network.\n",
    "    :param gp_latent_dim: the number of latent GPs.\n",
    "    :param sn_latent_dim: the number of standard normal latent dimensions.\n",
    "    :param kernel: the GP kernel.\n",
    "    :param add_jitter: whether to add jitter to the GP prior covariance matrix.\n",
    "    \"\"\"\n",
    "    def __init__(self, gp_encoder, sn_encoder, decoder, gp_latent_dim,\n",
    "                 sn_latent_dim, kernel, add_jitter=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.gp_encoder = gp_encoder\n",
    "        self.sn_encoder = sn_encoder\n",
    "        self.decoder = decoder\n",
    "        self.gp_latent_dim = gp_latent_dim\n",
    "        self.sn_latent_dim = sn_latent_dim\n",
    "        self.add_jitter = add_jitter\n",
    "\n",
    "        if not isinstance(kernel, list):\n",
    "            kernels = [copy.deepcopy(kernel) for _ in range(gp_latent_dim)]\n",
    "            self.kernels = KernelList(kernels)\n",
    "\n",
    "        else:\n",
    "            assert len(kernel) == gp_latent_dim, 'Number of kernels must be ' \\\n",
    "                                                 'equal to the latent ' \\\n",
    "                                                 'dimension.'\n",
    "            self.kernels = KernelList(copy.deepcopy(kernel))\n",
    "\n",
    "    def get_latent_gp_prior(self, x, diag=False):\n",
    "        # Gaussian process prior.\n",
    "        mf = torch.zeros(self.gp_latent_dim, x.shape[0])\n",
    "        kff = self.kernels.forward(x, x, diag)\n",
    "\n",
    "        if self.add_jitter:\n",
    "            # Add jitter to improve condition number.\n",
    "            kff = add_diagonal(kff, JITTER)\n",
    "\n",
    "        return mf, kff\n",
    "\n",
    "    def get_latent_sn_prior(self, x, diag=False):\n",
    "        # Standard normal prior.\n",
    "        pf_mu = torch.zeros(self.sn_latent_dim, x.shape[0])\n",
    "        pf_cov = torch.ones(self.sn_latent_dim, x.shape[0]).diag_embed()\n",
    "\n",
    "        return pf_mu, pf_cov\n",
    "\n",
    "    def get_latent_gp_dists(self, x, y, mask=None, x_test=None):\n",
    "        # Likelihood terms.\n",
    "        if mask is not None:\n",
    "            lf_y_mu, lf_y_sigma = self.gp_encoder(y, mask)\n",
    "        else:\n",
    "            lf_y_mu, lf_y_sigma = self.gp_encoder(y)\n",
    "\n",
    "        # Reshape.\n",
    "        lf_y_mu = lf_y_mu.transpose(0, 1)\n",
    "        lf_y_sigma = lf_y_sigma.transpose(0, 1)\n",
    "        lf_y_cov = lf_y_sigma.pow(2).diag_embed()\n",
    "        lf_y_precision = lf_y_sigma.pow(-2).diag_embed()\n",
    "        lf_y_root_precision = lf_y_sigma.pow(-1).diag_embed()\n",
    "\n",
    "        # GP prior.\n",
    "        pf_mu, kff = self.get_latent_gp_prior(x)\n",
    "\n",
    "        # See GPML section 3.4.3.\n",
    "        a = kff.matmul(lf_y_root_precision)\n",
    "        at = a.transpose(-1, -2)\n",
    "        w = lf_y_root_precision.matmul(a)\n",
    "        w = add_diagonal(w, 1)\n",
    "        winv = w.inverse()\n",
    "\n",
    "        if x_test is not None:\n",
    "            # GP prior.\n",
    "            ps_mu, kss = self.get_latent_prior(x_test)\n",
    "\n",
    "            # GP conditional prior.\n",
    "            ksf = self.kernels.forward(x_test, x)\n",
    "            kfs = ksf.transpose(-1, -2)\n",
    "\n",
    "            # GP test posterior.\n",
    "            b = lf_y_root_precision.matmul(winv.matmul(lf_y_root_precision))\n",
    "            c = ksf.matmul(b)\n",
    "            qs_cov = kss - c.matmul(kfs)\n",
    "            qs_mu = c.matmul(lf_y_mu.unsqueeze(2))\n",
    "            qs_mu = qs_mu.squeeze(2)\n",
    "\n",
    "            return qs_mu, qs_cov, ps_mu, kss\n",
    "        else:\n",
    "            # GP training posterior.\n",
    "            qf_cov = kff - a.matmul(winv.matmul(at))\n",
    "            qf_mu = qf_cov.matmul(lf_y_precision.matmul(lf_y_mu.unsqueeze(2)))\n",
    "            qf_mu = qf_mu.squeeze(2)\n",
    "\n",
    "            return qf_mu, qf_cov, pf_mu, kff, lf_y_mu, lf_y_cov\n",
    "\n",
    "    def get_latent_sn_dists(self, x, y, mask=None):\n",
    "        # Posterior.\n",
    "        if mask is not None:\n",
    "            qf_mu, qf_sigma = self.sn_encoder(y, mask)\n",
    "        else:\n",
    "            qf_mu, qf_sigma = self.sn_encoder(y)\n",
    "\n",
    "        # Reshape.\n",
    "        qf_mu = qf_mu.transpose(0, 1)\n",
    "        qf_sigma = qf_sigma.transpose(0, 1)\n",
    "        qf_cov = qf_sigma.pow(2).diag_embed()\n",
    "\n",
    "        # Prior.\n",
    "        pf_mu, pf_cov = self.get_latent_sn_prior(x)\n",
    "\n",
    "        return qf_mu, qf_cov, pf_mu, pf_cov\n",
    "\n",
    "    def sample_latent_gp_posterior(self, x, y=None, mask=None, num_samples=1,\n",
    "                                   full_cov=True, **kwargs):\n",
    "        # Latent posterior distribution.\n",
    "        if y is not None:\n",
    "            qf_mu, qf_cov = self.get_latent_gp_dists(x, y, mask, **kwargs)[:2]\n",
    "        else:\n",
    "            qf_mu, qf_cov = self.get_latent_gp_prior(x)\n",
    "\n",
    "        if full_cov:\n",
    "            # Use GPyTorch MultivariateNormal class for sampling using the\n",
    "            # full covariance matrix.\n",
    "            qf = MultivariateNormal(qf_mu, lazify(qf_cov))\n",
    "            samples = [qf.sample() for _ in range(num_samples)]\n",
    "        else:\n",
    "            qf_sigma = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5\n",
    "            samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)\n",
    "                       for _ in range(num_samples)]\n",
    "\n",
    "        return samples\n",
    "\n",
    "    def sample_latent_sn_posterior(self, x, y=None, mask=None,\n",
    "                                   num_samples=1, **kwargs):\n",
    "        if y is not None:\n",
    "            # Latent posterior distribution.\n",
    "            qf_mu, qf_cov = self.get_latent_sn_dists(x, y, mask)[:2]\n",
    "        else:\n",
    "            # Latent posterior distribution is the prior.\n",
    "            qf_mu, qf_cov = self.get_latent_sn_prior(x)\n",
    "\n",
    "        qf_sigma = torch.stack([cov.diag() for cov in qf_cov]) ** 0.5\n",
    "        samples = [qf_mu + qf_sigma * torch.randn_like(qf_mu)\n",
    "                   for _ in range(num_samples)]\n",
    "\n",
    "        return samples\n",
    "\n",
    "    def predict_y(self, **kwargs):\n",
    "        # Sample latent posterior distributions.\n",
    "        f_gp_samples = self.sample_latent_gp_posterior(**kwargs)\n",
    "        f_sn_samples = self.sample_latent_sn_posterior(**kwargs)\n",
    "\n",
    "        py_f_mus, py_f_sigmas, py_f_samples = [], [], []\n",
    "        for f_gp, f_sn in zip(f_gp_samples, f_sn_samples):\n",
    "            # Latent sample.\n",
    "            f = torch.cat([f_gp, f_sn], dim=0)\n",
    "\n",
    "            # Output conditional posterior distribution.\n",
    "            py_f_mu, py_f_sigma = self.decoder(f.transpose(0, 1))\n",
    "            py_f_mus.append(py_f_mu)\n",
    "            py_f_sigmas.append(py_f_sigma)\n",
    "            py_f_samples.append(\n",
    "                py_f_mu + py_f_sigma * torch.randn_like(py_f_mu))\n",
    "\n",
    "        py_f_mu = torch.stack(py_f_mus).mean(0).detach()\n",
    "        py_f_sigma = torch.stack(py_f_samples).std(0).detach()\n",
    "\n",
    "        return py_f_mu, py_f_sigma, py_f_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_diagonal_kl(m1, v1, m2, v2):\n",
    "    kl = 0.5 * ((v2 / v1).log() + (v1 + (m1 - m2) ** 2) / v2 - 1)\n",
    "    return kl.sum(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spatio-temporal-vae-env",
   "language": "python",
   "name": "spatio-temporal-vae-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
