{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab37752",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from ablations.lagcavae import ImageDataset\n",
    "from Lagrangian_caVAE.utils import my_collate\n",
    "from ablations.hgn import Model\n",
    "from keycld.data.dm import Data\n",
    "from keycld.util import NumpyLoader, visualize_n_maps\n",
    "import numpy as onp\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import pickle\n",
    "from moviepy.editor import ImageSequenceClip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7146e428",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained HGN ablation from the last Lightning checkpoint of run `name`.\n",
    "name = 'hgn-acrobot-None'\n",
    "\n",
    "checkpoint_path = f'../lightning_logs/{name}/last.ckpt'\n",
    "model = Model.load_from_checkpoint(checkpoint_path)\n",
    "# Evaluation only: put the module in eval mode (affects layers such as\n",
    "# dropout/batch-norm, or any code that branches on self.training).\n",
    "model.eval()\n",
    "\n",
    "# Evaluation horizon in frames. Set explicitly rather than reusing the value\n",
    "# stored in the checkpoint (model.hparams.T_pred).\n",
    "T_pred = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e025651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data for the environment recorded in the checkpoint's hparams.\n",
    "# NOTE(review): the meaning of the positional args 'random' and 'no' is defined by keycld.data.dm.Data.\n",
    "data = Data(model.hparams.environment, 'random', 'no')\n",
    "# Within this notebook the dataset is only used for its time grid t_eval.\n",
    "dataset = ImageDataset(data, T_pred)\n",
    "\n",
    "# Evaluation-time overrides of model settings: time grid, solver, step, alpha.\n",
    "# NOTE(review): semantics of step and alpha live in ablations.hgn.Model - confirm there.\n",
    "model.t_eval = torch.from_numpy(dataset.t_eval)\n",
    "model.hparams.solver = 'rk4'\n",
    "model.step = 4\n",
    "model.alpha = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48a24200",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack all validation runs into one tensor, presumably (N, T, H, W, C),\n",
    "# then reorder to (T, N, C, H, W) for the model.\n",
    "x = torch.from_numpy(onp.stack([item['x'] for item in data.val])).permute(1, 0, 4, 2, 3)\n",
    "\n",
    "# NOTE(review): x[:4] slices the leading axis (time, if the layout above is right),\n",
    "# so only the first 4 frames of every run are passed in; the output is read\n",
    "# from model.Xrec. Not wrapped in torch.no_grad(): check whether the forward\n",
    "# relies on autograd (e.g. Hamiltonian gradients) before adding it.\n",
    "model(x[:4])\n",
    "x_pred = model.Xrec\n",
    "\n",
    "# Back to batch-first, channels-last numpy arrays for plotting and metrics.\n",
    "x = x.permute(1, 0, 3, 4, 2).cpu().numpy()\n",
    "x_pred = x_pred.permute(1, 0, 3, 4, 2).detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441f025c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Video of validation run i: ground truth and prediction concatenated along\n",
    "# axis 2 (the width axis, assuming (T, H, W, C) frames), i.e. side by side.\n",
    "i = 1\n",
    "side_by_side = onp.concatenate((x[i], x_pred[i]), axis=2)\n",
    "# Scale to 0-255, clip, and quantize to uint8 for video encoding.\n",
    "frames = onp.clip(side_by_side * 255, 0, 255).astype(onp.uint8)\n",
    "video = ImageSequenceClip(list(frames), fps=30)\n",
    "video.resize((512, 256)).ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e10953",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Valid prediction time (VPT) per run: the number of leading frames whose\n",
    "# per-frame MSE stays within data.epsilon, i.e. the index of the first frame\n",
    "# that exceeds it. A run that never exceeds the threshold is valid for all of\n",
    "# its frames (len(mse)), not len(mse) - 1.\n",
    "# NOTE: zip() silently stops at the shorter of x / x_pred.\n",
    "vpts = []\n",
    "for run, prediction in zip(x, x_pred):\n",
    "    mse = onp.mean((run - prediction) ** 2, axis=(1, 2, 3))\n",
    "    exceeded = onp.flatnonzero(mse > data.epsilon)\n",
    "    vpts.append(int(exceeded[0]) if exceeded.size else len(mse))\n",
    "vpt_mean, vpt_std, vpt_median = onp.mean(vpts), onp.std(vpts), onp.median(vpts)\n",
    "print(vpt_mean, vpt_std, vpt_median)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
