{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This is a notebook to plot the results of the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored metrics.\n",
    "# change the name of the file to the file that the results were saved in\n",
    "# The objects are read back with one pickle.load call each, so the sequence of calls\n",
    "# must match what was dumped into the file (presumably in this same order; verify\n",
    "# against the script that wrote the results).\n",
    "# NOTE: pickle.load can execute arbitrary code; only open result files you trust.\n",
    "with open('temp.data', 'rb') as filehandle:\n",
    "    vLosses = pickle.load(filehandle)\n",
    "    tLosses = pickle.load(filehandle)\n",
    "    vERRs = pickle.load(filehandle)\n",
    "    tERRs = pickle.load(filehandle)\n",
    "    vACCs5 = pickle.load(filehandle)\n",
    "    tACCs5 = pickle.load(filehandle)\n",
    "    avg_GD = pickle.load(filehandle)\n",
    "\n",
    "# modify the list depending on the experiment and stored metrics\n",
    "# (built only after loading, so every name in it is already defined on a fresh kernel)\n",
    "List = [vLosses, tLosses, vERRs, tERRs, vACCs5, tACCs5, avg_GD]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _positive_gap(first_metric, second_metric):\n",
    "    \"\"\"Return the element-wise gap first - second, zeroed where first is not larger.\n",
    "\n",
    "    Multiplying by the boolean (a > b) keeps the difference when a exceeds b\n",
    "    and multiplies it by False (zero) otherwise.\n",
    "    \"\"\"\n",
    "    return [(a - b) * (a > b) for a, b in zip(first_metric, second_metric)]\n",
    "\n",
    "\n",
    "# Gap between the v* and t* curves for the loss and for the error.\n",
    "gen_loss = _positive_gap(vLosses, tLosses)\n",
    "gen_err = _positive_gap(vERRs, tERRs)\n",
    "\n",
    "# Number of stored entries in vLosses; used as the length of the x axis in the plots.\n",
    "num_epochs = np.array(vLosses).size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: vLosses, tLosses and gen_loss drawn on one log-scaled x axis.\n",
    "fig, axs = plt.subplots(figsize=(16, 9))\n",
    "epochs = np.arange(num_epochs)\n",
    "x = epochs\n",
    "y1, y2, y3 = vLosses, tLosses, gen_loss\n",
    "\n",
    "# semilogx returns a list of lines; the trailing comma unpacks the single\n",
    "# handle so it can be passed to the legend below.\n",
    "t0, = axs.semilogx(x, y1, linewidth=5.0, label='test loss')\n",
    "t1, = axs.semilogx(x, y2, linewidth=5.0, label='train loss')\n",
    "t2, = axs.semilogx(x, y3, linewidth=5.0, label='generalization loss')\n",
    "\n",
    "# axs.axvspan(10, 11, color='black', alpha=0.3) # to draw a vertical line\n",
    "\n",
    "axs.legend(handles=[t0, t1, t2], prop={'size': 40})\n",
    "\n",
    "axs.set_xlabel('epochs', fontsize=30)\n",
    "axs.set_ylabel('Cross entropy loss', fontsize=30)\n",
    "\n",
    "# Same tick label size on both axes.\n",
    "axs.tick_params(axis='both', labelsize=27)\n",
    "# fig.savefig('figures/loss.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Error curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: vERRs, tERRs and gen_err drawn on one log-scaled x axis.\n",
    "fig, axs = plt.subplots(figsize=(16, 9))\n",
    "epochs = np.arange(num_epochs)\n",
    "x = epochs\n",
    "y1, y2, y3 = vERRs, tERRs, gen_err\n",
    "\n",
    "# Keep the single line handle returned by each call for the legend.\n",
    "t0, = axs.semilogx(x, y1, linewidth=5.0, label='test error')\n",
    "t1, = axs.semilogx(x, y2, linewidth=5.0, label='train error')\n",
    "t2, = axs.semilogx(x, y3, linewidth=5.0, label='generalization error')\n",
    "\n",
    "# axs.axvspan(10, 11, color='black', alpha=0.3)\n",
    "\n",
    "axs.legend(handles=[t0, t1, t2], prop={'size': 40})\n",
    "\n",
    "axs.set_xlabel('epochs', fontsize=30)\n",
    "axs.set_ylabel('Error percentage', fontsize=30)\n",
    "\n",
    "# Same tick label size on both axes.\n",
    "axs.tick_params(axis='both', labelsize=27)\n",
    "#fig.savefig('figures/error.png') \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient disparity vs Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: loss curves on the left y axis, avg_GD on a twin right y axis.\n",
    "fig, axs = plt.subplots(figsize=(16, 9))\n",
    "epochs = np.arange(num_epochs)\n",
    "x = epochs\n",
    "y1, y3 = vLosses, gen_loss\n",
    "\n",
    "# Left axis: cross-entropy curves.\n",
    "color = 'tab:blue'\n",
    "axs.set_xlabel('epochs', fontsize=30)\n",
    "axs.set_ylabel('Cross entropy loss', color='black', fontsize=30)\n",
    "t0, = axs.semilogx(x, y3, label='generalization loss', color='tab:green', linewidth=5.0)\n",
    "t1, = axs.semilogx(x, y1, label='test loss', color=color, linewidth=5.0)\n",
    "\n",
    "axs.tick_params(axis='y', labelsize=27, labelcolor='black')\n",
    "axs.tick_params(axis='x', labelsize=27)\n",
    "\n",
    "# axs.axvspan(10, 11, color='black', alpha=0.3)\n",
    "\n",
    "# Right axis: twinx() creates a second y axis that shares the x axis of axs.\n",
    "ax2 = axs.twinx()\n",
    "color = 'tab:red'\n",
    "y4 = avg_GD\n",
    "\n",
    "ax2.set_ylabel('Average gradient disparity', color=color, fontsize=30)\n",
    "t2, = ax2.semilogx(x, y4, label='gradient disparity', color=color, linewidth=5.0)\n",
    "ax2.tick_params(axis='y', labelcolor=color, labelsize=27)\n",
    "\n",
    "# A single legend on the left axis lists the handles from both axes.\n",
    "axs.legend(handles=[t0, t1, t2], prop={'size': 40}, loc='upper left')\n",
    "#fig.savefig('figures/loss_vs_GD.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient disparity vs Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: gen_err on the left y axis against avg_GD on a twin right y axis,\n",
    "# both drawn over the same log-scaled x axis.\n",
    "fig, axs = plt.subplots(1, 1, figsize=[16,9])\n",
    "epochs = np.arange(0, num_epochs, 1)\n",
    "y3 = gen_err\n",
    "\n",
    "color = 'tab:green'\n",
    "axs.set_ylabel('Generalization error', color=color, fontsize=30)\n",
    "# NOTE(review): the x values come from the array named epochs and the other figures\n",
    "# label this axis 'epochs'; confirm that 'Iterations' is the intended label here.\n",
    "axs.set_xlabel('Iterations', fontsize=30)\n",
    "t0, = axs.semilogx(epochs, y3, color=color, linewidth=5.0, label='generalization error')\n",
    "\n",
    "axs.tick_params(axis='y', labelcolor=color, labelsize=27)\n",
    "axs.tick_params(axis='x', labelsize=27)\n",
    "\n",
    "# axs.axvspan(10, 11, color='black', alpha=0.3)\n",
    "\n",
    "# twinx() creates a second y axis that shares the x axis of axs.\n",
    "ax2 = axs.twinx()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax2.set_ylabel('Average gradient disparity', color=color, fontsize=30)\n",
    "\n",
    "y4 = avg_GD\n",
    "\n",
    "# Plot against this cell's own epochs array rather than an x variable left over from\n",
    "# an earlier cell, so the cell does not depend on which plotting cells ran before it.\n",
    "t1, = ax2.semilogx(epochs, y4, color=color, linewidth=5.0, label='gradient disparity')\n",
    "\n",
    "ax2.tick_params(axis='y', labelcolor=color, labelsize=27)\n",
    "\n",
    "\n",
    "# A single legend on the left axis lists the handles from both axes.\n",
    "axs.legend(handles=[t0,t1], prop={'size':40})\n",
    "#fig.savefig('figures/error_vs_GD.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
