{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c18d95-99e5-4ea8-920a-35d834b91cee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Data directories, both two levels above the current working directory.\n",
    "_FINAL_DFS_DIR, _CSV_DIR = (os.path.join('..', '..', sub) for sub in ('final_dfs', 'csv_files'))\n",
    "\n",
    "# Make the parent directory importable so the shared plotconfig module can be loaded.\n",
    "sys.path.append('..')\n",
    "import plotconfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69664876-e4ad-44a9-8206-4a1d734f0fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import ticker\n",
    "from scipy import stats\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f423c3-e4a5-4b44-ab42-db28216916bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all evaluation results and select the single configuration plotted below.\n",
    "dbf = os.path.join(_FINAL_DFS_DIR, 'results.parquet')\n",
    "df = pd.read_parquet(dbf, engine='pyarrow')\n",
    "\n",
    "# test_name to plot (e.g. \"diagonal\" or \"random\").\n",
    "TEST_NAME = \"diagonal\"\n",
    "\n",
    "mask = ((df['method_name'] == \"LinearRegression\") &\n",
    "        (df['training_size'] == plotconfig.LAST_N) &\n",
    "        (df['eeg_name'] == \"EEG_Raw\") &\n",
    "        (df['test_name'] == TEST_NAME))\n",
    "\n",
    "results = df[mask]\n",
    "# Fail with a readable message rather than an IndexError from iloc on an empty frame.\n",
    "assert not results.empty, f\"no rows in {dbf} match test_name={TEST_NAME!r}\"\n",
    "first_row = results.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442e8c8f-7dc7-4a1e-9a38-f8e0e6b15a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human (user) evaluations of image distance, compared against the algorithm below.\n",
    "humanfile = os.path.join(_CSV_DIR, 'image-distance-user-evaluations.csv')\n",
    "df_human = pd.read_csv(filepath_or_buffer=humanfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2ffc9d-c334-4b47-b168-aebf0f78c84e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest true distance over every selected result row. The exploded list entries are\n",
    "# cast to float so max() is a numeric comparison rather than an object-dtype one.\n",
    "max_true_distance = results['true_distances'].explode().astype(float).max()\n",
    "\n",
    "# Derive the plotted columns with assign (re-runnable, no in-place mutation):\n",
    "# - euclidean_d rescales 'distance' by max_true_distance (assumes 'distance' is\n",
    "#   normalized to [0, 1] -- TODO confirm);\n",
    "# - reversed_eval = 1 - distance_eval flips the rating direction (presumably so it\n",
    "#   trends the same way as the algorithm score -- TODO confirm).\n",
    "df_human = df_human.assign(\n",
    "    euclidean_d=df_human['distance'] * max_true_distance,\n",
    "    reversed_eval=1 - df_human['distance_eval'],\n",
    ")\n",
    "df_human.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d66a510-a9b9-42d3-96e1-98257178b51b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def annotate_regression(ax, x, y, label, xy, color, fontsize):\n",
    "    \"\"\"Fit y on x with scipy.stats.linregress and write r and p as text on ax.\n",
    "\n",
    "    label is the first line of the text; xy is the position of the text's top-left\n",
    "    corner in ax's axes-fraction coordinates.\n",
    "    \"\"\"\n",
    "    fit = stats.linregress(x, y)\n",
    "    ax.text(*xy,\n",
    "            f'{label}\\n'\n",
    "            f'r = {fit.rvalue:.2f}, \\n'\n",
    "            f'p = {fit.pvalue:.1e}',\n",
    "            transform=ax.transAxes,\n",
    "            verticalalignment='top',\n",
    "            fontsize=fontsize,\n",
    "            color=color)\n",
    "\n",
    "\n",
    "FONTSIZE = 20          # title and axis labels\n",
    "STAT_FONTSIZE = 20     # regression annotations\n",
    "TICK_LABELSIZE = 16\n",
    "SPINE_WIDTH = 2.5\n",
    "algo_color, human_color = sns.color_palette()[:2]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "\n",
    "# Algorithm score vs. true distance on the left y-axis.\n",
    "sns.regplot(data=first_row, x=\"true_distances\", y=\"scores\", ax=ax,\n",
    "            color=algo_color, scatter_kws={\"s\": 200})\n",
    "annotate_regression(ax, first_row[\"true_distances\"], first_row[\"scores\"],\n",
    "                    'Algorithm', (0.05, 0.40), algo_color, STAT_FONTSIZE)\n",
    "\n",
    "ax.set_title(\"Score vs distance to target face\", fontsize=FONTSIZE)\n",
    "ax.set_xlabel(\"Distance to target face\", fontsize=FONTSIZE)\n",
    "ax.set_ylabel(\"Algorithm Score\", fontsize=FONTSIZE, color=algo_color)\n",
    "# Fixed axis limits framing the current data; revisit if the results change.\n",
    "ax.set_xlim(-2, 48)\n",
    "ax.set_ylim(1.076, 1.106)\n",
    "\n",
    "ax.tick_params(axis='x', labelsize=TICK_LABELSIZE)\n",
    "ax.tick_params(axis='y', labelsize=TICK_LABELSIZE, colors=algo_color)\n",
    "ax.spines['left'].set_color(algo_color)\n",
    "ax.spines['left'].set_linewidth(SPINE_WIDTH)\n",
    "ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))\n",
    "ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n",
    "\n",
    "# Human evaluations share the x-axis but get their own right-hand y-axis.\n",
    "ax2 = ax.twinx()\n",
    "sns.regplot(data=df_human, x=\"euclidean_d\", y=\"reversed_eval\", ax=ax2,\n",
    "            color=human_color, scatter_kws={\"s\": 200}, marker='^')\n",
    "annotate_regression(ax2, df_human[\"euclidean_d\"], df_human[\"reversed_eval\"],\n",
    "                    'Human', (0.35, 0.78), human_color, STAT_FONTSIZE)\n",
    "\n",
    "ax2.set_ylabel(\"Human Score\", fontsize=FONTSIZE, color=human_color)\n",
    "ax2.tick_params(axis='y', labelsize=TICK_LABELSIZE, colors=human_color)\n",
    "ax2.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))\n",
    "ax2.spines['right'].set_color(human_color)\n",
    "ax2.spines['right'].set_linewidth(SPINE_WIDTH)\n",
    "\n",
    "# Lay out last, so the twin axis's tick labels and ylabel are taken into account.\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save through the project's shared plotconfig helper.\n",
    "plotconfig.save_fig(\"score_vs_distance\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
