{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook is used to evaluate the models on the Dorschky2024 dataset.\n",
    "When running this notebook, the following steps are performed:\n",
    "1. Load the data from the Dorschky2024 dataset\n",
    "2. Load the models from the latest run\n",
    "3. Evaluate the models on the Dorschky2024 dataset\n",
    "You will find the results in the `eval_dataframe` dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_data = 'data/dorschky2024'\n",
    "path_to_data = '/Users/markusgambietz/Downloads/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import subprocess\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "os.chdir('..')\n",
    "\n",
    "from src.multibody_sim.utils import load_data\n",
    "from matplotlib.animation import FuncAnimation\n",
    "from IPython.display import HTML\n",
    "from src.multibody_sim.equations_of_motion import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions to load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task 1: get some prediction from a checkpoint\n",
    "import omegaconf \n",
    "import hydra\n",
    "from src.utils import hydra_resolver, reconstruction\n",
    "from src import utils as ut\n",
    "predict_run_dir = \"logs/BASELINE\"\n",
    "checkpoint_path = \"tensorboard/tb_logs/all_models/checkpoints/\"\n",
    "checkpoint = \"last.ckpt\"\n",
    "cfg = omegaconf.OmegaConf.load(f\"{predict_run_dir}/.hydra/config.yaml\")\n",
    "import src.utils.utils as utt\n",
    "utt.init_global_cfg(cfg)\n",
    "log = ut.get_pylogger(__name__)\n",
    "hydra_resolver.register_resolvers(log)\n",
    "from src.models.sspinn_module import SspinnLitModule\n",
    "from src.models.components.lstm_net import LSTMNet\n",
    "from src.kinematics.kinematics_2d import *\n",
    "from src.multibody_sim.ground_contact import *\n",
    "from src.multibody_sim.equations_of_motion import *\n",
    "cfg = omegaconf.OmegaConf.load(f\"{predict_run_dir}/.hydra/config.yaml\")\n",
    "cfg.trainer.accelerator = device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latest_checkpoint():\n",
    "    path = \"logs/results-10.05_2/train/runs/\"\n",
    "    def get_newest_folder(path):\n",
    "        return max([os.path.join(path,d) for d in os.listdir(path) if os.path.isdir(os.path.join(path,d))], key=os.path.getmtime)\n",
    "    def get_newest_checkpoint(pred_dir, cp_path):\n",
    "        path2 = os.path.join(pred_dir,cp_path)\n",
    "        return max([os.path.join(path2,d) for d in os.listdir(path2) if d.endswith('.ckpt')], key=os.path.getmtime) \n",
    "    \n",
    "    pred_dir = get_newest_folder(path)\n",
    "    ckp = get_newest_checkpoint(pred_dir, checkpoint_path)\n",
    "    cfg = omegaconf.OmegaConf.load(f\"{pred_dir}/.hydra/config.yaml\")\n",
    "    cfg.trainer.accelerator = 'cpu'\n",
    "    return pred_dir, ckp, cfg\n",
    "def get_model(predict_run_dir, checkpoint_,  cfg):\n",
    "    datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "    global predict_dl\n",
    "    predict_dl = datamodule.predict_dataloader()\n",
    "\n",
    "    model = hydra.utils.instantiate(cfg.model)\n",
    "    try: \n",
    "        ckp = torch.load(f\"{predict_run_dir}/{checkpoint_path}/{checkpoint_}\")\n",
    "    except:\n",
    "        print(checkpoint_)\n",
    "        ckp = torch.load(f\"{checkpoint_}\",map_location=torch.device('cpu'))\n",
    "        \n",
    "    model.load_state_dict(ckp['state_dict'])\n",
    "    module = SspinnLitModule(model, None, None, torch.nn.functional.mse_loss, cfg.model.input_noise, cfg.model.loss_weights, cfg.model.input_variables, cfg.model.estimated_variables, cfg.model.loss_d_variables)\n",
    "    return datamodule, module, cfg\n",
    "    \n",
    "def get_latest_model():\n",
    "    predict_run_dir, checkpoint,  cfg = get_latest_checkpoint()\n",
    "    return get_model(predict_run_dir, checkpoint,  cfg)\n",
    "\n",
    "\n",
    "def get_specific_model(folder_path):\n",
    "    # Get the checkpoint:\n",
    "    path = \"logs/\"\n",
    "    pred_dir = os.path.normpath(path + folder_path)\n",
    "    def get_newest_checkpoint(pred_dir, cp_path):\n",
    "        path2 = os.path.join(pred_dir,cp_path)\n",
    "        print(path2)\n",
    "        return sorted([os.path.join(path2,d) for d in os.listdir(path2) if d.endswith('.ckpt')], key=os.path.getmtime) \n",
    "    list_ = get_newest_checkpoint(pred_dir, checkpoint_path)\n",
    "    cfg = omegaconf.OmegaConf.load(f\"{pred_dir}/.hydra/config.yaml\")\n",
    "\n",
    "    ckp = list_[-2]\n",
    "    return get_model(pred_dir, ckp,  cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "def load_constants_pkl(module, folder):    \n",
    "    file_ = os.path.join(\"logs/results-10.05_2/train/\",folder,\"tensorboard/tb_logs/all_models/constants.pkl\")\n",
    "    with open(file_, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "    data = {k: torch.from_numpy(v).to(device) for k, v in data.items()}\n",
    "    module.constants = data\n",
    "    print('pkl constants loaded succesfully')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#How to get the data\n",
    "datapath = 'data/dorschky2024'\n",
    "from src.datamodules.components import dorschky2024_evaluate_dataset\n",
    "dataset = dorschky2024_evaluate_dataset.Dorschky2024EvalDataset(datapath)\n",
    "no_trials = len(dataset)\n",
    "def get_data(idx):#\n",
    "    datasample, start_idx = dataset.__getitem__(idx, get_metadata=True)\n",
    "    for key in datasample:\n",
    "        datasample[key] = datasample[key].unsqueeze(0).to(device)\n",
    "    return datasample, start_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forward function of our module (backwards-compatible)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_of_model(module, batch, cfg, do_physics = False):\n",
    "    # Rebuild my own forward\n",
    "\n",
    "    x = torch.cat([batch[var] for var in module.input_variables], dim=-1)\n",
    "    y_hat = module.model.model.forward(x)\n",
    "    estimation = {}\n",
    "    prev_idx = 0\n",
    "    # split the output into the estimated variables\n",
    "    for k, v in cfg.model.estimated_variables.items():\n",
    "        estimation[k] = y_hat[:, :, prev_idx : prev_idx + len(v)]\n",
    "        prev_idx += len(v)\n",
    "    estimation[\"IK_data\"][:,:,0] = torch.cumsum(estimation['IK_data'][:,:,1],dim=1)/cfg.fps\n",
    "    IK_pred = estimation[\"IK_data\"]\n",
    "\n",
    "    loss_l = module.calculate_limit_losses(IK_pred)\n",
    "\n",
    "    reconstructed_data, imu_data, gc_positions, ankle_imu_globals = global_kinematics(\n",
    "        IK_pred,\n",
    "        batch[\"body_constants\"],\n",
    "        batch[\"imu_offsets\"],\n",
    "        batch[\"imu_rotations\"],\n",
    "        batch[\"ground_contact_model\"],\n",
    "        cfg,\n",
    "        device=module.device,\n",
    "    )\n",
    "\n",
    "        # get the ankle speed from the kinematics\n",
    "    try:\n",
    "        v_ankle = torch.cat([reconstructed_data['ankle_r'][:, :, 1:2], reconstructed_data['ankle_l'][:, :, 1:2]], dim=-1)\n",
    "        v_ankle_sim = torch.cat([ankle_imu_globals['ankle_r'][:, :, 1:2], ankle_imu_globals['ankle_l'][:, :, 1:2]], dim=-1)\n",
    "        loss_foot_speed = torch.abs(v_ankle_sim - v_ankle)\n",
    "        loss_foot_speed = torch.relu(loss_foot_speed - 0.25 * torch.max(batch['speed'], keepdim=True, dim=-2).values) # 25% of the speed as error is allowed -> no exact guidance\n",
    "        loss_foot_speed = module.criterion(loss_foot_speed, torch.zeros_like(loss_foot_speed, device=module.device))\n",
    "\n",
    "        ## NN ground contact model loss, only use if wgc > 0 (is specified)\n",
    "        if module.loss_weights.wgc > 0 and cfg.gc_ss_level == 'cps':\n",
    "            loss_gc, gc_positions = module.calculate_gc_loss(estimation, gc_positions, reconstructed_data, cfg.fps, cfg.euler)\n",
    "        elif module.loss_weights.wgc > 0 and cfg.gc_ss_level == 'ankle':\n",
    "            loss_gc, gc_positions, ankle_globals, learned_mu = module.calculate_ankle_loss(estimation, batch, reconstructed_data, cfg.fps, cfg.euler)\n",
    "            if cfg.ankle_imu_position == 'ss_foot':\n",
    "                # reconstruct the ankle IMU data based on the self-supervised foot position\n",
    "                for foot, imu_idx in zip(['ankle_r','ankle_l'],[3,6]):\n",
    "                    ankle_globals[foot][:,:,2::3] = reconstructed_data[foot][:,:,2::3]\n",
    "                    imu_ = get_joint_position(ankle_globals[foot], torch.zeros_like(ankle_globals[foot][:,:,-3:]), batch['imu_offsets'][:,:,2*imu_idx:2*imu_idx+2], module.device)\n",
    "                    imu_data[:,:,3*imu_idx:3*imu_idx+3] = to_local_coordinates_imu(imu_, batch['body_constants'][:,:,-1], module.device) # Same as in global reconstruction\n",
    "        \n",
    "        else:\n",
    "            loss_gc = 0\n",
    "\n",
    "        loss_r = module.criterion(imu_data, batch[\"IMU_data\"])\n",
    "\n",
    "        loss_t = module.calculate_time_loss(IK_pred, cfg.fps, cfg.euler)\n",
    "\n",
    "        # Get GC model forces and CoP\n",
    "        try: \n",
    "            if cfg.gc_model == 'sliding':\n",
    "                grf, moment, grf_cps, cp_mix_ = sliding_contact_point(\n",
    "                    gc_positions,\n",
    "                    batch[\"ground_contact_model\"],\n",
    "                    gc_model,\n",
    "                    cfg,\n",
    "                    device = module.device,\n",
    "                    get_moments=True,\n",
    "                    kinematics=reconstructed_data\n",
    "                )\n",
    "            else:\n",
    "                grf, moment, grf_cps = contact_points_2d(\n",
    "                    gc_positions,\n",
    "                    batch[\"ground_contact_model\"],\n",
    "                    gc_model,\n",
    "                    cfg,\n",
    "                    device = \"cpu\",\n",
    "                    get_moments=True,\n",
    "                    kinematics=reconstructed_data\n",
    "                )\n",
    "        except Exception as e:\n",
    "            print('Warning: Fallback to contact_points 2D version because of', e)\n",
    "            grf, moment, grf_cps = contact_points_2d(\n",
    "                gc_positions,\n",
    "                batch[\"ground_contact_model\"],\n",
    "                gc_model,\n",
    "                cfg,\n",
    "                device = \"cpu\",\n",
    "                get_moments=True,\n",
    "                kinematics=reconstructed_data\n",
    "            )\n",
    "\n",
    "        # Kane's Loss        \n",
    "        loss_grf = module.calculate_grf_bounds_loss(grf)\n",
    "\n",
    "        if do_physics:\n",
    "            loss_p = module.calculate_physics_loss(\n",
    "                IK_pred,\n",
    "                estimation[\"torques\"],\n",
    "                grf,\n",
    "                moment,\n",
    "                batch[\"body_constants\"],\n",
    "            )\n",
    "        else:\n",
    "            loss_p = 0\n",
    "    \n",
    "        loss = (\n",
    "            + loss_r * module.loss_weights.wr\n",
    "            + loss_t * module.loss_weights.wt\n",
    "            + loss_p * module.loss_weights.wp\n",
    "        )\n",
    "\n",
    "        return estimation, reconstructed_data, imu_data, gc_positions, grf_cps, moment, loss_p, grf, ankle_imu_globals, loss_foot_speed\n",
    "    except Exception as e:\n",
    "        print('Warning: Fallback to old version because of', e)\n",
    "        return estimation, reconstructed_data, imu_data, gc_positions, None, None, None, None, None, None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(metric, cfg, metadata, reconstructed_data, omc_data):\n",
    "    reconstructed_data = reconstructed_data.copy()\n",
    "    if metric in ['jitter_q', 'jitter_qdot', 'jitter_qddot', 'relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']:\n",
    "        jitter = 0\n",
    "        i = 0\n",
    "        if metric in ['relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']:\n",
    "            for key in reconstructed_data:\n",
    "                reconstructed_data[key] = reconstructed_data[key] - reconstructed_data['pelvis']\n",
    "        if metric in ['jitter_q', 'relative_jitter_q']:\n",
    "            idx = 0\n",
    "            exponent = 3\n",
    "        elif metric in ['jitter_qdot', 'relative_jitter_qdot']:\n",
    "            idx = 1\n",
    "            exponent = 2\n",
    "        elif metric in ['jitter_qddot', 'relative_jitter_qddot']:\n",
    "            idx = 2\n",
    "            exponent = 1\n",
    "        for key in reconstructed_data:\n",
    "            if key in ['hip_r','hip_l'] or (key == 'pelvis' and metric in ['relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']):\n",
    "                continue # Skip the hip joints as they are the same as root\n",
    "            jitter += torch.mean(torch.sqrt(reconstructed_data[key][0,:,idx].diff(exponent)**2 + reconstructed_data[key][0,:,idx+3].diff(exponent)**2))*cfg.fps**exponent\n",
    "            i += 1\n",
    "        return (jitter/i).cpu().detach().numpy()\n",
    "    if metric in ['relative_jpe','absolute_jpe']:\n",
    "        jpe = 0\n",
    "        n = 0\n",
    "        omc_start = metadata['omc_start']\n",
    "        omc_end = reconstructed_data['pelvis'].shape[1]\n",
    "\n",
    "        # Average Markers if multiple markers are available\n",
    "        for orientation in ['_X','_Y']:\n",
    "            for joint in marker_joint_roots:\n",
    "                omc_data[joint+orientation] = np.mean([omc_data[marker+orientation] for marker in marker_joint_roots[joint]],axis=0)*1e-3\n",
    "\n",
    "        if len(omc_data) > (omc_end-omc_start):\n",
    "            omc_data = omc_data[:omc_end-omc_start]\n",
    "        if len(omc_data) < (omc_end-omc_start):\n",
    "            print(len(omc_data) , (omc_end-omc_start))\n",
    "            for key in reconstructed_data:\n",
    "                reconstructed_data[key] = reconstructed_data[key][:,:omc_start+len(omc_data),:]\n",
    "        if np.abs(len(omc_data) - omc_end + omc_start) > 1:\n",
    "            print('Warning: difference between omc and imu data larger than rounding error(', np.abs(len(omc_data) - omc_end + omc_start), ')', metadata) \n",
    "        if metric == 'relative_jpe':\n",
    "            align_x = omc_data['pelvis_X'] - reconstructed_data['pelvis'][0,omc_start:omc_end,0].detach().cpu().numpy()\n",
    "            align_y = omc_data['pelvis_Y'] - reconstructed_data['pelvis'][0,omc_start:omc_end,3].detach().cpu().numpy()\n",
    "        elif metric == 'absolute_jpe': # Align to the first frame only\n",
    "            align_x = omc_data['pelvis_X'].iloc[0] - reconstructed_data['pelvis'][0,omc_start,0].detach().cpu().numpy()\n",
    "            align_y = omc_data['pelvis_Y'].iloc[0] - reconstructed_data['pelvis'][0,omc_start,3].detach().cpu().numpy()\n",
    "\n",
    "        for key in reconstructed_data:\n",
    "            if key in ['hip_r','hip_l']:\n",
    "                continue # Skip the hip joints as they are the same as root\n",
    "            if metric == 'relative_jpe' and key == 'pelvis':\n",
    "                continue # Skip the pelvis as it is the reference\n",
    "            # Align the pelvis position to the OMC data at omc_start\n",
    "            aligned_x = reconstructed_data[key][0,omc_start:omc_end,0].detach().cpu().numpy() + align_x\n",
    "            aligned_y = reconstructed_data[key][0,omc_start:omc_end,3].detach().cpu().numpy() + align_y\n",
    "            # the shape of aligned_x and omc_data[key+'_X'] are not always the same\n",
    "\n",
    "            err_x = aligned_x - omc_data[key+'_X']\n",
    "            err_y = aligned_y - omc_data[key+'_Y']\n",
    "\n",
    "            jpe += np.mean(np.sqrt(err_x**2 + err_y**2))\n",
    "            n += 1\n",
    "        return (jpe/n)\n",
    "    if metric == 'speed_error':\n",
    "        # Calculate the speed of the pelvis_markers_x\n",
    "        pelvis_m_ = np.mean(omc_data[pelvis_markers_x],axis=1)\n",
    "        speed = (pelvis_m_.iloc[-1]-pelvis_m_.iloc[0])/len(pelvis_m_)*cfg.fps\n",
    "        speed_pred = torch.mean(reconstructed_data['pelvis'][0,:,1]).detach().cpu().numpy()\n",
    "        return np.abs(speed*1e-3-speed_pred)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(metric, cfg, metadata, reconstructed_data, omc_data):\n",
    "    offset = 4 # The offset between the OMC data and the reconstructed data is ca. 4 frames - or 0.04s\n",
    "    reconstructed_data = reconstructed_data.copy()\n",
    "    if metric in ['jitter_q', 'jitter_qdot', 'jitter_qddot', 'relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']:\n",
    "        jitter = 0\n",
    "        i = 0\n",
    "        if metric in ['relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']:\n",
    "            for key in reconstructed_data:\n",
    "                reconstructed_data[key] = reconstructed_data[key] - reconstructed_data['pelvis']\n",
    "        if metric in ['jitter_q', 'relative_jitter_q']:\n",
    "            idx = 0\n",
    "            exponent = 3\n",
    "        elif metric in ['jitter_qdot', 'relative_jitter_qdot']:\n",
    "            idx = 1\n",
    "            exponent = 2\n",
    "        elif metric in ['jitter_qddot', 'relative_jitter_qddot']:\n",
    "            idx = 2\n",
    "            exponent = 1\n",
    "        for key in reconstructed_data:\n",
    "            if key in ['hip_r','hip_l'] or (key == 'pelvis' and metric in ['relative_jitter_q', 'relative_jitter_qdot', 'relative_jitter_qddot']):\n",
    "                continue # Skip the hip joints as they are the same as root\n",
    "            jitter += torch.mean(torch.sqrt(reconstructed_data[key][0,:,idx].diff(exponent)**2 + reconstructed_data[key][0,:,idx+3].diff(exponent)**2))*cfg.fps**exponent\n",
    "            i += 1\n",
    "        return (jitter/i).cpu().detach().numpy(), None\n",
    "    if metric in ['relative_jpe','absolute_jpe']:\n",
    "        jpe = 0\n",
    "        n = 0\n",
    "        omc_start = metadata['omc_start']\n",
    "        omc_end = omc_start + int(100*(omc_data.TIME.iloc[-1]-omc_data.TIME.iloc[0]))\n",
    "        len_max = omc_end - omc_start\n",
    "        len_max2 = reconstructed_data['pelvis'].shape[1]-omc_start-offset\n",
    "        # Average Markers if multiple markers are available\n",
    "        for orientation in ['_X','_Y']:\n",
    "            for joint in marker_joint_roots:\n",
    "                omc_data[joint+orientation] = np.mean([omc_data[marker+orientation] for marker in marker_joint_roots[joint]],axis=0)*1e-3\n",
    "        if len(omc_data) > len_max2:\n",
    "            omc_data = omc_data.iloc[:len_max2]\n",
    "        if len(omc_data) < (omc_end-omc_start):\n",
    "            for key in reconstructed_data:\n",
    "                reconstructed_data[key] = reconstructed_data[key][:,:,:]\n",
    "        if np.abs(len(omc_data) - len_max2) > 1:\n",
    "            print('Warning: Framedrops in', metadata, ' - This trial will be skipped')\n",
    "            print(len(omc_data), omc_end, omc_start)\n",
    "            # The downsampled signal could be +-1 sample in length, but shouldn't be outside that range. On closer inspections of these trials, we found some framedrops.\n",
    "            return None, None\n",
    "        if metric == 'relative_jpe':\n",
    "            align_x = omc_data['pelvis_X'] - reconstructed_data['pelvis'][0,omc_start+offset:omc_end+offset,0].detach().cpu().numpy()\n",
    "            align_y = omc_data['pelvis_Y'] - reconstructed_data['pelvis'][0,omc_start+offset:omc_end+offset,3].detach().cpu().numpy()\n",
    "        elif metric == 'absolute_jpe': # Align to the first frame only\n",
    "            align_x = omc_data['pelvis_X'].iloc[0] - reconstructed_data['pelvis'][0,omc_start+offset,0].detach().cpu().numpy()\n",
    "            align_y = omc_data['pelvis_Y'].iloc[0] - reconstructed_data['pelvis'][0,omc_start+offset,3].detach().cpu().numpy()\n",
    "        all_dict = {}\n",
    "        for key in reconstructed_data:\n",
    "            if key in ['hip_r','hip_l']:\n",
    "                continue # Skip the hip joints as they are the same as root\n",
    "            if metric == 'relative_jpe' and key == 'pelvis':\n",
    "                continue # Skip the pelvis as it is the reference\n",
    "            # Align the pelvis position to the OMC data at omc_start\n",
    "            aligned_x = reconstructed_data[key][0,omc_start+offset:omc_end+offset,0].detach().cpu().numpy() + align_x\n",
    "            aligned_y = reconstructed_data[key][0,omc_start+offset:omc_end+offset,3].detach().cpu().numpy() + align_y\n",
    "            # the shape of aligned_x and omc_data[key+'_X'] are not always the same\n",
    "\n",
    "            err_x = aligned_x - omc_data[key+'_X']\n",
    "            err_y = aligned_y - omc_data[key+'_Y']\n",
    "\n",
    "\n",
    "            jpe += np.mean(np.sqrt(err_x**2 + err_y**2))\n",
    "            all_dict[key] = np.array(np.sqrt(err_x**2 + err_y**2))\n",
    "            n += 1\n",
    "        return (jpe/n), all_dict\n",
    "    if metric == 'speed_error':\n",
    "        # Calculate the speed of the pelvis_markers_x\n",
    "        pelvis_m_ = np.mean(omc_data[pelvis_markers_x],axis=1)\n",
    "        speed = (pelvis_m_.iloc[-1]-pelvis_m_.iloc[0])/len(pelvis_m_)*cfg.fps\n",
    "        speed_pred = torch.mean(reconstructed_data['pelvis'][0,:,1]).detach().cpu().numpy()\n",
    "        return np.abs(speed*1e-3-speed_pred), None\n",
    "    if metric == 'speed_error_rmsd':\n",
    "        # Calculate the speed of the pelvis_markers_x\n",
    "        pelvis_m_ = np.mean(omc_data[pelvis_markers_x],axis=1)\n",
    "        speed = (pelvis_m_.iloc[-1]-pelvis_m_.iloc[0])/len(pelvis_m_)*cfg.fps\n",
    "        speed_pred = torch.mean(reconstructed_data['pelvis'][0,:,1]).detach().cpu().numpy()\n",
    "        return np.abs(speed*1e-3-speed_pred)**2, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dictionary for the different metrics\n",
    "marker_joint_roots = {\n",
    "    'pelvis': ['LGTRO', 'RGTRO'],\n",
    "    'knee_r': ['RKNE'],\n",
    "    'knee_l': ['LKNE'],\n",
    "    'ankle_r': ['RANK'],\n",
    "    'ankle_l': ['LANK'],\n",
    "}\n",
    "pelvis_markers_x = ['LASI_X', 'RASI_X', 'LPSI_X', 'RPSI_X', 'SACR_X']   \n",
    "def get_current_omc_trial(omc_data, trial, trigger_idx):\n",
    "    curr_trigger_event = omc_data[omc_data['TRIGGER'] == trigger_idx][omc_data['TRIAL'] == trial].index\n",
    "    t_start = omc_data[omc_data['TIME'] == 0].index\n",
    "    # get the next-lowest t_start from the current trigger event\n",
    "    idx = len(t_start[t_start < np.array(curr_trigger_event)[0]])-1\n",
    "    t_end = t_start-1\n",
    "    omc_data_roi = omc_data[t_start[idx]:t_end[(idx+1)%len(t_end)]]\n",
    "    # return every second row where CLAV_X is not NaN -> 100 Hz data\n",
    "    return omc_data_roi[~np.isnan(omc_data_roi['CLAV_X'])][::2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loop over all results folders "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loop over all folders in logs/results/train and return the folders as a list\n",
    "log_folders = [f for f in os.listdir('logs/') if os.path.isdir(os.path.join('logs/', f))]\n",
    "log_folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define all metrics\n",
    "metrics = [\n",
    "    'jitter_q',\n",
    "    'jitter_qdot',\n",
    "    'jitter_qddot',\n",
    "    'absolute_jpe',\n",
    "    'relative_jpe',\n",
    "    'speed_error',\n",
    "    'speed_error_rmsd',\n",
    "]\n",
    "metrics_walking = [metric+'_walking' for metric in metrics]\n",
    "metrics_running = [metric+'_running' for metric in metrics]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Initialize the results dataframe; this throws some warnings that are irrelevant for this version\n",
    "eval_dataframe = pd.DataFrame(columns=[*metrics, *metrics_walking, *metrics_running])\n",
    "# Skip the bool warnings in pandas\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "#  Loop over all runs and metrics\n",
    "for folder in log_folders:\n",
    "    try:\n",
    "        datamodule, module, cfg = get_specific_model(os.path.join('',folder))\n",
    "    except:\n",
    "        try: \n",
    "            datamodule, module, cfg = get_specific_model(os.path.join('',folder),-2)\n",
    "        except:\n",
    "            print('Error in', folder, 'wont load model')\n",
    "            for metric in metrics:\n",
    "                eval_dataframe.loc[folder, metric] = None\n",
    "            continue\n",
    "    try:\n",
    "        if cfg.model.optimize_constants is not None:\n",
    "            load_constants_pkl(module, folder)\n",
    "    except:\n",
    "        pass\n",
    "    dataset = dorschky2024_evaluate_dataset.Dorschky2024EvalDataset(datapath, cfg.datamodule.test_dataset.subjects)\n",
    "    module.to(device)\n",
    "    no_trials = len(dataset)\n",
    "\n",
    "    # create a dataframe for the evaluation results of the current model\n",
    "    current_metrics = pd.DataFrame(columns=['trial_idx', *metrics])\n",
    "    current_subject = -1\n",
    "    for i in range(no_trials):\n",
    "        datasample, metadata = get_data(i)\n",
    "        if metadata['trigger_no'] < 8:\n",
    "            continue\n",
    "        if int(metadata['subject']) != int(current_subject):\n",
    "            current_subject = metadata['subject']\n",
    "            omc_data_0 = pd.read_parquet(f'{path_to_data}/P{str(current_subject).zfill(2)}_OMC.parquet') # From 1000 Hz (Forceplate) / 200 Hz (OMC) to 100 Hz\n",
    "            omc_data_0.columns = [col.replace('-', '_') for col in omc_data_0.columns]\n",
    "            print('Subject loaded: ', current_subject)\n",
    "        omc_data = get_current_omc_trial(omc_data_0, metadata['trigger_no'], metadata['trigger_idx'])\n",
    "        estimation, reconstructed_data, imu_data, gc_positions, grf_cps, moment, loss_p, grf, ankle_imu_globals, _ = forward_of_model(module, datasample, cfg, do_physics=True)\n",
    "        for metric in metrics:\n",
    "            current_metrics.loc[i, metric], _ = evaluate(metric, cfg, metadata, reconstructed_data, omc_data)\n",
    "        current_metrics.loc[i, 'trial_idx'] = metadata['trigger_idx']\n",
    "    for metric in metrics:\n",
    "        eval_dataframe.loc[folder, metric] = current_metrics[metric].mean()\n",
    "    for metric2, metric in zip(metrics_walking, metrics):\n",
    "        c2 = current_metrics[current_metrics.trial_idx < 9.5]\n",
    "        eval_dataframe.loc[folder, metric2] = c2[metric].mean()\n",
    "    for metric2, metric in zip(metrics_running, metrics):\n",
    "        c2 = current_metrics[current_metrics.trial_idx > 9.5]\n",
    "        eval_dataframe.loc[folder, metric2] = c2[metric].mean()\n",
    "    display(eval_dataframe.loc[folder,metrics])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataframe"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pinn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
