{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6091e28a92c93161",
   "metadata": {},
   "source": [
    "# This notebook is used to visualize the results of the model.\n",
    "But first, we need to do the following:\n",
    "1. Imports \n",
    "2. Setting up the hydra configs\n",
    "3. Create a copy of the model's forward function that additionally returns all values needed for the plots\n",
    "4. Then, the visualizations are shown. This section is explained in more detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a84bc2f1-bb79-4de9-a883-6e23c20e9db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import subprocess\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "os.chdir('..')\n",
    "\n",
    "from src.multibody_sim.utils import load_data\n",
    "from matplotlib.animation import FuncAnimation\n",
    "from IPython.display import HTML\n",
    "from src.multibody_sim.equations_of_motion import *\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "771a0700-edb5-492a-ae92-9cab373455c6",
   "metadata": {},
   "source": [
    "# Setup the hydra configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e186140-5bbc-499e-9715-f09525249b1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task 1: get some prediction from a checkpoint\n",
    "# This cell loads the Hydra config of a finished run, then imports the model and simulation modules.\n",
    "import omegaconf \n",
    "import hydra\n",
    "from src.utils import hydra_resolver, reconstruction\n",
    "from src import utils as ut\n",
    "# Paths are relative to the working directory set by os.chdir('..') in the import cell.\n",
    "predict_run_dir = \"logs/BASELINE\"\n",
    "# Checkpoint sub-folder inside a run directory; read as a global by the model-loading helpers below.\n",
    "checkpoint_path = \"tensorboard/tb_logs/all_models/checkpoints/\"\n",
    "checkpoint = \"last.ckpt\"\n",
    "cfg = omegaconf.OmegaConf.load(f\"{predict_run_dir}/.hydra/config.yaml\")\n",
    "import src.utils.utils as utt\n",
    "# NOTE(review): presumably publishes cfg project-wide before the model modules are imported - confirm in src.utils.utils.\n",
    "utt.init_global_cfg(cfg)\n",
    "log = ut.get_pylogger(__name__)\n",
    "# NOTE(review): presumably registers the custom OmegaConf resolvers used by the run config - confirm in hydra_resolver.\n",
    "hydra_resolver.register_resolvers(log)\n",
    "\n",
    "# The star imports below supply helper names that later cells use unqualified\n",
    "# (e.g. global_kinematics, contact_points_2d, Kmatrix_loss_moveEst); which module\n",
    "# provides which name is not visible from this notebook.\n",
    "from src.models.sspinn_module import SspinnLitModule\n",
    "from src.models.components.lstm_net import LSTMNet\n",
    "from src.kinematics.kinematics_2d import *\n",
    "from src.multibody_sim.ground_contact import *\n",
    "from src.multibody_sim.equations_of_motion import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0de064-6d53-42e5-96c9-ebcd7d38899c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the run configuration (same file as in the previous cell) and force the trainer onto the CPU.\n",
    "cfg = omegaconf.OmegaConf.load(f\"{predict_run_dir}/.hydra/config.yaml\")\n",
    "cfg.trainer.accelerator = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d06480c-04d2-4822-bd48-ed61b3cda2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latest_checkpoint():\n",
    "    \"\"\"Locate the most recently modified run and its most recently modified checkpoint.\n",
    "\n",
    "    Scans ``logs/results-10.05_2/train/runs/`` for sub-directories, picks the one with the\n",
    "    newest modification time, and inside its ``checkpoint_path`` folder (notebook global)\n",
    "    picks the newest ``.ckpt`` file.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(pred_dir, ckp, cfg)``: run directory, full checkpoint path and the run's\n",
    "        Hydra config with ``trainer.accelerator`` set to ``'cpu'``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there is no run directory or no checkpoint file (``max`` of an empty list).\n",
    "    \"\"\"\n",
    "    runs_root = \"logs/results-10.05_2/train/runs/\"\n",
    "\n",
    "    # Newest run directory by modification time\n",
    "    run_dirs = []\n",
    "    for entry in os.listdir(runs_root):\n",
    "        candidate = os.path.join(runs_root, entry)\n",
    "        if os.path.isdir(candidate):\n",
    "            run_dirs.append(candidate)\n",
    "    pred_dir = max(run_dirs, key=os.path.getmtime)\n",
    "\n",
    "    # Newest checkpoint file inside that run\n",
    "    ckpt_dir = os.path.join(pred_dir, checkpoint_path)\n",
    "    ckpt_files = [os.path.join(ckpt_dir, name) for name in os.listdir(ckpt_dir) if name.endswith('.ckpt')]\n",
    "    ckp = max(ckpt_files, key=os.path.getmtime)\n",
    "\n",
    "    cfg = omegaconf.OmegaConf.load(f\"{pred_dir}/.hydra/config.yaml\")\n",
    "    cfg.trainer.accelerator = 'cpu'\n",
    "    return pred_dir, ckp, cfg\n",
    "def get_model(predict_run_dir, checkpoint_,  cfg):\n",
    "    \"\"\"Instantiate datamodule and network from ``cfg`` and load the checkpoint weights on the CPU.\n",
    "\n",
    "    Args:\n",
    "        predict_run_dir: Run directory that contains the ``checkpoint_path`` folder.\n",
    "        checkpoint_: Either a checkpoint file name inside\n",
    "            ``{predict_run_dir}/{checkpoint_path}`` or a path pointing directly at a ``.ckpt`` file.\n",
    "        cfg: Run configuration providing the ``datamodule`` and ``model`` nodes.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(datamodule, module, cfg)`` where ``module`` is the ``SspinnLitModule``\n",
    "        wrapping the network with the loaded ``state_dict``.\n",
    "\n",
    "    Side effects:\n",
    "        Sets the notebook-global ``predict_dl`` (read by ``get_data``).\n",
    "    \"\"\"\n",
    "    datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "    global predict_dl\n",
    "    predict_dl = datamodule.predict_dataloader()\n",
    "\n",
    "    model = hydra.utils.instantiate(cfg.model)\n",
    "\n",
    "    # Callers pass either a bare file name or an already complete checkpoint path.\n",
    "    # Resolve this explicitly instead of relying on a bare ``except`` that would also\n",
    "    # hide unrelated errors (corrupt file, interrupt, ...).\n",
    "    ckpt_file = f'{predict_run_dir}/{checkpoint_path}/{checkpoint_}'\n",
    "    if not os.path.isfile(ckpt_file):\n",
    "        ckpt_file = f'{checkpoint_}'\n",
    "    print(f'Loading checkpoint: {ckpt_file}')\n",
    "    # Always map to the CPU so checkpoints written on a GPU machine load here as well.\n",
    "    ckp = torch.load(ckpt_file, map_location=torch.device('cpu'))\n",
    "\n",
    "    model.load_state_dict(ckp['state_dict'])\n",
    "    module = SspinnLitModule(model, None, None, torch.nn.functional.mse_loss, cfg.model.input_noise, cfg.model.loss_weights, cfg.model.input_variables, cfg.model.estimated_variables, cfg.model.loss_d_variables)\n",
    "    return datamodule, module, cfg\n",
    "    \n",
    "def get_latest_model():\n",
    "    \"\"\"Load the model of the newest run: feeds the result of ``get_latest_checkpoint`` into ``get_model``.\"\"\"\n",
    "    return get_model(*get_latest_checkpoint())\n",
    "\n",
    "\n",
    "def get_specific_model(folder_path):\n",
    "    \"\"\"Load the model stored in ``logs/BASELINE`` + ``folder_path``.\n",
    "\n",
    "    Args:\n",
    "        folder_path: String appended verbatim to ``\"logs/BASELINE\"`` before normalisation\n",
    "            (``''`` selects the baseline run itself; a sub-folder needs a leading separator).\n",
    "\n",
    "    Returns:\n",
    "        The ``(datamodule, module, cfg)`` tuple produced by ``get_model``.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the checkpoint folder contains no ``.ckpt`` file.\n",
    "    \"\"\"\n",
    "    # Get the checkpoint:\n",
    "    path = \"logs/BASELINE\"\n",
    "    pred_dir = os.path.normpath(path + folder_path)\n",
    "\n",
    "    ckpt_dir = os.path.join(pred_dir, checkpoint_path)\n",
    "    print(ckpt_dir)\n",
    "    # Checkpoints ordered from oldest to newest modification time\n",
    "    ckpt_files = sorted(\n",
    "        [os.path.join(ckpt_dir, d) for d in os.listdir(ckpt_dir) if d.endswith('.ckpt')],\n",
    "        key=os.path.getmtime,\n",
    "    )\n",
    "    if not ckpt_files:\n",
    "        raise FileNotFoundError(f'No .ckpt files found in {ckpt_dir}')\n",
    "    cfg = omegaconf.OmegaConf.load(f\"{pred_dir}/.hydra/config.yaml\")\n",
    "\n",
    "    # The second-newest checkpoint is used on purpose (unchanged behaviour); with a single\n",
    "    # checkpoint in the folder fall back to that one instead of failing with an IndexError.\n",
    "    ckp = ckpt_files[-2] if len(ckpt_files) > 1 else ckpt_files[-1]\n",
    "    return get_model(pred_dir, ckp,  cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5998cb80-1d02-417b-b512-dad9de604f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How to get the data\n",
    "def get_data(idx):\n",
    "    \"\"\"Fetch sample ``idx`` from the prediction dataset and give every entry a leading batch axis.\n",
    "\n",
    "    Reads the notebook-global ``predict_dl`` (set by ``get_model``). The sample mapping is\n",
    "    updated in place: each value is replaced by ``value.unsqueeze(0)``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(datasample, start_idx)`` as delivered by the dataset's ``__getitem__(idx, True)``.\n",
    "    \"\"\"\n",
    "    datasample, start_idx = predict_dl.dataset.__getitem__(idx, True)\n",
    "    # Snapshot the keys first, then overwrite each entry with its batched version\n",
    "    for name in tuple(datasample):\n",
    "        datasample[name] = datasample[name].unsqueeze(0)\n",
    "    return datasample, start_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53f0c4c8-82b8-454b-a6cf-c7d05c35b38a",
   "metadata": {},
   "source": [
    "# Forward of model fn and plots "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c2b9bb9-0bef-4a1c-949b-c302c576739d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_of_model(module, batch, cfg, do_physics = False):\n",
    "    \"\"\"Re-run the model's forward pass step by step and return the intermediates used for plotting.\n",
    "\n",
    "    Args:\n",
    "        module: Lightning module as returned by ``get_model``; its ``input_variables``, ``model``,\n",
    "            ``criterion``, ``loss_weights``, ``device`` and loss helpers are used.\n",
    "        batch: Mapping of tensors. Entries read here: the keys in ``module.input_variables`` plus\n",
    "            ``body_constants``, ``imu_offsets``, ``imu_rotations``, ``ground_contact_model``,\n",
    "            ``speed`` and ``IMU_data``.\n",
    "        cfg: Run configuration (``fps``, ``euler``, ``gc_ss_level``, ``gc_model``,\n",
    "            ``ankle_imu_position`` and ``model.estimated_variables`` are read).\n",
    "        do_physics: If True, ``module.calculate_physics_loss`` is evaluated; otherwise ``loss_p`` is 0.\n",
    "\n",
    "    Returns:\n",
    "        10-tuple ``(estimation, reconstructed_data, imu_data, gc_positions, grf_cps, moment,\n",
    "        loss_p, grf, ankle_imu_globals, loss_foot_speed)``. If any exception is raised after the\n",
    "        kinematics reconstruction, a warning is printed and the last six entries are ``None``.\n",
    "\n",
    "    Note:\n",
    "        ``loss_l``, ``loss_gc``, ``loss_r``, ``loss_t``, ``loss_grf`` and the weighted ``loss`` are\n",
    "        computed but not returned.\n",
    "    \"\"\"\n",
    "    # Rebuild my own forward\n",
    "\n",
    "    # Network input: the configured input variables concatenated along the last axis\n",
    "    x = torch.cat([batch[var] for var in module.input_variables], dim=-1)\n",
    "    y_hat = module.model.model.forward(x)\n",
    "    estimation = {}\n",
    "    prev_idx = 0\n",
    "    # split the output into the estimated variables\n",
    "    for k, v in cfg.model.estimated_variables.items():\n",
    "        estimation[k] = y_hat[:, :, prev_idx : prev_idx + len(v)]\n",
    "        prev_idx += len(v)\n",
    "    # Overwrite IK channel 0 in place with the cumulative sum of channel 1 over axis 1, divided by fps\n",
    "    estimation[\"IK_data\"][:,:,0] = torch.cumsum(estimation['IK_data'][:,:,1],dim=1)/cfg.fps\n",
    "    IK_pred = estimation[\"IK_data\"]\n",
    "\n",
    "    loss_l = module.calculate_limit_losses(IK_pred)\n",
    "\n",
    "    reconstructed_data, imu_data, gc_positions, ankle_imu_globals = global_kinematics(\n",
    "        IK_pred,\n",
    "        batch[\"body_constants\"],\n",
    "        batch[\"imu_offsets\"],\n",
    "        batch[\"imu_rotations\"],\n",
    "        batch[\"ground_contact_model\"],\n",
    "        cfg,\n",
    "        device=module.device,\n",
    "    )\n",
    "\n",
    "        # get the ankle speed from the kinematics\n",
    "    # Everything below is best-effort: any exception ends in the None-padded return at the bottom.\n",
    "    try:\n",
    "        v_ankle = torch.cat([reconstructed_data['ankle_r'][:, :, 1:2], reconstructed_data['ankle_l'][:, :, 1:2]], dim=-1)\n",
    "        v_ankle_sim = torch.cat([ankle_imu_globals['ankle_r'][:, :, 1:2], ankle_imu_globals['ankle_l'][:, :, 1:2]], dim=-1)\n",
    "        loss_foot_speed = torch.abs(v_ankle_sim - v_ankle)\n",
    "        loss_foot_speed = torch.relu(loss_foot_speed - 0.25 * torch.max(batch['speed'], keepdim=True, dim=-2).values) # 25% of the speed as error is allowed -> no exact guidance\n",
    "        loss_foot_speed = module.criterion(loss_foot_speed, torch.zeros_like(loss_foot_speed, device=module.device))\n",
    "\n",
    "        ## NN ground contact model loss, only use if wgc > 0 (is specified)\n",
    "        if module.loss_weights.wgc > 0 and cfg.gc_ss_level == 'cps':\n",
    "            loss_gc, gc_positions = module.calculate_gc_loss(estimation, gc_positions, reconstructed_data, cfg.fps, cfg.euler)\n",
    "        elif module.loss_weights.wgc > 0 and cfg.gc_ss_level == 'ankle':\n",
    "            # Only this branch assigns learned_mu (and ankle_globals)\n",
    "            loss_gc, gc_positions, ankle_globals, learned_mu = module.calculate_ankle_loss(estimation, batch, reconstructed_data, cfg.fps, cfg.euler)\n",
    "            if cfg.ankle_imu_position == 'ss_foot':\n",
    "                # reconstruct the ankle IMU data based on the self-supervised foot position\n",
    "                for foot, imu_idx in zip(['ankle_r','ankle_l'],[3,6]):\n",
    "                    ankle_globals[foot][:,:,2::3] = reconstructed_data[foot][:,:,2::3]\n",
    "                    imu_ = get_joint_position(ankle_globals[foot], torch.zeros_like(ankle_globals[foot][:,:,-3:]), batch['imu_offsets'][:,:,2*imu_idx:2*imu_idx+2], module.device)\n",
    "                    imu_data[:,:,3*imu_idx:3*imu_idx+3] = to_local_coordinates_imu(imu_, batch['body_constants'][:,:,-1], module.device) # Same as in global reconstruction\n",
    "        \n",
    "        else:\n",
    "            loss_gc = 0\n",
    "\n",
    "        loss_r = module.criterion(imu_data, batch[\"IMU_data\"])\n",
    "\n",
    "        loss_t = module.calculate_time_loss(IK_pred, cfg.fps, cfg.euler)\n",
    "\n",
    "        # Get GC model forces and CoP\n",
    "        # NOTE(review): gc_model is not assigned in this function or in any visible notebook cell;\n",
    "        # it has to come from a star import or a global. If it is undefined, the NameError is caught\n",
    "        # below, the fallback call raises it again and the outer handler returns the None-padded tuple.\n",
    "        try: \n",
    "            if cfg.gc_model == 'sliding':\n",
    "                # learned_mu is unbound unless the 'ankle' branch above ran; the resulting\n",
    "                # UnboundLocalError is handled by the except clause of this inner try.\n",
    "                grf, moment, grf_cps, cp_mix_ = sliding_contact_point(\n",
    "                    gc_positions,\n",
    "                    batch[\"ground_contact_model\"],\n",
    "                    gc_model,\n",
    "                    cfg,\n",
    "                    device = module.device,\n",
    "                    get_moments=True,\n",
    "                    kinematics=reconstructed_data,\n",
    "                    learned_mu=learned_mu\n",
    "                )\n",
    "            else:\n",
    "                grf, moment, grf_cps = contact_points_2d(\n",
    "                    gc_positions,\n",
    "                    batch[\"ground_contact_model\"],\n",
    "                    gc_model,\n",
    "                    cfg,\n",
    "                    device = \"cpu\",\n",
    "                    get_moments=True,\n",
    "                    kinematics=reconstructed_data\n",
    "                )\n",
    "        except Exception as e:\n",
    "            print('Warning: Fallback to contact_points 2D version because of', e)\n",
    "            grf, moment, grf_cps = contact_points_2d(\n",
    "                gc_positions,\n",
    "                batch[\"ground_contact_model\"],\n",
    "                gc_model,\n",
    "                cfg,\n",
    "                device = \"cpu\",\n",
    "                get_moments=True,\n",
    "                kinematics=reconstructed_data\n",
    "            )\n",
    "\n",
    "        # Kane's Loss        \n",
    "        loss_grf = module.calculate_grf_bounds_loss(grf)\n",
    "\n",
    "        if do_physics:\n",
    "            loss_p = module.calculate_physics_loss(\n",
    "                IK_pred,\n",
    "                estimation[\"torques\"],\n",
    "                grf,\n",
    "                moment,\n",
    "                batch[\"body_constants\"],\n",
    "            )\n",
    "        else:\n",
    "            loss_p = 0\n",
    "    \n",
    "        # Weighted sum of reconstruction, temporal and physics terms; computed but not returned\n",
    "        loss = (\n",
    "            + loss_r * module.loss_weights.wr\n",
    "            + loss_t * module.loss_weights.wt\n",
    "            + loss_p * module.loss_weights.wp\n",
    "        )\n",
    "\n",
    "        return estimation, reconstructed_data, imu_data, gc_positions, grf_cps, moment, loss_p, grf, ankle_imu_globals, loss_foot_speed\n",
    "    except Exception as e:\n",
    "        print('Warning: Fallback to old version because of', e)\n",
    "        return estimation, reconstructed_data, imu_data, gc_positions, None, None, None, None, None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4043518-9668-44e7-9d02-eb4f75ebec61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_imu_pre_after(imu_reconstructed, imu_raw,cfg):\n",
    "    \"\"\"Plot simulated against recorded IMU signals, three consecutive channels per subplot.\n",
    "\n",
    "    Only the first sample (index 0 on the leading axis) is drawn. Simulated signals use solid\n",
    "    lines, recorded signals dashed lines of the same colour.\n",
    "\n",
    "    Args:\n",
    "        imu_reconstructed: Simulated IMU tensor, indexed as ``[sample, time, channel]``.\n",
    "        imu_raw: Recorded IMU data with the same indexing.\n",
    "        cfg: Configuration providing ``fps`` and the channel names in\n",
    "            ``datamodule.dataset_variables.IMU_data``.\n",
    "    \"\"\"\n",
    "    imu_reconstructed = imu_reconstructed.cpu().clone().detach()\n",
    "    # Take the sizes from the argument itself; the previous version read the notebook-global\n",
    "    # ``imu_data`` and therefore only worked after the forward cell had been run.\n",
    "    n_steps = imu_reconstructed.shape[1]\n",
    "    n_channels = imu_reconstructed.shape[-1]\n",
    "    # Define time vector\n",
    "    x_ = np.linspace(0, n_steps/cfg.fps, n_steps)\n",
    "    plt.figure(figsize=(5/3*n_channels, 15))\n",
    "    colors = ['r','g','b']\n",
    "    ikey = cfg.datamodule.dataset_variables.IMU_data\n",
    "    count = 1\n",
    "    for idx, imu in enumerate(ikey):\n",
    "        plt.subplot(4, 2, count)\n",
    "        color = colors[idx % 3]\n",
    "        plt.plot(x_, imu_reconstructed[0, :, idx], color)\n",
    "        plt.plot(x_, imu_raw[0, :, idx], f'{color}--')\n",
    "        plt.xlabel('time in [s]')\n",
    "        plt.ylabel('signal in [m/s^2 | 1/s]')\n",
    "        plt.grid()\n",
    "        # After every third channel: label the six curves of this subplot and move to the next one\n",
    "        if (idx + 1) % 3 == 0:\n",
    "            plt.legend([f'{ikey[idx-2][4:]} sim',f'{ikey[idx-2][4:]} raw',f'{ikey[idx-1][4:]} sim',f'{ikey[idx-1][4:]} raw',f'{imu[4:]} sim',f'{imu[4:]} raw'])\n",
    "            count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48358daa-df6f-420c-b22a-c9bc47bbe2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_kinematics(IK_data_sim, IK_data_gt, start_idx, cfg, offset):\n",
    "    \"\"\"Plot every third IK channel of the first sample in a 3x3 grid.\n",
    "\n",
    "    For channel 0 the simulated curve is replaced by channel 1, labelled ``v_sim (m/s)``.\n",
    "\n",
    "    Args:\n",
    "        IK_data_sim: Simulated IK tensor, indexed as ``[sample, time, channel]``.\n",
    "        IK_data_gt: Optional reference data; when given, the window\n",
    "            ``start_idx+offset : start_idx+offset+len`` is drawn in each panel.\n",
    "        start_idx: Start index of the sample inside the reference recording.\n",
    "        cfg: Configuration providing ``fps`` and ``datamodule.dataset_variables.IK_data``.\n",
    "        offset: Extra shift applied to the reference window (unused when ``IK_data_gt`` is None).\n",
    "    \"\"\"\n",
    "    sim = IK_data_sim.cpu().detach().numpy()\n",
    "    plt.figure(figsize=(15, 5))\n",
    "    siglen = sim.shape[1]\n",
    "    time_axis = np.linspace(0, siglen/cfg.fps, siglen)\n",
    "    channel_names = cfg.datamodule.dataset_variables.IK_data\n",
    "    has_reference = IK_data_gt is not None\n",
    "    for panel, idx in enumerate(range(0, len(channel_names), 3), start=1):\n",
    "        ik = channel_names[idx]\n",
    "        is_first = idx == 0\n",
    "        plt.subplot(3, 3, panel)\n",
    "        if not is_first:\n",
    "            plt.plot(time_axis, sim[0, :, idx])\n",
    "        if has_reference:\n",
    "            plt.plot(time_axis, IK_data_gt[0, start_idx+offset:start_idx+siglen+offset, idx])\n",
    "        plt.xlabel('time in [s]')\n",
    "        plt.ylabel('signal in [m | rad]')\n",
    "        if is_first:\n",
    "            plt.plot(time_axis, sim[0, :, 1])\n",
    "            plt.legend(['v_sim (m/s)'])\n",
    "        else:\n",
    "            plt.legend([f'{ik} sim',f'{ik} raw'])\n",
    "        plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d92d8328-5257-4504-be11-bd1f738ebe61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_kinetics(torques, cfg=None):\n",
    "    \"\"\"Plot each estimated joint torque of the first sample in its own subplot (2x3 grid).\n",
    "\n",
    "    Args:\n",
    "        torques: Torque tensor, indexed as ``[sample, time, channel]``.\n",
    "        cfg: Configuration providing ``fps`` and ``model.estimated_variables.torques``.\n",
    "            Optional for backward compatibility: when omitted, the notebook-global ``cfg`` is used.\n",
    "    \"\"\"\n",
    "    if cfg is None:\n",
    "        # Previous behaviour: rely on the configuration loaded in the notebook namespace\n",
    "        cfg = globals()['cfg']\n",
    "    torques = torques.cpu().detach().numpy()\n",
    "    x_ = np.linspace(0,torques.shape[1]/cfg.fps,torques.shape[1])\n",
    "    plt.figure(figsize=(15, 5))\n",
    "    for idx, ik in enumerate(cfg.model.estimated_variables.torques):\n",
    "        plt.subplot(2,3, idx+1)\n",
    "        plt.plot(x_,torques[0, :, idx])\n",
    "        plt.legend([f'{ik}'])\n",
    "        plt.grid()\n",
    "        plt.xlabel('time in [s]')\n",
    "        plt.ylabel('torque in Nm/kg(BW)')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5817723-16d8-44d8-a942-33a6e6829991",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grf(grf,cop,cfg):\n",
    "    \"\"\"Plot the four ground-reaction-force components (x/y, right/left) of the first sample over time.\n",
    "\n",
    "    Values are divided by 9.81 and labelled in body weights (BW).\n",
    "\n",
    "    Args:\n",
    "        grf: Force tensor, indexed as ``[sample, time, channel]``. With 8 channels the heel\n",
    "            (0, 1, 4, 5) and toe (2, 3, 6, 7) contributions are summed for the main curve and\n",
    "            additionally drawn as dashed / dash-dotted lines; otherwise the channels are used as is.\n",
    "        cop: Unused; kept for interface compatibility.\n",
    "        cfg: Configuration providing ``fps``.\n",
    "    \"\"\"\n",
    "    grf = grf.cpu().detach().numpy()\n",
    "    grf = grf / 9.81 # in BW\n",
    "    split_contacts = grf.shape[-1] == 8\n",
    "    idx_heel = [0,1,4,5]\n",
    "    idx_toe = [2,3,6,7]\n",
    "    if split_contacts:\n",
    "        grf_ = grf[:,:,idx_heel] + grf[:,:,idx_toe]\n",
    "    else:\n",
    "        grf_ = grf\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    x_ = np.linspace(0,grf.shape[1]/cfg.fps,grf.shape[1])\n",
    "    grf_names = ['grf_x_r','grf_y_r','grf_x_l','grf_y_l']\n",
    "    for i in range(4):\n",
    "        plt.subplot(2,2, i+1)\n",
    "        # Draw against the time vector so the x axis matches its 'time in [s]' label\n",
    "        # (the curves were previously plotted against the sample index).\n",
    "        plt.plot(x_, grf_[0,:,i])\n",
    "        if split_contacts:\n",
    "            plt.plot(x_, grf[0,:,idx_heel[i]],'k--')\n",
    "            plt.plot(x_, grf[0,:,idx_toe[i]],'k-.')\n",
    "            plt.legend([grf_names[i],'heel','toe'])\n",
    "\n",
    "        plt.xlabel('time in [s]')\n",
    "        plt.ylabel('signal in [BW]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d596c73-7d68-4918-b8af-792a872c9bd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_kanesloss(loss_k, pow = 1):\n",
    "    \"\"\"Plot the nine loss channels of the first sample in three panels (pelvis, right leg, left leg).\n",
    "\n",
    "    Args:\n",
    "        loss_k: Loss tensor, indexed as ``[sample, time, channel]`` with at least 9 channels.\n",
    "        pow: Exponent applied element-wise to the values before plotting.\n",
    "    \"\"\"\n",
    "    # .cpu() makes this work for tensors on any device, as in the other plot helpers\n",
    "    loss_k = loss_k.detach().cpu().numpy()\n",
    "    # (curve labels, y-axis label) per panel; panel p shows channels 3*p .. 3*p+2\n",
    "    panels = [\n",
    "        (['pelvis_x', 'pelvis_y', 'pelvis_a'], 'loss in [N/kg(BW)]/[Nm/kg(BW)]'),\n",
    "        (['hip_r', 'knee_r', 'ankle_r'], 'loss in [Nm/kg(BW)]'),\n",
    "        (['hip_l', 'knee_l', 'ankle_l'], 'loss in [Nm/kg(BW)]'),\n",
    "    ]\n",
    "    fig = plt.figure(figsize=(20, 5))\n",
    "    for panel_idx, (labels, ylabel) in enumerate(panels):\n",
    "        fig.add_subplot(1, 3, panel_idx + 1)\n",
    "        for offset, col in enumerate(labels):\n",
    "            line, = plt.plot(loss_k[0, :, 3*panel_idx + offset]**pow)\n",
    "            line.set_label(col)\n",
    "        plt.legend()\n",
    "        plt.grid()\n",
    "        plt.xlabel('timestep')\n",
    "        plt.ylabel(ylabel)\n",
    "                \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6427751-2b21-4260-8649-4c533e0d10d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_euler(IK_data):\n",
    "    \"\"\"Visual check of temporal consistency for nine groups of three consecutive IK channels.\n",
    "\n",
    "    For group ``g`` (channels ``3g``, ``3g+1``, ``3g+2``) of the first sample, each panel overlays\n",
    "    the first difference of a channel (times 100) with the following channel, both divided by\n",
    "    the standard deviation of that following channel: black for ``3g`` vs ``3g+1``, red for\n",
    "    ``3g+1`` vs ``3g+2``.\n",
    "    NOTE(review): the factor 100 presumably equals the frame rate - confirm against cfg.fps.\n",
    "    \"\"\"\n",
    "    a = IK_data.detach().cpu()\n",
    "    fig = plt.figure(figsize = (20,10))\n",
    "\n",
    "    for panel, base in enumerate(range(0, 27, 3), start=1):\n",
    "        fig.add_subplot(3, 3, panel)\n",
    "        ch0 = a[0, :, base]\n",
    "        ch1 = a[0, :, base + 1]\n",
    "        ch2 = a[0, :, base + 2]\n",
    "\n",
    "        # Black pair: difference of the first channel against the second channel\n",
    "        scale = torch.std(ch1)\n",
    "        plt.plot(ch0.diff(1)*100/scale, 'k-')\n",
    "        plt.plot(ch1/scale, 'k--')\n",
    "\n",
    "        # Red pair: difference of the second channel against the third channel\n",
    "        scale = torch.std(ch2)\n",
    "        plt.plot(ch1.diff(1)*100/scale, 'r-')\n",
    "        plt.plot(ch2/scale, 'r--')\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe5caf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_zones(trigger_value):\n",
    "    \"\"\"Return the speed-range label (e.g. ``\"0.9-1.0\"``) for a trigger value between 7 and 12.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: for any other trigger value.\n",
    "    \"\"\"\n",
    "    zone_labels = (\"0.9-1.0\", \"1.2-1.4\", \"1.8-2.0\", \"3.0-3.3\", \"3.9-4.1\", \"4.7-4.9\")\n",
    "    # Trigger values 7..12 map onto the labels in order\n",
    "    zone_by_trigger = dict(zip(range(7, 13), zone_labels))\n",
    "    return zone_by_trigger[trigger_value]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ebed4cb-835d-46cc-be61-0c83198a7eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.plot_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f50ac9-0c75-4c7e-a8ce-2d1aaa135776",
   "metadata": {},
   "source": [
    "# All visualizations are shown here "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9494489a838723f1",
   "metadata": {},
   "source": [
    "Load the latest - or a specific model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fae01756-4890-4a5b-8c1e-859137368eb4",
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load a specific model\n",
    "# `baseline` is appended to \"logs/BASELINE\" inside get_specific_model; '' selects that folder itself.\n",
    "baseline = ''\n",
    "# Note: this rebinds the global cfg to the configuration of the loaded run.\n",
    "datamodule, module, cfg = get_specific_model(baseline)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f36d9c24105e5d3b",
   "metadata": {},
   "source": [
    "Then, we can load a specific data sample and run the forward function of the model - set a different index to see different samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "289a4332-4754-47ae-b2d6-95d4bdc03e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the sample to inspect - change it to look at a different sequence.\n",
    "idx = 3\n",
    "datasample, start_idx = get_data(idx)\n",
    "# Naming of the outputs: `grf` holds the per-contact-point forces (grf_cps), `grf_` the forces\n",
    "# passed on to the physics loss, `loss_k` the physics loss (0 here because do_physics=False).\n",
    "# The last six values are None if forward_of_model printed a fallback warning.\n",
    "estimation, reconstructed_data, imu_data, gc_positions, grf, moments, loss_k, grf_, ankle_imu_gl, lfs = forward_of_model(module, datasample, cfg, False)\n",
    "# Same in-place update as inside forward_of_model (channel 0 = cumulative sum of channel 1 / fps);\n",
    "# repeating it does not change the values.\n",
    "estimation['IK_data'][:,:,0] = torch.cumsum(estimation['IK_data'][:,:,1],dim=1)/cfg.fps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb8917468f78689",
   "metadata": {},
   "source": [
    "Plot the recorded and simulated IMU data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbb71ba8-78d3-4109-a14e-9402b3232fc7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Solid lines: simulated IMU signals; dashed lines: recorded signals of the same sample\n",
    "plot_imu_pre_after(imu_data, datasample['IMU_data'], cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d86f8ee581006b3",
   "metadata": {},
   "source": [
    "Plot the kinematics over the full sequence. The mean absolute simulated speed and the corresponding target speed zone are printed alongside the plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2217754a-e7d5-4d88-a94a-507ef5ed69ee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# No reference data is passed (None), so the offset argument (19) has no effect here.\n",
    "plot_kinematics(estimation['IK_data'], None, start_idx, cfg, 19)\n",
    "# Mean absolute value of IK channel 1 (plotted above as 'v_sim') and the target zone of this sample's trigger value\n",
    "print(f\"Speed: {torch.mean(torch.abs(estimation['IK_data'][0,:,1])):.2f} m/s, target zone: {target_zones(predict_dl.dataset.imu_data.hastrigger.iloc[idx])} m/s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a111c22e1c8ba3b3",
   "metadata": {},
   "source": [
    "Stick figure representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6116a336-45b5-4877-92b1-5ecb5c1fe8c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# plot_stick_2d_data comes from the star import of src.utils.plot_utils.\n",
    "# NOTE(review): 0, 256 and 8 look like start frame, end frame and frame step - confirm against its signature.\n",
    "plot_stick_2d_data(reconstructed_data, gc_positions,0,256,8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7be4e81d25c710cc",
   "metadata": {},
   "source": [
    "Finally, an animation of the stick figure. By setting subtrans to true, the stick figure keeps a fixed position in the middle of the screen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17b476c-73f9-4f18-bc1a-9a8a1954b5b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# animate_stick_2d comes from the star import of src.utils.plot_utils.\n",
    "# NOTE(review): 0 and 256 look like the first and last frame of the animation - confirm against its signature.\n",
    "animate_stick_2d(reconstructed_data, gc_positions, 0, 256, subtrans = True, save = False, fps = 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee7945b7bb2d6128",
   "metadata": {},
   "source": [
    "### Further Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b72bf03-1e15-47a7-9dd0-88f54cf00cb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GRF\n",
    "plot_grf(grf,None,cfg)\n",
    "# Time-averaged sum of the even-indexed and of the odd-indexed force channels, divided by 9.81.\n",
    "# The number of time steps is read from the tensor instead of being hard-coded as 256.\n",
    "grf[0,:,::2].sum()/grf.shape[1]/9.81, grf[0,:,1::2].sum()/grf.shape[1]/9.81"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c335fddf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Foot-Speed\n",
    "# Width of the grey band as a fraction of the peak speed; forward_of_model uses the same\n",
    "# 25 % tolerance for its foot-speed loss.\n",
    "mmax = 0.25\n",
    "fig = plt.figure(figsize=(15, 5))\n",
    "for i, (side, ankle_key) in enumerate([('right', 'ankle_r'), ('left', 'ankle_l')]):\n",
    "    fig.add_subplot(1,2,i+1)\n",
    "    # Convert once and reuse; the band length follows the data instead of a hard-coded 256\n",
    "    speed = datasample['speed'][0,:,i].detach().cpu().numpy()\n",
    "    m = np.max(speed)\n",
    "    plt.plot(speed,'k')\n",
    "    plt.fill_between(np.arange(len(speed)), speed-mmax*m, speed+mmax*m, color='k', alpha=0.2)\n",
    "    plt.plot(ankle_imu_gl[ankle_key][0,:,1].detach().cpu().numpy(),'r')\n",
    "    plt.title(f'foot-IMU speed {side}')\n",
    "    plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99b5b4a8-5908-4290-8526-6f26558583d3",
   "metadata": {
    "jupyter": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Temporal consistency\n",
    "# Solid: scaled first difference of a channel; dashed: the following channel (see plot_euler)\n",
    "plot_euler(estimation['IK_data'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3e9789c-1eea-4785-9802-8bfaa02014eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kane's Loss\n",
    "# grf_ and moments are the 8th and 6th outputs of forward_of_model; both are None if it printed a fallback warning.\n",
    "# NOTE(review): Kmatrix_loss_moveEst is supplied by a star import; the meaning of the positional argument 1 is not visible here.\n",
    "K_loss = Kmatrix_loss_moveEst(estimation['IK_data'],estimation['torques'],grf_,moments,datasample['body_constants'],1,device='cpu')\n",
    "plot_kanesloss(K_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed4c0c8-4700-45f4-b6c7-bcaa40950022",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Joint torques\n",
    "# One subplot per entry of cfg.model.estimated_variables.torques (first sample only)\n",
    "plot_kinetics(estimation['torques'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05976ab7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pinn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
