{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c625837",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from ablations.lagcavae import ImageDataset, Model\n",
    "from Lagrangian_caVAE.utils import my_collate\n",
    "# from ablations.hgn import Model\n",
    "from keycld.data.dm import Data\n",
    "from keycld.util import NumpyLoader, visualize_n_maps\n",
    "import numpy as onp\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import pickle\n",
    "from moviepy.editor import ImageSequenceClip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f22fcd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment to evaluate; its checkpoint lives under ../lightning_logs/<name>/.\n",
    "name = 'lagcavae-None-no-None'\n",
    "checkpoint_path = f'../lightning_logs/{name}/last.ckpt'\n",
    "\n",
    "# Seed the torch and NumPy global RNGs so any sampling in the cells below is repeatable.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "onp.random.seed(SEED)\n",
    "\n",
    "model = Model.load_from_checkpoint(checkpoint_path)\n",
    "\n",
    "# Roll out for a fixed number of steps rather than the value stored in\n",
    "# model.hparams.T_pred (presumably the horizon used during training).\n",
    "T_pred = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05e93471",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pendulum data generated with the same control setting the checkpoint was trained with.\n",
    "data = Data('pendulum', 'random', model.hparams.control)\n",
    "\n",
    "# In this notebook the dataset is only used for its t_eval time grid.\n",
    "dataset = ImageDataset(data, T_pred)\n",
    "\n",
    "# Evaluation-time overrides on the loaded model: integrate over the dataset's\n",
    "# time grid using a fixed-step RK4 solver.\n",
    "model.t_eval = torch.from_numpy(dataset.t_eval)\n",
    "model.hparams.solver = 'rk4'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4aaa29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack every validation trajectory into batched tensors.\n",
    "x = torch.from_numpy(onp.stack([item['x'] for item in data.val]))\n",
    "u = torch.from_numpy(onp.stack([item['action'] for item in data.val]))\n",
    "\n",
    "# Time-major layout for the model: (N, T, H, W, C) -> (T, N, C, H, W),\n",
    "# assuming each item['x'] is (T, H, W, C) -- TODO confirm against Data.\n",
    "# NOTE(review): u stays batch-first while x is time-major; confirm Model.forward expects that.\n",
    "x = x.permute(1, 0, 4, 2, 3)\n",
    "\n",
    "# The forward call's return value is unused; the rollout is read back from model.Xrec.\n",
    "model(x, u)\n",
    "x_pred = model.Xrec\n",
    "\n",
    "# Invert the permutation above, giving (N, T, H, W, C) NumPy arrays for plotting and metrics.\n",
    "x = x.permute(1, 0, 3, 4, 2).cpu().numpy()\n",
    "x_pred = x_pred.permute(1, 0, 3, 4, 2).detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38bb6b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side video of one validation rollout: ground truth (left), prediction (right).\n",
    "sample_idx = 5\n",
    "# Concatenating on axis=2 places the two frames next to each other, assuming (T, H, W, C) frames.\n",
    "frames = onp.concatenate([x[sample_idx], x_pred[sample_idx]], axis=2)\n",
    "# Pixel values are presumably in [0, 1]; clip before the uint8 cast so out-of-range values saturate instead of wrapping.\n",
    "frames = (frames * 255).clip(0, 255).astype(onp.uint8)\n",
    "# Scale to 256 px tall while preserving aspect ratio (512x256 for square frames, no distortion otherwise).\n",
    "ImageSequenceClip(list(frames), fps=30).resize(height=256).ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "482ba17c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Valid prediction time (VPT): for each validation rollout, the number of leading\n",
    "# frames whose per-frame MSE stays at or below data.epsilon.\n",
    "vpts = []\n",
    "for run, prediction in zip(x, x_pred):\n",
    "    # Mean over every axis but the first: one error value per frame (assumes run is (T, H, W, C)).\n",
    "    mse = onp.mean((run - prediction) ** 2, axis=(1, 2, 3))\n",
    "    # ~(mse <= eps) rather than mse > eps so a diverged (NaN) frame counts as a failure.\n",
    "    exceeded = onp.flatnonzero(~(mse <= data.epsilon))\n",
    "    # The first failing frame ends the valid horizon; a rollout that never fails\n",
    "    # is valid for all len(mse) frames (and an empty rollout for 0).\n",
    "    vpts.append(int(exceeded[0]) if exceeded.size else len(mse))\n",
    "vpt_mean, vpt_std, vpt_median = onp.mean(vpts), onp.std(vpts), onp.median(vpts)\n",
    "print(vpt_mean, vpt_std, vpt_median)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.14 ('.venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.14"
  },
  "vscode": {
   "interpreter": {
    "hash": "4805c36e082d3af34e16eb135be64aadd1f679f5f61fec690034a274ac6709ff"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
