{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a3c675d8-6120-4a17-a5cd-5464f7e7ec29",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from sklearn.manifold import TSNE\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.stats as stats\n",
    "from scipy.stats import kurtosis, skew\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy.spatial.distance import cdist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "99ab0284-2e1f-4aae-9c74-9aa625dc4caf",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Model trajectory embeddings (STCRL) and the human exploration/exploitation scores.\n",
    "data_path = \"../saved_models/STCRL/embeddings/embedded_trajectories.csv\"\n",
    "human_scores_path = \"../Dataset/Scores/human_scores.csv\"\n",
    "df, df_human = pd.read_csv(data_path), pd.read_csv(human_scores_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "feab3501-104a-479a-9300-3623cca7846e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['participant_id', 'session_no', 'task_type', 'trial_no', 'day', 'block',\n",
       "       'start_point_x', 'start_point_y', 'target_point_x', 'target_point_y',\n",
       "       'start_time', 'end_time', 'quadrant', 'is_success', 'actual_dist',\n",
       "       'movement_dist', 'completion_time', 'path', 'time_string',\n",
       "       'time_diff_ms', 'Age', 'Cohort', 'Gestational_Age',\n",
       "       'mabc_total_test_score', 'mabc_standard_score', 'mabc_percentile',\n",
       "       'distances', 'rmsd', 'normalized_trajectory',\n",
       "       'trajectory_embedding_completion_time',\n",
       "       'trajectory_embedding_sequential', 'trajectory_embedding_rmsd',\n",
       "       'trajectory_embedding_success', 'trajectory_embedding_multi',\n",
       "       'trajectory_embedding_weighted_multi', 'exploration_score',\n",
       "       'exploitation_score', 'exploration_exploitation_ratio'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attach the human exploration/exploitation scores to each embedded trajectory.\n",
    "# pd.merge on the index pairs rows by index label (the row position for a freshly read CSV),\n",
    "# so verify that alignment first instead of letting the inner join silently drop or mismatch rows.\n",
    "SCORE_COLUMNS = ['exploration_score', 'exploitation_score', 'exploration_exploitation_ratio']\n",
    "assert len(df) == len(df_human), f\"row count mismatch: embeddings={len(df)}, human scores={len(df_human)}\"\n",
    "assert df.index.equals(df_human.index), \"embedding and human-score frames have different indexes\"\n",
    "key_columns = [c for c in ('participant_id', 'session_no', 'trial_no')\n",
    "               if c in df.columns and c in df_human.columns]\n",
    "if key_columns:\n",
    "    misaligned = df[key_columns].ne(df_human[key_columns]).any(axis=1)\n",
    "    assert not misaligned.any(), f\"{int(misaligned.sum())} rows disagree on {key_columns}; scores would be attached to the wrong trials\"\n",
    "# Keep df_human intact (no in-place rebinding) so re-running this cell still sees the key columns.\n",
    "human_scores = df_human[SCORE_COLUMNS]\n",
    "merged_df = pd.merge(df, human_scores, left_index=True, right_index=True)\n",
    "merged_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8e524aa0-9369-4562-9269-c172edc4854e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['participant_id', 'session_no', 'task_type', 'trial_no', 'day', 'block',\n",
       "       'start_point_x', 'start_point_y', 'target_point_x', 'target_point_y',\n",
       "       'start_time', 'end_time', 'quadrant', 'is_success', 'actual_dist',\n",
       "       'movement_dist', 'completion_time', 'path', 'time_string',\n",
       "       'time_diff_ms', 'Age', 'Cohort', 'Gestational_Age',\n",
       "       'mabc_total_test_score', 'mabc_standard_score', 'mabc_percentile',\n",
       "       'distances', 'rmsd', 'normalized_trajectory',\n",
       "       'trajectory_embedding_completion_time',\n",
       "       'trajectory_embedding_sequential', 'trajectory_embedding_rmsd',\n",
       "       'trajectory_embedding_success', 'trajectory_embedding_multi',\n",
       "       'trajectory_embedding_weighted_multi', 'exploration_score',\n",
       "       'exploitation_score', 'exploration_exploitation_ratio'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# DataFrame.keys() returns the column Index, i.e. the same value as merged_df.columns.\n",
    "merged_df.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6e041ca6-4896-4ad7-ab9f-421ec80d370e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "import warnings\n",
    "\n",
    "def _mean_ci(values, confidence=0.95):\n",
    "    \"\"\"\n",
    "    Mean and two-sided Student-t confidence interval of a numeric sequence.\n",
    "\n",
    "    NaNs are dropped first. With no remaining values every field is NaN and n=0;\n",
    "    with a single value the interval collapses onto the mean (no spread to estimate).\n",
    "\n",
    "    Returns:\n",
    "        Dict with keys 'mean', 'ci_lower', 'ci_upper' and 'n' (count of non-NaN values).\n",
    "    \"\"\"\n",
    "    arr = np.asarray(values, dtype=float)\n",
    "    arr = arr[~np.isnan(arr)]\n",
    "    n = len(arr)\n",
    "    if n == 0:\n",
    "        return {'mean': np.nan, 'ci_lower': np.nan, 'ci_upper': np.nan, 'n': 0}\n",
    "    mean_val = np.mean(arr)\n",
    "    if n == 1:\n",
    "        return {'mean': mean_val, 'ci_lower': mean_val, 'ci_upper': mean_val, 'n': 1}\n",
    "    # Critical t value (n-1 degrees of freedom) times the standard error of the mean.\n",
    "    margin_error = stats.t.ppf((1 + confidence) / 2, n - 1) * stats.sem(arr)\n",
    "    return {'mean': mean_val, 'ci_lower': mean_val - margin_error,\n",
    "            'ci_upper': mean_val + margin_error, 'n': n}\n",
    "\n",
    "\n",
    "def _format_mean_ci(stat_dict, decimals=4):\n",
    "    \"\"\"Render a _mean_ci result as 'mean ± margin (95% CI: lower-upper)' with `decimals` places.\"\"\"\n",
    "    if stat_dict['n'] == 0:\n",
    "        return \"No data\"\n",
    "    mean_val = stat_dict['mean']\n",
    "    if stat_dict['n'] == 1:\n",
    "        return f\"{mean_val:.{decimals}f} (n=1)\"\n",
    "    lower, upper = stat_dict['ci_lower'], stat_dict['ci_upper']\n",
    "    return (f\"{mean_val:.{decimals}f} ± {upper - mean_val:.{decimals}f} \"\n",
    "            f\"(95% CI: {lower:.{decimals}f}-{upper:.{decimals}f})\")\n",
    "\n",
    "\n",
    "def _wilson_interval_pct(successes, trials, confidence=0.95):\n",
    "    \"\"\"\n",
    "    Wilson score interval for a binomial proportion, returned as (lower, upper) in percent.\n",
    "\n",
    "    The Wilson interval stays informative when all or none of the trials succeed\n",
    "    (e.g. 24/24 gives roughly 86.2%-100%), so those outcomes need no special case.\n",
    "    Returns (nan, nan) when `trials` is 0.\n",
    "    \"\"\"\n",
    "    if trials == 0:\n",
    "        return np.nan, np.nan\n",
    "    z = stats.norm.ppf((1 + confidence) / 2)\n",
    "    p = successes / trials\n",
    "    denominator = 1 + z**2 / trials\n",
    "    center = (p + z**2 / (2 * trials)) / denominator\n",
    "    margin = z * np.sqrt((p * (1 - p) + z**2 / (4 * trials)) / trials) / denominator\n",
    "    # Clamp floating-point round-off that can land just outside [0, 100].\n",
    "    return max(0.0, (center - margin) * 100), min(100.0, (center + margin) * 100)\n",
    "\n",
    "\n",
    "def calculate_trajectory_statistics(df, participant_id=None, session_no=None, task_type=None):\n",
    "    \"\"\"\n",
    "    Calculate trajectory performance statistics with 95% confidence intervals.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with one row per trial. Uses 'exploration_score', 'exploitation_score',\n",
    "            'exploration_exploitation_ratio', 'is_success' (success when equal to 1),\n",
    "            'completion_time' and 'rmsd', plus 'participant_id' / 'session_no' /\n",
    "            'task_type' when the matching filter is given.\n",
    "        participant_id: Specific participant ID (None for all participants)\n",
    "        session_no: Specific session number (None for all sessions)\n",
    "        task_type: \"Unimanual\" / \"Bimanual\" (case-insensitive, surrounding spaces ignored),\n",
    "            the numeric code 0 (Unimanual) / 1 (Bimanual), or None for both\n",
    "\n",
    "    Returns:\n",
    "        Dict with 'filters_applied' (list of str) and 'n_trials'. If no rows match it also\n",
    "        holds 'error'; otherwise 'statistics' maps each metric to {'raw': ..., 'formatted': str}.\n",
    "        Continuous metrics use mean ± Student-t 95% CI (NaNs ignored); 'success_rate' is a\n",
    "        percentage with a Wilson score 95% CI.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if task_type is a string other than \"unimanual\" / \"bimanual\".\n",
    "    \"\"\"\n",
    "    # Boolean masks return new frames, so the caller's df is never modified (no copy needed).\n",
    "    data = df\n",
    "    filter_description = []\n",
    "\n",
    "    if participant_id is not None:\n",
    "        data = data[data['participant_id'] == participant_id]\n",
    "        filter_description.append(f\"Participant: {participant_id}\")\n",
    "\n",
    "    if session_no is not None:\n",
    "        data = data[data['session_no'] == session_no]\n",
    "        filter_description.append(f\"Session: {session_no}\")\n",
    "\n",
    "    if task_type is not None:\n",
    "        if isinstance(task_type, str):\n",
    "            task_key = task_type.strip().lower()\n",
    "            # Reject typos instead of silently treating any non-\"unimanual\" string as Bimanual.\n",
    "            if task_key not in (\"unimanual\", \"bimanual\"):\n",
    "                raise ValueError(f\"task_type must be 'Unimanual' or 'Bimanual', got {task_type!r}\")\n",
    "            task_numeric = 0 if task_key == \"unimanual\" else 1\n",
    "        else:\n",
    "            task_numeric = task_type\n",
    "        data = data[data['task_type'] == task_numeric]\n",
    "        # Unknown numeric codes are reported as such rather than mislabelled as Bimanual.\n",
    "        task_name = {0: \"Unimanual\", 1: \"Bimanual\"}.get(task_numeric, f\"code {task_numeric}\")\n",
    "        filter_description.append(f\"Task: {task_name}\")\n",
    "\n",
    "    if len(data) == 0:\n",
    "        return {\n",
    "            'error': 'No data found matching the specified criteria',\n",
    "            'filters_applied': filter_description,\n",
    "            'n_trials': 0\n",
    "        }\n",
    "\n",
    "    def summarize(column, decimals=4):\n",
    "        # Mean ± CI of one numeric column of the filtered rows, raw and formatted.\n",
    "        raw = _mean_ci(data[column])\n",
    "        return {'raw': raw, 'formatted': _format_mean_ci(raw, decimals)}\n",
    "\n",
    "    total_trials = len(data)\n",
    "    successful_trials = int(np.sum(data['is_success'] == 1))\n",
    "    success_rate = (successful_trials / total_trials) * 100\n",
    "    success_ci_lower, success_ci_upper = _wilson_interval_pct(successful_trials, total_trials)\n",
    "\n",
    "    statistics = {\n",
    "        'exploration_score': summarize('exploration_score'),\n",
    "        'exploitation_score': summarize('exploitation_score'),\n",
    "        'exploration_exploitation_ratio': summarize('exploration_exploitation_ratio'),\n",
    "        'success_rate': {\n",
    "            'raw': {\n",
    "                'rate': success_rate,\n",
    "                'ci_lower': success_ci_lower,\n",
    "                'ci_upper': success_ci_upper,\n",
    "                'successful_trials': successful_trials,\n",
    "                'total_trials': total_trials\n",
    "            },\n",
    "            'formatted': f\"{success_rate:.2f}% (95% CI: {success_ci_lower:.2f}%-{success_ci_upper:.2f}%) [{successful_trials}/{total_trials}]\"\n",
    "        },\n",
    "        # Completion time is reported with 2 decimal places; the other metrics with 4.\n",
    "        'completion_time': summarize('completion_time', decimals=2),\n",
    "        'rmsd': summarize('rmsd'),\n",
    "    }\n",
    "\n",
    "    return {\n",
    "        'filters_applied': filter_description,\n",
    "        'n_trials': total_trials,\n",
    "        'statistics': statistics,\n",
    "    }\n",
    "\n",
    "def print_trajectory_statistics(df, participant_id=None, session_no=None, task_type=None):\n",
    "    \"\"\"\n",
    "    Print formatted trajectory statistics in a readable format.\n",
    "\n",
    "    Accepts the same filters as calculate_trajectory_statistics, prints either the\n",
    "    error and filters (when nothing matches) or a two-section report, and returns\n",
    "    the results dict from calculate_trajectory_statistics unchanged.\n",
    "    \"\"\"\n",
    "    results = calculate_trajectory_statistics(df, participant_id, session_no, task_type)\n",
    "\n",
    "    if 'error' in results:\n",
    "        print(f\"Error: {results['error']}\")\n",
    "        print(f\"Filters applied: {', '.join(results['filters_applied']) if results['filters_applied'] else 'None'}\")\n",
    "        return results\n",
    "\n",
    "    print(\"=\" * 80)\n",
    "    print(\"TRAJECTORY PERFORMANCE STATISTICS\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    # Print filter information\n",
    "    if results['filters_applied']:\n",
    "        print(f\"Filters applied: {', '.join(results['filters_applied'])}\")\n",
    "    else:\n",
    "        print(\"Filters applied: None (All data)\")\n",
    "\n",
    "    print(f\"Total trials analyzed: {results['n_trials']}\")\n",
    "    print()\n",
    "\n",
    "    # Named metric_stats (not `stats`) so it does not shadow the scipy.stats module name.\n",
    "    metric_stats = results['statistics']\n",
    "\n",
    "    print(\"EXPLORATION & EXPLOITATION METRICS:\")\n",
    "    print(\"-\" * 40)\n",
    "    print(f\"Exploration Score:     {metric_stats['exploration_score']['formatted']}\")\n",
    "    print(f\"Exploitation Score:    {metric_stats['exploitation_score']['formatted']}\")\n",
    "    print(f\"Exploration/Exploitation Ratio: {metric_stats['exploration_exploitation_ratio']['formatted']}\")\n",
    "    print()\n",
    "\n",
    "    print(\"PERFORMANCE METRICS:\")\n",
    "    print(\"-\" * 40)\n",
    "    print(f\"Success Rate:          {metric_stats['success_rate']['formatted']}\")\n",
    "    # NOTE(review): the ' ms' unit is not established by this notebook, and the observed\n",
    "    # means (roughly 2-10) look more like seconds - confirm the unit of 'completion_time'.\n",
    "    print(f\"Mean Completion Time:  {metric_stats['completion_time']['formatted']} ms\")\n",
    "    print(f\"Mean RMSD:             {metric_stats['rmsd']['formatted']}\")\n",
    "    print()\n",
    "\n",
    "    return results\n",
    "\n",
    "# Example usage functions for common scenarios:\n",
    "\n",
    "def analyze_participant_session_task(df, participant_id, session_no, task_type):\n",
    "    \"\"\"Print and return statistics for one participant, one session and one task type.\"\"\"\n",
    "    filters = dict(participant_id=participant_id, session_no=session_no, task_type=task_type)\n",
    "    return print_trajectory_statistics(df, **filters)\n",
    "\n",
    "def analyze_participant_all_sessions(df, participant_id, task_type=None):\n",
    "    \"\"\"Print and return statistics for one participant pooled over all sessions (optionally one task type).\"\"\"\n",
    "    filters = dict(participant_id=participant_id, task_type=task_type)\n",
    "    return print_trajectory_statistics(df, **filters)\n",
    "\n",
    "def analyze_session_all_participants(df, session_no, task_type=None):\n",
    "    \"\"\"Print and return statistics for one session pooled over all participants (optionally one task type).\"\"\"\n",
    "    filters = dict(session_no=session_no, task_type=task_type)\n",
    "    return print_trajectory_statistics(df, **filters)\n",
    "\n",
    "def analyze_task_all_data(df, task_type):\n",
    "    \"\"\"Print and return statistics for one task type pooled over all participants and sessions.\"\"\"\n",
    "    filters = dict(task_type=task_type)\n",
    "    return print_trajectory_statistics(df, **filters)\n",
    "\n",
    "def analyze_all_data(df):\n",
    "    \"\"\"Print and return statistics over the whole DataFrame (no filters applied).\"\"\"\n",
    "    return print_trajectory_statistics(df, participant_id=None, session_no=None, task_type=None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6c597695-817d-4316-be5d-007290cba6b9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TRAJECTORY PERFORMANCE STATISTICS\n",
      "================================================================================\n",
      "Filters applied: Participant: MTRLRN015, Session: 5, Task: Bimanual\n",
      "Total trials analyzed: 24\n",
      "\n",
      "EXPLORATION & EXPLOITATION METRICS:\n",
      "----------------------------------------\n",
      "Exploration Score:     0.0271 ± 0.0051 (95% CI: 0.0221-0.0322)\n",
      "Exploitation Score:    0.4803 ± 0.0631 (95% CI: 0.4172-0.5434)\n",
      "Exploration/Exploitation Ratio: 0.0681 ± 0.0230 (95% CI: 0.0451-0.0910)\n",
      "\n",
      "PERFORMANCE METRICS:\n",
      "----------------------------------------\n",
      "Success Rate:          100.00% (95% CI: 100.00%-100.00%) [24/24]\n",
      "Mean Completion Time:  5.0968 ± 0.6407 (95% CI: 4.4560-5.7375) ms\n",
      "Mean RMSD:             152.8541 ± 15.4856 (95% CI: 137.3685-168.3396)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results = analyze_participant_session_task(merged_df, participant_id=\"MTRLRN015\", session_no=5, task_type=\"Bimanual\")\n",
    "# Other scopes (all take the merged frame built earlier in this notebook):\n",
    "# results = analyze_participant_all_sessions(merged_df, 'MTRLRN002')\n",
    "# results = analyze_task_all_data(merged_df, 'Bimanual')\n",
    "# results = analyze_all_data(merged_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "02b391a8-1d6b-4d69-a75e-93b5993cdb49",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TRAJECTORY PERFORMANCE STATISTICS\n",
      "================================================================================\n",
      "Filters applied: Participant: MTRLRN015, Session: 1, Task: Bimanual\n",
      "Total trials analyzed: 24\n",
      "\n",
      "EXPLORATION & EXPLOITATION METRICS:\n",
      "----------------------------------------\n",
      "Exploration Score:     0.0916 ± 0.0282 (95% CI: 0.0634-0.1198)\n",
      "Exploitation Score:    0.3448 ± 0.0647 (95% CI: 0.2802-0.4095)\n",
      "Exploration/Exploitation Ratio: 0.5718 ± 0.4314 (95% CI: 0.1405-1.0032)\n",
      "\n",
      "PERFORMANCE METRICS:\n",
      "----------------------------------------\n",
      "Success Rate:          20.83% (95% CI: 9.24%-40.47%) [5/24]\n",
      "Mean Completion Time:  10.1939 ± 0.7247 (95% CI: 9.4692-10.9185) ms\n",
      "Mean RMSD:             270.9083 ± 42.6062 (95% CI: 228.3021-313.5145)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Session 1 counterpart of the session-5 Bimanual run for MTRLRN015.\n",
    "results = analyze_participant_session_task(\n",
    "    merged_df,\n",
    "    participant_id=\"MTRLRN015\",\n",
    "    session_no=1,\n",
    "    task_type=\"Bimanual\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8d5ce1e0-1242-463f-8034-8d7e9955e8ad",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TRAJECTORY PERFORMANCE STATISTICS\n",
      "================================================================================\n",
      "Filters applied: Participant: MTRLRN070, Session: 5, Task: Unimanual\n",
      "Total trials analyzed: 24\n",
      "\n",
      "EXPLORATION & EXPLOITATION METRICS:\n",
      "----------------------------------------\n",
      "Exploration Score:     0.0153 ± 0.0054 (95% CI: 0.0099-0.0207)\n",
      "Exploitation Score:    0.6404 ± 0.0926 (95% CI: 0.5478-0.7331)\n",
      "Exploration/Exploitation Ratio: 0.0433 ± 0.0315 (95% CI: 0.0118-0.0748)\n",
      "\n",
      "PERFORMANCE METRICS:\n",
      "----------------------------------------\n",
      "Success Rate:          100.00% (95% CI: 100.00%-100.00%) [24/24]\n",
      "Mean Completion Time:  2.3605 ± 0.4659 (95% CI: 1.8946-2.8264) ms\n",
      "Mean RMSD:             35.8417 ± 14.1163 (95% CI: 21.7254-49.9580)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Participant MTRLRN070, Unimanual task, session 5.\n",
    "results = analyze_participant_session_task(\n",
    "    merged_df,\n",
    "    participant_id=\"MTRLRN070\",\n",
    "    session_no=5,\n",
    "    task_type=\"Unimanual\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cf5e098c-f569-41bc-bbab-1f130c00b0f0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TRAJECTORY PERFORMANCE STATISTICS\n",
      "================================================================================\n",
      "Filters applied: Participant: MTRLRN070, Session: 1, Task: Unimanual\n",
      "Total trials analyzed: 24\n",
      "\n",
      "EXPLORATION & EXPLOITATION METRICS:\n",
      "----------------------------------------\n",
      "Exploration Score:     0.0950 ± 0.0260 (95% CI: 0.0689-0.1210)\n",
      "Exploitation Score:    0.3226 ± 0.0480 (95% CI: 0.2746-0.3706)\n",
      "Exploration/Exploitation Ratio: 0.5579 ± 0.4252 (95% CI: 0.1328-0.9831)\n",
      "\n",
      "PERFORMANCE METRICS:\n",
      "----------------------------------------\n",
      "Success Rate:          91.67% (95% CI: 74.15%-97.68%) [22/24]\n",
      "Mean Completion Time:  7.3421 ± 1.0995 (95% CI: 6.2426-8.4416) ms\n",
      "Mean RMSD:             152.7404 ± 23.7967 (95% CI: 128.9437-176.5371)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Session 1 counterpart of the session-5 Unimanual run for MTRLRN070.\n",
    "results = analyze_participant_session_task(\n",
    "    merged_df,\n",
    "    participant_id=\"MTRLRN070\",\n",
    "    session_no=1,\n",
    "    task_type=\"Unimanual\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9549759c-d991-4c3c-b48a-c691225bdf01",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}