{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named matplotlib colours used as the plotting palette; the figure cell below\n",
    "# takes the first two entries (cmap[:2]).\n",
    "cmap = [\n",
    "    'magenta', 'blue', 'green', 'red', 'purple',\n",
    "    'brown', 'pink', 'gray', 'olive', 'cyan',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "import matplotlib\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "torch.set_num_threads(1) #cpu num\n",
    "from fractions import Fraction\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import math\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_alpha(_aa):\n",
    "    \"\"\"Average the logged alpha records within each epoch.\n",
    "\n",
    "    Every record of ``_aa`` is bucketed by ``math.ceil(record[0])``; columns 2, 3\n",
    "    and 4 of the records are then averaged separately inside each bucket.\n",
    "\n",
    "    Returns a 4-tuple ``(epochs, col2_means, col3_means, col4_means)``: ``epochs``\n",
    "    is the sorted list of bucket keys and each of the other three lists holds\n",
    "    one mean per epoch, in the same order.\n",
    "    \"\"\"\n",
    "    rows_by_epoch = defaultdict(list)\n",
    "    for row_idx, record in enumerate(_aa):\n",
    "        rows_by_epoch[math.ceil(record[0])].append(row_idx)\n",
    "\n",
    "    # array form allows selecting a whole bucket of rows with one fancy index\n",
    "    table = np.array(_aa)\n",
    "    epochs = sorted(rows_by_epoch)\n",
    "\n",
    "    def column_means(col):\n",
    "        # mean of one column over the rows of every epoch, in epoch order\n",
    "        return [np.mean(table[rows_by_epoch[ep], col]) for ep in epochs]\n",
    "\n",
    "    return epochs, column_means(2), column_means(3), column_means(4)\n",
    "    \n",
    "\n",
    "def fillup(_epp, _ep, _l):\n",
    "    \"\"\"Forward-fill ``_l`` at the positions of ``_epp`` that are missing from ``_ep``.\n",
    "\n",
    "    For every index ``i >= 1`` of ``_epp`` whose value is not contained in ``_ep``,\n",
    "    the value ``_l[i-1]`` is inserted into ``_l`` at position ``i``. Index 0 is\n",
    "    never examined. ``_l`` is modified in place and the same list is returned.\n",
    "    \"\"\"\n",
    "    for pos in range(1, len(_epp)):\n",
    "        if _epp[pos] in _ep:\n",
    "            continue\n",
    "        # repeat the preceding value so the gap is filled\n",
    "        previous = _l[pos - 1]\n",
    "        _l.insert(pos, previous)\n",
    "    return _l\n",
    "\n",
    "def runmax(_l):\n",
    "    \"\"\"Return the running maximum of ``_l`` as a list.\n",
    "\n",
    "    Element ``i`` of the result equals ``max(_l[:i+1])``; an empty input\n",
    "    yields ``[]``.\n",
    "    \"\"\"\n",
    "    # accumulate() carries the maximum forward in a single pass instead of\n",
    "    # re-scanning a growing prefix for every position (O(n) rather than O(n^2))\n",
    "    return list(itertools.accumulate(_l, max))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change the kernel's working directory to ../Logs so that the relative log\n",
    "# folders built from `logfolder` can be listed and loaded.\n",
    "# NOTE(review): the path is relative to the current directory - presumably the\n",
    "# notebook is started from a sibling directory of Logs; confirm.\n",
    "%cd ../Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets to process; one log folder is read and one subplot is drawn per entry,\n",
    "# in this order.\n",
    "DNAME = ['ijcnn1', 'rcv1', 'real-sim', 'news20', 'covtype']\n",
    "# One value per dataset. Not referenced by any other cell of this notebook\n",
    "# (presumably an x-tick spacing - confirm before relying on it).\n",
    "TICK = [5, 10, 5, 10, 5]\n",
    "# Algorithm folder names. Not referenced by any other cell of this notebook; the\n",
    "# loading cell spells out 'ai_sarah' literally.\n",
    "ALGO = ['ai_sarah']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template of a log directory, filled in as logfolder % (dataset_name, algorithm_name).\n",
    "# The trailing slash matters when a file name is appended directly to the result.\n",
    "logfolder = '%s/reg/%s/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-epoch averaged alpha traces of every AI-SARAH log file into one\n",
    "# long-format DataFrame per dataset: adf[dataset_name].\n",
    "\n",
    "# `gamma` is written into every row below but is never assigned anywhere in this\n",
    "# notebook, so on a fresh kernel this cell stopped with a NameError. Keep a value\n",
    "# that is already in the namespace; otherwise fall back to NaN so the column exists.\n",
    "# NOTE(review): confirm where gamma should really come from (the log's 'parm' entry?).\n",
    "try:\n",
    "    gamma\n",
    "except NameError:\n",
    "    gamma = float('nan')\n",
    "\n",
    "adf = {}\n",
    "for dname in DNAME:\n",
    "    print(dname)\n",
    "    rows = []\n",
    "\n",
    "    temp_logfolder = logfolder % (dname, 'ai_sarah')\n",
    "    # sorted() makes the row order independent of the filesystem's listing order\n",
    "    for fname in sorted(os.listdir(temp_logfolder)):\n",
    "        # skip names containing 'DONE' or 'RUN' and anything without '.tar' in its name\n",
    "        if 'DONE' in fname or '.tar' not in fname or 'RUN' in fname:\n",
    "            continue\n",
    "\n",
    "        # NOTE: torch.load unpickles the file - only load logs from a trusted source\n",
    "        log = torch.load(os.path.join(temp_logfolder, fname))\n",
    "        parm = log['parm']\n",
    "        bs, seed = parm[0], parm[1]\n",
    "        aEP, aL, arL, amL = avg_alpha(log['alpha'])\n",
    "\n",
    "        # one row per (epoch, series); the three series share the epoch axis aEP\n",
    "        for kind, series in (('alpha', aL), ('alpha_newton', arL), ('alpha_max', amL)):\n",
    "            rows += [['AI-SARAH', gamma, bs, seed, e, kind, value]\n",
    "                     for e, value in zip(aEP, series)]\n",
    "\n",
    "    columns = ['algo', 'gamma', 'BS', 'seed', 'ep', 'type', 'alpha']\n",
    "    adf[dname] = pd.DataFrame(data=rows, columns=columns).sort_values(by=['algo', 'type'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change the kernel's working directory to ../Plot now that the logs are loaded.\n",
    "# NOTE(review): presumably this restores the directory the notebook lives in - confirm.\n",
    "%cd ../Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: one panel per dataset showing the averaged 'alpha' and 'alpha_max'\n",
    "# series against the effective pass, with a 95% confidence band.\n",
    "sns.set(context=\"paper\", font='serif')\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "# ENDEP (last effective pass shown per dataset) is not assigned anywhere in this\n",
    "# notebook, so on a fresh kernel this cell stopped with a NameError. Use it when it\n",
    "# is already in the namespace; otherwise show every effective pass that was logged.\n",
    "try:\n",
    "    end_ep = list(ENDEP)\n",
    "except NameError:\n",
    "    end_ep = [adf[name]['ep'].max() for name in DNAME]\n",
    "\n",
    "# squeeze=False always returns a 2-D array of axes, so ax[d] also works when\n",
    "# DNAME holds a single dataset\n",
    "fig, axes = plt.subplots(1, len(DNAME), figsize=(2.5*len(DNAME), 4),\n",
    "                         sharex=False, sharey=False, squeeze=False)\n",
    "ax = axes[0]\n",
    "\n",
    "palette = cmap[:2]\n",
    "markers = ('o', 'v')\n",
    "dashes = [(2, 2)]*2\n",
    "# legend texts, in the order of the two plotted series\n",
    "label = [r'$\\alpha$', r'$\\alpha_{max}$']\n",
    "\n",
    "for d, name in enumerate(DNAME):\n",
    "    frame = adf[name]\n",
    "    # keep effective passes 1..end_ep[d] and leave out the 'alpha_newton' series\n",
    "    keep = (frame['ep'] >= 1) & (frame['ep'] <= end_ep[d]) & (frame['type'] != 'alpha_newton')\n",
    "\n",
    "    # only the first panel builds a legend; it is restyled after the loop\n",
    "    sns.lineplot(x='ep', y='alpha', hue='type', style='type', data=frame[keep],\n",
    "                 palette=palette, markers=markers, dashes=dashes,\n",
    "                 markevery=5, markersize=8, lw=2.0, zorder=-1, ci=95,\n",
    "                 legend='full' if d == 0 else False, ax=ax[d])\n",
    "\n",
    "    ax[d].set_xlim([0, end_ep[d]+0.1])\n",
    "    ax[d].set_ylabel('')\n",
    "    ax[d].set_title(name, fontsize=15)\n",
    "    # this label used to be set after the loop, which labelled only the last panel\n",
    "    ax[d].set_xlabel('Effective Pass', fontsize=12)\n",
    "    # grid()'s first argument is the on/off flag, not the 'which' selector\n",
    "    ax[d].grid(True)\n",
    "\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "# NOTE(review): handles[1:] drops the first legend entry - presumably the 'type'\n",
    "# title entry that older seaborn versions prepend; re-check after a seaborn upgrade.\n",
    "ax[0].legend(handles=handles[1:], labels=label,\n",
    "             fancybox=True, fontsize=12, loc='lower right')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably a leftover guard that stops a Run All here; consider\n",
    "# removing it - confirm nothing relies on the session ending at this point.\n",
    "exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
