{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d549596-55ef-4afa-a2ad-1e3a3524a1d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from dataset import NeuralPhysDataset\n",
    "\n",
    "def custom_sort(x):\n",
    "    numbers = re.findall(r'\\d+', x)\n",
    "    return int(numbers[0]) if numbers else float('inf')\n",
    "def show_pre(pre):\n",
    "    pre_=[]\n",
    "    for i in pre:\n",
    "        i = i.reshape(3,128,256).detach().transpose(2,0).numpy()\n",
    "        pre_.append(np.swapaxes(i[:128], 0, 1))\n",
    "    return pre_ \n",
    "\n",
    "\n",
    "traj_index=892\n",
    "val_dataset = NeuralPhysDataset(flag='val')\n",
    "image_folder = 'training_data/'+ 'single_pendulum'+'/{}'.format(traj_index)\n",
    "images = sorted(os.listdir(image_folder), key=custom_sort)\n",
    "x0 = val_dataset.get_data(os.path.join(image_folder, images[0]))\n",
    "x1 = val_dataset.get_data(os.path.join(image_folder, images[1]))\n",
    "X_0 = torch.cat([x0, x1], 2)\n",
    "def Img(frame):\n",
    "    x0 = val_dataset.get_data(os.path.join(image_folder, images[frame]))\n",
    "    x1 = val_dataset.get_data(os.path.join(image_folder, images[frame+1]))\n",
    "    x = torch.cat([x0, x1], 2)\n",
    "    img_true = x.transpose(2,0).numpy()\n",
    "    \n",
    "    return x, np.swapaxes(img_true[:128], 0, 1)\n",
    "################## Ours\n",
    "seed=0\n",
    "filename='single_pendulum' + '_caehyp_vp{}'.format(seed)\n",
    "cae_path='outputs/'+filename+'/model_best.pkl'\n",
    "cae=torch.load(cae_path,map_location='cpu').eval()\n",
    "\n",
    "filename = 'latent_'+'single_pendulum' + '_caehyp_vp{}'.format(seed)\n",
    "latent_net_path='outputs/'+filename+'/model_best.pkl'\n",
    "net=torch.load(latent_net_path,map_location='cpu').eval()\n",
    "\n",
    "x0= cae.encoder(X_0.unsqueeze(0)).squeeze()\n",
    "latent_pre = net.predict(x0, h=net.h, steps=60, keepinitx=True)\n",
    "pre_cae=[]\n",
    "for i in range(latent_pre.shape[0]):\n",
    "    pre_cae.append(cae.decoder(latent_pre[i:i+1]))\n",
    "pre_cae=show_pre(pre_cae) \n",
    "\n",
    "\n",
    "####################             PLOT          #####################\n",
    "fig, ax=plt.subplots(2, 8, figsize=(12,4))\n",
    "fig.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.02, hspace=0.4)\n",
    "\n",
    "\n",
    "a = Img(0) \n",
    "\n",
    "for i in range(8): \n",
    "    traj_i=1+4*i\n",
    "    x, ima_true = Img(traj_i*2)\n",
    "    ax[0,i].imshow(ima_true)\n",
    "    ax[0,i].axis('off')\n",
    "    ax[1,i].imshow(pre_cae[traj_i*2])\n",
    "    ax[1,i].axis('off')\n",
    "    \n",
    "    ax[1,i].text(70,150, 't={}/60'.format(traj_i*2), fontsize=18, color='black',\n",
    "      horizontalalignment='center', verticalalignment='center')\n",
    "   \n",
    "\n",
    "    \n",
    "titlesize=25\n",
    "ax[0,0].set_title('Ground truth', fontsize=titlesize, loc='left')\n",
    "ax[1,0].set_title('Prediction of CpAE', fontsize=titlesize, loc='left')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Itorch",
   "language": "python",
   "name": "itorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
