{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import pdb\n",
    "import time\n",
    "import pickle\n",
    "import copy\n",
    "sys.path.append('../../')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import pandas as pd\n",
    "import data\n",
    "import gpytorch\n",
    "import gpvae\n",
    "import wbml.metric\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Jura data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = data.jura.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = [[i, j] for (i, j) in train.index]\n",
    "x = np.array(x)\n",
    "y = np.array(train)\n",
    "y_dim = y.shape[1]\n",
    "decoder_scale = 1. / (np.sum(~np.isnan(y)) / (y.shape[0] * y.shape[1]))\n",
    "\n",
    "# Log-transform the data.\n",
    "# y = np.log(y)\n",
    "\n",
    "# Normalise data.\n",
    "y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)\n",
    "y = (y - y_mean) / y_std"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct GPVAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "args = {'batch_size': 100,\n",
    "        'latent_dim': 5,\n",
    "        'init_lengthscale': 1.,\n",
    "        'init_scale': 1.,\n",
    "        'init_period': .1,\n",
    "        'auxiliary_dim': 1,\n",
    "        'num_inducing': 100,\n",
    "        'lr': 0.001,\n",
    "        'num_samples': 1,\n",
    "        'num_f_samples': 1,\n",
    "        'num_s_samples': 1,\n",
    "        'num_elbo_samples': 100,\n",
    "        'num_test_samples': 100,\n",
    "        'epochs': 5000,\n",
    "        'cache_freq': 100,\n",
    "        'decoder_scale': decoder_scale\n",
    "       }\n",
    "\n",
    "# Set up dataset and dataloaders.\n",
    "dataset = gpvae.utils.dataset_utils.TupleDataset(torch.tensor(x), torch.tensor(y), contains_nan=True)\n",
    "loader = DataLoader(dataset, batch_size=args['batch_size'], shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/scipy/cluster/vq.py:579: UserWarning: One of the clusters is empty. Re-run kmeans with a different initialization.\n",
      "  warnings.warn(\"One of the clusters is empty. \"\n"
     ]
    }
   ],
   "source": [
    "# kernel\n",
    "# kernel = gpvae.kernels.RBFKernel(lengthscale=args['init_lengthscale'], scale=args['init_scale'])\n",
    "# k2 = gpvae.kernels.RBFKernel(lengthscale=1., scale=args['init_scale'] / 2)\n",
    "# kernel = gpvae.kernels.AdditiveKernel(k1, k2)\n",
    "# kernel = k1\n",
    "# kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RQKernel())\n",
    "\n",
    "kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "# kernel.base_kernel._set_lengthscale(args['init_lengthscale'])\n",
    "# kernel._set_outputscale(args['init_scale'])\n",
    "\n",
    "# decoder hyperparameters\n",
    "decoder_args = {'in_dim': args['latent_dim'],\n",
    "                'out_dim': y_dim,\n",
    "                'hidden_dims': [20, 20],\n",
    "                'sigma': 0.1,\n",
    "                'train_sigma': False\n",
    "               }\n",
    "\n",
    "# jonny encoder hyperparameters\n",
    "j_encoder_args = {'in_dim': y_dim, \n",
    "                'out_dim': args['latent_dim'],\n",
    "                'middle_dim': 20,\n",
    "                'hidden_dims': [20, 20],\n",
    "                'shared_hidden_dims': [20, 20],\n",
    "                'initial_sigma': 1.,\n",
    "                'initial_mu': 0.,\n",
    "                'min_sigma': 0.01\n",
    "               }\n",
    "\n",
    "# vfe encoder hyperparameters\n",
    "vfe_encoder_args = {'out_dim': args['latent_dim']}\n",
    "\n",
    "# vfe inducing points\n",
    "z = kmeans2(x, args['num_inducing'], minit='points')[0]\n",
    "z = torch.tensor(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "j_args = copy.deepcopy(args)\n",
    "j_args['model'] = 'jonny'\n",
    "j_args['encoder_args'] = j_encoder_args\n",
    "j_args['decoder_args'] = decoder_args\n",
    "j_encoder = gpvae.networks.IndexNet(**j_encoder_args)\n",
    "j_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "j_model = gpvae.models.GPVAE(j_encoder, j_decoder, args['latent_dim'], kernel=kernel, add_jitter=True)\n",
    "\n",
    "# vfe_gpvae_args = [args, vfe_encoder_args, decoder_args]\n",
    "# vfe_encoder = gpvae.mf_networks.MeanFieldSparseNet(z, **vfe_encoder_args)\n",
    "# vfe_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "# vfe_model = gpvae.models.SparseGPVAE(vfe_encoder, vfe_decoder, args['latent_dim'], kernel=kernel)\n",
    "\n",
    "vae_args = copy.deepcopy(args)\n",
    "vae_args['model'] = 'gpvae'\n",
    "vae_args['encoder_args'] = encoder_args\n",
    "vae_args['decoder_args'] = decoder_args\n",
    "vae_encoder = gpvae.networks.IndexNet(**encoder_args)\n",
    "vae_decoder = gpvae.networks.LinearGaussian(**decoder_args)\n",
    "vae_model = gpvae.models.VAE(vae_encoder, vae_decoder, args['latent_dim'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'vfe_model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-ebfdf0b4bda9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvfe_metrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvfe_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'vfe'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgpvae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_estimators\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpvae_vfe_analytical_estimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'vfe_model' is not defined"
     ]
    }
   ],
   "source": [
    "vfe_metrics = train_model(vfe_model, 'vfe', gpvae.gpvae_estimators.gpvae_vfe_analytical_estimator, loader, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:27: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "633f7b7b2cac4552a1f49aa6af15aec5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 129.987\n",
      "ELBO: -47625.031\n",
      "IWAE: 0.000\n",
      "MAE: 0.579\n",
      "\n",
      "Epoch 100\n",
      "Loss: 17.655\n",
      "ELBO: -22188.288\n",
      "IWAE: 0.000\n",
      "MAE: 0.557\n",
      "\n",
      "Epoch 200\n",
      "Loss: 7.094\n",
      "ELBO: -14320.119\n",
      "IWAE: 0.000\n",
      "MAE: 0.567\n",
      "\n",
      "Epoch 300\n",
      "Loss: -0.001\n",
      "ELBO: -9193.854\n",
      "IWAE: 0.000\n",
      "MAE: 0.648\n",
      "\n",
      "Epoch 400\n",
      "Loss: -1.229\n",
      "ELBO: -7073.327\n",
      "IWAE: 0.000\n",
      "MAE: 0.615\n",
      "\n",
      "Epoch 500\n",
      "Loss: -1.924\n",
      "ELBO: -5796.346\n",
      "IWAE: 0.000\n",
      "MAE: 0.604\n",
      "\n",
      "Epoch 600\n",
      "Loss: -2.128\n",
      "ELBO: -5062.194\n",
      "IWAE: 0.000\n",
      "MAE: 0.567\n",
      "\n",
      "Epoch 700\n",
      "Loss: -2.559\n",
      "ELBO: -4440.593\n",
      "IWAE: 0.000\n",
      "MAE: 0.544\n",
      "\n",
      "Epoch 800\n",
      "Loss: -3.134\n",
      "ELBO: -3879.661\n",
      "IWAE: 0.000\n",
      "MAE: 0.610\n",
      "\n",
      "Epoch 900\n",
      "Loss: -3.610\n",
      "ELBO: -3488.687\n",
      "IWAE: 0.000\n",
      "MAE: 0.598\n",
      "\n",
      "Epoch 1000\n",
      "Loss: -4.288\n",
      "ELBO: -3175.299\n",
      "IWAE: 0.000\n",
      "MAE: 0.535\n",
      "\n",
      "Epoch 1100\n",
      "Loss: -4.273\n",
      "ELBO: -2938.700\n",
      "IWAE: 0.000\n",
      "MAE: 0.590\n",
      "\n",
      "Epoch 1200\n",
      "Loss: -4.886\n",
      "ELBO: -2596.921\n",
      "IWAE: 0.000\n",
      "MAE: 0.604\n",
      "\n",
      "Epoch 1300\n",
      "Loss: -5.089\n",
      "ELBO: -2459.920\n",
      "IWAE: 0.000\n",
      "MAE: 0.604\n",
      "\n",
      "Epoch 1400\n",
      "Loss: -5.172\n",
      "ELBO: -2247.837\n",
      "IWAE: 0.000\n",
      "MAE: 0.564\n",
      "\n",
      "Epoch 1500\n",
      "Loss: -5.532\n",
      "ELBO: -2093.011\n",
      "IWAE: 0.000\n",
      "MAE: 0.610\n",
      "\n",
      "Epoch 1600\n",
      "Loss: -5.678\n",
      "ELBO: -1982.757\n",
      "IWAE: 0.000\n",
      "MAE: 0.550\n",
      "\n",
      "Epoch 1700\n",
      "Loss: -6.098\n",
      "ELBO: -1925.292\n",
      "IWAE: 0.000\n",
      "MAE: 0.603\n",
      "\n",
      "Epoch 1800\n",
      "Loss: -6.197\n",
      "ELBO: -1809.166\n",
      "IWAE: 0.000\n",
      "MAE: 0.492\n",
      "\n",
      "Epoch 1900\n",
      "Loss: -6.291\n",
      "ELBO: -1731.586\n",
      "IWAE: 0.000\n",
      "MAE: 0.470\n",
      "\n",
      "Epoch 2000\n",
      "Loss: -6.352\n",
      "ELBO: -1654.804\n",
      "IWAE: 0.000\n",
      "MAE: 0.474\n",
      "\n",
      "Epoch 2100\n",
      "Loss: -6.780\n",
      "ELBO: -1591.532\n",
      "IWAE: 0.000\n",
      "MAE: 0.462\n",
      "\n",
      "Epoch 2200\n",
      "Loss: -6.603\n",
      "ELBO: -1537.948\n",
      "IWAE: 0.000\n",
      "MAE: 0.474\n",
      "\n",
      "Epoch 2300\n",
      "Loss: -6.593\n",
      "ELBO: -1416.884\n",
      "IWAE: 0.000\n",
      "MAE: 0.457\n",
      "\n",
      "Epoch 2400\n",
      "Loss: -6.589\n",
      "ELBO: -1344.916\n",
      "IWAE: 0.000\n",
      "MAE: 0.432\n",
      "\n",
      "Epoch 2500\n",
      "Loss: -6.842\n",
      "ELBO: -1242.601\n",
      "IWAE: 0.000\n",
      "MAE: 0.414\n",
      "\n",
      "Epoch 2600\n",
      "Loss: -6.997\n",
      "ELBO: -1187.809\n",
      "IWAE: 0.000\n",
      "MAE: 0.416\n",
      "\n",
      "Epoch 2700\n",
      "Loss: -7.221\n",
      "ELBO: -1187.525\n",
      "IWAE: 0.000\n",
      "MAE: 0.419\n",
      "\n",
      "Epoch 2800\n",
      "Loss: -7.429\n",
      "ELBO: -1106.850\n",
      "IWAE: 0.000\n",
      "MAE: 0.417\n",
      "\n",
      "Epoch 2900\n",
      "Loss: -7.370\n",
      "ELBO: -1096.756\n",
      "IWAE: 0.000\n",
      "MAE: 0.421\n",
      "\n",
      "Epoch 3000\n",
      "Loss: -7.422\n",
      "ELBO: -1084.361\n",
      "IWAE: 0.000\n",
      "MAE: 0.416\n",
      "\n",
      "Epoch 3100\n",
      "Loss: -7.558\n",
      "ELBO: -1057.081\n",
      "IWAE: 0.000\n",
      "MAE: 0.415\n",
      "\n",
      "Epoch 3200\n",
      "Loss: -7.773\n",
      "ELBO: -1071.413\n",
      "IWAE: 0.000\n",
      "MAE: 0.412\n",
      "\n",
      "Epoch 3300\n",
      "Loss: -7.717\n",
      "ELBO: -1084.133\n",
      "IWAE: 0.000\n",
      "MAE: 0.418\n",
      "\n",
      "Epoch 3400\n",
      "Loss: -7.823\n",
      "ELBO: -1073.106\n",
      "IWAE: 0.000\n",
      "MAE: 0.426\n",
      "\n",
      "Epoch 3500\n",
      "Loss: -8.146\n",
      "ELBO: -1049.678\n",
      "IWAE: 0.000\n",
      "MAE: 0.406\n",
      "\n",
      "Epoch 3600\n",
      "Loss: -7.666\n",
      "ELBO: -1039.072\n",
      "IWAE: 0.000\n",
      "MAE: 0.415\n",
      "\n",
      "Epoch 3700\n",
      "Loss: -7.970\n",
      "ELBO: -1040.451\n",
      "IWAE: 0.000\n",
      "MAE: 0.407\n",
      "\n",
      "Epoch 3800\n",
      "Loss: -8.172\n",
      "ELBO: -1030.759\n",
      "IWAE: 0.000\n",
      "MAE: 0.410\n",
      "\n",
      "Epoch 3900\n",
      "Loss: -7.885\n",
      "ELBO: -1035.690\n",
      "IWAE: 0.000\n",
      "MAE: 0.420\n",
      "\n",
      "Epoch 4000\n",
      "Loss: -8.066\n",
      "ELBO: -1044.768\n",
      "IWAE: 0.000\n",
      "MAE: 0.417\n",
      "\n",
      "Epoch 4100\n",
      "Loss: -8.417\n",
      "ELBO: -1019.575\n",
      "IWAE: 0.000\n",
      "MAE: 0.425\n",
      "\n",
      "Epoch 4200\n",
      "Loss: -8.305\n",
      "ELBO: -1040.476\n",
      "IWAE: 0.000\n",
      "MAE: 0.421\n",
      "\n",
      "Epoch 4300\n",
      "Loss: -8.204\n",
      "ELBO: -1040.292\n",
      "IWAE: 0.000\n",
      "MAE: 0.440\n",
      "\n",
      "Epoch 4400\n",
      "Loss: -8.323\n",
      "ELBO: -1062.147\n",
      "IWAE: 0.000\n",
      "MAE: 0.421\n",
      "\n",
      "Epoch 4500\n",
      "Loss: -8.692\n",
      "ELBO: -1012.685\n",
      "IWAE: 0.000\n",
      "MAE: 0.424\n",
      "\n",
      "Epoch 4600\n",
      "Loss: -8.349\n",
      "ELBO: -1066.487\n",
      "IWAE: 0.000\n",
      "MAE: 0.411\n",
      "\n",
      "Epoch 4700\n",
      "Loss: -8.387\n",
      "ELBO: -1024.496\n",
      "IWAE: 0.000\n",
      "MAE: 0.422\n",
      "\n",
      "Epoch 4800\n",
      "Loss: -8.547\n",
      "ELBO: -1042.758\n",
      "IWAE: 0.000\n",
      "MAE: 0.423\n",
      "\n",
      "Epoch 4900\n",
      "Loss: -8.747\n",
      "ELBO: -1023.056\n",
      "IWAE: 0.000\n",
      "MAE: 0.413\n",
      "\n",
      "Epoch 4999\n",
      "Loss: -8.369\n",
      "ELBO: -1083.462\n",
      "IWAE: 0.000\n",
      "MAE: 0.422\n",
      "\n",
      "\n",
      "Save model? (yes/no)yes\n"
     ]
    }
   ],
   "source": [
    "metrics = train_model(\n",
    "    j_model, td_estimator, loader, gpvae_args,\n",
    "    gpvae.estimators.gpvae_estimators.elbo_estimator, None, train, test, y_std, y_mean, log_transform=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/matt/projects/mlmi/SpatioTemporalVAE/venv/lib/python3.7/site-packages/ipykernel_launcher.py:27: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e74943ab8bf45a499df333e2eabeb88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Loss: 157.486\n",
      "ELBO: -49351.936\n",
      "IWAE: 0.000\n",
      "MAE: 0.559\n",
      "\n",
      "Epoch 100\n",
      "Loss: 16.379\n",
      "ELBO: -5496.473\n",
      "IWAE: 0.000\n",
      "MAE: 0.563\n",
      "\n",
      "Epoch 200\n",
      "Loss: 5.459\n",
      "ELBO: -1875.057\n",
      "IWAE: 0.000\n",
      "MAE: 0.490\n",
      "\n",
      "Epoch 300\n",
      "Loss: 3.925\n",
      "ELBO: -1465.890\n",
      "IWAE: 0.000\n",
      "MAE: 0.490\n",
      "\n",
      "Epoch 400\n",
      "Loss: 3.764\n",
      "ELBO: -1434.300\n",
      "IWAE: 0.000\n",
      "MAE: 0.489\n",
      "\n",
      "Epoch 500\n",
      "Loss: 3.520\n",
      "ELBO: -1346.798\n",
      "IWAE: 0.000\n",
      "MAE: 0.484\n",
      "\n",
      "Epoch 600\n",
      "Loss: 3.367\n",
      "ELBO: -1318.660\n",
      "IWAE: 0.000\n",
      "MAE: 0.472\n",
      "\n",
      "Epoch 700\n",
      "Loss: 3.358\n",
      "ELBO: -1315.962\n",
      "IWAE: 0.000\n",
      "MAE: 0.462\n",
      "\n",
      "Epoch 800\n",
      "Loss: 3.219\n",
      "ELBO: -1239.554\n",
      "IWAE: 0.000\n",
      "MAE: 0.433\n",
      "\n",
      "Epoch 900\n",
      "Loss: 3.177\n",
      "ELBO: -1217.915\n",
      "IWAE: 0.000\n",
      "MAE: 0.415\n",
      "\n",
      "Epoch 1000\n",
      "Loss: 3.016\n",
      "ELBO: -1170.086\n",
      "IWAE: 0.000\n",
      "MAE: 0.409\n",
      "\n",
      "Epoch 1100\n",
      "Loss: 2.906\n",
      "ELBO: -1146.907\n",
      "IWAE: 0.000\n",
      "MAE: 0.413\n",
      "\n",
      "Epoch 1200\n",
      "Loss: 2.959\n",
      "ELBO: -1176.549\n",
      "IWAE: 0.000\n",
      "MAE: 0.420\n",
      "\n",
      "Epoch 1300\n",
      "Loss: 2.936\n",
      "ELBO: -1165.764\n",
      "IWAE: 0.000\n",
      "MAE: 0.410\n",
      "\n",
      "Epoch 1400\n",
      "Loss: 3.028\n",
      "ELBO: -1168.353\n",
      "IWAE: 0.000\n",
      "MAE: 0.415\n",
      "\n",
      "Epoch 1500\n",
      "Loss: 2.905\n",
      "ELBO: -1168.826\n",
      "IWAE: 0.000\n",
      "MAE: 0.419\n",
      "\n",
      "Epoch 1600\n",
      "Loss: 2.948\n",
      "ELBO: -1118.015\n",
      "IWAE: 0.000\n",
      "MAE: 0.437\n",
      "\n",
      "Epoch 1700\n",
      "Loss: 2.853\n",
      "ELBO: -1122.694\n",
      "IWAE: 0.000\n",
      "MAE: 0.427\n",
      "\n",
      "Epoch 1800\n",
      "Loss: 2.916\n",
      "ELBO: -1172.563\n",
      "IWAE: 0.000\n",
      "MAE: 0.407\n",
      "\n",
      "Epoch 1900\n",
      "Loss: 2.870\n",
      "ELBO: -1122.847\n",
      "IWAE: 0.000\n",
      "MAE: 0.426\n",
      "\n",
      "Epoch 2000\n",
      "Loss: 2.862\n",
      "ELBO: -1092.135\n",
      "IWAE: 0.000\n",
      "MAE: 0.435\n",
      "\n",
      "Epoch 2100\n",
      "Loss: 2.875\n",
      "ELBO: -1116.423\n",
      "IWAE: 0.000\n",
      "MAE: 0.427\n",
      "\n",
      "Epoch 2200\n",
      "Loss: 2.867\n",
      "ELBO: -1122.194\n",
      "IWAE: 0.000\n",
      "MAE: 0.397\n",
      "\n",
      "Epoch 2300\n",
      "Loss: 2.793\n",
      "ELBO: -1098.703\n",
      "IWAE: 0.000\n",
      "MAE: 0.423\n",
      "\n",
      "Epoch 2400\n",
      "Loss: 2.775\n",
      "ELBO: -1093.314\n",
      "IWAE: 0.000\n",
      "MAE: 0.421\n",
      "\n",
      "Epoch 2499\n",
      "Loss: 2.873\n",
      "ELBO: -1091.498\n",
      "IWAE: 0.000\n",
      "MAE: 0.406\n",
      "\n",
      "\n",
      "Save model? (yes/no)no\n"
     ]
    }
   ],
   "source": [
    "vae_metrics = train_model(\n",
    "    vae_model, gpvae.estimators.vae_estimators.td_estimator, loader, vae_args,\n",
    "    gpvae.estimators.vae_estimators.elbo_estimator, None, train, test, y_std, y_mean, log_transform=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_models = [gpvae.models.TitsiasSparseGPVAE]\n",
    "\n",
    "def train_model(model, loss_fn, loader, args, elbo_estimator=None, iwae_estimator=None, \n",
    "                train=None, test=None, y_std=1., y_mean=1., log_transform=False):\n",
    "    metrics = {'epochs': [],\n",
    "               'losses': [],\n",
    "               'elbos': [],\n",
    "               'iwaes': [],\n",
    "               'smses': [],\n",
    "               'smlls': [],\n",
    "               'mlls': [],\n",
    "               'maes': []\n",
    "              }\n",
    "    \n",
    "    model.train(True)\n",
    "    optimiser = optim.Adam(model.parameters(), lr=args['lr'])\n",
    "    \n",
    "    # Get dataset.\n",
    "    dataset = loader.dataset.dataset()\n",
    "    if loader.dataset.contains_nan:\n",
    "        x, y, m, idx = dataset\n",
    "    else:\n",
    "        x, y, idx = dataset\n",
    "        m = None\n",
    "   \n",
    "    # Training.\n",
    "    for epoch in tqdm_notebook(range(args['epochs'])):\n",
    "        epoch_losses = []\n",
    "        for i, batch in enumerate(loader):\n",
    "            if loader.dataset.contains_nan:\n",
    "                x_b, y_b, m_b, idx_b = batch\n",
    "            else:\n",
    "                x_b, y_b, idx_b = batch\n",
    "                m_b = None\n",
    "            \n",
    "            optimiser.zero_grad()\n",
    "            \n",
    "            if type(model) in mf_models:\n",
    "                loss = loss_fn(\n",
    "                    model, x=x_b, y=y_b, mask=m_b, num_samples=1,\n",
    "                    decoder_scale=args['decoder_scale'], mf=True, idx=idx_b)\n",
    "            else:\n",
    "                loss = loss_fn(\n",
    "                    model, x=x_b, y=y_b, mask=m_b, num_samples=1,\n",
    "                    decoder_scale=args['decoder_scale'])\n",
    "                \n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "            \n",
    "            epoch_losses.append(loss.item())\n",
    "            \n",
    "        # Evaluate model.\n",
    "        if (epoch % args['cache_freq'] == 0) or (epoch == args['epochs'] - 1):\n",
    "            model.eval()\n",
    "            \n",
    "            # Average loss over previous epoch.\n",
    "            mean_loss = np.mean(epoch_losses)\n",
    "            metrics['losses'].append(mean_loss)\n",
    "            \n",
    "            if elbo_estimator is not None:\n",
    "                # ELBO estimate.\n",
    "                if type(model) in mf_models:\n",
    "                    elbo = elbo_estimator(\n",
    "                        model, x, y, mask=m, num_samples=10, mf=True, idx=idx)\n",
    "                else:\n",
    "                    elbo = elbo_estimator(model, x, y, mask=m, num_samples=10)\n",
    "                    \n",
    "                metrics['elbos'].append(elbo)\n",
    "            else:\n",
    "                elbo = 0.\n",
    "            \n",
    "            if iwae_estimator is not None:\n",
    "                # IWAE estimate.\n",
    "                iwae = iwae_estimator(model, x, y, mask=m, num_samples=10)\n",
    "                metrics['iwaes'].append(iwae)\n",
    "            else:\n",
    "                iwae = 0.\n",
    "            \n",
    "            if test is not None:\n",
    "                # Test predictions.\n",
    "                if type(model) in mf_models:\n",
    "                    mean, sigma = model.predict_y(\n",
    "                        x=x, idx=idx, num_samples=100)[:2]\n",
    "                else:\n",
    "                    mean, sigma = model.predict_y(\n",
    "                        x=x, y=y, mask=m, num_samples=100)[:2]\n",
    "                    \n",
    "                mean = mean.numpy() * y_std + y_mean\n",
    "                sigma = sigma.numpy() * y_std\n",
    "                \n",
    "                if log_transform:\n",
    "                    mean = np.exp(mean)\n",
    "                    sigma = np.exp(sigma)\n",
    "            \n",
    "                # Evaluate test predictions.\n",
    "                mean = pd.DataFrame(mean, index=train.index, columns=train.columns)\n",
    "                var = pd.DataFrame(sigma ** 2, index=train.index, columns=train.columns)\n",
    "                \n",
    "                smse = wbml.metric.smse(mean, test).mean()\n",
    "                smll = wbml.metric.smll(mean, var, test).mean()\n",
    "                mll = wbml.metric.mll(mean, var, test).mean()\n",
    "                mae = wbml.metric.mae(mean, test).mean()\n",
    "                \n",
    "                metrics['smses'].append(smse)\n",
    "                metrics['smlls'].append(smll)\n",
    "                metrics['mlls'].append(mll)\n",
    "                metrics['maes'].append(mae)\n",
    "            else:\n",
    "                smse = 0.\n",
    "                smll = 0.\n",
    "                mll = 0.\n",
    "            \n",
    "            print('Epoch {}\\nLoss: {:.3f}\\nELBO: {:.3f}\\n'\n",
    "                  'IWAE: {:.3f}\\nMAE: {:.3f}\\n'.format(epoch, mean_loss, elbo, iwae, mae))\n",
    "            \n",
    "            model.train(True)\n",
    "            \n",
    "    valid_response = False\n",
    "    while not valid_response:\n",
    "        response = input('Save model? (yes/no)')\n",
    "        if response == 'yes':\n",
    "            save(model, args, metrics)\n",
    "            valid_response = True\n",
    "        elif response == 'no':\n",
    "            valid_response = True\n",
    "        else:\n",
    "            pass\n",
    "        \n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save(model, args, metrics):\n",
    "    if 'model' not in args:\n",
    "        print(\"Error: 'model' does not exist in args. Aborting save.\")\n",
    "        return\n",
    "\n",
    "    results_dir = '_results/' + args['model']\n",
    "    if os.path.isdir(results_dir):\n",
    "        i = 1\n",
    "        while os.path.isdir(results_dir + '_' + str(i)):\n",
    "            i += 1\n",
    "\n",
    "        results_dir = results_dir + '_' + str(i)\n",
    "\n",
    "    os.makedirs(results_dir)\n",
    "    results_path = results_dir + '/report.txt'\n",
    "    model_path = results_dir + '/model_state_dict.pt'\n",
    "\n",
    "    # Pickle args and metrics.\n",
    "    with open(results_dir + '/args.pkl', 'wb') as f:\n",
    "        pickle.dump(args, f)\n",
    "\n",
    "    with open(results_dir + '/metrics.pkl', 'wb') as f:\n",
    "        pickle.dump(metrics, f)\n",
    "\n",
    "    # Save args and results in text format.\n",
    "    with open(results_path, 'w') as f:\n",
    "        f.write('Args: \\n')\n",
    "        if isinstance(args, list):\n",
    "            for d in args:\n",
    "                f.write(str(d) + '\\n')\n",
    "        else:\n",
    "            f.write(str(args) + '\\n')\n",
    "\n",
    "        f.write('\\nPerformance: \\n')\n",
    "        for (key, values) in metrics.items():\n",
    "            try:\n",
    "                f.write('{}: {}\\n'.format(key, values[-1]))\n",
    "            except IndexError:\n",
    "                pass\n",
    "\n",
    "    # Save model.state_dict().\n",
    "    torch.save(model.state_dict(), model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spatio-temporal-vae-env",
   "language": "python",
   "name": "spatio-temporal-vae-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
