{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d609154",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd0e97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_ID = 9\n",
    "\n",
    "fdf = pd.read_csv(f\"results/exp_{EXP_ID}/frequency_ablation_study.csv\",index_col=0)\n",
    "ddf = pd.read_csv(f\"results/exp_{EXP_ID}/decay_ablation_study.csv\",index_col=0)\n",
    "pdf = pd.read_csv(f\"results/exp_{EXP_ID}/mode_ablation_study.csv\",index_col=0)\n",
    "sdf = pd.read_csv(f\"results/exp_{EXP_ID}/sampfreq_ablation_study.csv\",index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc44027",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_names_1 = [\n",
    "    'HS', \n",
    "    'operator', \n",
    "    'martin',\n",
    "    'eigenvalue', \n",
    "    'subspace', \n",
    "    'chordal_50',\n",
    "]\n",
    "\n",
    "metric_names_2 = [\n",
    "    'chordal_10',\n",
    "    'chordal_20',\n",
    "    'chordal_30',\n",
    "    'chordal_40',\n",
    "    'chordal_50',\n",
    "    'chordal_60',\n",
    "    'chordal_70',\n",
    "    'chordal_80',\n",
    "    'chordal_90',\n",
    "]\n",
    "\n",
    "metric_labels_1 = {\n",
    "    'HS' : 'Hilbert-Schmidt', \n",
    "    'operator' : 'Operator', \n",
    "    'martin' : 'Martin',\n",
    "    'eigenvalue' : 'SOT', \n",
    "    'subspace' : 'GOT', \n",
    "    'chordal_50' : r'SGOT ($\\eta = 0.5$)',\n",
    "}\n",
    "\n",
    "metric_labels_2 = {\n",
    "    'chordal_10' : r'$\\eta = 0.1$',\n",
    "    'chordal_20' : r'$\\eta = 0.2$',\n",
    "    'chordal_30' : r'$\\eta = 0.3$',\n",
    "    'chordal_40' : r'$\\eta = 0.4$',\n",
    "    'chordal_50' : r'$\\eta = 0.5$',\n",
    "    'chordal_60' : r'$\\eta = 0.6$',\n",
    "    'chordal_70' : r'$\\eta = 0.7$',\n",
    "    'chordal_80' : r'$\\eta = 0.8$',\n",
    "    'chordal_90' : r'$\\eta = 0.9$',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b22a455",
   "metadata": {},
   "outputs": [],
   "source": [
    "tab10 = mpl.colormaps['tab10']\n",
    "metric_plot_colors_1 = {\n",
    "    'HS': {'color': tab10(1), 'linestyle': '-'}, \n",
    "    'operator': {'color': tab10(4), 'linestyle': '-'}, \n",
    "    'martin': {'color': tab10(7), 'linestyle': '-'},\n",
    "    'eigenvalue': {'color': tab10(2), 'linestyle': '-'}, \n",
    "    'subspace': {'color': tab10(3), 'linestyle': '-'},  \n",
    "    'chordal_50': {'color': tab10(0), 'linestyle': '-'},\n",
    "}\n",
    "palette = mpl.colormaps[\"jet\"](np.linspace(0, 1, 9))\n",
    "metric_plot_colors_2 = {\n",
    "    'chordal_10': {'color': palette[0], 'linestyle': '-'},\n",
    "    'chordal_20': {'color': palette[1], 'linestyle': '-'},\n",
    "    'chordal_30': {'color': palette[2], 'linestyle': '-'},\n",
    "    'chordal_40': {'color': palette[3], 'linestyle': '-'},\n",
    "    'chordal_50': {'color': palette[4], 'linestyle': '-'},\n",
    "    'chordal_60': {'color': palette[5], 'linestyle': '-'},\n",
    "    'chordal_70': {'color': palette[6], 'linestyle': '-'},\n",
    "    'chordal_80': {'color': palette[7], 'linestyle': '-'},\n",
    "    'chordal_90': {'color': palette[8], 'linestyle': '-'}\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d3d828",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"Helvetica\"\n",
    "plt.rcParams[\"xtick.labelsize\"] = 10\n",
    "plt.rcParams[\"ytick.labelsize\"] = 10\n",
    "plt.rcParams[\"axes.labelsize\"] = 14\n",
    "plt.rcParams[\"legend.fontsize\"] = 14\n",
    "fig, axes = plt.subplots(1, 4, figsize=(13, 3), gridspec_kw={'width_ratios': [1, 1, 1, 1], 'wspace': 0.05})\n",
    "\n",
    "\n",
    "# Frequency\n",
    "for metric in metric_names_1:\n",
    "    label = metric_labels_1[metric]\n",
    "    style = metric_plot_colors_1[metric]\n",
    "    axes[0].plot(fdf['frequency'].values, fdf[metric]/fdf[metric].max(), label=label, **style)\n",
    "axes[0].set_xlabel('Frequency (Hz)')\n",
    "axes[0].set_ylabel('Metric value')\n",
    "\n",
    "# Decay\n",
    "for metric in metric_names_1:\n",
    "    label = metric_labels_1[metric]\n",
    "    style = metric_plot_colors_1[metric]\n",
    "    axes[1].plot(ddf['decay'].values, ddf[metric]/ddf[metric].max(), label=label, **style)\n",
    "axes[1].set_xlabel('Decay rate')\n",
    "axes[1].set_yticklabels([])\n",
    "\n",
    "# Power\n",
    "for metric in metric_names_1:\n",
    "    label = metric_labels_1[metric]\n",
    "    style = metric_plot_colors_1[metric]\n",
    "    axes[2].plot(pdf['power'].values, pdf[metric]/pdf[metric].max(), label=label, **style)\n",
    "axes[2].set_xlabel('Subspace shift')\n",
    "axes[2].set_yticklabels([])\n",
    "\n",
    "# Sampling Frequency\n",
    "for metric in metric_names_1:\n",
    "    label = metric_labels_1[metric]\n",
    "    style = metric_plot_colors_1[metric]\n",
    "    axes[3].plot(sdf['sampfreq'].values, sdf[metric], label=label, **style)\n",
    "axes[3].set_xlabel('Sampling frequency (Hz)')\n",
    "\n",
    "#manage spacing\n",
    "fig.subplots_adjust(wspace=0.05)  # Squeeze the first three plots together\n",
    "# Manually adjust the position of the last axis to increase the gap\n",
    "pos2 = axes[2].get_position()\n",
    "pos3 = axes[3].get_position()\n",
    "gap = 0.05  # increase this value for a wider gap\n",
    "axes[3].set_position([\n",
    "    pos3.x0 + gap, pos3.y0, pos3.width, pos3.height\n",
    "])\n",
    "\n",
    "# Draw a vertical line between the third and fourth subplot\n",
    "divider_x = pos2.x1 + (axes[3].get_position().x0 - pos2.x1) / 2\n",
    "fig.transFigure.invalidate()\n",
    "fig.lines.append(\n",
    "    plt.Line2D([divider_x, divider_x], [0.00, 0.88], color='k', linewidth=1, transform=fig.transFigure, linestyle='-')\n",
    ")\n",
    "\n",
    "# Unique legend outside\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=len(labels),frameon=False,columnspacing=1.0,handletextpad=0.2)\n",
    "plt.savefig(f\"results/exp_{EXP_ID}/metric_comparison.pdf\", bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f610d60",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"font.family\"] = \"Helvetica\"\n",
    "plt.rcParams[\"xtick.labelsize\"] = 9\n",
    "plt.rcParams[\"ytick.labelsize\"] = 9\n",
    "plt.rcParams[\"axes.labelsize\"] = 13\n",
    "plt.rcParams[\"legend.fontsize\"] = 13\n",
    "fig, axes = plt.subplots(1, 4, figsize=(13, 3), gridspec_kw={'width_ratios': [1, 1, 1, 1], 'wspace': 0.05})\n",
    "\n",
    "\n",
    "# Frequency\n",
    "for metric in metric_names_2:\n",
    "    label = metric_labels_2[metric]\n",
    "    style = metric_plot_colors_2[metric]\n",
    "    axes[0].plot(fdf['frequency'].values, fdf[metric]/fdf[metric].max(), label=label, **style)\n",
    "axes[0].set_xlabel('(a) Frequency (Hz)')\n",
    "axes[0].set_ylabel('Metric value')\n",
    "\n",
    "# Decay\n",
    "for metric in metric_names_2:\n",
    "    label = metric_labels_2[metric]\n",
    "    style = metric_plot_colors_2[metric]\n",
    "    axes[1].plot(ddf['decay'].values, ddf[metric]/ddf[metric].max(), label=label, **style)\n",
    "axes[1].set_xlabel('(b) Decay rate')\n",
    "axes[1].set_yticklabels([])\n",
    "\n",
    "# Power\n",
    "for metric in metric_names_2:\n",
    "    label = metric_labels_2[metric]\n",
    "    style = metric_plot_colors_2[metric]\n",
    "    axes[2].plot(pdf['power'].values*2 + 4, pdf[metric]/pdf[metric].max(), label=label, **style)\n",
    "axes[2].set_xlabel('(c) Operator rank')\n",
    "axes[2].set_yticklabels([])\n",
    "\n",
    "# Sampling Frequency\n",
    "for metric in metric_names_2:\n",
    "    label = metric_labels_2[metric]\n",
    "    style = metric_plot_colors_2[metric]\n",
    "    axes[3].plot(sdf['sampfreq'].values, sdf[metric], label=label, **style)\n",
    "axes[3].set_xlabel('(d) Sampling frequency (Hz)')\n",
    "\n",
    "#manage spacing\n",
    "fig.subplots_adjust(wspace=0.05)  # Squeeze the first three plots together\n",
    "# Manually adjust the position of the last axis to increase the gap\n",
    "pos2 = axes[2].get_position()\n",
    "pos3 = axes[3].get_position()\n",
    "gap = 0.05  # increase this value for a wider gap\n",
    "axes[3].set_position([\n",
    "    pos3.x0 + gap, pos3.y0, pos3.width, pos3.height\n",
    "])\n",
    "\n",
    "# Draw a vertical line between the third and fourth subplot\n",
    "divider_x = pos2.x1 + (axes[3].get_position().x0 - pos2.x1) / 2\n",
    "fig.transFigure.invalidate()\n",
    "fig.lines.append(\n",
    "    plt.Line2D([divider_x, divider_x], [0.15, 0.85], color='k', linewidth=1, transform=fig.transFigure, linestyle='-')\n",
    ")\n",
    "\n",
    "# Unique legend outside\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.54, 1.02), ncol=len(labels),frameon=False,columnspacing=1.0,handletextpad=0.2)\n",
    "plt.savefig(f\"results/exp_{EXP_ID}/eta_ablation_study.pdf\", bbox_inches='tight', pad_inches=0.1)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "KooPOT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
