{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "from argparse import ArgumentParser\n",
    "import os, sys\n",
    "THIS_DIR = os.path.abspath('')\n",
    "PARENT_DIR = os.path.dirname(os.path.abspath(''))\n",
    "sys.path.append(PARENT_DIR)\n",
    "\n",
    "# Third party imports\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning import Trainer, seed_everything\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from torchdiffeq import odeint\n",
    "from torchvision import utils\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# local application imports\n",
    "from lag_caVAE.lag import Lag_Net\n",
    "from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD\n",
    "from hyperspherical_vae.distributions import VonMisesFisher\n",
    "from hyperspherical_vae.distributions import HypersphericalUniform\n",
    "from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset\n",
    "from examples.pend_lag_cavae_trainer import Model as Model_lag_cavae\n",
    "from ablations.ablation_pend_MLPdyna_cavae_trainer import Model as Model_MLPdyna_cavae\n",
    "from ablations.ablation_pend_lag_vae_trainer import Model as Model_lag_vae\n",
    "from ablations.ablation_pend_lag_caAE_trainer import Model as Model_lag_caAE\n",
    "from ablations.HGN import Model as Model_HGN\n",
    "\n",
    "seed_everything(0)\n",
    "%matplotlib inline\n",
    "DPI = 600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'pend-lag-cavae-T_p=4-epoch=701.ckpt')\n",
    "model_lag_cavae = Model_lag_cavae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-pend-MLPdyna-cavae-T_p=4-epoch=919.ckpt')\n",
    "model_MLPdyna_cavae = Model_MLPdyna_cavae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-pend-lag-vae-T_p=4-epoch=916.ckpt')\n",
    "model_lag_vae = Model_lag_vae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-pend-lag-caAE-T_p=4-epoch=778.ckpt')\n",
    "model_lag_caAE = Model_lag_caAE.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'baseline-pend-HGN-T_p=4-epoch=1543.ckpt')\n",
    "model_HGN = Model_HGN.load_from_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load train data, prepare for plotting prediction\n",
    "data_path=os.path.join(PARENT_DIR, 'datasets', 'pendulum-gym-image-dataset-train.pkl')\n",
    "train_dataset = HomoImageDataset(data_path, T_pred=4)\n",
    "# prepare model\n",
    "model_lag_cavae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_cavae.hparams.annealing = False\n",
    "model_MLPdyna_cavae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_vae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_caAE.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_HGN.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_HGN.step = 3 ; model_HGN.alpha = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/desmond/.pyenv/versions/embed/lib/python3.7/site-packages/torch/nn/functional.py:3447: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n",
      "/home/desmond/.pyenv/versions/embed/lib/python3.7/site-packages/torch/nn/functional.py:3384: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n"
     ]
    }
   ],
   "source": [
    "lag_cavae_train_loss = []\n",
    "MLPdyna_cavae_train_loss = []\n",
    "lag_vae_train_loss = []\n",
    "lag_caAE_train_loss = []\n",
    "\n",
    "for i in range(len(train_dataset.x)):\n",
    "    batch = (torch.from_numpy(train_dataset.x[i]), torch.from_numpy(train_dataset.u[i]))\n",
    "    lag_cavae_train_loss.append(model_lag_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    MLPdyna_cavae_train_loss.append(model_MLPdyna_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    lag_vae_train_loss.append(model_lag_vae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    lag_caAE_train_loss.append(model_lag_caAE.training_step(batch, 0)['log']['recon_loss'].item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "HGN_train_loss = []\n",
    "train_dataset.u_idx = 0\n",
    "dataLoader = DataLoader(train_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)\n",
    "for batch in dataLoader:\n",
    "    HGN_train_loss.append(model_HGN.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data, prepare for plotting prediction\n",
    "data_path=os.path.join(PARENT_DIR, 'datasets', 'pendulum-gym-image-dataset-test.pkl')\n",
    "test_dataset = HomoImageDataset(data_path, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "lag_cavae_test_loss = []\n",
    "MLPdyna_cavae_test_loss = []\n",
    "lag_vae_test_loss = []\n",
    "lag_caAE_test_loss = []\n",
    "\n",
    "for i in range(len(train_dataset.x)):\n",
    "    batch = (torch.from_numpy(test_dataset.x[i]), torch.from_numpy(test_dataset.u[i]))\n",
    "    lag_cavae_test_loss.append(model_lag_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    MLPdyna_cavae_test_loss.append(model_MLPdyna_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    lag_vae_test_loss.append(model_lag_vae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "    lag_caAE_test_loss.append(model_lag_caAE.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "HGN_test_loss = []\n",
    "train_dataset.u_idx = 0\n",
    "dataLoader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)\n",
    "for batch in dataLoader:\n",
    "    HGN_test_loss.append(model_HGN.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lag_cavae: train: 0.0018338330462574957, test: 0.0018673422187566757\n",
      "MLPdyna_cavae: train: 0.0018255940079689025, test: 0.001863210164010525\n",
      "lag_vae: train: 0.002403554953634739, test: 0.002519616447389126\n",
      "lag_caAE: train: 0.001860453300178051, test: 0.00189580075442791\n",
      "HGN: train: 0.0005488340655574575, test: 0.000710727070691064\n"
     ]
    }
   ],
   "source": [
    "scale = 32*32*5\n",
    "print(f'lag_cavae: train: {np.mean(lag_cavae_train_loss)/scale}, test: {np.mean(lag_cavae_test_loss)/scale}')\n",
    "print(f'MLPdyna_cavae: train: {np.mean(MLPdyna_cavae_train_loss)/scale}, test: {np.mean(MLPdyna_cavae_test_loss)/scale}')\n",
    "print(f'lag_vae: train: {np.mean(lag_vae_train_loss)/scale}, test: {np.mean(lag_vae_test_loss)/scale}')\n",
    "print(f'lag_caAE: train: {np.mean(lag_caAE_train_loss)/scale}, test: {np.mean(lag_caAE_test_loss)/scale}')\n",
    "print(f'HGN: train: {np.mean(HGN_train_loss)/scale}, test: {np.mean(HGN_test_loss)/scale}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 64-bit ('embed': venv)",
   "language": "python",
   "name": "python37464bitembedvenv00bda69ab17746c28ffa9a3e65936c26"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
