{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "from argparse import ArgumentParser\n",
    "import os, sys\n",
    "THIS_DIR = os.path.abspath('')\n",
    "PARENT_DIR = os.path.dirname(os.path.abspath(''))\n",
    "sys.path.append(PARENT_DIR)\n",
    "\n",
    "# Third party imports\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning import Trainer, seed_everything\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from torchdiffeq import odeint\n",
    "from torchvision import utils\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# local application imports\n",
    "from lag_caVAE.lag import Lag_Net\n",
    "from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD\n",
    "from hyperspherical_vae.distributions import VonMisesFisher\n",
    "from hyperspherical_vae.distributions import HypersphericalUniform\n",
    "from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset\n",
    "from examples.cart_lag_cavae_trainer import Model as Model_lag_cavae\n",
    "from ablations.ablation_cart_MLPdyna_cavae_trainer import Model as Model_MLPdyna_cavae\n",
    "from ablations.ablation_cart_lag_vae_trainer import Model as Model_lag_vae\n",
    "from ablations.ablation_cart_lag_MLPEnc_caDec_trainer import Model as Model_lag_MLPEnc_caDec\n",
    "from ablations.ablation_cart_lag_caEnc_MLPDec_trainer import Model as Model_lag_caEnc_MLPDec\n",
    "from ablations.ablation_cart_lag_caAE_trainer import Model as Model_lag_caAE\n",
    "from ablations.HGN import Model as Model_HGN\n",
    "\n",
    "seed_everything(0)\n",
    "%matplotlib inline\n",
    "DPI = 600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'updated-cart-lag-cavae-T_p=4-epoch=905.ckpt')\n",
    "model_lag_cavae = Model_lag_cavae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-cart-MLPdyna-cavae-T_p=4-epoch=807.ckpt')\n",
    "model_MLPdyna_cavae = Model_MLPdyna_cavae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-cart-lag-vae-T_p=4-epoch=987.ckpt')\n",
    "model_lag_vae = Model_lag_vae.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-cart-lag-MLPEnc-caDec-T_p=4-epoch=524.ckpt')\n",
    "model_lag_MLPEnc_caDec = Model_lag_MLPEnc_caDec.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# this checkpoint is trained with learning rate 1e-4\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-cart-lag-caEnc-MLPDec-T_p=4-epoch=954.ckpt')\n",
    "model_lag_caEnc_MLPDec = Model_lag_caEnc_MLPDec.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# this checkpoint is trained with learning rate 1e-4\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'ablation-cart-lag-caAE-T_p=4-epoch=909.ckpt')\n",
    "model_lag_caAE = Model_lag_caAE.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "checkpoint_path = os.path.join(PARENT_DIR, \n",
    "                               'checkpoints', \n",
    "                               'baseline-cart-HGN-T_p=4-epoch=1777.ckpt')\n",
    "model_HGN = Model_HGN.load_from_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load train data, prepare for plotting prediction\n",
    "# WARNING: this might requires ~18G memory at peak\n",
    "data_path=os.path.join(PARENT_DIR, 'datasets', 'cartpole-gym-image-dataset-rgb-u9-train.pkl')\n",
    "train_dataset = HomoImageDataset(data_path, T_pred=4)\n",
    "# prepare model\n",
    "model_lag_cavae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_cavae.hparams.annealing = False\n",
    "model_MLPdyna_cavae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_vae.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_MLPEnc_caDec.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_caEnc_MLPDec.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_lag_caAE.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_HGN.t_eval = torch.from_numpy(train_dataset.t_eval)\n",
    "model_HGN.step = 3 ; model_HGN.alpha = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/z0042y5x/.pyenv/versions/3.7.5/envs/lag/lib/python3.7/site-packages/torch/nn/functional.py:3447: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n",
      "/home/z0042y5x/.pyenv/versions/3.7.5/envs/lag/lib/python3.7/site-packages/torch/nn/functional.py:3384: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
      "  warnings.warn(\"Default grid_sample and affine_grid behavior has changed \"\n"
     ]
    }
   ],
   "source": [
    "lag_cavae_train_loss = []\n",
    "MLPdyna_cavae_train_loss = []\n",
    "lag_MLPEnc_caDec_train_loss = []\n",
    "lag_caEnc_MLPDec_train_loss = []\n",
    "lag_vae_train_loss = []\n",
    "lag_caAE_train_loss = []\n",
    "\n",
    "for i in range(len(train_dataset.x)):\n",
    "    train_dataset.u_idx = i\n",
    "    dataLoader = DataLoader(train_dataset, batch_size=512, shuffle=False, collate_fn=my_collate)\n",
    "    for batch in dataLoader:\n",
    "        lag_cavae_train_loss.append(model_lag_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        MLPdyna_cavae_train_loss.append(model_MLPdyna_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_vae_train_loss.append(model_lag_vae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_MLPEnc_caDec_train_loss.append(model_lag_MLPEnc_caDec.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_caEnc_MLPDec_train_loss.append(model_lag_caEnc_MLPDec.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_caAE_train_loss.append(model_lag_caAE.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "HGN_train_loss = []\n",
    "train_dataset.u_idx = 0\n",
    "dataLoader = DataLoader(train_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)\n",
    "for batch in dataLoader:\n",
    "    HGN_train_loss.append(model_HGN.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "del dataLoader\n",
    "del train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load train data, prepare for plotting prediction\n",
    "# WARNING: this might requires ~18G memory at peak\n",
    "data_path=os.path.join(PARENT_DIR, 'datasets', 'cartpole-gym-image-dataset-rgb-u9-test.pkl')\n",
    "test_dataset = HomoImageDataset(data_path, T_pred=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "lag_cavae_test_loss = []\n",
    "MLPdyna_cavae_test_loss = []\n",
    "lag_MLPEnc_caDec_test_loss = []\n",
    "lag_caEnc_MLPDec_test_loss = []\n",
    "lag_vae_test_loss = []\n",
    "lag_caAE_test_loss = []\n",
    "\n",
    "for i in range(len(test_dataset.x)):\n",
    "    test_dataset.u_idx = i\n",
    "    dataLoader = DataLoader(test_dataset, batch_size=512, shuffle=False, collate_fn=my_collate)\n",
    "    for batch in dataLoader:\n",
    "        lag_cavae_test_loss.append(model_lag_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        MLPdyna_cavae_test_loss.append(model_MLPdyna_cavae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_vae_test_loss.append(model_lag_vae.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_MLPEnc_caDec_test_loss.append(model_lag_MLPEnc_caDec.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_caEnc_MLPDec_test_loss.append(model_lag_caEnc_MLPDec.training_step(batch, 0)['log']['recon_loss'].item())\n",
    "        lag_caAE_test_loss.append(model_lag_caAE.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "HGN_test_loss = []\n",
    "test_dataset.u_idx = 0\n",
    "dataLoader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)\n",
    "for batch in dataLoader:\n",
    "    HGN_test_loss.append(model_HGN.training_step(batch, 0)['log']['recon_loss'].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lag_cavae: train: 0.0034782227568535343, test: 0.0035928564741172725\n",
      "MLPdyna_cavae: train: 0.013915377296507358, test: 0.015017079458468491\n",
      "lag_vae: train: 364.9612683741065, test: 599.4507895821602\n",
      "lag_MLPEnc_caDec: train: 0.018483501424392066, test: 0.01914455969300535\n",
      "lag_caEnc_MLPDec: train: 14433.177732683418, test: 15153.190610957206\n",
      "lag_caAE: train: 0.0041534248480780255, test: 0.004238036494805581\n",
      "HGN: train: 0.002811681083403528, test: 0.0029128929716534914\n"
     ]
    }
   ],
   "source": [
    "scale = 64*64*5\n",
    "print(f'lag_cavae: train: {np.mean(lag_cavae_train_loss)/scale}, test: {np.mean(lag_cavae_test_loss)/scale}')\n",
    "print(f'MLPdyna_cavae: train: {np.mean(MLPdyna_cavae_train_loss)/scale}, test: {np.mean(MLPdyna_cavae_test_loss)/scale}')\n",
    "print(f'lag_vae: train: {np.mean(lag_vae_train_loss)/scale}, test: {np.mean(lag_vae_test_loss)/scale}')\n",
    "print(f'lag_MLPEnc_caDec: train: {np.mean(lag_MLPEnc_caDec_train_loss)/scale}, test: {np.mean(lag_MLPEnc_caDec_test_loss)/scale}')\n",
    "print(f'lag_caEnc_MLPDec: train: {np.mean(lag_caEnc_MLPDec_train_loss)/scale}, test: {np.mean(lag_caEnc_MLPDec_test_loss)/scale}')\n",
    "print(f'lag_caAE: train: {np.mean(lag_caAE_train_loss)/scale}, test: {np.mean(lag_caAE_test_loss)/scale}')\n",
    "print(f'HGN: train: {np.mean(HGN_train_loss)/scale}, test: {np.mean(HGN_test_loss)/scale}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
