{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour names that the plotting cell indexes to build its palette.\n",
    "cmap = [\n",
    "    'blue', 'orange', 'green', 'red', 'purple',\n",
    "    'brown', 'pink', 'gray', 'olive', 'cyan',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "import matplotlib\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "torch.set_num_threads(1) #cpu num\n",
    "from fractions import Fraction\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import math\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_alpha(_aa):\n",
    "    \"\"\"Average columns 2, 3 and 4 of ``_aa`` per epoch bucket.\n",
    "\n",
    "    Every row is assigned to the bucket ``math.ceil(row[0])``. For each\n",
    "    bucket, in ascending order, the three columns are averaged over the\n",
    "    rows that fall into it.\n",
    "\n",
    "    Returns a 4-tuple ``(buckets, col2_means, col3_means, col4_means)`` of\n",
    "    equally long lists.\n",
    "    \"\"\"\n",
    "    rows_by_epoch = defaultdict(list)\n",
    "    for row_idx, row in enumerate(_aa):\n",
    "        rows_by_epoch[math.ceil(row[0])].append(row_idx)\n",
    "\n",
    "    table = np.array(_aa)\n",
    "    epochs = sorted(rows_by_epoch)\n",
    "\n",
    "    # One accumulator per averaged column, filled bucket by bucket.\n",
    "    col_means = {2: [], 3: [], 4: []}\n",
    "    for epoch in epochs:\n",
    "        bucket = table[rows_by_epoch[epoch]]\n",
    "        for col, means in col_means.items():\n",
    "            means.append(np.mean(bucket[:, col]))\n",
    "    return epochs, col_means[2], col_means[3], col_means[4]\n",
    "    \n",
    "\n",
    "def fillup(_epp,_ep,_l):\n",
    "    \"\"\"Pad ``_l`` in place so that it lines up with the epoch grid ``_epp``.\n",
    "\n",
    "    ``_l`` holds one value per epoch listed in ``_ep``. For every grid\n",
    "    position ``ind >= 1`` whose epoch ``_epp[ind]`` is missing from ``_ep``,\n",
    "    the value at ``ind - 1`` is repeated at ``ind``. Position 0 is never\n",
    "    filled because there is no earlier value to repeat.\n",
    "\n",
    "    The list ``_l`` is modified in place and is also the return value.\n",
    "    \"\"\"\n",
    "    # Membership is tested once per grid point; a set avoids rescanning _ep\n",
    "    # each time. Fall back to the sequence itself for unhashable entries.\n",
    "    try:\n",
    "        recorded = set(_ep)\n",
    "    except TypeError:\n",
    "        recorded = _ep\n",
    "\n",
    "    for ind in range(1, len(_epp)):\n",
    "        if _epp[ind] not in recorded:\n",
    "            _l.insert(ind, _l[ind - 1])\n",
    "    return _l\n",
    "\n",
    "def runmax(_l):\n",
    "    \"\"\"Return the running maximum of ``_l``: element ``i`` equals ``max(_l[:i+1])``.\n",
    "\n",
    "    Computed in one pass rather than re-scanning every prefix.\n",
    "    \"\"\"\n",
    "    return list(itertools.accumulate(_l, max))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset names; each one is a log sub-directory to load and gets one plot panel.\n",
    "DNAME = [\n",
    "    'ijcnn1',\n",
    "    'rcv1',\n",
    "    'real-sim',\n",
    "    'news20',\n",
    "    'covtype',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the log directory so the relative '<dataset>/non_reg/ai_sarah/' paths used by the loading cell resolve.\n",
    "%cd ../SenseLogs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the current working directory (shell `pwd`).\n",
    "! pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label written to the 'algo' column of both frames. The original cell used\n",
    "# `upper_bound+heuristic`, two names that are never defined in this notebook, so it\n",
    "# raised NameError on Restart & Run All. The log sub-folder name is used instead.\n",
    "ALGO = 'ai_sarah'\n",
    "\n",
    "df={}\n",
    "adf={}\n",
    "for dname in DNAME:\n",
    "\n",
    "    logfolder = '%s/non_reg/ai_sarah/'%dname\n",
    "\n",
    "    # Load every checkpoint once and reuse it for both passes below.\n",
    "    # NOTE: torch.load unpickles the file, so only point this at trusted logs.\n",
    "    runs = [torch.load(os.path.join(logfolder, f))\n",
    "            for f in os.listdir(logfolder) if '.tar' in f]\n",
    "\n",
    "    # Pass 1: for each gamma (parm[2]) collect the rounded epochs of all its runs;\n",
    "    # EPP[gamma] is their sorted union and serves as the shared epoch grid.\n",
    "    epp = defaultdict(list)\n",
    "    for run in runs:\n",
    "        epp[run['parm'][2]] += list(np.round(np.array(run['hist'])[:,0]))\n",
    "    EPP = {o: sorted(set(epp[o])) for o in sorted(epp)}\n",
    "\n",
    "    DF = []\n",
    "    aDF = []\n",
    "\n",
    "    # Pass 2: align each run with its gamma's epoch grid and emit one row per grid point.\n",
    "    for run in runs:\n",
    "        BS, seed, gamma = run['parm'][0], run['parm'][1], run['parm'][2]\n",
    "        hist = np.array(run['hist'])\n",
    "        stat = np.array(run['stat'])\n",
    "        grid = EPP[gamma]\n",
    "\n",
    "        # NOTE(review): rounding can map two records to the same epoch; fillup assumes\n",
    "        # one value per grid epoch -- confirm the logs never do that.\n",
    "        ep = list(np.round(hist[:,0]))\n",
    "\n",
    "        # fillup repeats the previous value wherever this run has no record for a grid epoch.\n",
    "        loss = fillup(grid,ep,list(hist[:,1]))\n",
    "        grad = fillup(grid,ep,list(hist[:,2]))\n",
    "        # The 'test' column is stored as its running maximum.\n",
    "        test = runmax(fillup(grid,ep,list(hist[:,3])))\n",
    "        innerT = fillup(grid,ep,list(stat[:,2]))\n",
    "        outerT = fillup(grid,ep,list(stat[:,1]))\n",
    "        timeT = fillup(grid,ep,list(stat[:,3]))\n",
    "\n",
    "        aEP,aL,arL,amL = avg_alpha(run['alpha'])\n",
    "\n",
    "        DF += [[ALGO, gamma, BS, seed, e,oT,iT,tT,l,g,t]\n",
    "               for e,oT,iT,tT,l,g,t in zip(grid,outerT,innerT,timeT,loss,grad,test)]\n",
    "\n",
    "        # Long format: one row per (epoch, alpha statistic).\n",
    "        for kind, series in (('alpha', aL), ('alpha_real', arL), ('alpha_max', amL)):\n",
    "            aDF += [[ALGO, gamma, BS, seed, e, kind, a] for e,a in zip(aEP,series)]\n",
    "\n",
    "    df[dname] = pd.DataFrame(\n",
    "        data=DF,\n",
    "        columns=['algo','gamma','BS','seed','ep','outer','inner','time','loss','grad','test'],\n",
    "    ).sort_values(by=['algo'])\n",
    "\n",
    "    adf[dname] = pd.DataFrame(\n",
    "        data=aDF,\n",
    "        columns=['algo','gamma','BS','seed','ep','type','alpha'],\n",
    "    ).sort_values(by=['algo','type'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seaborn theme: 'paper' context with serif fonts, then the 'ticks' axes style.\n",
    "sns.set(context=\"paper\",font='serif')\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "# One panel per dataset in DNAME; squeeze=False keeps the axes indexable for any count.\n",
    "n_panels = len(DNAME)\n",
    "fig, ax = plt.subplots(1,n_panels,figsize=(3*n_panels,3),sharex=True,sharey=False,squeeze=False)\n",
    "ax = ax[0]\n",
    "\n",
    "# Six colours (orange first, then blue, green, ...), six markers and six dash patterns\n",
    "# for the gamma levels.\n",
    "palette = cmap[1:2]+cmap[0:1]+cmap[2:6]\n",
    "markers = ('v', 'o', '^', '<', '>', '8')\n",
    "dashes = [(2,2)]*6\n",
    "\n",
    "for d, dname in enumerate(DNAME):\n",
    "    # Only the first panel carries a legend.\n",
    "    sns.lineplot(x='ep',y='grad',dashes=dashes,palette=palette,\n",
    "                 markers=markers,zorder=-1,markevery=5,markersize=8,lw=2.0,\n",
    "                 legend='full' if d==0 else False,hue='gamma',ci=95,style='gamma',\n",
    "                 data=df[dname],ax=ax[d])\n",
    "    ax[d].set_yscale('log')\n",
    "    # The first positional argument of grid() is the on/off flag, so pass a real boolean.\n",
    "    ax[d].grid(True)\n",
    "    if d==0:\n",
    "        ax[d].set_ylabel(r'$||\\nabla P(w)||^2$',fontsize=12)\n",
    "    else:\n",
    "        ax[d].set_ylabel('')\n",
    "    ax[d].set_xlabel('Effective Pass',fontsize=12)\n",
    "\n",
    "# Relabel the legend: the first entry becomes the gamma symbol, the rest are shown as fractions.\n",
    "handles, labels = ax[0].get_legend_handles_labels()\n",
    "labels = [r'$\\gamma$']+[str(Fraction(n)) for n in labels[1:]]\n",
    "ax[0].legend(handles=handles,labels=labels,fontsize=10,loc='lower left',fancybox=True)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): exit(0) ends the session at this point, presumably to stop a Run All here -- confirm it is still wanted.\n",
    "exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
