{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "860523f3-df37-436b-a887-18669fde4436",
   "metadata": {},
   "source": [
    "# Learning from Demonstration with Implicit Nonlinear Dynamics Models\n",
    "\n",
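    "This notebook loads the trained models, generates evaluation data from them, and analyses the results: metric tables and qualitative trajectory plots for the single-task and multi-task character drawing settings."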
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "751c897a-464a-48f9-bc89-c9ebe0c089e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "from hydra import compose, initialize\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "import io\n",
    "from PIL import Image\n",
    "from PIL import Image, ImageDraw\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objs as go\n",
    "from plotly.subplots import make_subplots\n",
    "from IPython.display import Image as IPImage\n",
    "from IPython.display import display\n",
    "from IPython.display import HTML\n",
    "\n",
    "from implicit_nonlinear_dynamics.evaluate.utils import sample_within_threshold\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_esn import load_esn, evaluate_demo_esn, output_representations_esn\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_esn_multitask import load_esn_multi, evaluate_demo_esn_multi, output_representations_esn_multi\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_feedforward import load_feedforward, evaluate_demo_feedforward, output_representations_feedforward\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_feedforward_multitask import load_feedforward_multi, evaluate_demo_feedforward_multi, output_representations_feedforward_multi\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_ours import load_ours, evaluate_demo_ours, output_representations_ours\n",
    "from implicit_nonlinear_dynamics.evaluate.evaluate_ours_multitask import load_ours_multi, evaluate_demo_ours_multi, output_representations_ours_multi\n",
    "\n",
    "# Same ten colours, in the same order, as Plotly's default qualitative palette.\n",
    "# Models are coloured by position (COLOURS[model_idx]); COLOURS[-5] is used for expert demonstrations.\n",
    "COLOURS = [\n",
    "    '#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',\n",
    "    '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49a960d6",
   "metadata": {},
   "source": [
    "# Validate Trained Models\n",
    "\n",
    "This first section loads each trained model and checks that evaluation data can be generated from it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c442e84-e03d-4394-b18a-b459ce13ebc8",
   "metadata": {},
   "source": [
    "## Validate ESN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb9cc08-2fd0-47ed-a272-b667ce1adaba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the ESN config; the composed tree keeps its settings under the \"config\" key.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn\")[\"config\"]\n",
    "\n",
    "esn_model = load_esn(cfg)\n",
    "# evaluate_demo_esn returns six values; the trailing two are not needed here.\n",
    "(esn_demo, esn_predicted, esn_frechet, esn_latency, _, _) = evaluate_demo_esn(esn_model, cfg)\n",
    "esn_representations = output_representations_esn(esn_model, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "731f9f1c-c299-4829-a889-28251721d38b",
   "metadata": {},
   "source": [
    "## Validate ESN Multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb594623-fdd3-4bc2-9978-67eb1ce75564",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-task ESN: same flow as the single-task cell, but the evaluator returns four values.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn_multi\")[\"config\"]\n",
    "\n",
    "esn_model_multi = load_esn_multi(cfg)\n",
    "(esn_multi_demo, esn_multi_predicted,\n",
    " esn_multi_frechet, esn_multi_latency) = evaluate_demo_esn_multi(esn_model_multi, cfg)\n",
    "esn_multi_representations = output_representations_esn_multi(esn_model_multi, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fddbf188-1b3a-454b-b563-9e90446f745b",
   "metadata": {},
   "source": [
    "## Validate Feedforward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da14529c-1903-4e6a-97e1-1d1b00b030a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "\n",
    "    # Set the evaluation noise scale to 0.0 (i.e. no added noise) -- noise is not applied here.\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "    # Action chunking + temporal ensemble overrides. output_dim is 2 * action_chunks\n",
    "    # (presumably one (x, y) pair per chunked action -- TODO confirm).\n",
    "    # NOTE(review): the other temporal-ensemble cells use 24-action chunks (output_dim 48);\n",
    "    # confirm which chunk size is intended for this validation run.\n",
    "    cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 96\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 48\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo, feedforward_predicted, feedforward_frechet, feedforward_latency, _, _ = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "feedforward_representations = output_representations_feedforward(feedforward_model, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41c184e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feedforward with action chunking + temporal ensemble, evaluated on the \"S\" character.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "    # apply ensemble: chunked output size, then the inference-time overrides\n",
    "    cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "    for param, override in {\n",
    "        \"action_chunks\": 24,\n",
    "        \"num_predicted_actions\": 24,\n",
    "        \"prediction_horizon\": 2,\n",
    "        \"ensemble_weight\": 0.001,\n",
    "    }.items():\n",
    "        cfg[\"architecture\"][\"inference_params\"][param] = override\n",
    "\n",
    "cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "(temp_feedforward_demo_s, temp_feedforward_predicted_s, temp_feedforward_frechet_s,\n",
    " temp_feedforward_latency_s, temp_feedforward_jerk_x_s, temp_feedforward_jerk_y_s) = evaluate_demo_feedforward(feedforward_model, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e1ce0a3-0c59-4f17-a865-0a6cd55b5930",
   "metadata": {},
   "source": [
    "## Validate Feedforward Multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1cae56-6a7f-42fa-bb05-2816d0dbf711",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-task feedforward model; its evaluator returns four values.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward_multi\")[\"config\"]\n",
    "\n",
    "feedforward_model_multi = load_feedforward_multi(cfg)\n",
    "(feedforward_multi_demo, feedforward_multi_predicted,\n",
    " feedforward_multi_frechet, feedforward_multi_latency) = evaluate_demo_feedforward_multi(feedforward_model_multi, cfg)\n",
    "feedforward_multi_representations = output_representations_feedforward_multi(feedforward_model_multi, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ed14879-0c4c-408e-b040-bd9340cf8b3e",
   "metadata": {},
   "source": [
    "## Validate Ours"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c44635b5-62cb-427b-ac9b-74957cf1ae1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-task \"ours\" model, evaluated with noise_scale set to 0.0.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/ours\")[\"config\"]\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "ours_model = load_ours(cfg)\n",
    "(ours_demo, ours_predicted, ours_frechet, ours_latency, _, _) = evaluate_demo_ours(ours_model, cfg)\n",
    "ours_representations = output_representations_ours(ours_model, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7183ca1b-13b5-440f-867b-1bd8b7a0ecc0",
   "metadata": {},
   "source": [
    "## Validate Ours Multi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68cc5085-7af9-4d5d-9b07-f461421a4d94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-task \"ours\" model; its evaluator returns four values.\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/ours_multi\")[\"config\"]\n",
    "\n",
    "ours_multi_model = load_ours_multi(cfg)\n",
    "(ours_multi_demo, ours_multi_predicted,\n",
    " ours_multi_frechet, ours_multi_latency) = evaluate_demo_ours_multi(ours_multi_model, cfg)\n",
    "ours_multi_representations = output_representations_ours_multi(ours_multi_model, cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0589320",
   "metadata": {},
   "source": [
    "## Generate Website GIFs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c870d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gif_from_predictions(demo_data, predicted_data, colour_idx, title, filename='demo_vs_predicted.gif', duration=100):\n",
    "    \"\"\"\n",
    "    Create and display an animated GIF of a single-task character drawing.\n",
    "\n",
    "    Every frame shows the full expert demonstration plus the predicted trajectory\n",
    "    drawn up to and including the current time step, so the last frame contains\n",
    "    the complete prediction.\n",
    "\n",
    "    Args:\n",
    "        demo_data: demonstration indexed as demo_data[0, dim, t]; dims 0 and 1 are drawn as x and y.\n",
    "        predicted_data: prediction indexed as predicted_data[0, t, feature]; the last two\n",
    "            features are drawn as x and y (as in the subplot cells). One frame per time step.\n",
    "        colour_idx: line colour of the prediction -- a colour string such as COLOURS[0],\n",
    "            not an index, despite the name.\n",
    "        title: caption rendered at the bottom of every frame.\n",
    "        filename: path the GIF is written to.\n",
    "        duration: display time of each frame in milliseconds.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if predicted_data contains no time steps.\n",
    "    \"\"\"\n",
    "    num_frames = predicted_data.shape[1]\n",
    "    if num_frames == 0:\n",
    "        raise ValueError(\"predicted_data must contain at least one time step\")\n",
    "\n",
    "    # Predicted x/y are the last two features, matching the other plotting cells.\n",
    "    pred_x = predicted_data[0, :, -2]\n",
    "    pred_y = predicted_data[0, :, -1]\n",
    "\n",
    "    images = []\n",
    "    # Frame k shows the first k predicted points (k = 1..num_frames), so no frame is\n",
    "    # empty and the final frame includes the last predicted point.\n",
    "    for k in range(1, num_frames + 1):\n",
    "        fig = go.Figure()\n",
    "\n",
    "        # Full expert demonstration (identical in every frame).\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x=demo_data[0, 0, :],\n",
    "            y=demo_data[0, 1, :],\n",
    "            mode='lines',\n",
    "            name='Demo',\n",
    "            line=dict(color=COLOURS[-5], width=2)\n",
    "        ))\n",
    "\n",
    "        # Prediction so far; 'tonexty' fills the area between it and the demo trace.\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x=pred_x[:k],\n",
    "            y=pred_y[:k],\n",
    "            mode='lines',\n",
    "            name='Predicted',\n",
    "            line=dict(color=colour_idx, width=2),\n",
    "            fill='tonexty'\n",
    "        ))\n",
    "\n",
    "        # Bare canvas (no grid, ticks or legend) with the title as a bottom caption.\n",
    "        fig.update_layout(\n",
    "            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
    "            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),\n",
    "            plot_bgcolor='white',\n",
    "            showlegend=False,\n",
    "            title=dict(\n",
    "                text=title,\n",
    "                font=dict(size=60),\n",
    "                y=0.0,\n",
    "                x=0.5,\n",
    "                xanchor='center',\n",
    "                yanchor='bottom',\n",
    "            ),\n",
    "        )\n",
    "\n",
    "        # Render the frame to PNG in memory (plotly static export needs kaleido installed).\n",
    "        buf = io.BytesIO()\n",
    "        fig.write_image(buf, format='png')\n",
    "        buf.seek(0)\n",
    "        images.append(Image.open(buf))\n",
    "\n",
    "    # loop=0 makes the GIF repeat indefinitely.\n",
    "    images[0].save(filename, save_all=True, append_images=images[1:], duration=duration, loop=0)\n",
    "\n",
    "    display(IPImage(filename=filename))\n",
    "\n",
    "\n",
    "# Re-evaluate the models used for the GIFs with noise_scale = 0.0.\n",
    "# NOTE: this overwrites the ours_*, esn_* and temp_feedforward_*_s variables set by earlier cells.\n",
    "\n",
    "# ours\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/ours\")[\"config\"]\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "ours_model = load_ours(cfg)\n",
    "ours_demo, ours_predicted, ours_frechet, ours_latency, _, _ = evaluate_demo_ours(ours_model, cfg)\n",
    "ours_representations = output_representations_ours(ours_model, cfg)\n",
    "\n",
    "# esn\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn\")[\"config\"]\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "esn_model = load_esn(cfg)\n",
    "esn_demo, esn_predicted, esn_frechet, esn_latency, _, _ = evaluate_demo_esn(esn_model, cfg)\n",
    "esn_representations = output_representations_esn(esn_model, cfg)\n",
    "\n",
    "# feedforward action chunking (default feedforward config, no temporal ensemble)\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_s, feedforward_predicted_s, feedforward_frechet_s, feedforward_latency_s, feedforward_jerk_x_s, feedforward_jerk_y_s = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "# feedforward action chunking + temporal ensemble\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "    # apply ensemble\n",
    "    cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "temp_feedforward_demo_s, temp_feedforward_predicted_s, temp_feedforward_frechet_s, temp_feedforward_latency_s, temp_feedforward_jerk_x_s, temp_feedforward_jerk_y_s = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "# Only the first FF_GIF_STEPS predicted steps of the two feedforward rollouts are animated.\n",
    "FF_GIF_STEPS = 100\n",
    "\n",
    "create_gif_from_predictions(ours_demo, ours_predicted, COLOURS[0], \"Ours\", \"demo_vs_predicted_ours.gif\")\n",
    "create_gif_from_predictions(esn_demo, esn_predicted, COLOURS[1], \"ESN\", \"demo_vs_predicted_esn.gif\")\n",
    "create_gif_from_predictions(feedforward_demo_s, feedforward_predicted_s[:, :FF_GIF_STEPS, :], COLOURS[2], \"Feedforward\", \"demo_vs_predicted_ff.gif\")\n",
    "create_gif_from_predictions(temp_feedforward_demo_s, temp_feedforward_predicted_s[:, :FF_GIF_STEPS, :], COLOURS[3], \"Temporal Ensemble\", \"demo_vs_predicted_ff_ensemble.gif\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab51d305-5952-44f6-a5bc-3d5e2e1659fe",
   "metadata": {},
   "source": [
    "# Single Task Demonstrations\n",
    "\n",
    "In this section we compute all evaluation metrics for the single-task character drawing setting, for the \"S\", \"C\" and \"L\" characters as reported in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ac8ffc-0616-434e-ae19-f18c0e80079a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each single-task model is evaluated on the S, C and L characters, loading a fresh\n",
    "# model per character; evaluate_on_shapes factors out that repeated pattern.\n",
    "SINGLE_TASK_SHAPES = (\"Sshape\", \"CShape\", \"LShape\")\n",
    "\n",
    "\n",
    "def evaluate_on_shapes(load_fn, evaluate_fn, cfg, shapes=SINGLE_TASK_SHAPES):\n",
    "    \"\"\"\n",
    "    Load and evaluate a fresh model for each single-task character shape.\n",
    "\n",
    "    For each shape, cfg[\"dataset\"][\"shape\"] is set (cfg is mutated in place), a model\n",
    "    is loaded with load_fn(cfg) and evaluated with evaluate_fn(model, cfg).\n",
    "\n",
    "    Returns:\n",
    "        (results, model): results maps each shape to the value returned by evaluate_fn,\n",
    "        unpacked below as (demo, predicted, frechet, latency, jerk_x, jerk_y); model is\n",
    "        the model loaded for the last shape.\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "    model = None\n",
    "    for shape in shapes:\n",
    "        cfg[\"dataset\"][\"shape\"] = shape\n",
    "        model = load_fn(cfg)\n",
    "        results[shape] = evaluate_fn(model, cfg)\n",
    "    return results, model\n",
    "\n",
    "\n",
    "# ours\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/ours\")[\"config\"]\n",
    "\n",
    "ours_results, ours_model = evaluate_on_shapes(load_ours, evaluate_demo_ours, cfg)\n",
    "ours_demo_s, ours_predicted_s, ours_frechet_s, ours_latency_s, ours_jerk_x_s, ours_jerk_y_s = ours_results[\"Sshape\"]\n",
    "ours_demo_c, ours_predicted_c, ours_frechet_c, ours_latency_c, ours_jerk_x_c, ours_jerk_y_c = ours_results[\"CShape\"]\n",
    "ours_demo_l, ours_predicted_l, ours_frechet_l, ours_latency_l, ours_jerk_x_l, ours_jerk_y_l = ours_results[\"LShape\"]\n",
    "\n",
    "# esn\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn\")[\"config\"]\n",
    "\n",
    "esn_results, esn_model = evaluate_on_shapes(load_esn, evaluate_demo_esn, cfg)\n",
    "esn_demo_s, esn_predicted_s, esn_frechet_s, esn_latency_s, esn_jerk_x_s, esn_jerk_y_s = esn_results[\"Sshape\"]\n",
    "esn_demo_c, esn_predicted_c, esn_frechet_c, esn_latency_c, esn_jerk_x_c, esn_jerk_y_c = esn_results[\"CShape\"]\n",
    "esn_demo_l, esn_predicted_l, esn_frechet_l, esn_latency_l, esn_jerk_x_l, esn_jerk_y_l = esn_results[\"LShape\"]\n",
    "\n",
    "# feedforward action chunking\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "\n",
    "ff_results, feedforward_model = evaluate_on_shapes(load_feedforward, evaluate_demo_feedforward, cfg)\n",
    "feedforward_demo_s, feedforward_predicted_s, feedforward_frechet_s, feedforward_latency_s, feedforward_jerk_x_s, feedforward_jerk_y_s = ff_results[\"Sshape\"]\n",
    "feedforward_demo_c, feedforward_predicted_c, feedforward_frechet_c, feedforward_latency_c, feedforward_jerk_x_c, feedforward_jerk_y_c = ff_results[\"CShape\"]\n",
    "feedforward_demo_l, feedforward_predicted_l, feedforward_frechet_l, feedforward_latency_l, feedforward_jerk_x_l, feedforward_jerk_y_l = ff_results[\"LShape\"]\n",
    "\n",
    "# feedforward action chunking + temporal ensemble\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")[\"config\"]\n",
    "    # apply ensemble\n",
    "    cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "ff_ensemble_results, feedforward_model = evaluate_on_shapes(load_feedforward, evaluate_demo_feedforward, cfg)\n",
    "temp_feedforward_demo_s, temp_feedforward_predicted_s, temp_feedforward_frechet_s, temp_feedforward_latency_s, temp_feedforward_jerk_x_s, temp_feedforward_jerk_y_s = ff_ensemble_results[\"Sshape\"]\n",
    "temp_feedforward_demo_c, temp_feedforward_predicted_c, temp_feedforward_frechet_c, temp_feedforward_latency_c, temp_feedforward_jerk_x_c, temp_feedforward_jerk_y_c = ff_ensemble_results[\"CShape\"]\n",
    "temp_feedforward_demo_l, temp_feedforward_predicted_l, temp_feedforward_frechet_l, temp_feedforward_latency_l, temp_feedforward_jerk_x_l, temp_feedforward_jerk_y_l = ff_ensemble_results[\"LShape\"]\n",
    "\n",
    "\n",
    "# Gather every single-task result as data[model][shape][field].\n",
    "_FIELDS = (\"demo\", \"predicted\", \"frechet\", \"latency\", \"jerk_x\", \"jerk_y\")\n",
    "data = {\n",
    "    \"ours\": {\n",
    "        \"Sshape\": dict(zip(_FIELDS, (ours_demo_s, ours_predicted_s, ours_frechet_s, ours_latency_s, ours_jerk_x_s, ours_jerk_y_s))),\n",
    "        \"CShape\": dict(zip(_FIELDS, (ours_demo_c, ours_predicted_c, ours_frechet_c, ours_latency_c, ours_jerk_x_c, ours_jerk_y_c))),\n",
    "        \"LShape\": dict(zip(_FIELDS, (ours_demo_l, ours_predicted_l, ours_frechet_l, ours_latency_l, ours_jerk_x_l, ours_jerk_y_l))),\n",
    "    },\n",
    "    \"esn\": {\n",
    "        \"Sshape\": dict(zip(_FIELDS, (esn_demo_s, esn_predicted_s, esn_frechet_s, esn_latency_s, esn_jerk_x_s, esn_jerk_y_s))),\n",
    "        \"CShape\": dict(zip(_FIELDS, (esn_demo_c, esn_predicted_c, esn_frechet_c, esn_latency_c, esn_jerk_x_c, esn_jerk_y_c))),\n",
    "        \"LShape\": dict(zip(_FIELDS, (esn_demo_l, esn_predicted_l, esn_frechet_l, esn_latency_l, esn_jerk_x_l, esn_jerk_y_l))),\n",
    "    },\n",
    "    \"ff\": {\n",
    "        \"Sshape\": dict(zip(_FIELDS, (feedforward_demo_s, feedforward_predicted_s, feedforward_frechet_s, feedforward_latency_s, feedforward_jerk_x_s, feedforward_jerk_y_s))),\n",
    "        \"CShape\": dict(zip(_FIELDS, (feedforward_demo_c, feedforward_predicted_c, feedforward_frechet_c, feedforward_latency_c, feedforward_jerk_x_c, feedforward_jerk_y_c))),\n",
    "        \"LShape\": dict(zip(_FIELDS, (feedforward_demo_l, feedforward_predicted_l, feedforward_frechet_l, feedforward_latency_l, feedforward_jerk_x_l, feedforward_jerk_y_l))),\n",
    "    },\n",
    "    \"ff_ensemble\": {\n",
    "        \"Sshape\": dict(zip(_FIELDS, (temp_feedforward_demo_s, temp_feedforward_predicted_s, temp_feedforward_frechet_s, temp_feedforward_latency_s, temp_feedforward_jerk_x_s, temp_feedforward_jerk_y_s))),\n",
    "        \"CShape\": dict(zip(_FIELDS, (temp_feedforward_demo_c, temp_feedforward_predicted_c, temp_feedforward_frechet_c, temp_feedforward_latency_c, temp_feedforward_jerk_x_c, temp_feedforward_jerk_y_c))),\n",
    "        \"LShape\": dict(zip(_FIELDS, (temp_feedforward_demo_l, temp_feedforward_predicted_l, temp_feedforward_frechet_l, temp_feedforward_latency_l, temp_feedforward_jerk_x_l, temp_feedforward_jerk_y_l))),\n",
    "    },\n",
    "}\n",
    "\n",
    "# Summarise each (model, shape) pair by the mean of every metric. Rows are collected\n",
    "# first and the DataFrame is built once, rather than grown row by row with df.loc.\n",
    "metric_rows = []\n",
    "for model in [\"ours\", \"esn\", \"ff\", \"ff_ensemble\"]:\n",
    "    for shape in [\"Sshape\", \"CShape\", \"LShape\"]:\n",
    "        result = data[model][shape]\n",
    "        metric_rows.append({\n",
    "            \"model\": model,\n",
    "            \"shape\": shape,\n",
    "            # float() turns the mean (a numpy or possibly JAX scalar) into a plain float,\n",
    "            # so the metric columns get a numeric dtype.\n",
    "            \"frechet_dist\": float(np.mean(result[\"frechet\"])),\n",
    "            \"latency\": float(np.mean(result[\"latency\"])),\n",
    "            \"jerk_x\": float(np.mean(result[\"jerk_x\"])),\n",
    "            \"jerk_y\": float(np.mean(result[\"jerk_y\"])),\n",
    "        })\n",
    "\n",
    "df = pd.DataFrame(metric_rows, columns=['model', 'shape', 'frechet_dist', 'latency', 'jerk_x', 'jerk_y'])\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc5adca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a 3x4 grid of subplots: one row per character (S, C, L), one column per model.\n",
    "fig = make_subplots(rows=3, cols=4)\n",
    "# Global font size for the whole figure; set once rather than on every loop iteration.\n",
    "fig.update_layout(font=dict(size=6))\n",
    "\n",
    "for model_idx, model in enumerate([\"ours\", \"esn\", \"ff\", \"ff_ensemble\"]):\n",
    "    for shape_idx, shape in enumerate([\"Sshape\", \"CShape\", \"LShape\"]):\n",
    "        # Predicted x/y are the last two features; the demo is cut to its first 1000 steps.\n",
    "        sample_predicted, sample_demo = sample_within_threshold(\n",
    "            data[model][shape][\"predicted\"][0, :, -2:].T,\n",
    "            data[model][shape][\"demo\"][0, :, :1000],\n",
    "            1.0)\n",
    "        row = shape_idx + 1\n",
    "        col = model_idx + 1\n",
    "\n",
    "        # Expert demonstration; only the very first subplot contributes a legend entry.\n",
    "        fig.add_trace(go.Scatter(\n",
    "                x=sample_demo[0, :],\n",
    "                y=sample_demo[1, :],\n",
    "                mode='lines',\n",
    "                name='expert demonstration',\n",
    "                line=dict(color=COLOURS[-5], width=4),\n",
    "                legendgroup=\"demo\",\n",
    "                showlegend=(model_idx == 0 and shape_idx == 0),\n",
    "            ),\n",
    "            row=row, col=col)\n",
    "\n",
    "        # Model prediction; one legend entry per model (taken from the first row).\n",
    "        fig.add_trace(go.Scatter(\n",
    "                x=sample_predicted[0, :],\n",
    "                y=sample_predicted[1, :],\n",
    "                mode='lines',\n",
    "                name=model,\n",
    "                line=dict(color=COLOURS[model_idx]),\n",
    "                fill='tonexty',\n",
    "                legendgroup=model,\n",
    "                showlegend=(shape_idx == 0),\n",
    "            ),\n",
    "            row=row, col=col)\n",
    "\n",
    "        fig.update_xaxes(title_text=\"x1\", row=row, col=col, gridcolor='lightgrey', gridwidth=1)\n",
    "        fig.update_yaxes(title_text=\"x2\", row=row, col=col, gridcolor='lightgrey', gridwidth=1)\n",
    "\n",
    "fig.update_layout(\n",
    "    legend=dict(\n",
    "        orientation=\"h\",\n",
    "        yanchor=\"bottom\",\n",
    "        xanchor=\"center\",\n",
    "        x=0.5,\n",
    "        y=-0.1,\n",
    "        itemsizing='trace',\n",
    "    ),\n",
    "    height=900,\n",
    "    width=900,\n",
    "    plot_bgcolor='white',\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb6b3a84-6ad6-4b5d-8146-f7c619cf6bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Individual figures for the \"S\" character: expert demonstration vs. each model's prediction.\n",
    "axis_style = dict(showline=True, linewidth=1, linecolor='black', gridcolor='lightgrey', gridwidth=1)\n",
    "legend_style = dict(orientation=\"h\", yanchor=\"bottom\", xanchor=\"center\", x=0.5, y=-0.2, itemsizing='trace')\n",
    "\n",
    "for shape_idx, shape in enumerate([\"Sshape\"]):\n",
    "    for model_idx, model in enumerate([\"ours\", \"esn\", \"ff\", \"ff_ensemble\"]):\n",
    "        result = data[model][shape]\n",
    "        sample_predicted, sample_demo = sample_within_threshold(\n",
    "            result[\"predicted\"][0, :, -2:].T,\n",
    "            result[\"demo\"][0, :, :1000],\n",
    "            1.0,\n",
    "        )\n",
    "\n",
    "        demo_trace = go.Scatter(\n",
    "            x=sample_demo[0, :], y=sample_demo[1, :], mode='lines',\n",
    "            name='expert demonstration', line=dict(color=COLOURS[-5], width=4),\n",
    "            legendgroup=\"demo\", showlegend=True,\n",
    "        )\n",
    "        predicted_trace = go.Scatter(\n",
    "            x=sample_predicted[0, :], y=sample_predicted[1, :], mode='lines',\n",
    "            name=model, line=dict(color=COLOURS[model_idx]), fill='tonexty',\n",
    "            legendgroup=model, showlegend=True,\n",
    "        )\n",
    "\n",
    "        # Trace order matters: 'tonexty' fills from the prediction to the demo trace added before it.\n",
    "        fig = go.Figure(data=[demo_trace, predicted_trace])\n",
    "        fig.update_xaxes(title_text=\"x1\", **axis_style)\n",
    "        fig.update_yaxes(title_text=\"x2\", **axis_style)\n",
    "        fig.update_layout(legend=legend_style, height=600, width=800, plot_bgcolor='white')\n",
    "        fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "942e0787-1a2d-4242-a4dd-c53f58cb0d5d",
   "metadata": {},
   "source": [
    "# Multitask Demonstrations\n",
    "\n",
    "In this section, we generate the evaluation metrics for the multi-task character drawing setting and plot predicted trajectories against the expert demonstrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95662253-8e5a-4aae-a2dd-d7ba489421f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot a grid of predicted vs. demonstrated trajectories (2x5 = 10 demos by default).\n",
    "def create_multi_plot(predicted_vals, demo_vals, model_idx, sample_indices=range(1, 70, 7), n_cols=5):\n",
    "    \"\"\"\n",
    "    Plot predicted vs. demonstrated trajectories for a selection of multi-task demos.\n",
    "\n",
    "    One subplot per entry of sample_indices, n_cols subplots per row. The defaults\n",
    "    reproduce a 2x5 grid of demos 1, 8, 15, ..., 64.\n",
    "\n",
    "    Args:\n",
    "        predicted_vals: predictions indexed as predicted_vals[i, t, feature]; the last\n",
    "            two features are plotted as x and y.\n",
    "        demo_vals: demonstrations indexed as demo_vals[i, dim, t]; only the first 1000\n",
    "            time steps are used.\n",
    "        model_idx: index into COLOURS used for the predicted trajectories.\n",
    "        sample_indices: indices (along the first axis) of the demos to plot.\n",
    "        n_cols: number of subplot columns.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if sample_indices is empty or n_cols < 1.\n",
    "    \"\"\"\n",
    "    indices = list(sample_indices)\n",
    "    if not indices or n_cols < 1:\n",
    "        raise ValueError(\"sample_indices must be non-empty and n_cols must be >= 1\")\n",
    "    n_rows = -(-len(indices) // n_cols)  # ceiling division\n",
    "\n",
    "    fig = make_subplots(rows=n_rows, cols=n_cols)\n",
    "\n",
    "    for panel, demo_index in enumerate(indices):\n",
    "        row = (panel // n_cols) + 1\n",
    "        col = (panel % n_cols) + 1\n",
    "\n",
    "        sample_predicted, sample_demo = sample_within_threshold(\n",
    "            predicted_vals[demo_index, :, -2:].T,\n",
    "            demo_vals[demo_index, :, :1000],\n",
    "            1.0)\n",
    "\n",
    "        # Legend entries come from the first subplot only, to avoid duplicates.\n",
    "        fig.add_trace(go.Scatter(\n",
    "                x=sample_demo[0, :],\n",
    "                y=sample_demo[1, :],\n",
    "                mode='lines',\n",
    "                name='expert demonstration',\n",
    "                legendgroup=\"demo\",\n",
    "                showlegend=(panel == 0),\n",
    "                line=dict(color=COLOURS[-5], width=4)),\n",
    "            row=row, col=col)\n",
    "\n",
    "        fig.add_trace(go.Scatter(\n",
    "                x=sample_predicted[0, :],\n",
    "                y=sample_predicted[1, :],\n",
    "                mode='lines',\n",
    "                fill='tonexty',\n",
    "                name='predicted trajectory',\n",
    "                legendgroup=\"predicted\",\n",
    "                showlegend=(panel == 0),\n",
    "                line=dict(color=COLOURS[model_idx])),\n",
    "            row=row, col=col)\n",
    "\n",
    "        fig.update_xaxes(title_text=\"x1\", showgrid=True, gridwidth=1, gridcolor='LightGrey', row=row, col=col)\n",
    "        fig.update_yaxes(title_text=\"x2\", showgrid=True, gridwidth=1, gridcolor='LightGrey', row=row, col=col)\n",
    "\n",
    "    fig.update_layout(\n",
    "        legend=dict(\n",
    "            orientation=\"h\",\n",
    "            yanchor=\"bottom\",\n",
    "            xanchor=\"center\",\n",
    "            x=0.5,\n",
    "            y=-0.2,\n",
    "            itemsizing='trace',\n",
    "        ),\n",
    "        height=300 * n_rows,  # 600 px for the default two rows\n",
    "        width=1000,\n",
    "        showlegend=True,  # figure-level legend on; per-trace entries are limited to the first subplot\n",
    "        plot_bgcolor='white',\n",
    "        font=dict(size=10),\n",
    "    )\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "# ours    \n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/ours_multi\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "    \n",
    "ours_multi_model = load_ours_multi(cfg)\n",
    "ours_multi_demo, ours_multi_predicted, ours_multi_frechet, ours_multi_latency = evaluate_demo_ours_multi(ours_multi_model, cfg)\n",
    "create_multi_plot(ours_multi_predicted, ours_multi_demo, 0)\n",
    "\n",
    "# esn    \n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/esn_multi\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "esn_model_multi = load_esn_multi(cfg)\n",
    "esn_multi_demo, esn_multi_predicted, esn_multi_frechet, esn_multi_latency = evaluate_demo_esn_multi(esn_model_multi, cfg)\n",
    "create_multi_plot(esn_multi_predicted, esn_multi_demo, 1)\n",
    "\n",
    "# feedforward\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward_multi\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "\n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "feedforward_model_multi = load_feedforward_multi(cfg)\n",
    "feedforward_multi_demo, feedforward_multi_predicted, feedforward_multi_frechet, feedforward_multi_latency = evaluate_demo_feedforward_multi(feedforward_model_multi, cfg)\n",
    "create_multi_plot(feedforward_multi_predicted, feedforward_multi_demo, 2)\n",
    "\n",
    "# feedforward + ensemble\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward_multi\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "feedforward_model_multi = load_feedforward_multi(cfg)\n",
    "temp_feedforward_multi_demo, temp_feedforward_multi_predicted, temp_feedforward_multi_frechet, temp_feedforward_multi_latency = evaluate_demo_feedforward_multi(feedforward_model_multi, cfg)\n",
    "create_multi_plot(temp_feedforward_multi_predicted, temp_feedforward_multi_demo, 3)\n",
    "\n",
    "\n",
    "data = {\n",
    "    \"ours\": {\n",
    "        \"demo\": ours_multi_demo, \n",
    "        \"predicted\": ours_multi_predicted, \n",
    "        \"frechet\": ours_multi_frechet, \n",
    "        \"latency\": ours_multi_latency,\n",
    "    },\n",
    "    \"esn\": {\n",
    "        \"demo\": esn_multi_demo, \n",
    "        \"predicted\": esn_multi_predicted, \n",
    "        \"frechet\": esn_multi_frechet, \n",
    "        \"latency\": esn_multi_latency,\n",
    "    },\n",
    "    \"ff\": {\n",
    "        \"demo\": feedforward_multi_demo, \n",
    "        \"predicted\": feedforward_multi_predicted, \n",
    "        \"frechet\": feedforward_multi_frechet, \n",
    "        \"latency\": feedforward_multi_latency,\n",
    "    },\n",
    "    \"ff_ensemble\": {\n",
    "        \"demo\": temp_feedforward_multi_demo, \n",
    "        \"predicted\": temp_feedforward_multi_predicted, \n",
    "        \"frechet\": temp_feedforward_multi_frechet, \n",
    "        \"latency\": temp_feedforward_multi_latency,\n",
    "    },\n",
    "}\n",
    "\n",
    "# convert dictionary data to pandas dataframe\n",
    "df = pd.DataFrame(columns=['model', 'frechet_dist', 'latency'])\n",
    "for model_idx, model in enumerate([\"ours\", \"esn\", \"ff\", \"ff_ensemble\"]):\n",
    "        df.loc[len(df)] = [model, np.mean(data[model][\"frechet\"]), np.mean(data[model][\"latency\"])]\n",
    "\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f198eb01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gif_from_predictions(demo_data, predicted_data, colour_idx, title, filename='demo_vs_predicted_multi.gif', duration=100):\n",
    "    \"\"\"\n",
    "    Create a GIF of multi-task character drawing setting.\n",
    "    \"\"\"\n",
    "\n",
    "    # Create a list to store the images\n",
    "    images = []\n",
    "\n",
    "    # preprocess data\n",
    "    predicted = []\n",
    "    demo = []\n",
    "    for idx, j in enumerate(range(1, 70, 7)):\n",
    "        # sample up to convergence\n",
    "        sample_predicted, sample_demo = sample_within_threshold(\n",
    "                                predicted_data[j, :, -2:].T, \n",
    "                                demo_data[j, :, :1000], \n",
    "                                1.0)\n",
    "\n",
    "        # repeat where less than 100 frames\n",
    "        if 100 > sample_predicted.shape[1]:\n",
    "            repeats = np.repeat(np.expand_dims(sample_predicted[:,-1], axis=0), 100 - sample_predicted.shape[1], axis=0).T\n",
    "            sample_predicted = np.concatenate([sample_predicted, repeats], axis=1)\n",
    "\n",
    "        predicted.append(sample_predicted)\n",
    "        demo.append(sample_demo)\n",
    "\n",
    "    # Create the frames\n",
    "    for i in range(100):\n",
    "        fig = make_subplots(\n",
    "            rows=2, \n",
    "            cols=5,\n",
    "        )\n",
    "\n",
    "        for idx, j in enumerate(range(10)):\n",
    "            row = (idx // 5) + 1\n",
    "            col = (idx % 5) + 1\n",
    "\n",
    "            # Plot the demo data\n",
    "            fig.add_trace(go.Scatter(\n",
    "                x=demo[j][0, :], \n",
    "                y=demo[j][1, :], \n",
    "                mode='lines', \n",
    "                name='Demo', \n",
    "                line=dict(color=COLOURS[-5], width=2)),\n",
    "                row=row, col=col)\n",
    "\n",
    "            # Plot the predicted data up to the current frame\n",
    "            fig.add_trace(go.Scatter(\n",
    "                x=predicted[j][0, :i],\n",
    "                y=predicted[j][1, :i],\n",
    "                mode='lines', \n",
    "                name='Predicted', \n",
    "                line=dict(color=colour_idx, width=2),\n",
    "                fill='tonexty'),\n",
    "                row=row, col=col)\n",
    "\n",
    "            fig.update_xaxes(\n",
    "                showgrid=False,\n",
    "                zeroline=False, \n",
    "                showticklabels=False,\n",
    "                row=row, \n",
    "                col=col,\n",
    "             )\n",
    "        \n",
    "            fig.update_yaxes(\n",
    "                showgrid=False,\n",
    "                zeroline=False, \n",
    "                showticklabels=False,\n",
    "                row=row, \n",
    "                col=col\n",
    "            )\n",
    "\n",
    "        # Update layout\n",
    "        fig.update_layout(\n",
    "            plot_bgcolor='white',\n",
    "            showlegend=False,\n",
    "            title=dict(\n",
    "                    text=title, \n",
    "                    font=dict(size=60),\n",
    "                    y=0.0,\n",
    "                    x=0.5,\n",
    "                    xanchor='center',\n",
    "                    yanchor='bottom',\n",
    "                    ),\n",
    "        )\n",
    "\n",
    "        # Save the plot to a bytes buffer\n",
    "        buf = io.BytesIO()\n",
    "        fig.write_image(buf, format='png')\n",
    "        buf.seek(0)\n",
    "\n",
    "        # Create an image from the bytes buffer\n",
    "        img = Image.open(buf)\n",
    "        images.append(img)\n",
    "\n",
    "    # Save the images as a gif\n",
    "    images[0].save(filename, save_all=True, append_images=images[1:], duration=duration, loop=0)\n",
    "\n",
    "    # Display the gif using IPython display\n",
    "    display(IPImage(filename=filename))\n",
    "\n",
    "\n",
    "create_gif_from_predictions(ours_multi_demo, ours_multi_predicted, COLOURS[0], 'Ours', filename='demo_vs_predicted_ours_multi.gif', duration=100)\n",
    "create_gif_from_predictions(esn_multi_demo, esn_multi_predicted, COLOURS[1], 'ESN', filename='demo_vs_predicted_esn_multi.gif', duration=100)\n",
    "create_gif_from_predictions(feedforward_multi_demo, feedforward_multi_predicted, COLOURS[2], 'Feedforward', filename='demo_vs_predicted_ff_multi.gif', duration=100)\n",
    "create_gif_from_predictions(temp_feedforward_multi_demo, temp_feedforward_multi_predicted, COLOURS[3], 'Temporal Ensemble', filename='demo_vs_predicted_ff_ensemble_multi.gif', duration=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bf5ef0c-ac03-44b1-80cb-1d324dc9834b",
   "metadata": {},
   "source": [
    "## Frechet Distance by Perturbation Rate\n",
    "\n",
    "Evaluating how the Frechet distance varies with the amount of perturbation (noise) applied to the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef6a961-52ed-49f6-aca3-470a701db16a",
   "metadata": {},
   "outputs": [],
   "source": [
    "perturbations = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0]\n",
    "\n",
    "# Ours\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/ours\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "ours_model = load_ours(cfg)\n",
    "ours_dists = []\n",
    "for pert in perturbations:\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = pert\n",
    "    ours_demo, ours_predicted, ours_frechet, ours_latency, _, _ = evaluate_demo_ours(ours_model, cfg)\n",
    "    ours_dists.append(np.mean(ours_frechet))\n",
    "\n",
    "\n",
    "# ESN\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/esn\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "esn_model = load_esn(cfg)\n",
    "esn_dists = []\n",
    "for pert in perturbations:\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = pert\n",
    "    esn_demo, esn_predicted, esn_frechet, esn_latency, _, _ = evaluate_demo_esn(esn_model, cfg)\n",
    "    esn_dists.append(np.mean(esn_frechet))\n",
    "\n",
    "\n",
    "# FeedForward\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_dists = []\n",
    "for pert in perturbations:\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = pert\n",
    "    feedforward_demo, feedforward_predicted, feedforward_frechet, feedforward_latency, _, _ = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "    feedforward_dists.append(np.mean(feedforward_frechet))\n",
    "\n",
    "\n",
    "# FeedForward with Temporal Ensemble\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24  \n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001 \n",
    "\n",
    "feedforward_temp_dists = []\n",
    "for pert in perturbations:\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = pert / 2 # as we ensemble predictions\n",
    "    feedforward_temp_demo, feedforward_temp_predicted, feedforward_temp_frechet, feedforward_temp_latency, _, _ = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "    feedforward_temp_dists.append(np.mean(feedforward_temp_frechet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b79240",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot Results\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=perturbations,         \n",
    "        y=ours_dists, \n",
    "        mode='lines+markers', \n",
    "        name='ours',\n",
    "        line=dict(color=COLOURS[0]),\n",
    "    )\n",
    ")\n",
    "\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=perturbations, \n",
    "        y=esn_dists, \n",
    "        mode='lines+markers', \n",
    "        name='esn',\n",
    "        line=dict(color=COLOURS[1]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=perturbations, \n",
    "        y=feedforward_dists, \n",
    "        mode='lines+markers', \n",
    "        name='ff',\n",
    "        line=dict(color=COLOURS[2]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=perturbations, \n",
    "        y=feedforward_temp_dists, \n",
    "        mode='lines+markers', \n",
    "        name='ff_ensemble',\n",
    "        line=dict(color=COLOURS[3]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.update_xaxes(\n",
    "    title_text=\"Noise Scale\", \n",
    "    showline=True,\n",
    "    linecolor=\"black\",\n",
    ")\n",
    "            \n",
    "fig.update_yaxes(\n",
    "    title_text=\"Frechet Distance\", \n",
    "    gridcolor='lightgrey',\n",
    "    gridwidth=1,\n",
    "    showline=True,\n",
    "    linecolor=\"black\",\n",
    ")\n",
    "\n",
    "fig.update_layout(\n",
    "    height=600, \n",
    "    width=800, \n",
    "    plot_bgcolor='white',\n",
    ")\n",
    "\n",
    "fig.update_layout(legend=dict(\n",
    "    orientation=\"h\",\n",
    "    yanchor=\"bottom\",\n",
    "    xanchor=\"center\",\n",
    "    x=0.5,\n",
    "    y=-0.2,\n",
    "))\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62b75819-9ba0-4241-af13-d7f9789492c5",
   "metadata": {},
   "source": [
    "## Smoothness with Temporal Ensemble vs Ours\n",
    "\n",
    "Evaluating how much jerk is present in the trajectories generated by various models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a06542-1bb9-4e4d-b288-45a7f10da137",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_48, feedforward_predicted_48, feedforward_frechet_48, feedforward_latency_48, feedforward_jerk_x_48, feedforward_jerk_y_48 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_24, feedforward_predicted_24, feedforward_frechet_24, feedforward_latency_24, feedforward_jerk_x_24, feedforward_jerk_y_24 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 12\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 4\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_12, feedforward_predicted_12, feedforward_frechet_12, feedforward_latency_12, feedforward_jerk_x_12, feedforward_jerk_y_12 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/ours\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "ours_model = load_ours(cfg)\n",
    "ours_demo, ours_predicted, ours_frechet, ours_latency, ours_jerk_x, ours_jerk_y = evaluate_demo_ours(ours_model, cfg)\n",
    "ours_representations = output_representations_ours(ours_model, cfg)\n",
    "\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn\")\n",
    "    cfg = cfg[\"config\"]\n",
    "    cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "\n",
    "esn_model = load_esn(cfg)\n",
    "esn_demo, esn_predicted, esn_frechet, esn_latency, esn_jerk_x, esn_jerk_y = evaluate_demo_esn(esn_model, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78df7ce0-e5c0-4123-8055-8d62846f8fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Box(\n",
    "        y=ours_jerk_y, \n",
    "        name=\"ours\",\n",
    "        line=dict(color=COLOURS[0]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Box(\n",
    "        y=esn_jerk_y, \n",
    "        name=\"ESN\",\n",
    "        line=dict(color=COLOURS[1]),        \n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Box(\n",
    "        y=feedforward_jerk_y_48, \n",
    "        name=\"ff\",\n",
    "        line=dict(color=COLOURS[2]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Box(\n",
    "        y=feedforward_jerk_y_24, \n",
    "        name=\"ff_ensemble\",\n",
    "        line=dict(color=COLOURS[3]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.update_traces(boxpoints=False)\n",
    "\n",
    "fig.update_layout(\n",
    "    plot_bgcolor='white',\n",
    "    height=600, \n",
    "    width=800, \n",
    ")\n",
    "\n",
    "fig.update_xaxes(\n",
    "    title={\n",
    "        \"text\" : \"Model\",\n",
    "    },\n",
    "    showline=True, \n",
    "    linecolor='black',\n",
    "    )\n",
    "\n",
    "fig.update_yaxes(\n",
    "    title={\n",
    "        \"text\" : \"Mean Absolute Jerk\",\n",
    "    },\n",
    "    showline=True, \n",
    "    linecolor='black',\n",
    "    gridcolor='lightgrey',\n",
    "    )\n",
    "\n",
    "fig.update_layout(legend=dict(\n",
    "    orientation=\"h\",\n",
    "    yanchor=\"bottom\",\n",
    "    xanchor=\"center\",\n",
    "    x=0.5,\n",
    "    y=-0.2,\n",
    "))\n",
    "\n",
    "fig.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94cda2f8-2425-4705-8ce8-b9f0e8a0af1b",
   "metadata": {},
   "source": [
    "## Compare Velocity Profiles\n",
    "\n",
    "Examining how closely the predicted motion profiles match the demonstration, measured with dynamic time warping (DTW), with and without temporal ensembling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea9a556-d856-40b6-867c-80fdda5dc0d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_48, feedforward_predicted_48, feedforward_frechet_48, feedforward_latency_48, feedforward_jerk_x_48, feedforward_jerk_y_48 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 2\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_24, feedforward_predicted_24, feedforward_frechet_24, feedforward_latency_24, feedforward_jerk_x_24, feedforward_jerk_y_24 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/feedforward\")\n",
    "        cfg = cfg[\"config\"]\n",
    "\n",
    "        # apply noise\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        \n",
    "        # apply ensemble \n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 12\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 4\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_12, feedforward_predicted_12, feedforward_frechet_12, feedforward_latency_12, feedforward_jerk_x_12, feedforward_jerk_y_12 = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "        cfg = compose(config_name=\"config/ours\")\n",
    "        cfg = cfg[\"config\"]\n",
    "        cfg[\"evaluation\"][\"noise_scale\"] = 0.0\n",
    "        cfg[\"architecture\"][\"model_params\"][\"output_dim\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"action_chunks\"] = 24\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 48\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "        cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "ours_model = load_ours(cfg)\n",
    "ours_demo, ours_predicted, ours_frechet, ours_latency, ours_jerk_x, ours_jerk_y = evaluate_demo_ours(ours_model, cfg)\n",
    "ours_representations = output_representations_ours(ours_model, cfg)\n",
    "\n",
    "\n",
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/esn\")\n",
    "    cfg = cfg[\"config\"]\n",
    "    cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "\n",
    "esn_model = load_esn(cfg)\n",
    "esn_demo, esn_predicted, esn_frechet, esn_latency, esn_jerk_x, esn_jerk_y = evaluate_demo_esn(esn_model, cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95bdbcf-e343-4983-a819-58a962a5d453",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dtw import *\n",
    "\n",
    "predicted_48, demo_48 = sample_within_threshold(feedforward_predicted_48[0].T, feedforward_demo_48[0])\n",
    "predicted_24, demo_24 = sample_within_threshold(feedforward_predicted_24[0].T, feedforward_demo_24[0])\n",
    "predicted_12, demo_12 = sample_within_threshold(feedforward_predicted_12[0].T, feedforward_demo_12[0])\n",
    "ours_predicted_sampled, ours_demo_sampled = sample_within_threshold(ours_predicted[0].T, ours_demo[0])\n",
    "esn_predicted_sampled, esn_demo_sampled = sample_within_threshold(esn_predicted[0].T, esn_demo[0])\n",
    "\n",
    "\n",
    "a = np.diff(np.diff(ours_predicted_sampled[0]))\n",
    "b = np.diff(np.diff(ours_demo_sampled[0]))\n",
    "ours_alignment = dtw(a,b)\n",
    "\n",
    "a = np.diff(np.diff(esn_predicted_sampled[0]))\n",
    "b = np.diff(np.diff(esn_demo_sampled[0]))\n",
    "esn_alignment = dtw(a,b)\n",
    "\n",
    "a = np.diff(np.diff(predicted_48[0]))\n",
    "b = np.diff(np.diff(ours_demo_sampled[0]))\n",
    "alignment_48 = dtw(a,b)\n",
    "alignment_48.distance\n",
    "\n",
    "a = np.diff(np.diff(predicted_24[0]))\n",
    "b = np.diff(np.diff(ours_demo_sampled[0]))\n",
    "alignment_24 = dtw(a,b)\n",
    "alignment_24.distance\n",
    "\n",
    "a = np.diff(np.diff(predicted_12[0]))\n",
    "b = np.diff(np.diff(ours_demo_sampled[0]))\n",
    "alignment_12 = dtw(a,b)\n",
    "alignment_12.distance\n",
    "\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        name=\"ours\",\n",
    "        y=[ours_alignment.distance],\n",
    "        marker=dict(\n",
    "            size=15, \n",
    "            color=COLOURS[0],\n",
    "            symbol=\"star-diamond\",\n",
    "            ),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        name=\"esn\",\n",
    "        y=[esn_alignment.distance],\n",
    "        marker=dict(\n",
    "            size=15,\n",
    "            color=COLOURS[1],\n",
    "            symbol=\"star-diamond\",\n",
    "            ),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        name=\"ff_ensemble\",\n",
    "        y=[alignment_48.distance, alignment_24.distance, alignment_12.distance],\n",
    "        line=dict(color=COLOURS[3]),\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        name=\"ff\",\n",
    "        y=[alignment_48.distance],\n",
    "        marker=dict(\n",
    "            size=15,\n",
    "            color=COLOURS[2],\n",
    "            symbol=\"star-diamond\",\n",
    "            ),\n",
    "    )\n",
    ")\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    plot_bgcolor='white',\n",
    "    height=600, \n",
    "    width=800, \n",
    ")\n",
    "\n",
    "fig.update_xaxes(\n",
    "    title={\n",
    "        \"text\" : \"Temporal Ensemble History\",\n",
    "    },\n",
    "    showline=True, \n",
    "    linecolor='black',\n",
    "    tickmode='linear',\n",
    "    dtick=1,\n",
    "    )\n",
    "\n",
    "fig.update_yaxes(\n",
    "    title={\n",
    "        \"text\" : \"Euclidean Distance\",\n",
    "    },\n",
    "    showline=True, \n",
    "    linecolor='black',\n",
    "    gridcolor='lightgrey',\n",
    "    )\n",
    "\n",
    "fig.update_layout(legend=dict(\n",
    "    orientation=\"h\",\n",
    "    yanchor=\"bottom\",\n",
    "    xanchor=\"center\",\n",
    "    x=0.5,\n",
    "    y=-0.2,\n",
    "))\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42eac075-8603-43c6-8ec7-015abf7603cf",
   "metadata": {},
   "source": [
    "# Explore Learnt Representations\n",
    "\n",
    "Try to formalise the changes in representation and qualitative changes in dynamics. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba630f40-7895-433a-a61b-9058c1330c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "import pandas as pd\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Compute the dot product of the two vectors\n",
    "    dot_product = jnp.dot(a, b)\n",
    "    \n",
    "    # Compute the L2 norms of the vectors\n",
    "    norm_a = jnp.linalg.norm(a)\n",
    "    norm_b = jnp.linalg.norm(b)\n",
    "    \n",
    "    # Compute the cosine similarity\n",
    "    cosine_sim = dot_product / (norm_a * norm_b)\n",
    "    \n",
    "    return cosine_sim\n",
    "\n",
    "# one-to-all\n",
    "cosine_similarity_mapped = jax.vmap(cosine_similarity, in_axes=(None, 0))\n",
    "# feedforward_similarities = cosine_similarity_mapped(feedforward_representations[0,0,:], feedforward_representations[0,:,:])\n",
    "feedforward_similarities = cosine_similarity_mapped(jnp.ones((5000,)), feedforward_representations[0,:,:])\n",
    "# ours_similarities = cosine_similarity_mapped(ours_representations[0,0,:], ours_representations[0,:,:])\n",
    "ours_similarities = cosine_similarity_mapped(jnp.ones((5000,)), ours_representations[0,:,:])\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    \"index\":np.arange(100), \n",
    "    \"feedforward_cosine_similarities\": feedforward_similarities[:100],\n",
    "    \"ours_cosine_similarities\": ours_similarities,\n",
    "})\n",
    "\n",
    "# Melt the DataFrame to make it long-form for plotting multiple lines\n",
    "df_melted = df.melt(id_vars=\"index\", value_vars=[\"feedforward_cosine_similarities\", \"ours_cosine_similarities\"],\n",
    "                    var_name=\"method\", value_name=\"cosine_similarity\")\n",
    "\n",
    "# Plot the lines\n",
    "fig = px.line(df_melted, x=\"index\", y=\"cosine_similarity\", color=\"method\",\n",
    "              title=\"Cosine Similarity of Feedforward and Ours Over Time\")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# lagged\n",
    "# cosine_similarity_mapped = jax.vmap(cosine_similarity, in_axes=(0, 0))\n",
    "# similarities = cosine_similarity_mapped(ours_representations[0,:-1,:], ours_representations[0,1:,:])\n",
    "# df = pd.DataFrame({\"index\":np.arange(similarities.shape[0]), \"cosine_similarities\": similarities})\n",
    "# fig = px.line(df, x=\"index\", y=\"cosine_similarities\", title=\"Cosine Similarity Adjacent Points\")\n",
    "# fig.show()\n",
    "\n",
    "# # rate of change of cosine similarity across series\n",
    "# df[\"cosine_similarity_deltas\"] = df[\"cosine_similarities\"].diff()\n",
    "# fig = px.line(df, x=\"index\", y=\"cosine_similarity_deltas\", title=\"Cosine Similarity Adjacent Points\")\n",
    "# fig.show()\n",
    "\n",
    "# -----------------------------------------------------------\n",
    "# Compute lagged cosine similarities (cosine similarity between adjacent points) for both feedforward and ours\n",
    "cosine_similarity_mapped = jax.vmap(cosine_similarity, in_axes=(0, 0))\n",
    "feedforward_lagged_similarities = cosine_similarity_mapped(feedforward_representations[0, :-1, :], feedforward_representations[0, 1:, :])\n",
    "ours_lagged_similarities = cosine_similarity_mapped(ours_representations[0, :-1, :], ours_representations[0, 1:, :])\n",
    "\n",
    "# Create DataFrame for lagged cosine similarities\n",
    "df_lagged = pd.DataFrame({\n",
    "    \"index\": np.arange(99),  # Assuming both have the same length\n",
    "    \"feedforward_lagged_cosine_similarities\": feedforward_lagged_similarities[:99],\n",
    "    \"ours_lagged_cosine_similarities\": ours_lagged_similarities\n",
    "})\n",
    "\n",
    "# Compute the rate of change (delta) for both representations\n",
    "df_lagged[\"feedforward_cosine_similarity_deltas\"] = df_lagged[\"feedforward_lagged_cosine_similarities\"].diff()\n",
    "df_lagged[\"ours_cosine_similarity_deltas\"] = df_lagged[\"ours_lagged_cosine_similarities\"].diff()\n",
    "\n",
    "# Melt the DataFrame to make it long-form for plotting multiple lines\n",
    "df_lagged_melted = df_lagged.melt(id_vars=\"index\", \n",
    "                                  value_vars=[\"feedforward_cosine_similarity_deltas\", \"ours_cosine_similarity_deltas\"],\n",
    "                                  var_name=\"method\", value_name=\"cosine_similarity_delta\")\n",
    "\n",
    "# Plot the rate of change for both feedforward and ours\n",
    "fig = px.line(df_lagged_melted, x=\"index\", y=\"cosine_similarity_delta\", color=\"method\",\n",
    "              title=\"Rate of Change in Cosine Similarity Between Adjacent Points (Feedforward vs Ours)\")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ee54a9-7c8c-4ebc-a511-8be41cc83f57",
   "metadata": {},
   "outputs": [],
   "source": [
    "with initialize(version_base=None, config_path=\".\", job_name=\"eval\"):\n",
    "    cfg = compose(config_name=\"config/feedforward\")\n",
    "    cfg = cfg[\"config\"]\n",
    "    cfg[\"dataset\"][\"shape\"] = \"Sshape\"\n",
    "    cfg[\"evaluation\"][\"noise_scale\"] = 0.25\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"num_predicted_actions\"] = 12\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"prediction_horizon\"] = 1\n",
    "    cfg[\"architecture\"][\"inference_params\"][\"ensemble_weight\"] = 0.001\n",
    "\n",
    "feedforward_model = load_feedforward(cfg)\n",
    "feedforward_demo_s, feedforward_predicted_s, feedforward_frechet_s, feedforward_latency_s = evaluate_demo_feedforward(feedforward_model, cfg)\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(go.Scatter(\n",
    "                        x=feedforward_predicted_s[0, :, 0], \n",
    "                        y=feedforward_predicted_s[0, :, 1],\n",
    "                        mode='lines', \n",
    "                        name='model prediction', \n",
    "                        line=dict(color='blue'),\n",
    "                        legendgroup=\"predicted\",\n",
    "                        showlegend=True,\n",
    "                        ))\n",
    "\n",
    "fig.add_trace(go.Scatter(\n",
    "                x=feedforward_demo_s[0, 0, :], \n",
    "                y=feedforward_demo_s[0, 1, :],\n",
    "                mode='lines', \n",
    "                name='expert demonstration', \n",
    "                line=dict(color='red'),\n",
    "                legendgroup=\"demo\",\n",
    "                showlegend=True,\n",
    "              ))\n",
    "\n",
    "fig.update_layout(\n",
    "    font=dict(size=6)\n",
    ")\n",
    "\n",
    "fig.update_xaxes(\n",
    "    title_text=\"x1\", \n",
    "    gridcolor='lightgrey',\n",
    "    gridwidth=1,\n",
    ")\n",
    "\n",
    "fig.update_yaxes(\n",
    "    title_text=\"x2\", \n",
    "    gridcolor='lightgrey',\n",
    "    gridwidth=1,\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a00759a-8984-42d2-8395-825a47fb7f0f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
