{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "fd33d6bf-74b2-453b-b418-454863eefb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import re\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as stats\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy.spatial.distance import cdist\n",
    "from scipy.stats import kurtosis, skew\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "7d4ab09f-b4ed-4677-861d-826a0acd3903",
   "metadata": {},
   "outputs": [],
   "source": [
    "monkey_df = pd.read_csv(\"../Dataset/SMT_Dataset/monkey_trajectory_dataset_updated.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "71d874c9-f284-4389-9965-8517b45fda6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['task_type', 'session_no', 'trial_no', 'Time', 'HandPos', 'TargetPos',\n",
       "       'time_diff_ms', 'path', 'target_poisiton', 'normalized_trajectory',\n",
       "       'participant_id', 'completion_time', 'rmsd', 'is_success'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "monkey_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "e56e8a6b-a97d-4b08-ac55-4d6c1ae21a3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>time_diff_ms</th>\n",
       "      <th>path</th>\n",
       "      <th>target_poisiton</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[   6.   16.   26.   36.   46.   56.   66.   7...</td>\n",
       "      <td>[(-13.166, 280.695), (-13.3428, 280.618), (-13...</td>\n",
       "      <td>(23.0, 320.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[  52.   62.   72.   82.   92.  102.  112.  12...</td>\n",
       "      <td>[(-18.7382, 283.289), (-19.4048, 282.418), (-1...</td>\n",
       "      <td>(23.0, 320.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[1064. 1074. 1084. 1094. 1104. 1114. 1124. 113...</td>\n",
       "      <td>[(-20.5466, 283.837), (-20.168, 283.326), (-19...</td>\n",
       "      <td>(23.0, 320.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[ 352.  362.  372.  382.  392.  402.  412.  42...</td>\n",
       "      <td>[(-14.548, 281.899), (-14.9, 281.372), (-15.25...</td>\n",
       "      <td>(23.0, 320.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[  46.   56.   66.   76.   86.   96.  106.  11...</td>\n",
       "      <td>[(-20.558, 279.653), (-20.508, 279.572), (-20....</td>\n",
       "      <td>(23.0, 320.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23634</th>\n",
       "      <td>[ 681.  691.  701.  711.  721.  731.  741.  75...</td>\n",
       "      <td>[(-5.489, 295.829), (-5.39344, 295.876), (-5.1...</td>\n",
       "      <td>(-7.0, 229.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23635</th>\n",
       "      <td>[ 470.  480.  490.  500.  510.  520.  530.  54...</td>\n",
       "      <td>[(-0.2296, 289.508), (-0.537, 289.575), (-0.81...</td>\n",
       "      <td>(-7.0, 229.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23636</th>\n",
       "      <td>[ 300.  310.  320.  330.  340.  350.  360.  37...</td>\n",
       "      <td>[(-6.225, 295.896), (-6.225, 295.896), (-6.037...</td>\n",
       "      <td>(-7.0, 229.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23637</th>\n",
       "      <td>[ 160.  170.  180.  190.  200.  210.  220.  23...</td>\n",
       "      <td>[(-9.774, 288.568), (-9.774, 288.568), (-9.774...</td>\n",
       "      <td>(-7.0, 229.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23638</th>\n",
       "      <td>[   5.   15.   25.   35.   45.   55.   65.   7...</td>\n",
       "      <td>[(-10.01, 290.327), (-10.01, 290.327), (-10.01...</td>\n",
       "      <td>(-7.0, 229.0)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>23639 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            time_diff_ms  \\\n",
       "0      [   6.   16.   26.   36.   46.   56.   66.   7...   \n",
       "1      [  52.   62.   72.   82.   92.  102.  112.  12...   \n",
       "2      [1064. 1074. 1084. 1094. 1104. 1114. 1124. 113...   \n",
       "3      [ 352.  362.  372.  382.  392.  402.  412.  42...   \n",
       "4      [  46.   56.   66.   76.   86.   96.  106.  11...   \n",
       "...                                                  ...   \n",
       "23634  [ 681.  691.  701.  711.  721.  731.  741.  75...   \n",
       "23635  [ 470.  480.  490.  500.  510.  520.  530.  54...   \n",
       "23636  [ 300.  310.  320.  330.  340.  350.  360.  37...   \n",
       "23637  [ 160.  170.  180.  190.  200.  210.  220.  23...   \n",
       "23638  [   5.   15.   25.   35.   45.   55.   65.   7...   \n",
       "\n",
       "                                                    path target_poisiton  \n",
       "0      [(-13.166, 280.695), (-13.3428, 280.618), (-13...   (23.0, 320.0)  \n",
       "1      [(-18.7382, 283.289), (-19.4048, 282.418), (-1...   (23.0, 320.0)  \n",
       "2      [(-20.5466, 283.837), (-20.168, 283.326), (-19...   (23.0, 320.0)  \n",
       "3      [(-14.548, 281.899), (-14.9, 281.372), (-15.25...   (23.0, 320.0)  \n",
       "4      [(-20.558, 279.653), (-20.508, 279.572), (-20....   (23.0, 320.0)  \n",
       "...                                                  ...             ...  \n",
       "23634  [(-5.489, 295.829), (-5.39344, 295.876), (-5.1...   (-7.0, 229.0)  \n",
       "23635  [(-0.2296, 289.508), (-0.537, 289.575), (-0.81...   (-7.0, 229.0)  \n",
       "23636  [(-6.225, 295.896), (-6.225, 295.896), (-6.037...   (-7.0, 229.0)  \n",
       "23637  [(-9.774, 288.568), (-9.774, 288.568), (-9.774...   (-7.0, 229.0)  \n",
       "23638  [(-10.01, 290.327), (-10.01, 290.327), (-10.01...   (-7.0, 229.0)  \n",
       "\n",
       "[23639 rows x 3 columns]"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "monkey_df[['time_diff_ms', 'path', 'target_poisiton']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "cc05cc7a-9fd8-4dfa-8f8e-c3bcd11c517f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        (23.0, 320.0)\n",
       "1        (23.0, 320.0)\n",
       "2        (23.0, 320.0)\n",
       "3        (23.0, 320.0)\n",
       "4        (23.0, 320.0)\n",
       "             ...      \n",
       "23634    (-7.0, 229.0)\n",
       "23635    (-7.0, 229.0)\n",
       "23636    (-7.0, 229.0)\n",
       "23637    (-7.0, 229.0)\n",
       "23638    (-7.0, 229.0)\n",
       "Name: target_poisiton, Length: 23639, dtype: object"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "def extract_single_target(target_pos_str):\n",
    "    \"\"\"Return the first (x, y) pair found in a bracketed position string.\n",
    "\n",
    "    Only the first pair is read; the notebook assumes every sample in a row\n",
    "    carries the same target.\n",
    "\n",
    "    Args:\n",
    "        target_pos_str: text in which a '[' is followed by two\n",
    "            whitespace-separated numbers, e.g. '[23 320;23 320]'.\n",
    "\n",
    "    Returns:\n",
    "        (x, y) as floats, or None when the value is not a string or holds\n",
    "        no such pair.\n",
    "    \"\"\"\n",
    "    if not isinstance(target_pos_str, str):\n",
    "        # Non-text cells (e.g. NaN for a missing value) would make re.search raise\n",
    "        return None\n",
    "    # Strict float pattern (sign, decimals, exponent) so float() cannot fail on\n",
    "    # matches such as '-' or '1.2.3'; whitespace right after '[' is tolerated.\n",
    "    number = r'[-+]?(?:\\d+\\.?\\d*|\\.\\d+)(?:[eE][-+]?\\d+)?'\n",
    "    match = re.search(r'\\[\\s*(' + number + r')\\s+(' + number + r')', target_pos_str)\n",
    "    if match is None:\n",
    "        return None\n",
    "    return (float(match.group(1)), float(match.group(2)))\n",
    "\n",
    "# Derive one (x, y) target tuple per row from the raw TargetPos text\n",
    "target_tuples = monkey_df['TargetPos'].apply(extract_single_target)\n",
    "monkey_df['target_poisiton'] = target_tuples\n",
    "monkey_df['target_poisiton']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "87bf4ad7-9a18-460a-81c3-7e657e6b678b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        [6.0, 16.0, 26.0, 36.0, 46.0, 56.0, 66.0, 76.0...\n",
       "1        [52.0, 62.0, 72.0, 82.0, 92.0, 102.0, 112.0, 1...\n",
       "2        [1064.0, 1074.0, 1084.0, 1094.0, 1104.0, 1114....\n",
       "3        [352.0, 362.0, 372.0, 382.0, 392.0, 402.0, 412...\n",
       "4        [46.0, 56.0, 66.0, 76.0, 86.0, 96.0, 106.0, 11...\n",
       "                               ...                        \n",
       "23634    [681.0, 691.0, 701.0, 711.0, 721.0, 731.0, 741...\n",
       "23635    [470.0, 480.0, 490.0, 500.0, 510.0, 520.0, 530...\n",
       "23636    [300.0, 310.0, 320.0, 330.0, 340.0, 350.0, 360...\n",
       "23637    [160.0, 170.0, 180.0, 190.0, 200.0, 210.0, 220...\n",
       "23638    [5.0, 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0...\n",
       "Name: time_diff_ms, Length: 23639, dtype: object"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def matlab_time_to_array(matlab_str):\n",
    "    \"\"\"Convert a bracketed, semicolon-separated number string to a NumPy array.\n",
    "\n",
    "    Every value is multiplied by 1000. The result is stored in the\n",
    "    'time_diff_ms' column, so the input is presumably in seconds (TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        matlab_str: text such as '[0.006;0.016;0.026]'.\n",
    "\n",
    "    Returns:\n",
    "        1-D float array of the values times 1000; empty when the string\n",
    "        holds no values.\n",
    "    \"\"\"\n",
    "    # Trim outer whitespace first so strip('[]') actually sees the brackets\n",
    "    inner = matlab_str.strip().strip('[]')\n",
    "    # Skip blank tokens so '[]' or a trailing ';' no longer crash float()\n",
    "    tokens = [tok for tok in inner.split(';') if tok.strip()]\n",
    "    return np.array([float(tok) * 1000 for tok in tokens])\n",
    "\n",
    "# Turn every row's 'Time' string into a numeric array scaled by 1000\n",
    "time_arrays = monkey_df['Time'].apply(matlab_time_to_array)\n",
    "monkey_df['time_diff_ms'] = time_arrays\n",
    "monkey_df['time_diff_ms']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "4dc4c275-4c68-4cb7-8043-9a324a4eab5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        [(-13.166, 280.695), (-13.3428, 280.618), (-13...\n",
       "1        [(-18.7382, 283.289), (-19.4048, 282.418), (-1...\n",
       "2        [(-20.5466, 283.837), (-20.168, 283.326), (-19...\n",
       "3        [(-14.548, 281.899), (-14.9, 281.372), (-15.25...\n",
       "4        [(-20.558, 279.653), (-20.508, 279.572), (-20....\n",
       "                               ...                        \n",
       "23634    [(-5.489, 295.829), (-5.39344, 295.876), (-5.1...\n",
       "23635    [(-0.2296, 289.508), (-0.537, 289.575), (-0.81...\n",
       "23636    [(-6.225, 295.896), (-6.225, 295.896), (-6.037...\n",
       "23637    [(-9.774, 288.568), (-9.774, 288.568), (-9.774...\n",
       "23638    [(-10.01, 290.327), (-10.01, 290.327), (-10.01...\n",
       "Name: path_, Length: 23639, dtype: object"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Parse each stringified list of (x, y) tuples. Applying to the single column\n",
    "# avoids building a full row object per record; `ast` comes from the imports cell.\n",
    "monkey_df['path_'] = monkey_df['path'].apply(ast.literal_eval)\n",
    "monkey_df['path_']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "dba2af0f-f253-479b-b7f2-7bff5461a3fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_trajectory_3d(trajectory, time_sequence, target_position=np.array([1.0, 0.0])):\n",
    "    \"\"\"\n",
    "    Rigidly align a 2-D trajectory and append its timestamps as a third column.\n",
    "\n",
    "    The path is translated to start at (0, 0), rotated so its last point lies\n",
    "    along the direction of `target_position`, and uniformly scaled so the last\n",
    "    point is as far from the origin as `target_position` is.\n",
    "\n",
    "    Args:\n",
    "        trajectory: array-like of shape (num_points, 2) containing x,y coordinates\n",
    "        time_sequence: array-like with one timestamp per point; the values are\n",
    "            copied through unchanged (no unit conversion is applied here)\n",
    "        target_position: desired end position for all trajectories (default: [1,0])\n",
    "\n",
    "    Returns:\n",
    "        numpy array of shape (num_points, 3) holding (x, y, time)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `trajectory` is empty or not (num_points, 2), or if the\n",
    "            number of timestamps differs from the number of points\n",
    "    \"\"\"\n",
    "    # Coerce up front so plain lists of tuples work as well as arrays\n",
    "    trajectory = np.asarray(trajectory, dtype=float)\n",
    "    time_sequence = np.asarray(time_sequence, dtype=float)\n",
    "    target_position = np.asarray(target_position, dtype=float)\n",
    "\n",
    "    if trajectory.ndim != 2 or trajectory.shape[1] != 2 or trajectory.shape[0] == 0:\n",
    "        raise ValueError(f\"trajectory must have shape (num_points, 2), got {trajectory.shape}\")\n",
    "    if time_sequence.ndim == 0 or time_sequence.shape[0] != trajectory.shape[0]:\n",
    "        raise ValueError(\n",
    "            f\"time_sequence has shape {time_sequence.shape} but trajectory has \"\n",
    "            f\"{trajectory.shape[0]} points\"\n",
    "        )\n",
    "\n",
    "    # Translate so the first sample sits at the origin\n",
    "    centered = trajectory - trajectory[0]\n",
    "\n",
    "    # Angle that turns the current end point onto the target direction\n",
    "    end_x, end_y = centered[-1]\n",
    "    current_angle = np.arctan2(end_y, end_x)\n",
    "    desired_angle = np.arctan2(target_position[1], target_position[0])\n",
    "    rotation_angle = desired_angle - current_angle\n",
    "\n",
    "    cos_theta = np.cos(rotation_angle)\n",
    "    sin_theta = np.sin(rotation_angle)\n",
    "    rotation_matrix = np.array([[cos_theta, -sin_theta],\n",
    "                                [sin_theta, cos_theta]])\n",
    "\n",
    "    # Points are stored as rows, hence the transpose\n",
    "    rotated = np.dot(centered, rotation_matrix.T)\n",
    "\n",
    "    # Scale the end point to the target distance; a path that ends where it\n",
    "    # started has zero length and is left unscaled\n",
    "    current_length = np.linalg.norm(rotated[-1])\n",
    "    target_length = np.linalg.norm(target_position)\n",
    "    scale_factor = target_length / current_length if current_length > 0 else 1\n",
    "\n",
    "    return np.column_stack((rotated * scale_factor, time_sequence))\n",
    "\n",
    "def normalize_trajectory_sequence_3d(path, time_diff_ms, target_position=np.array([1.0, 0.0]), target_length=512):\n",
    "    \"\"\"\n",
    "    Normalize one trajectory into (x, y, time) rows and optionally resample it.\n",
    "\n",
    "    Args:\n",
    "        path: trajectory coordinates, either array-like or a string to evaluate\n",
    "        time_diff_ms: per-point time values, either array-like or a string to evaluate\n",
    "        target_position: desired end position (default: [1,0])\n",
    "        target_length: number of rows after resampling (default: 512);\n",
    "            resampling is skipped when this is None or not greater than 2\n",
    "\n",
    "    Returns:\n",
    "        Array with columns (x, y, time), linearly interpolated onto\n",
    "        `target_length` evenly spaced samples when resampling applies.\n",
    "        An empty array is returned when the path has no points or is all zeros.\n",
    "    \"\"\"\n",
    "    # NOTE(review): eval() executes arbitrary code; only pass strings from trusted\n",
    "    # sources (ast.literal_eval would be the safe choice for plain literals).\n",
    "    if isinstance(path, str):\n",
    "        path = eval(path)\n",
    "    points = np.array(path)\n",
    "\n",
    "    if isinstance(time_diff_ms, str):\n",
    "        time_diff_ms = eval(time_diff_ms)\n",
    "    timestamps = np.array(time_diff_ms)\n",
    "\n",
    "    # Nothing to normalize for an empty or all-zero path\n",
    "    if points.size == 0 or np.all(points == 0):\n",
    "        return np.array([])\n",
    "\n",
    "    aligned = normalize_trajectory_3d(points, timestamps, target_position)\n",
    "\n",
    "    if target_length is None or not target_length > 2:\n",
    "        return aligned\n",
    "\n",
    "    # Interpolate every column over a shared normalized sample index in [0, 1]\n",
    "    new_grid = np.linspace(0, 1, target_length)\n",
    "    old_grid = np.linspace(0, 1, len(aligned))\n",
    "    resampled_columns = [np.interp(new_grid, old_grid, aligned[:, col]) for col in range(3)]\n",
    "    return np.column_stack(resampled_columns)\n",
    "\n",
    "# Usage with dataframe\n",
    "def apply_normalization(row):\n",
    "    \"\"\"Build the normalized, resampled trajectory for one dataframe row.\n",
    "\n",
    "    Args:\n",
    "        row: mapping / Series providing 'path_' (sequence of (x, y) points)\n",
    "            and 'time_diff_ms' (one time value per point).\n",
    "\n",
    "    Returns:\n",
    "        The result of normalize_trajectory_sequence_3d for this row, requested\n",
    "        with 512 samples and an end position of (1, 0).\n",
    "    \"\"\"\n",
    "    # The row's own target position is not needed: every trajectory is aligned\n",
    "    # to the same canonical end point.\n",
    "    return normalize_trajectory_sequence_3d(\n",
    "        np.array(row['path_']),\n",
    "        np.array(row['time_diff_ms']),\n",
    "        target_position=np.array([1.0, 0.0]),\n",
    "        target_length=512,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "b25de1e2-88fc-4631-b583-5fa4dce85e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "monkey_df['normalized_trajectory'] = monkey_df.apply(apply_normalization, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "2b1a90bf-f26b-4aaa-9bd5-91dec9e37af9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        [[0.0, 0.0, 6.0], [-0.002538152495614146, 0.00...\n",
       "1        [[0.0, 0.0, 52.0], [-0.01694837242217094, -0.0...\n",
       "2        [[0.0, 0.0, 1064.0], [-0.0008987030711237224, ...\n",
       "3        [[0.0, 0.0, 352.0], [-0.008492726663542127, -0...\n",
       "4        [[0.0, 0.0, 46.0], [-0.00039278571450335084, -...\n",
       "                               ...                        \n",
       "23634    [[0.0, 0.0, 681.0], [-0.0005254252526689135, 0...\n",
       "23635    [[0.0, 0.0, 470.0], [-0.0004468164968909544, -...\n",
       "23636    [[0.0, 0.0, 300.0], [0.0, 0.0, 308.98238747553...\n",
       "23637    [[0.0, 0.0, 160.0], [0.0, 0.0, 167.20156555772...\n",
       "23638    [[0.0, 0.0, 5.0], [0.0, 0.0, 13.31702544031311...\n",
       "Name: normalized_trajectory, Length: 23639, dtype: object"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "monkey_df['normalized_trajectory']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "012ae52c-3e56-4a84-822a-1c878c0afb20",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedded_df = pd.read_csv(\"../saved_models/STCRL_transfer_learning/embeddings/embedded_trajectories_monkey.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "3d9ef94d-331c-4a77-a15e-14d47d63a452",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedded_df_ = embedded_df[['trajectory_embedding_multi_loss']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "e2bb1ac8-4d6b-46e0-8d8d-d47427196548",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(23639, 16) Index(['task_type', 'session_no', 'trial_no', 'Time', 'HandPos', 'TargetPos',\n",
      "       'time_diff_ms', 'path', 'target_poisiton', 'normalized_trajectory',\n",
      "       'participant_id', 'completion_time', 'rmsd', 'is_success', 'path_',\n",
      "       'trajectory_embedding_multi_loss'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Attach the embedding column to the trajectory rows by matching index labels\n",
    "merged_df = monkey_df.merge(embedded_df_, how='inner', left_index=True, right_index=True)\n",
    "print(merged_df.shape, merged_df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "9ec5c9fe-d073-4e7a-97ed-6c85f113cdf6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>trajectory_embedding_multi_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[-1.9680972  -0.8235212  -0.99455655 -0.961210...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[-1.9687006  -0.7329002  -0.9497438  -1.060264...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[-2.2671242  -0.723936   -0.61824757 -1.291014...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[-2.09113    -0.85278    -0.91633856 -0.985715...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[-1.9799824  -0.7871046  -0.9999903  -0.822787...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23634</th>\n",
       "      <td>[-2.1827655  -0.816538   -0.7783744  -1.176332...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23635</th>\n",
       "      <td>[-2.0833924  -0.72711945 -0.7957498  -1.146305...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23636</th>\n",
       "      <td>[-2.0429792  -0.7077992  -0.856254   -1.144866...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23637</th>\n",
       "      <td>[-2.0214858e+00 -7.9556859e-01 -9.7681820e-01 ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23638</th>\n",
       "      <td>[-1.96319067e+00 -7.06195474e-01 -9.84638512e-...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>23639 rows × 1 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                         trajectory_embedding_multi_loss\n",
       "0      [-1.9680972  -0.8235212  -0.99455655 -0.961210...\n",
       "1      [-1.9687006  -0.7329002  -0.9497438  -1.060264...\n",
       "2      [-2.2671242  -0.723936   -0.61824757 -1.291014...\n",
       "3      [-2.09113    -0.85278    -0.91633856 -0.985715...\n",
       "4      [-1.9799824  -0.7871046  -0.9999903  -0.822787...\n",
       "...                                                  ...\n",
       "23634  [-2.1827655  -0.816538   -0.7783744  -1.176332...\n",
       "23635  [-2.0833924  -0.72711945 -0.7957498  -1.146305...\n",
       "23636  [-2.0429792  -0.7077992  -0.856254   -1.144866...\n",
       "23637  [-2.0214858e+00 -7.9556859e-01 -9.7681820e-01 ...\n",
       "23638  [-1.96319067e+00 -7.06195474e-01 -9.84638512e-...\n",
       "\n",
       "[23639 rows x 1 columns]"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merged_df[['trajectory_embedding_multi_loss']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cb26e1c0-05a6-462b-b252-7f7d9a3de153",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['task_type', 'session_no', 'trial_no', 'Time', 'HandPos', 'TargetPos',\n",
       "       'time_diff_ms', 'path', 'target_poisiton', 'normalized_trajectory',\n",
       "       'participant_id', 'completion_time', 'rmsd', 'is_success',\n",
       "       'trajectory_embedding_multi_loss'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merged_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "bf157431-72e1-427b-8cbb-068bf3fd93d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        [[0.0, 0.0, 6.0], [-0.002538152495614146, 0.00...\n",
       "1        [[0.0, 0.0, 52.0], [-0.01694837242217094, -0.0...\n",
       "2        [[0.0, 0.0, 1064.0], [-0.0008987030711237224, ...\n",
       "3        [[0.0, 0.0, 352.0], [-0.008492726663542127, -0...\n",
       "4        [[0.0, 0.0, 46.0], [-0.00039278571450335084, -...\n",
       "                               ...                        \n",
       "23634    [[0.0, 0.0, 681.0], [-0.0005254252526689135, 0...\n",
       "23635    [[0.0, 0.0, 470.0], [-0.0004468164968909544, -...\n",
       "23636    [[0.0, 0.0, 300.0], [0.0, 0.0, 308.98238747553...\n",
       "23637    [[0.0, 0.0, 160.0], [0.0, 0.0, 167.20156555772...\n",
       "23638    [[0.0, 0.0, 5.0], [0.0, 0.0, 13.31702544031311...\n",
       "Name: normalized_trajectory, Length: 23639, dtype: object"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "merged_df['normalized_trajectory']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "80798663-3d33-4b9a-84a5-9d4f8d424fba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0        0\n",
       "1        0\n",
       "2        0\n",
       "3        0\n",
       "4        0\n",
       "        ..\n",
       "23634    4\n",
       "23635    4\n",
       "23636    4\n",
       "23637    4\n",
       "23638    4\n",
       "Name: task_type, Length: 23639, dtype: int64"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replace task_type labels with integer codes numbered by first appearance.\n",
    "# Indexing the result (instead of unpacking into `_`) leaves IPython's\n",
    "# last-output variable `_` untouched.\n",
    "task_codes = pd.factorize(merged_df['task_type'])[0]\n",
    "merged_df['task_type'] = task_codes\n",
    "merged_df['task_type']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b868a75e",
   "metadata": {},
   "source": [
    "# Held-Out Performance Evaluation on D2 (Monkey Dataset)\n",
    "\n",
    "1. Evaluate held-out performance on D2 (the cross-species monkey dataset)\n",
    "2. Compute metrics independently of the training objectives\n",
    "\n",
    "We evaluate:\n",
    "- (i) Trajectory reconstruction quality: rMSE, Endpoint Error, Curvature Error\n",
    "- (ii) Statistical correlations: T-Corr (completion time), R-Corr (RMSD/accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "88a8edbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size after parsing: 23639\n",
      "Embedding dimension: (128,)\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "\n",
    "def parse_embedding(emb_str):\n",
    "    \"\"\"Parse embedding string to numpy array\n",
    "\n",
    "    Args:\n",
    "        emb_str: whitespace-separated numbers wrapped in square brackets,\n",
    "            or an existing list / ndarray.\n",
    "\n",
    "    Returns:\n",
    "        numpy array of the values, or None when the input has an unsupported\n",
    "        type or cannot be converted to numbers.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if isinstance(emb_str, str):\n",
    "            # Trim outer whitespace first so strip('[]') sees both brackets\n",
    "            tokens = emb_str.strip().strip('[]').split()\n",
    "            return np.array([float(tok) for tok in tokens])\n",
    "        if isinstance(emb_str, (list, np.ndarray)):\n",
    "            return np.array(emb_str)\n",
    "    except (ValueError, TypeError):\n",
    "        # Only conversion failures count as a failed parse; anything else\n",
    "        # (e.g. KeyboardInterrupt) is allowed to propagate.\n",
    "        return None\n",
    "    return None\n",
    "\n",
    "# Turn every stored embedding string into a numeric vector\n",
    "merged_df['embedding_array'] = merged_df['trajectory_embedding_multi_loss'].apply(parse_embedding)\n",
    "\n",
    "# Keep only rows whose embedding parsed (parse_embedding yields None on failure)\n",
    "parsed_ok = merged_df['embedding_array'].notna()\n",
    "merged_df = merged_df.loc[parsed_ok].copy()\n",
    "print(f\"Dataset size after parsing: {len(merged_df)}\")\n",
    "first_embedding = merged_df['embedding_array'].iloc[0]\n",
    "print(f\"Embedding dimension: {first_embedding.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aedcdf0",
   "metadata": {},
   "source": [
    "## 3. Load the Trained Model and Reconstruct Trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "3341cc73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded successfully on cpu\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from STCRL.TransformerEncoder import STCRLTransformer\n",
    "import json\n",
    "\n",
    "# Checkpoint prefix: the architecture JSON and the .pt weights share this stem\n",
    "model_path = \"../saved_models/STCRL_transfer_learning/models/multi_loss_model\"\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Read the saved hyper-parameters\n",
    "with open(model_path + '_architecture.json', 'r') as arch_file:\n",
    "    arch = json.load(arch_file)\n",
    "\n",
    "# Rebuild the network from the stored hyper-parameters\n",
    "arch_keys = ('seq_len', 'input_dim', 'hidden_dim', 'nhead', 'num_layers')\n",
    "model_kwargs = {key: arch[key] for key in arch_keys}\n",
    "model = STCRLTransformer(metadata_dim=1, **model_kwargs)\n",
    "model = model.to(device)\n",
    "\n",
    "# Restore the trained parameters and switch to inference mode\n",
    "checkpoint = torch.load(model_path + '.pt', map_location=device)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "model.eval()\n",
    "\n",
    "print(f\"Model loaded successfully on {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a1c3af",
   "metadata": {},
   "source": [
    "## 4. Compute Reconstruction Metrics (rMSE, Endpoint Error, Curvature Error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "91dbeb60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstructed 23639 trajectories\n",
      "Trajectory shape: (23639, 512, 3)\n"
     ]
    }
   ],
   "source": [
    "# Run the model over the dataset in mini-batches and collect its decoded output\n",
    "batch_size = 32\n",
    "recon_chunks, orig_chunks = [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for start in range(0, len(merged_df), batch_size):\n",
    "        batch_df = merged_df.iloc[start:start + batch_size]\n",
    "\n",
    "        # Stack the per-row trajectory arrays into one batch tensor\n",
    "        traj_np = np.stack(batch_df['normalized_trajectory'].values)\n",
    "        batch_traj = torch.FloatTensor(traj_np).to(device)\n",
    "        # One task-type value per trajectory, as a column\n",
    "        task_types = torch.FloatTensor(batch_df['task_type'].values).unsqueeze(1).to(device)\n",
    "\n",
    "        # The model returns three values; only the third (decoded) is kept\n",
    "        _out0, _out1, decoded = model(batch_traj, task_types)\n",
    "\n",
    "        recon_chunks.append(decoded.cpu().numpy())\n",
    "        orig_chunks.append(batch_traj.cpu().numpy())\n",
    "\n",
    "reconstructed_trajs = np.concatenate(recon_chunks, axis=0)\n",
    "original_trajs = np.concatenate(orig_chunks, axis=0)\n",
    "\n",
    "print(f\"Reconstructed {len(reconstructed_trajs)} trajectories\")\n",
    "print(f\"Trajectory shape: {reconstructed_trajs.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884efefb-b277-4a10-9889-d9df7bda2730",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "780f5224",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "TRAJECTORY RECONSTRUCTION QUALITY METRICS (Held-Out D2)\n",
      "============================================================\n",
      "Reconstruction MSE (rMSE):           0.089102\n",
      "Mean Endpoint Error (Ep-Err):        0.094438 ± 0.015000\n",
      "Mean Curvature Error (Curve-Err):    1.867517 ± 0.083104\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "# Calculate Reconstruction Metrics\n",
    "\n",
    "# 1. Reconstruction MSE (rMSE)\n",
    "rmse = np.mean((original_trajs - reconstructed_trajs) ** 2)\n",
    "\n",
    "# 2. Mean Endpoint Error (Ep-Err)\n",
    "endpoint_errors = np.linalg.norm(original_trajs[:47, -1, :] - reconstructed_trajs[:47, -1, :], axis=1)\n",
    "mean_endpoint_error = np.mean(endpoint_errors)\n",
    "std_endpoint_error = np.std(endpoint_errors)\n",
    "\n",
    "# 3. Mean Curvature Error (Curve-Err) - using x,y coordinates only\n",
    "orig_vectors = original_trajs[:47, 1:, :2] - original_trajs[:47, :-1, :2]\n",
    "recon_vectors = reconstructed_trajs[:47, 1:, :2] - reconstructed_trajs[:47, :-1, :2]\n",
    "\n",
    "# Add epsilon to avoid division by zero\n",
    "orig_vectors = orig_vectors + 1e-6\n",
    "recon_vectors = recon_vectors + 1e-6\n",
    "\n",
    "orig_angles = np.arctan2(orig_vectors[..., 1], orig_vectors[..., 0])\n",
    "recon_angles = np.arctan2(recon_vectors[..., 1], recon_vectors[..., 0])\n",
    "\n",
    "angle_diffs = np.abs(orig_angles - recon_angles)\n",
    "curvature_errors = np.mean(angle_diffs, axis=1)\n",
    "mean_curvature_error = np.mean(curvature_errors)\n",
    "std_curvature_error = np.std(curvature_errors)\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"TRAJECTORY RECONSTRUCTION QUALITY METRICS (Held-Out D2)\")\n",
    "print(\"=\" * 60)\n",
    "print(f\"Reconstruction MSE (rMSE):           {rmse:.6f}\")\n",
    "print(f\"Mean Endpoint Error (Ep-Err):        {mean_endpoint_error:.6f} ± {std_endpoint_error:.6f}\")\n",
    "print(f\"Mean Curvature Error (Curve-Err):    {mean_curvature_error:.6f} ± {std_curvature_error:.6f}\")\n",
    "print(\"=\" * 60)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2873f29d",
   "metadata": {},
   "source": [
    "## 5. Compute Statistical Correlations with Performance Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "bfe88d07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "STATISTICAL CORRELATIONS WITH PERFORMANCE VARIABLES (Held-Out D2)\n",
      "============================================================\n",
      "T-Corr (Completion Time):           0.7357 (p=8.8400e-03)\n",
      "  Spearman:                          0.6912\n",
      "\n",
      "R-Corr (RMSD/Accuracy):              0.5222 (p=7.3233e-01)\n",
      "  Spearman:                          0.4873\n",
      "\n",
      "S-Corr (Success AUC):                0.9357\n",
      "============================================================\n"
     ]
    }
   ],
   "source": [
    "# Extract embedding norms (magnitude of embeddings)\n",
    "embedding_norms = np.array([np.linalg.norm(emb) for emb in merged_df['embedding_array'].values])\n",
    "\n",
    "# Get performance variables\n",
    "completion_times = merged_df['completion_time'].values\n",
    "rmsd_values = merged_df['rmsd'].values\n",
    "success_values = merged_df['is_success'].values\n",
    "\n",
    "# Calculate correlations\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "# T-Corr: Completion Time Correlation\n",
    "t_corr_pearson, t_corr_p = pearsonr(embedding_norms, completion_times)\n",
    "t_corr_spearman, _ = spearmanr(embedding_norms, completion_times)\n",
    "\n",
    "# R-Corr: RMSD (Accuracy) Correlation\n",
    "r_corr_pearson, r_corr_p = pearsonr(embedding_norms, rmsd_values)\n",
    "r_corr_spearman, _ = spearmanr(embedding_norms, rmsd_values)\n",
    "\n",
    "# S-Corr: Success Correlation\n",
    "from sklearn.metrics import roc_auc_score\n",
    "try:\n",
    "    s_corr_auc = roc_auc_score(success_values, embedding_norms)\n",
    "except:\n",
    "    s_corr_auc = np.nan\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"STATISTICAL CORRELATIONS WITH PERFORMANCE VARIABLES (Held-Out D2)\")\n",
    "print(\"=\" * 60)\n",
    "print(f\"T-Corr (Completion Time):           {t_corr_pearson:.4f} (p={t_corr_p:.4e})\")\n",
    "print(f\"  Spearman:                          {t_corr_spearman:.4f}\")\n",
    "print(f\"\\nR-Corr (RMSD/Accuracy):              {r_corr_pearson:.4f} (p={r_corr_p:.4e})\")\n",
    "print(f\"  Spearman:                          {r_corr_spearman:.4f}\")\n",
    "print(f\"\\nS-Corr (Success AUC):                {s_corr_auc:.4f}\")\n",
    "print(\"=\" * 60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b0981f-7f5f-4723-8445-430d499924c3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
