{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as onp\n",
    "from functools import partial\n",
    "import wandb\n",
    "import pickle\n",
    "from moviepy.editor import ImageSequenceClip\n",
    "import matplotlib.pyplot as plt\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.experimental.ode import odeint\n",
    "from functools import partial\n",
    "from tqdm import tqdm\n",
    "\n",
    "from keycld import models\n",
    "from keycld.models import predict, predict_run, KeyCLD\n",
    "from keycld.data.dm import Data\n",
    "from keycld import util\n",
    "from keycld.util import NumpyLoader, visualize_n_maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set RUN_PATH to the 'Run path' shown on the run's Overview page of the wandb dashboard.\n",
    "# This notebook has only been tested with KeyCLD models.\n",
    "RUN_PATH = 'user/dm-cartpole/1kmhb7uj'\n",
    "\n",
    "wandb_api = wandb.Api()\n",
    "run = wandb_api.run(RUN_PATH)\n",
    "\n",
    "# Configuration stored with the run; it is indexed by key further down in this cell.\n",
    "args = run.config\n",
    "\n",
    "def get_params(epoch):\n",
    "    \"\"\"Download `params_{epoch}.p` from the wandb run and return the unpickled object.\n",
    "\n",
    "    The file is fetched with `replace=True` and then opened by its relative\n",
    "    name, so an existing local copy is replaced.\n",
    "    \"\"\"\n",
    "    filename = f'params_{epoch}.p'\n",
    "    run.file(filename).download(replace=True)\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load files from runs you trust.\n",
    "    with open(filename, 'rb') as checkpoint:\n",
    "        return pickle.load(checkpoint)\n",
    "\n",
    "# Build the dataset from the settings stored in the run configuration.\n",
    "data = Data(\n",
    "    args['environment'],\n",
    "    args['init_mode'],\n",
    "    args['control'],\n",
    ")\n",
    "dataloader = NumpyLoader(\n",
    "    data.train,\n",
    "    batch_size=1,\n",
    "    num_workers=12,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "# Fetch the parameters saved under the last epoch index (num_epochs - 1).\n",
    "final_epoch = args['num_epochs'] - 1\n",
    "params = get_params(final_epoch)\n",
    "model = KeyCLD(\n",
    "    data.n,\n",
    "    data.n,\n",
    "    args['num_hidden_dim'],\n",
    "    constraint_fn=data.constraint_fn,\n",
    ")\n",
    "\n",
    "# Display the run configuration as the cell output.\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mass matrix evaluated at an all-zeros input of length 2 * num_keypoints.\n",
    "zero_state = jnp.zeros(2 * model.num_keypoints)\n",
    "mass_matrix_static = model.mass_matrix(params, zero_state)\n",
    "\n",
    "# Print with 5 fixed decimals and without scientific notation.\n",
    "print_options = dict(precision=5, suppress=True, floatmode='fixed')\n",
    "with onp.printoptions(**print_options):\n",
    "    print(mass_matrix_static)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# potential energy\n",
    "def calculate_potential_energy(image):\n",
    "    \"\"\"Return the model's potential energy for a single image.\n",
    "\n",
    "    The image is given a leading axis of length 1 and encoded; the keypoints\n",
    "    extracted from the encoder output are flattened and passed on as the state.\n",
    "    \"\"\"\n",
    "    batched_image = image[None]\n",
    "    encoded_maps = model.encoder(params, batched_image)\n",
    "    # map_to_keypoints returns a pair; only the first element is used here.\n",
    "    located_keypoints, _ = util.map_to_keypoints(encoded_maps)\n",
    "    return model.potential_energy(params, located_keypoints.flatten())\n",
    "images = data.grid['x']\n",
    "positions = data.grid['positions']\n",
    "\n",
    "# Nested vmap: map the per-image function over the two leading axes of `images`.\n",
    "energy_over_grid = jax.vmap(jax.vmap(calculate_potential_energy))\n",
    "potential_energies = energy_over_grid(images)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(potential_energies)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_uint8_frames(frames):\n",
    "    \"\"\"Convert float frames in [0, 1] to uint8 frames in [0, 255].\n",
    "\n",
    "    Values are clipped to [0, 1] first so that out-of-range values saturate\n",
    "    instead of producing arbitrary results in the uint8 cast.\n",
    "    \"\"\"\n",
    "    return (onp.clip(frames, 0, 1) * 255).astype(onp.uint8)\n",
    "\n",
    "i = 0  # index of the validation item to visualise\n",
    "item = data.val[i]\n",
    "t = item['t']\n",
    "x = item['x']\n",
    "action = item['action']\n",
    "\n",
    "# The encoder takes `params` as its first argument, as in the potential-energy cell.\n",
    "keypoint_maps = model.encoder(params, x)\n",
    "keypoints, keypoint_maps = util.map_to_keypoints(keypoint_maps)\n",
    "\n",
    "# Normalise each map by its maximum over axes 1 and 2 before visualising.\n",
    "keypoint_maps_n = keypoint_maps / keypoint_maps.max((1, 2), keepdims=True)\n",
    "heatmaps = to_uint8_frames(onp.concatenate([x, visualize_n_maps(keypoint_maps_n)], axis=-2))\n",
    "\n",
    "# Only the first three entries of `keypoints` are handed to the predictor.\n",
    "# NOTE(review): `params` is passed to `predict` (right after `model`) and to\n",
    "# `model.renderer` by analogy with the other model calls in this notebook, since\n",
    "# `model` is constructed without parameters - confirm against keycld.models.\n",
    "keypoints_pred = predict(model, params, t, keypoints[:3], action)\n",
    "x_recon, gaussian_maps = model.renderer(params, keypoints_pred)\n",
    "\n",
    "prediction = to_uint8_frames(onp.concatenate([x, visualize_n_maps(gaussian_maps), x_recon], axis=-2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inline video of the `heatmaps` frames (input images next to their keypoint maps).\n",
    "heatmap_clip = ImageSequenceClip(list(heatmaps), fps=30)\n",
    "heatmap_clip.resize((512, 256)).ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inline video of the `prediction` frames (input, rendered keypoint maps, reconstruction).\n",
    "prediction_clip = ImageSequenceClip(list(prediction), fps=30)\n",
    "prediction_clip.resize((3 * 256, 256)).ipython_display()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.14 ('.venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.14"
  },
  "vscode": {
   "interpreter": {
    "hash": "4805c36e082d3af34e16eb135be64aadd1f679f5f61fec690034a274ac6709ff"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
