{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "innocent-madagascar",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.io as sio\n",
    "import torch\n",
    "from physvae import utils\n",
    "from physvae.locomotion.model import VAE\n",
    "from torchdiffeq import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dynamic-harvey",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data\n",
    "datadir = './data/locomotion/'\n",
    "dataname = 'test'\n",
    "data_test = sio.loadmat('{}/data_{}.mat'.format(datadir, dataname))['data'].astype(np.float32)\n",
    "_, dim_x, dim_t = data_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "miniature-porcelain",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load training args as dict\n",
    "dim_y = 3\n",
    "modeldir = './out_locomotion/'\n",
    "\n",
    "with open('{}/args.json'.format(modeldir), 'r') as f:\n",
    "    args_tr_dict = json.load(f)\n",
    "\n",
    "# set model\n",
    "device = \"cpu\" # torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = VAE(args_tr_dict).to(device)\n",
    "\n",
    "# load model\n",
    "model.load_state_dict(torch.load('{}/model.pt'.format(modeldir), map_location=device))\n",
    "model.eval()\n",
    "print('model loaded')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mysterious-pricing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference & reconstruction on test data, and compute full trajectories (including generalized momenta)\n",
    "data_test_tensor = torch.Tensor(data_test).to(device).contiguous()\n",
    "\n",
    "# reg\n",
    "z_phy_stat, z_aux2_stat, init_yy = model.encode(data_test_tensor)\n",
    "z_phy, z_aux2 = model.draw(z_phy_stat, z_aux2_stat, hard_z=False)\n",
    "x_PB, x_P, _, _ = model.decode(z_phy, z_aux2, init_yy, full=True)\n",
    "def ODEfunc(t:torch.Tensor, yy:torch.Tensor):\n",
    "    return model.physics_model(z_phy, yy)\n",
    "yy_seq = odeint(ODEfunc, init_yy, model.t_intg, method='dopri5') # <len_intg x n x 2dim_y>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "intended-cooling",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot\n",
    "idx=0\n",
    "dat = data_test[idx].T\n",
    "\n",
    "# reg\n",
    "plt.figure()\n",
    "plt.subplot(2,1,1)\n",
    "plt.plot(dat)\n",
    "plt.plot(x_PB[idx].detach().cpu().numpy().T, 'k--')\n",
    "plt.subplot(2,1,2)\n",
    "plt.plot(dat)\n",
    "plt.plot(x_P[idx].detach().cpu().numpy().T, 'k--')\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
