{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beefdd90-a169-4c76-b50b-4c3af4ff1332",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Result tables are read from ../../final_dfs, resolved against the working directory.\n",
    "_FINAL_DFS_DIR = os.path.join('..', '..', 'final_dfs')\n",
    "\n",
    "# Put the parent directory on sys.path before importing plotconfig\n",
    "# (presumably where that module lives -- the import must stay after the append).\n",
    "sys.path.append('..')\n",
    "import plotconfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f557a9c-6278-4418-b3e9-218469b10d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import ticker\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "001da3bd-0f73-475a-ac48-3c680250f533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the full results table from the final_dfs directory.\n",
    "dbf = os.path.join(_FINAL_DFS_DIR, 'results.parquet')\n",
    "df = pd.read_parquet(dbf, engine='pyarrow')\n",
    "\n",
    "# Keep only rows at the comparison training size whose eeg_name is \"EEG_Raw\"\n",
    "# and whose test_name is \"random\".\n",
    "mask = (\n",
    "    df['training_size'].eq(plotconfig.N_FOR_PERF_SCORE_COMPARISON)\n",
    "    & df['eeg_name'].eq(\"EEG_Raw\")\n",
    "    & df['test_name'].eq(\"random\")\n",
    ")\n",
    "df = df.loc[mask]\n",
    "\n",
    "# Display the methods that remain after filtering.\n",
    "df['method_name'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bc3c611-0ab3-4bb7-a15c-83593bd8fbcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every metric column gets the same three summary statistics.\n",
    "summary_stats = ['mean', 'std', 'sem']\n",
    "metric_columns = [\n",
    "    'target_raw_score',\n",
    "    'target_shuffled_score',\n",
    "    'pearsonr_statistic',\n",
    "    'fit_time',\n",
    "]\n",
    "\n",
    "# One group per method, using the filtered df from the loading cell.\n",
    "grouped_df = df.groupby(['method_name'])\n",
    "compiled_result = grouped_df.agg({column: summary_stats for column in metric_columns})\n",
    "\n",
    "# Collapse the (metric, statistic) column pairs into flat names such as 'fit_time_mean'.\n",
    "compiled_result.columns = [f'{metric}_{stat}'.strip() for metric, stat in compiled_result.columns]\n",
    "\n",
    "# Turn the method_name index back into an ordinary column.\n",
    "compiled_result = compiled_result.reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d971a5da-442b-425f-a604-833d02d65960",
   "metadata": {},
   "outputs": [],
   "source": [
    "compiled_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b03d970-0c28-48b4-b99e-0b93735ae685",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "\n",
    "# Build one explicit method -> colour mapping and pass it to seaborn so the\n",
    "# scatter markers and the error bars below are guaranteed to share colours.\n",
    "# Relying on seaborn's implicit hue colours only matches while the number of\n",
    "# methods fits the default colour cycle; beyond that a plain zip() drops\n",
    "# methods and the lookup in the loop raises KeyError.\n",
    "methods = compiled_result['method_name'].unique()\n",
    "palette = sns.color_palette()\n",
    "if len(methods) > len(palette):\n",
    "    palette = sns.color_palette('husl', len(methods))\n",
    "color_dict = dict(zip(methods, palette))\n",
    "\n",
    "sns.scatterplot(data=compiled_result, x=\"target_raw_score_mean\", y=\"target_shuffled_score_mean\",\n",
    "                hue=\"method_name\", palette=color_dict, s=50, ax=ax)\n",
    "\n",
    "# Error bars: one standard deviation of each score within the method group.\n",
    "for method in methods:\n",
    "    method_data = compiled_result[compiled_result['method_name'] == method]\n",
    "    ax.errorbar(method_data['target_raw_score_mean'], method_data['target_shuffled_score_mean'],\n",
    "                xerr=method_data['target_raw_score_std'], yerr=method_data['target_shuffled_score_std'],\n",
    "                fmt='none', ecolor=color_dict[method], alpha=0.5, elinewidth=2, capsize=5, capthick=2)\n",
    "\n",
    "# Dashed y = x reference line; the same bounds are reused as the axis limits.\n",
    "bounds = [0.88, 1.45]\n",
    "ax.plot(bounds, bounds, c='k', linestyle='--')\n",
    "\n",
    "ax.set_xlabel('RMSE')\n",
    "ax.set_ylabel('RMSE Shuffled')\n",
    "ax.set_title('EEG vs Shuffled EEG')\n",
    "\n",
    "ax.set_xlim(bounds)\n",
    "ax.set_ylim(bounds)\n",
    "ax.set_aspect('equal')\n",
    "\n",
    "# Save through the shared plotconfig helper.\n",
    "plotconfig.save_fig(\"aligned_vs_shuffled\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94bc5ba0-d2f5-47f3-af27-23aa6776b153",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "\n",
    "# Build one explicit method -> colour mapping and pass it to seaborn so the\n",
    "# scatter markers and the error bars below are guaranteed to share colours.\n",
    "# Relying on seaborn's implicit hue colours only matches while the number of\n",
    "# methods fits the default colour cycle; beyond that a plain zip() drops\n",
    "# methods and the lookup in the loop raises KeyError.\n",
    "methods = compiled_result['method_name'].unique()\n",
    "palette = sns.color_palette()\n",
    "if len(methods) > len(palette):\n",
    "    palette = sns.color_palette('husl', len(methods))\n",
    "color_dict = dict(zip(methods, palette))\n",
    "\n",
    "sns.scatterplot(data=compiled_result, x=\"fit_time_mean\", y=\"pearsonr_statistic_mean\",\n",
    "                hue=\"method_name\", palette=color_dict, s=50, ax=ax)\n",
    "\n",
    "# Error bars: one standard deviation of each quantity within the method group.\n",
    "for method in methods:\n",
    "    method_data = compiled_result[compiled_result['method_name'] == method]\n",
    "    ax.errorbar(method_data['fit_time_mean'], method_data['pearsonr_statistic_mean'],\n",
    "                xerr=method_data['fit_time_std'], yerr=method_data['pearsonr_statistic_std'],\n",
    "                fmt='none', ecolor=color_dict[method], alpha=0.5, elinewidth=2, capsize=5, capthick=2)\n",
    "\n",
    "# The x-axis stays linear; call ax.set_xscale('log') here if a log axis is wanted.\n",
    "\n",
    "# NOTE(review): the x data is the 'fit_time' column but the label and title say\n",
    "# 'Scoring Time' -- confirm which wording is intended before publishing.\n",
    "ax.set_xlabel('Scoring Time (s)')\n",
    "ax.set_ylabel('Pearson Correlation')\n",
    "ax.set_title('Model Performance: Scoring Time vs Correlation')\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# Flip the y-axis so larger correlations are drawn lower on the figure.\n",
    "ax.invert_yaxis()\n",
    "\n",
    "# Save through the shared plotconfig helper.\n",
    "plotconfig.save_fig(\"time_vs_pearson\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a069c6ae-d4e4-4110-ad78-15aa14e9fcfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "\n",
    "# Build one explicit method -> colour mapping and pass it to seaborn so the\n",
    "# scatter markers and the error bars below are guaranteed to share colours.\n",
    "# Relying on seaborn's implicit hue colours only matches while the number of\n",
    "# methods fits the default colour cycle; beyond that a plain zip() drops\n",
    "# methods and the lookup in the loop raises KeyError.\n",
    "methods = compiled_result['method_name'].unique()\n",
    "palette = sns.color_palette()\n",
    "if len(methods) > len(palette):\n",
    "    palette = sns.color_palette('husl', len(methods))\n",
    "color_dict = dict(zip(methods, palette))\n",
    "\n",
    "sns.scatterplot(data=compiled_result, x=\"target_raw_score_mean\", y=\"pearsonr_statistic_mean\",\n",
    "                hue=\"method_name\", palette=color_dict, s=50, ax=ax)\n",
    "\n",
    "# Error bars: one standard deviation of each quantity within the method group.\n",
    "for method in methods:\n",
    "    method_data = compiled_result[compiled_result['method_name'] == method]\n",
    "    ax.errorbar(method_data['target_raw_score_mean'], method_data['pearsonr_statistic_mean'],\n",
    "                xerr=method_data['target_raw_score_std'], yerr=method_data['pearsonr_statistic_std'],\n",
    "                fmt='none', ecolor=color_dict[method], alpha=0.5, elinewidth=2, capsize=5, capthick=2)\n",
    "\n",
    "# The title previously said 'Scoring Time vs Correlation' (copied from the\n",
    "# fit-time figure); this plot has RMSE on the x-axis.\n",
    "ax.set_xlabel('RMSE')\n",
    "ax.set_ylabel('Pearson Correlation')\n",
    "ax.set_title('Model Performance: RMSE vs Correlation')\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# Flip the y-axis so larger correlations are drawn lower on the figure.\n",
    "ax.invert_yaxis()\n",
    "\n",
    "# Save through the shared plotconfig helper.\n",
    "plotconfig.save_fig(\"rmse_target_vs_pearson\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a926158d-944d-49f1-a56c-8412637abc2f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5cff4e9-d9fb-4c82-878c-aa60fa9c0246",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
