{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Default font size for all matplotlib text drawn after this cell runs.\n",
    "font = dict(size=30)\n",
    "\n",
    "plt.rc(\"font\", **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tolerance string embedded in each CSV file name.\n",
    "tolerance = \"1e-05\"\n",
    "\n",
    "# Model identifiers used in the CSV file names and as keys of df_names.\n",
    "names = [\"node\", \"anode\", \"sonode\", \"hbnode\", \"ghbnode\", \"nesterovnode\", \"gnesterovnode\"]\n",
    "# Display labels; entry i is the label for names[i].\n",
    "alt_names = [\"NODE\", \"ANODE\", \"SONODE\", \"HBNODE\", \"GHBNODE\", \"NesterovNODE\", \"GNesterovNODE\"]\n",
    "\n",
    "# The files are read with header=None, so the column names are supplied here.\n",
    "log_columns = [\n",
    "\t\"model\", \"test#\", \"train/test\", \"iter\", \"loss\",\n",
    "\t\"acc\", \"forwardnfe\", \"backwardnfe\", \"time/iter\", \"time_elapsed\",\n",
    "]\n",
    "\n",
    "# One DataFrame per model, keyed by model identifier.\n",
    "df_names = {\n",
    "\tname: pd.read_csv(f\"../imgdat/1_2/{name}_{tolerance}.csv\", header=None, names=log_columns)\n",
    "\tfor name in names\n",
    "}\n",
    "\n",
    "# Preview the table loaded for the last model in the list.\n",
    "df_names[names[-1]].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model line styling; entry i of each list styles names[i].\n",
    "colors = [\"mediumvioletred\", \"red\", \"deepskyblue\", \"royalblue\", \"navy\", \"green\", \"darkorange\"]\n",
    "line_styles = [':', '--', '-.', '-.', '-.', '-', '-']\n",
    "line_widths = [5, 5, 5, 5, 5, 7, 7]\n",
    "\n",
    "MAX_ITERATIONS = 40  # max number of iterations logged per model and split\n",
    "\n",
    "\n",
    "def plot_metric(ax, split, attribute, ylabel, ylim, xlim=None):\n",
    "\t\"\"\"Plot `attribute` against \"iter\" for every model on `ax`.\n",
    "\n",
    "\tReads the module-level `names`, `alt_names` and `df_names` built by the\n",
    "\tdata-loading cell, plus the styling lists defined above.\n",
    "\n",
    "\tArgs:\n",
    "\t\tax: matplotlib Axes to draw on.\n",
    "\t\tsplit: value of the \"train/test\" column selecting the rows to plot.\n",
    "\t\tattribute: name of the logged column used for the y values.\n",
    "\t\tylabel: label for the y axis.\n",
    "\t\tylim: (low, high) limits for the y axis.\n",
    "\t\txlim: optional (low, high) limits for the x axis; left automatic when None.\n",
    "\n",
    "\tRaises:\n",
    "\t\tAssertionError: if a model has more than MAX_ITERATIONS rows in the split.\n",
    "\t\"\"\"\n",
    "\tax.set_aspect(\"auto\")\n",
    "\tfor i, name in enumerate(names):\n",
    "\t\tdf_model = df_names[name]\n",
    "\t\tdf_split = df_model.loc[df_model[\"train/test\"] == split]\n",
    "\t\tassert df_split.shape[0] <= MAX_ITERATIONS, f\"{name}: {df_split.shape[0]} {split} rows, expected at most {MAX_ITERATIONS}\"\n",
    "\t\tax.plot(df_split[\"iter\"], df_split[attribute], line_styles[i], linewidth=line_widths[i], color=colors[i], label=alt_names[i])\n",
    "\tax.set(xlabel=\"Epoch\", ylabel=ylabel)\n",
    "\tif xlim is not None:\n",
    "\t\tax.set_xlim(*xlim)\n",
    "\tax.set_ylim(*ylim)\n",
    "\tax.grid()\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(25, 15))\n",
    "gs = fig.add_gridspec(2, 6, hspace=0.25, wspace=1.2)\n",
    "# Top row: three panels spanning two grid columns each.\n",
    "ax1 = fig.add_subplot(gs[0, :2])\n",
    "ax2 = fig.add_subplot(gs[0, 2:4])\n",
    "ax3 = fig.add_subplot(gs[0, 4:])\n",
    "# Bottom row: two panels on grid columns 1-4, leaving columns 0 and 5 empty.\n",
    "ax4 = fig.add_subplot(gs[1, 1:3])\n",
    "ax5 = fig.add_subplot(gs[1, 3:5])\n",
    "\n",
    "# Training curves.\n",
    "plot_metric(ax1, \"train\", \"forwardnfe\", \"Train NFEs (forward)\", ylim=(20, 90))\n",
    "plot_metric(ax2, \"train\", \"backwardnfe\", \"Train NFEs (backward)\", ylim=(35, 110))\n",
    "plot_metric(ax4, \"train\", \"loss\", \"Train Loss\", ylim=(0.0, 0.3))\n",
    "# Test curves.\n",
    "plot_metric(ax5, \"test\", \"acc\", \"Test Accuracy\", ylim=(0.96, 0.985), xlim=(5, 40))\n",
    "plot_metric(ax3, \"test\", \"forwardnfe\", \"Test NFEs (forward)\", ylim=(30, 90))\n",
    "\n",
    "# Single legend built from the accuracy panel's lines and anchored in\n",
    "# figure coordinates below that panel.\n",
    "axbox = ax5.get_position()\n",
    "l5 = ax5.legend(bbox_to_anchor=(0.5, axbox.y0-0.22), loc=\"lower center\",\n",
    "                bbox_transform=fig.transFigure, ncol=3)\n",
    "fig.savefig(\"mnist.pdf\", transparent=True, bbox_inches='tight', pad_inches=0)\n",
    "plt.show()\n",
    "\n",
    "# Best test accuracy reached by each model (printed once, after the figure).\n",
    "for name in names:\n",
    "\tdf_model = df_names[name]\n",
    "\ttest_acc = df_model.loc[df_model[\"train/test\"] == \"test\", \"acc\"]\n",
    "\tprint(f\"Accuracy of {name}: {np.max(test_acc)}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "6fde2ce5c4dc601402da0c7de54c7ec149de6129e6c5f2f584a44d976ec7ca4a"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('nesterov_node': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
