{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Make the project-local `util` module importable before importing from it.\n",
    "sys.path.append('../../')\n",
    "from util import plot_data, plot_data_ncol, read_dataframes\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment identifiers. DATASET_NAME is also embedded in the saved figure's filename.\n",
    "DATASET_NAME = 'CIFAR-100 (T=4)'  # NOTE(review): T presumably = number of SNN timesteps; confirm\n",
    "MODEL_NAME = 'Spiking ResNet-19'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory containing each ablation-study seed's runs.\n",
    "# TODO: fill in these paths before running -- they are intentionally left blank here.\n",
    "S42_DIR = ''\n",
    "S43_DIR = ''\n",
    "S44_DIR = ''\n",
    "\n",
    "dfs_S42 = read_dataframes(dir=S42_DIR)\n",
    "dfs_S43 = read_dataframes(dir=S43_DIR)\n",
    "dfs_S44 = read_dataframes(dir=S44_DIR)\n",
    "print(len(dfs_S42), len(dfs_S43), len(dfs_S44))\n",
    "\n",
    "# Runs are aggregated position-by-position across seeds, so every seed must\n",
    "# contribute the same number of runs; otherwise the mean/std below would raise\n",
    "# IndexError or silently drop the extra runs.\n",
    "assert len(dfs_S42) == len(dfs_S43) == len(dfs_S44), 'Seeds have different numbers of runs'\n",
    "\n",
    "# Best (maximum) validation accuracy reached by each run.\n",
    "dfs_S42_val_acc = [df[\"Validation Accuracy\"].max() for df in dfs_S42]\n",
    "dfs_S43_val_acc = [df[\"Validation Accuracy\"].max() for df in dfs_S43]\n",
    "dfs_S44_val_acc = [df[\"Validation Accuracy\"].max() for df in dfs_S44]\n",
    "\n",
    "# Rows = seeds, columns = runs; reduce over the seed axis.\n",
    "val_acc_by_seed = np.array([dfs_S42_val_acc, dfs_S43_val_acc, dfs_S44_val_acc])\n",
    "dfs_avg_val_acc = val_acc_by_seed.mean(axis=0).tolist()\n",
    "# Population std (ddof=0, np.std's default) across the three seeds; use ddof=1 for the sample std.\n",
    "dfs_avg_val_acc_std = val_acc_by_seed.std(axis=0, ddof=0).tolist()\n",
    "print(len(dfs_avg_val_acc), dfs_avg_val_acc)\n",
    "print(dfs_avg_val_acc_std)\n",
    "\n",
    "# x-axis labels: run i is plotted as polynomial degree i + 1 ('1', '2', ...),\n",
    "# sized to the number of runs instead of a hard-coded nine.\n",
    "labels = [str(degree) for degree in range(1, len(dfs_avg_val_acc) + 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "\n",
    "# Error bars (+/- 1 std across seeds) are drawn first so the mean curve sits on top.\n",
    "ax.errorbar(\n",
    "    labels,\n",
    "    dfs_avg_val_acc,\n",
    "    yerr=dfs_avg_val_acc_std,\n",
    "    fmt='none',  # bars only; the mean line is drawn separately below\n",
    "    ecolor='dimgray',\n",
    "    elinewidth=1.5,\n",
    "    capsize=6,\n",
    "    capthick=1.4,\n",
    "    alpha=1.0,\n",
    "    label='Standard Deviation'\n",
    ")\n",
    "\n",
    "# Main accuracy curve\n",
    "ax.plot(\n",
    "    labels,\n",
    "    dfs_avg_val_acc,\n",
    "    marker='s',\n",
    "    linestyle='-',\n",
    "    linewidth=1.5,\n",
    "    markersize=6.5,\n",
    "    color='black',\n",
    "    label='Mean Validation Accuracy'\n",
    ")\n",
    "\n",
    "# Axis labels\n",
    "ax.set_xlabel('Polynomial Degree', fontsize=16, labelpad=10)\n",
    "ax.set_ylabel('Validation Accuracy (%)', fontsize=16, labelpad=10)\n",
    "\n",
    "# Tick formatting: one tick per polynomial degree (slicing labels[1:] hid degree 1)\n",
    "ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax.set_xticks(labels)\n",
    "\n",
    "# Axis spines cleanup\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "# Grid for better readability\n",
    "ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.6)\n",
    "\n",
    "# Show the `label=` entries given to errorbar/plot above; without this they are never displayed.\n",
    "ax.legend(fontsize=12)\n",
    "\n",
    "# Layout and save; create the output directory first so savefig works on a fresh checkout.\n",
    "plt.tight_layout()\n",
    "graphs_dir = Path('./Graphs')\n",
    "graphs_dir.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(graphs_dir / f'{DATASET_NAME}_ablation.pdf', bbox_inches='tight', dpi=600, format='pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SNN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
