{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "import matplotlib\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "torch.set_num_threads(1) #cpu num\n",
    "from fractions import Fraction\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import math\n",
    "from collections import defaultdict\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ../AllLogs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logfolder = '%s/reg/%s/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DNAME = ['ijcnn1','rcv1','real-sim','news20','covtype']\n",
    "ENDEP = [20,30,20,40,20]\n",
    "TICK=[5,10,5,10,5]\n",
    "ALGO_NAME = ['AI-SARAH', 'SARAH', 'SARAH+', 'SVRG', 'Adam', 'SGD w/m']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Legend for Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "hue_order = ['AI-SARAH','SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "handles=[Line2D([0],[0],color=palette[i],linewidth=6) for i in range(len(hue_order))]\n",
    "fig,ax=plt.subplots(1,1,figsize=(15,1))\n",
    "ax.legend(handles=handles,labels=hue_order,fancybox=True,fontsize=10,loc='center',\\\n",
    "         ncol=len(hue_order),prop={'size': 15})\n",
    "ax.axis('off')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 3 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df2={}\n",
    "for dd in range(len(DNAME)):\n",
    "    dname = DNAME[dd]\n",
    "    DF = []\n",
    "    endep = ENDEP[dd]    \n",
    "    for al in range(len(ALGO_NAME)):\n",
    "        alg_name = ALGO_NAME[al]\n",
    "        temp_logfolder = logfolder%(dname,algo)\n",
    "        cnt=0\n",
    "        for f in os.listdir(temp_logfolder):\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            parm = temp['parm']\n",
    "            if alg_name=='AI-SARAH' and parm[2]!=1/32:\n",
    "                continue\n",
    "            stat = temp['stat']\n",
    "            stat = [si for si in stat if si[0]<=endep]\n",
    "            stat = np.array(stat)\n",
    "            hist = temp['hist']\n",
    "            hist = [si for si in hist if si[0]<=endep]\n",
    "            hist = np.array(hist)\n",
    "            \n",
    "            if alg_name in ['Adam', 'SGD w/m']:\n",
    "                tp = list(stat[:,1])\n",
    "            else:\n",
    "                tp = list(stat[:,3])\n",
    "                \n",
    "            tp=[np.sum(tp[:oi]) for oi in range(1,len(tp)+1)]\n",
    "    \n",
    "            ep = list(hist[:,0])\n",
    "            grad = list(hist[:,2])\n",
    "            grad = [np.min(grad[:oi]) for oi in range(1,len(grad)+1)]\n",
    "            DF+=[[alg_name,cnt,ei,ti,gi] for ei,ti,gi in zip(ep,tp,grad)]\n",
    "            cnt+=1\n",
    "    df2[dname]=pd.DataFrame(data=DF,columns=['algo','cnt','ep','tp','grad'])    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "hue_order = ['AI-SARAH','SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "handles=[Line2D([0],[0],color=palette[i],linewidth=3) for i in range(len(hue_order))]\n",
    "\n",
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharey=False)\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df2[DNAME[d]].copy(deep=True)\n",
    "    subdf=subdf.groupby(['algo','cnt']).agg({'tp':'max'}).reset_index()\n",
    "    box=sns.boxenplot(x=\"algo\",y=\"tp\",width=0.8,data=subdf,palette=palette,order=hue_order,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].set_xlabel('')\n",
    "    ax[d].set_xticks([])\n",
    "    ax[d].set_title(DNAME[d],fontsize=15)\n",
    "ax[0].set_ylabel('Avg. Wall Clock (sec)',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "hue_order = ['AI-SARAH','SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "handles=[Line2D([0],[0],color=palette[i],linewidth=3) for i in range(len(hue_order))]\n",
    "\n",
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"whitegrid\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharey=False)\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df2[DNAME[d]].copy(deep=True)\n",
    "    subdf=subdf.groupby(['algo','cnt']).agg({'tp':'max'}).reset_index()\n",
    "    subdf=subdf.groupby('algo').agg({'tp':'sum'}).reset_index()\n",
    "    s1 = subdf.loc[subdf.algo=='AI-SARAH'].copy(deep=True)\n",
    "    s2 = subdf.loc[subdf.algo!='AI-SARAH'].copy(deep=True)\n",
    "    s2['tp']=s2['tp'].apply(lambda x: x*5.0)\n",
    "    subdf = pd.concat([s1,s2], ignore_index=True, sort=False)\n",
    "    \n",
    "    sns.barplot(x=\"algo\", y=\"tp\", data=subdf, order=hue_order,palette=palette,ax=ax[d])\n",
    "    ax[d].set_yscale('log')\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].set_xlabel('Algorithm',fontsize=12)\n",
    "    ax[d].set_xticks([])\n",
    "ax[0].set_ylabel('Total Wall Clock (sec)',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Legend for Figure 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "hue_order = ['AI-SARAH','SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "handles=[Line2D([0],[0],color=palette[0],linewidth=3,dashes=[2,2])]\n",
    "handles+=[Line2D([0],[0],color=palette[i],linewidth=6) for i in range(1,len(hue_order))]\n",
    "fig,ax=plt.subplots(1,1,figsize=(15,1))\n",
    "ax.legend(handles=handles,labels=hue_order,fancybox=True,fontsize=10,loc='center',\\\n",
    "         ncol=len(hue_order),prop={'size': 15})\n",
    "ax.axis('off')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df={}\n",
    "for dd in range(len(DNAME)):\n",
    "    dname = DNAME[dd]\n",
    "    DF = []\n",
    "    endep = ENDEP[dd]    \n",
    "    for al in range(len(ALGO_NAME)):\n",
    "        alg_name = ALGO_NAME[al]\n",
    "        temp_logfolder = logfolder%(dname,algo)\n",
    "        cnt=0\n",
    "        for f in os.listdir(temp_logfolder):\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            parm = temp['parm']\n",
    "            if alg_name=='AI-SARAH' and parm[2]!=1/32:\n",
    "                continue\n",
    "            stat = temp['stat']\n",
    "            stat = [si for si in stat if si[0]<=endep]\n",
    "            stat = np.array(stat)\n",
    "            hist = temp['hist']\n",
    "            hist = [si for si in hist if si[0]<=endep]\n",
    "            hist = np.array(hist)\n",
    "            ep = list(hist[:,0])\n",
    "            grad = list(hist[:,2])\n",
    "            grad = [np.min(grad[:oi]) for oi in range(1,len(grad)+1)]\n",
    "            DF+=[[alg_name,cnt,ei,gi] for ei,gi in zip(ep,grad)]\n",
    "            cnt+=1\n",
    "    df[dname]=pd.DataFrame(data=DF,columns=['algo','cnt','ep','grad'])  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ENDEP = [20,30,20,40,20]\n",
    "STEP = [10,15,10,20,10]\n",
    "\n",
    "CRIT = [[5,ti,ei] for ti,ei in zip(STEP,ENDEP)]\n",
    "\n",
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"whitegrid\")\n",
    "subdf2 = subdf.loc[subdf.algo!='AI-SARAH']\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "\n",
    "aisarah_line=[Line2D([0],[0],color=palette[0],linewidth=3)]\n",
    "\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharey=False)\n",
    "for d in range(len(DNAME)):\n",
    "    crit = CRIT[d]\n",
    "    subdf = df[DNAME[d]].copy(deep=True)\n",
    "    subdf['ep_round']=np.round(subdf['ep'])\n",
    "    cond = [(subdf.ep_round<=ci) for ci in crit]\n",
    "    subdf['ep_step']=np.select(cond,crit)\n",
    "    subdf = subdf.groupby(['algo','cnt','ep_step']).agg({'grad':'min'}).reset_index()\n",
    "    subdf2=subdf.loc[subdf.algo!='AI-SARAH']\n",
    "    \n",
    "    aisarah = subdf.loc[subdf.algo=='AI-SARAH'].groupby(['algo','ep_step']).agg({'grad':'min'}).reset_index()\n",
    "    \n",
    "    hue_order = ['SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "    sns.boxenplot(data=subdf2, x=\"ep_step\", y=\"grad\", hue=\"algo\", hue_order=hue_order,\\\n",
    "                      palette=palette[1:], \\\n",
    "                      linewidth=2,ax=ax[d])\n",
    "    for ci in range(len(crit)):\n",
    "        ax[d].axhline(y=aisarah.loc[aisarah.ep_step==crit[ci]]['grad'].values.item(),\\\n",
    "                      xmin=ci*0.33, xmax=(ci+1)*0.33,color=palette[0],ls='--',lw=2)\n",
    "    ax[d].set_yscale('log')\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].set_xlabel('Effective Pass',fontsize=12)\n",
    "    ax[d].set_title(DNAME[d],fontsize=15)\n",
    "    if d>=0:\n",
    "        ax[d].legend('',frameon=False)\n",
    "ax[0].set_ylabel(r'$||\\nabla P(w)||^2$',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"whitegrid\")\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "\n",
    "aisarah_line=[Line2D([0],[0],color=palette[0],linewidth=3)]\n",
    "\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharey=False)\n",
    "for d in range(len(DNAME)):\n",
    "    subdf = df2[DNAME[d]].copy(deep=True)\n",
    "    subdf['tp_round']=np.round(subdf['tp'])\n",
    "    timelist=sorted(list(subdf.tp_round.unique()))\n",
    "    crit=[10.0,timelist[int(len(timelist)/2)-1],timelist[-1]]\n",
    "    SUBDF=[]\n",
    "    for ci in crit:\n",
    "        temp=subdf.loc[subdf.tp_round<=ci].groupby(['algo','cnt']).agg({'grad':'min'}).reset_index()\n",
    "        temp['tp_step']=ci\n",
    "        SUBDF.append(temp)\n",
    "    SUBDF=pd.concat(SUBDF, ignore_index=True, sort=False)\n",
    "    subdf2=SUBDF.loc[subdf.algo!='AI-SARAH']\n",
    "    \n",
    "    aisarah = SUBDF.loc[SUBDF.algo=='AI-SARAH'].groupby(['algo','tp_step']).agg({'grad':'min'}).reset_index()\n",
    "    \n",
    "    hue_order = ['SARAH','SARAH+','SVRG','Adam','SGD w/m']\n",
    "    sns.boxenplot(data=subdf2, x=\"tp_step\", y=\"grad\", hue=\"algo\", hue_order=hue_order,\\\n",
    "                      palette=palette[1:], \\\n",
    "                      linewidth=2,ax=ax[d])\n",
    "    for ci in range(len(crit)):\n",
    "        ax[d].axhline(y=aisarah.loc[aisarah.tp_step==crit[ci]]['grad'].values.item(),\\\n",
    "                      xmin=ci*0.33, xmax=(ci+1)*0.33,color=palette[0],ls='--',lw=2)\n",
    "    ax[d].set_yscale('log')\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].set_xlabel('Wall Clock (sec)',fontsize=12)\n",
    "    if d>=0:\n",
    "        ax[d].legend('',frameon=False)\n",
    "ax[0].set_ylabel(r'$||\\nabla P(w)||^2$',fontsize=12)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
