{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "import matplotlib\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "torch.set_num_threads(1) #cpu num\n",
    "from fractions import Fraction\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import math\n",
    "from collections import defaultdict\n",
    "from matplotlib.lines import Line2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fillup(_epp,_ep,_l):\n",
    "    \"\"\"Align one run's recorded values to a shared, sorted x-grid.\n",
    "\n",
    "    _epp : sorted grid shared by all runs of an algorithm (TPP or GPP below).\n",
    "    _ep  : x positions at which this run recorded a value (expected to be a\n",
    "           subset of _epp).\n",
    "    _l   : this run's values, assumed one per entry of _ep, in the same order.\n",
    "\n",
    "    Returns _l, modified in place, with one value per grid point. A grid\n",
    "    point the run did not record repeats the previous grid point's value;\n",
    "    missing leading grid points repeat the run's first value, so every\n",
    "    value stays aligned with its own x position.\n",
    "    \"\"\"\n",
    "    if len(_l) == 0:\n",
    "        return _l\n",
    "    present = set(_ep)  # O(1) membership instead of scanning _ep for every grid point\n",
    "    for ind in range(len(_epp)):\n",
    "        if _epp[ind] not in present:\n",
    "            # Index 0 has no predecessor. Skipping it (as before) shifted every later\n",
    "            # value one grid point to the left; back-fill it from the first value.\n",
    "            _l.insert(ind, _l[ind-1] if ind > 0 else _l[0])\n",
    "    return _l\n",
    "\n",
    "def runmax(_l):\n",
    "    \"\"\"Running maximum: element i of the result is max(_l[:i+1]).\n",
    "\n",
    "    Applied to the 'test' series so the figures show the best value so far.\n",
    "    Folds once over _l (O(n)) instead of re-scanning every prefix (O(n^2));\n",
    "    an empty input gives an empty list.\n",
    "    \"\"\"\n",
    "    return list(itertools.accumulate(_l, max))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move into the Logs directory (relative to this notebook's directory). Guarded so\n",
    "# that re-running the cell does not keep climbing up the directory tree.\n",
    "if os.path.basename(os.getcwd()) != 'Logs':\n",
    "    os.chdir('../Logs')\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template for one algorithm's log folder, relative to the Logs directory;\n",
    "# filled as logfolder % (dataset name, algorithm folder) by the loading cells.\n",
    "logfolder = '%s/reg/%s/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets in plotting order (one subplot column each).\n",
    "DNAME = ['ijcnn1', 'rcv1', 'real-sim', 'news20', 'covtype']\n",
    "# Per dataset (same order as DNAME): only rows with ep <= ENDEP[d] are plotted.\n",
    "ENDEP = [20, 30, 20, 40, 20]\n",
    "# Per-dataset tick spacing. NOTE(review): not read by any cell of this notebook.\n",
    "TICK = [5, 10, 5, 10, 5]\n",
    "# Algorithm display names; stored in the 'algo' column, which is the hue of every figure.\n",
    "# NOTE(review): the loading cells also index ALGO (the log sub-folder for each entry\n",
    "# of ALGO_NAME, same order), which is not defined in this notebook. TODO define it here.\n",
    "ALGO_NAME = ['AI-SARAH', 'SARAH', 'SARAH+', 'SVRG', 'Adam', 'SGD w/m']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset values, same order as DNAME.\n",
    "# NOTE(review): mu is copied into mu_d by the loading cells but not used further there.\n",
    "mu = [2e-5, 4.94e-5, 1.844e-5, 6.668e-5, 2.295e-6]\n",
    "# Batch size; the loading cells compute perEpoch = TrainSize[d] // BS.\n",
    "BS = 64\n",
    "# Presumably the training-set size of each dataset.\n",
    "TrainSize = [49990, 20242, 54231, 14997, 435759]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wall-clock view: df[dataset] has one row per (run, time-grid point); 'tp' is the\n",
    "# rounded cumulative time from the run's 'stat' log, aligned across runs by fillup.\n",
    "# NOTE(review): ALGO (log sub-folder per entry of ALGO_NAME) is not defined in this\n",
    "# notebook - define it before running this cell.\n",
    "# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.\n",
    "df={}\n",
    "for dd in range(len(DNAME)):\n",
    "    dname = DNAME[dd]\n",
    "    mu_d = mu[dd]\n",
    "    perEpoch = TrainSize[dd]//BS\n",
    "    DF = []\n",
    "    for al in range(len(ALGO_NAME)):\n",
    "        # Use this algorithm's own folder. The old code read an undefined or stale\n",
    "        # `algo` (left over from another cell), so it could load the wrong logs.\n",
    "        algo = ALGO[al]\n",
    "        alg_name = ALGO_NAME[al]\n",
    "        temp_logfolder = logfolder%(dname,algo)\n",
    "        # Column of 'stat' summed into cumulative time: 1 for Adam / SGD w/m, 3 otherwise.\n",
    "        time_col = 1 if alg_name in ['Adam','SGD w/m'] else 3\n",
    "\n",
    "        # Read each .tar log once (skip files whose name contains DONE or RUN).\n",
    "        runs = []\n",
    "        for f in sorted(os.listdir(temp_logfolder)):\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            runs.append((f, temp['parm'], np.array(temp['hist']), np.array(temp['stat'])))\n",
    "\n",
    "        # Shared grid = union of every run's rounded cumulative times. Built after the\n",
    "        # loop so an empty folder yields an empty grid instead of a previous grid.\n",
    "        tpp = []\n",
    "        for run in runs:\n",
    "            tpp += list(np.round(np.cumsum(run[3][:,time_col]),2))\n",
    "        TPP = sorted(set(tpp))\n",
    "\n",
    "        for f, parm, hist, stat in runs:\n",
    "            bs, seed = parm[0], parm[1]\n",
    "            lr, gamma, schedule, m = -1, -1, -1, -1  # -1 = not used by this algorithm\n",
    "            if alg_name == 'AI-SARAH':\n",
    "                gamma = parm[2]\n",
    "                if gamma != 1/32:  # only the gamma = 1/32 runs are kept\n",
    "                    continue\n",
    "            elif alg_name == 'SARAH+':\n",
    "                lr, gamma = parm[2], parm[3]\n",
    "            elif alg_name in ['SARAH','SVRG']:\n",
    "                lr, m = parm[2], parm[3]\n",
    "            else:  # Adam / SGD w/m: lr and schedule are taken from the file name\n",
    "                name_parts = f.split('-')\n",
    "                lr, schedule = name_parts[1], name_parts[3]\n",
    "\n",
    "            tp = np.round(np.cumsum(stat[:,time_col]),2)\n",
    "            loss = fillup(TPP,tp,list(hist[:,1]))\n",
    "            grad = fillup(TPP,tp,list(hist[:,2]))\n",
    "            test = runmax(fillup(TPP,tp,list(hist[:,3])))  # best value so far\n",
    "            ep = fillup(TPP,tp,list(np.round(hist[:,0])))\n",
    "\n",
    "            DF+=[[alg_name,bs,seed,lr,schedule,gamma,m,tp_i,li,gi,ti,ep_i]\n",
    "                 for tp_i,li,gi,ti,ep_i in zip(TPP,loss,grad,test,ep)]\n",
    "\n",
    "    df[dname]=pd.DataFrame(data=DF,columns=['algo','bs','seed','lr','schedule',\n",
    "                                            'gamma','m','tp',\n",
    "                                            'loss','grad','test','ep'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass-count view: df2[dataset] has one row per (run, grid point); 'gp' is a rounded\n",
    "# pass count computed from the run's 'stat' log and used to align runs via fillup.\n",
    "# NOTE(review): ALGO (log sub-folder per entry of ALGO_NAME) is not defined in this\n",
    "# notebook - define it before running this cell.\n",
    "# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.\n",
    "\n",
    "def _grid_passes(alg_name, stat, perEpoch):\n",
    "    \"\"\"Rounded per-checkpoint pass counts of one run (its alignment grid points).\n",
    "\n",
    "    Adam / SGD w/m: round(stat[:,0]).\n",
    "    Others: round((stat[:,1] + k*stat[:,2]) / perEpoch), with k = 4 for\n",
    "    AI-SARAH and k = 2 for SARAH / SARAH+ / SVRG.\n",
    "    \"\"\"\n",
    "    if alg_name in ['Adam','SGD w/m']:\n",
    "        return list(np.round(stat[:,0]))\n",
    "    k = 4 if alg_name == 'AI-SARAH' else 2\n",
    "    return list(np.round((stat[:,1]+stat[:,2]*k)/perEpoch))\n",
    "\n",
    "df2={}\n",
    "for dd in range(len(DNAME)):\n",
    "    dname = DNAME[dd]\n",
    "    mu_d = mu[dd]\n",
    "    perEpoch = TrainSize[dd]//BS\n",
    "    DF = []\n",
    "    for al in range(len(ALGO_NAME)):\n",
    "        algo = ALGO[al]\n",
    "        alg_name = ALGO_NAME[al]\n",
    "        temp_logfolder = logfolder%(dname,algo)\n",
    "\n",
    "        # Read each .tar log once (skip files whose name contains DONE or RUN).\n",
    "        runs = []\n",
    "        for f in sorted(os.listdir(temp_logfolder)):\n",
    "            if '.tar' not in f or 'DONE' in f or 'RUN' in f:\n",
    "                continue\n",
    "            temp = torch.load(temp_logfolder+f)\n",
    "            runs.append((f, temp['parm'], np.array(temp['hist']), np.array(temp['stat'])))\n",
    "\n",
    "        # Shared grid = union of every run's grid points. Built after the loop so an\n",
    "        # empty folder yields an empty grid instead of the previous algorithm's grid.\n",
    "        gpp = []\n",
    "        for run in runs:\n",
    "            gpp += _grid_passes(alg_name, run[3], perEpoch)\n",
    "        GPP = sorted(set(gpp))\n",
    "\n",
    "        for f, parm, hist, stat in runs:\n",
    "            bs, seed = parm[0], parm[1]\n",
    "            lr, gamma, schedule, m = -1, -1, -1, -1  # -1 = not used by this algorithm\n",
    "            if alg_name == 'AI-SARAH':\n",
    "                gamma = parm[2]\n",
    "                if gamma != 1/32:  # only the gamma = 1/32 runs are kept\n",
    "                    continue\n",
    "            elif alg_name == 'SARAH+':\n",
    "                lr, gamma = parm[2], parm[3]\n",
    "            elif alg_name in ['SARAH','SVRG']:\n",
    "                lr, m = parm[2], parm[3]\n",
    "            else:  # Adam / SGD w/m: lr and schedule are taken from the file name\n",
    "                name_parts = f.split('-')\n",
    "                lr, schedule = name_parts[1], name_parts[3]\n",
    "            print('algo: %s, BS: %s, seed: %s, LR: %s, Sche: %s, gamma: %s, m: %s'\n",
    "                  %(alg_name,bs,seed,lr,schedule,gamma,m))\n",
    "\n",
    "            gp = _grid_passes(alg_name, stat, perEpoch)\n",
    "            loss = fillup(GPP,gp,list(hist[:,1]))\n",
    "            grad = fillup(GPP,gp,list(hist[:,2]))\n",
    "            test = runmax(fillup(GPP,gp,list(hist[:,3])))  # best value so far\n",
    "            ep = fillup(GPP,gp,list(np.round(hist[:,0])))\n",
    "\n",
    "            DF+=[[alg_name,bs,seed,lr,schedule,gamma,m,gp_i,li,gi,ti,ep_i]\n",
    "                 for gp_i,li,gi,ti,ep_i in zip(GPP,loss,grad,test,ep)]\n",
    "\n",
    "    df2[dname]=pd.DataFrame(data=DF,columns=['algo','bs','seed','lr','schedule',\n",
    "                                             'gamma','m','gp',\n",
    "                                             'loss','grad','test','ep'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Legend for Figures 5, 6 and 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone legend strip: one dashed line + marker per algorithm.\n",
    "# Apply the same seaborn theme as the figure cells so the legend's look does not\n",
    "# depend on whether a figure cell happened to run before this one.\n",
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"ticks\")\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "hue_order = ['AI-SARAH','SARAH','SARAH+','SVRG','Adam','SGD w/m']  # same order as ALGO_NAME\n",
    "handles = [Line2D([0],[0],color=colour,linewidth=3,dashes=[2,2],marker=marker,markersize=10)\n",
    "           for colour, marker in zip(palette, markers)]\n",
    "fig,ax=plt.subplots(1,1,figsize=(15,1))\n",
    "# prop sets the font size; a separate fontsize argument is ignored when prop is given.\n",
    "ax.legend(handles=handles,labels=hue_order,fancybox=True,loc='center',\n",
    "          ncol=len(hue_order),prop={'size': 15})\n",
    "ax.axis('off')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 5: squared gradient norm vs. effective passes and wall-clock time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Squared gradient norm vs. effective passes (df2), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df2[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='ep',y='grad',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=5,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_yscale('log')\n",
    "    ax[d].set_xlabel('Effective Pass',fontsize=12)\n",
    "    ax[d].set_title(DNAME[d],fontsize=15)\n",
    "ax[0].set_ylabel(r'$||\\nabla P(w)||^2$',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Squared gradient norm vs. wall-clock time (df), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='tp',y='grad',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=40,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_yscale('log')\n",
    "    ax[d].set_xlabel('Wall Clock (sec)',fontsize=12)\n",
    "ax[0].set_ylabel(r'$||\\nabla P(w)||^2$',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 6: objective $P(w)$ vs. effective passes and wall-clock time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective P(w) vs. effective passes (df2), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "loss_limit=[[0.181,0.3],[0.2,0.3],[0.155,0.3],[0.325,0.42],[0.5155,0.54]]  # y-range per dataset\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df2[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='ep',y='loss',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=5,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_ylim(loss_limit[d])\n",
    "    ax[d].set_xlabel('Effective Pass',fontsize=12)\n",
    "    ax[d].set_title(DNAME[d],fontsize=15)\n",
    "ax[0].set_ylabel(r'$P(w)$',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective P(w) vs. wall-clock time (df), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "loss_limit=[[0.181,0.3],[0.2,0.3],[0.155,0.3],[0.325,0.42],[0.5155,0.54]]  # y-range per dataset\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='tp',y='loss',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=40,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_ylim(loss_limit[d])\n",
    "    ax[d].set_xlabel('Wall Clock (sec)',fontsize=12)\n",
    "ax[0].set_ylabel(r'$P(w)$',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 7: best-so-far test accuracy vs. effective passes and wall-clock time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best-so-far test accuracy vs. effective passes (df2), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "test_limit = [[0.90,0.925],[0.94,0.96],[0.90,0.975],[0.90,0.937],[0.70,0.761]]  # y-range per dataset\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df2[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='ep',y='test',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=5,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_ylim(test_limit[d])\n",
    "    ax[d].set_xlabel('Effective Pass',fontsize=12)\n",
    "    ax[d].set_title(DNAME[d],fontsize=15)\n",
    "ax[0].set_ylabel('Accuracy',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best-so-far test accuracy vs. wall-clock time (df), one panel per dataset.\n",
    "sns.set(context=\"paper\",font='serif')  # sets every theme parameter; no bare sns.set() needed\n",
    "sns.set_style(\"ticks\")\n",
    "fig, ax = plt.subplots(1,len(DNAME),figsize=(3*len(DNAME),3),sharex='col',sharey=False)\n",
    "test_limit = [[0.90,0.925],[0.94,0.96],[0.90,0.975],[0.90,0.937],[0.70,0.761]]  # y-range per dataset\n",
    "palette = sns.color_palette(\"tab10\")[:6]\n",
    "markers = ('o', 'v', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d in range(len(DNAME)):\n",
    "    subdf=df[DNAME[d]]\n",
    "    subdf=subdf[subdf['ep']<=ENDEP[d]]\n",
    "    # Fixed hue/style order (= ALGO_NAME) keeps colours and markers tied to the same\n",
    "    # algorithm in every panel, even if an algorithm has no rows for a dataset.\n",
    "    sns.lineplot(x='tp',y='test',dashes=dashes,palette=palette,\n",
    "                 markers=markers,hue_order=ALGO_NAME,style_order=ALGO_NAME,\n",
    "                 zorder=-1,markevery=40,markersize=8,lw=2.0,\n",
    "                 legend=False,hue='algo',ci=95,style='algo',\n",
    "                 data=subdf,ax=ax[d])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].grid(True)  # first argument of grid() is the on/off flag ('both' was merely truthy)\n",
    "    ax[d].set_ylim(test_limit[d])\n",
    "    ax[d].set_xlabel('Wall Clock (sec)',fontsize=12)\n",
    "ax[0].set_ylabel('Accuracy',fontsize=12)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): in a Jupyter kernel exit() presumably shuts the kernel down rather than\n",
    "# just stopping 'Run All', discarding df, df2 and all other state. Confirm this is intended.\n",
    "exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
