{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
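  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot accuracy vs. inference time\n",
    "\n",
    "Scatter plot of accuracy against log-scale inference time (ms) for each method on ogbn-products, drawn with a broken y-axis and saved to `../figures/acc2infertime_plot.pdf`."
   ]
  },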
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs. log-scale inference time, drawn as a broken-y-axis scatter plot:\n",
    "# rows with acc > acc_split go on the top panel, the rest on the bottom panel.\n",
    "# The ', ' separator is multi-character, which only pandas' python engine supports;\n",
    "# passing engine='python' explicitly avoids the ParserWarning.\n",
    "data = pd.read_csv('../result/old_best/speedtest/ogbn-products_plot_pioneer.csv', sep=', ', engine='python')\n",
    "# data = pd.read_csv('../result/best/speedtest/speed.csv', sep=', ', engine='python')\n",
    "\n",
    "fontdict = {'size': 50}\n",
    "gridspec_kw = {'height_ratios': [1.2, 1], 'hspace': 0.05}\n",
    "bottom_yrange = (59.3, 66.9)\n",
    "top_yrange = (73.2, 82.6)\n",
    "xrange = (0.34, 360)\n",
    "acc_split = 70  # rows with acc <= acc_split are drawn on the bottom panel\n",
    "\n",
    "# Marker and colour per method. A row takes the entry of the FIRST key that is a\n",
    "# substring of its name, so specific keys must precede keys they contain\n",
    "# (e.g. 'SDGNN' must come before 'GNN', otherwise SDGNN rows get the GNN style).\n",
    "marker_color_dict = {\n",
    "    'SDGNN-Rev': ['P', 'c'],\n",
    "    'SDGNN-SAGE': ['P', 'c'],\n",
    "    'SDGNN': ['P', 'c'],\n",
    "    'NOSMOG': ['s', 'r'],\n",
    "    'GNN': ['X', 'g'],\n",
    "    'GLNN': ['^', 'orange'],\n",
    "    'MLP': ['o', 'b'],\n",
    "    'PPRGo': ['D', 'brown']\n",
    "}\n",
    "\n",
    "# Reference values (name, time (ms), acc) that had been pasted here as bare text;\n",
    "# kept as comments because, as code, the first line raises NameError.\n",
    "# NOSMOG, 0.46, 80.91\n",
    "# GLNN, 0.46, 64.96\n",
    "# MLP, 0.40, 59.94\n",
    "# GNN, 128.71, 82.79\n",
    "# SDGNN, 0.74, 82.87\n",
    "# PPRGo, 0.74, 79.37\n",
    "\n",
    "def lookup_marker_color(name):\n",
    "    # Return [marker, color] for the first key of marker_color_dict contained in name, else None.\n",
    "    for key, m_c in marker_color_dict.items():\n",
    "        if key in name:\n",
    "            return m_c\n",
    "    return None\n",
    "\n",
    "styles = [lookup_marker_color(name) for name in data['name']]\n",
    "unmatched = [name for name, m_c in zip(data['name'], styles) if m_c is None]\n",
    "if unmatched:\n",
    "    # Fail with the offending names instead of a length-mismatch error on column assignment.\n",
    "    raise ValueError(f'No marker/color entry in marker_color_dict for: {unmatched}')\n",
    "data['marker'] = [m_c[0] for m_c in styles]\n",
    "data['color'] = [m_c[1] for m_c in styles]\n",
    "\n",
    "# Label offsets per name: [x-shift in decades (see compute_gap), y-shift in accuracy points].\n",
    "pos_dict = {name: [0.05, -0.28] for name in data['name']}\n",
    "pos_dict['GL-Rev-w4'] = [0.0, 1.2]\n",
    "pos_dict['NOS-SAGE'] = [-0.1, -1.3]\n",
    "pos_dict['SAGE-L2-N20'] = [-0.2, -1.0]\n",
    "pos_dict['SAGE-L2-full'] = [0.0, 1.2]\n",
    "pos_dict['SAGE-L3-full'] = [-0.2, -1.2]\n",
    "\n",
    "def compute_gap(x, a):\n",
    "    # Shift x by `a` decades on a log axis (x + x*(10**a - 1) == x * 10**a).\n",
    "    return x * 10 ** a\n",
    "\n",
    "# Panel index per row name: 0 = top panel, 1 = bottom panel.\n",
    "data_section = {name: (1 if acc <= acc_split else 0) for name, acc in zip(data['name'], data['acc'])}\n",
    "\n",
    "f, ax = plt.subplots(2, 1, figsize=(22, 16), gridspec_kw=gridspec_kw)\n",
    "f.set_dpi(300)\n",
    "f.text(0.04, 0.42, 'Accuracy', rotation=90, fontdict=fontdict)\n",
    "f.text(0.25, 0.01, 'Log Scale Inference Time (ms)', fontdict=fontdict)\n",
    "\n",
    "def style_panel(panel, yrange):\n",
    "    # Shared limits, log x-scale, tick and spine styling for one panel of the broken axis.\n",
    "    panel.set_ylim(*yrange)\n",
    "    panel.set_xlim(*xrange)\n",
    "    panel.set_xscale('log')\n",
    "    panel.tick_params(labelsize=fontdict['size'], width=4, length=8, grid_linewidth=8, pad=8)\n",
    "    for side in ['top', 'bottom', 'left', 'right']:\n",
    "        panel.spines[side].set_linewidth(4)\n",
    "\n",
    "style_panel(ax[0], top_yrange)\n",
    "ax[0].xaxis.tick_top()\n",
    "ax[0].tick_params(labeltop=False)  # ticks on top of the upper panel, but no tick labels there\n",
    "ax[0].set_yticks(ticks=[75, 80])\n",
    "ax[0].spines['bottom'].set_visible(False)\n",
    "\n",
    "style_panel(ax[1], bottom_yrange)\n",
    "ax[1].spines['top'].set_visible(False)\n",
    "\n",
    "for _, row in data.iterrows():\n",
    "    name = row['name']\n",
    "    if 'N40' in name:\n",
    "        continue\n",
    "    cur_fontdict = {'size': 40}\n",
    "    if 'SDGNN-' in name:\n",
    "        cur_fontdict['weight'] = 'bold'\n",
    "    dx, dy = pos_dict[name]\n",
    "    sec = data_section[name]\n",
    "    ax[sec].scatter(row['time (ms)'], row['acc'], marker=row['marker'], c=row['color'], s=fontdict['size']*15)\n",
    "    ax[sec].text(compute_gap(row['time (ms)'], dx), row['acc'] + dy, name, fontdict=cur_fontdict)\n",
    "    if name in ('SAGE-L2-full', 'GL-Rev-w4'):\n",
    "        # Short connector line between the offset label and its point.\n",
    "        ax[sec].plot(\n",
    "            (compute_gap(row['time (ms)'], dx + 0.03), compute_gap(row['time (ms)'], dx)),\n",
    "            (row['acc'] + dy - 0.2, row['acc'] + 0.4),\n",
    "            color='k',\n",
    "            linewidth=3\n",
    "        )\n",
    "\n",
    "# Diagonal break marks where the two panels meet, in axes coordinates.\n",
    "d = 0.015  # size of the diagonal marks\n",
    "kwargs = dict(transform=ax[0].transAxes, color='k', clip_on=False, linewidth=4)\n",
    "ax[0].plot((-d, +d), (-d, +d), **kwargs)        # top-left diagonal\n",
    "ax[0].plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal\n",
    "\n",
    "# The bottom panel is shorter, so its marks are stretched by the top panel's height ratio\n",
    "# to keep the same on-screen slope.\n",
    "h = d * gridspec_kw['height_ratios'][0]\n",
    "kwargs.update(transform=ax[1].transAxes)\n",
    "ax[1].plot((-d, +d), (1 - h, 1 + h), **kwargs)        # bottom-left diagonal\n",
    "ax[1].plot((1 - d, 1 + d), (1 - h, 1 + h), **kwargs)  # bottom-right diagonal\n",
    "\n",
    "plt.savefig('../figures/acc2infertime_plot.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
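  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot normalized regret vs. time\n",
    "\n",
    "One curve per run directory under `res_path`; each run provides `rel_regret_vs_time.csv` and a `tag.txt` whose first line is the legend label."
   ]
  },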
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Normalized regret vs. wall-clock time, one curve per run directory under res_path.\n",
    "# Each run directory holds rel_regret_vs_time.csv (no header: time, rel_regret)\n",
    "# and tag.txt, whose first line is used as the legend label.\n",
    "res_path = '/home/yaochen/data/to_plot'\n",
    "\n",
    "df_holder = []\n",
    "legend_holder = []\n",
    "# sorted() makes the run -> colour/style assignment below deterministic\n",
    "# (os.listdir order is arbitrary); non-directory entries are skipped.\n",
    "for run_dir in sorted(os.listdir(res_path)):\n",
    "    run_path = os.path.join(res_path, run_dir)\n",
    "    if not os.path.isdir(run_path):\n",
    "        continue\n",
    "    df_holder.append(\n",
    "        pd.read_csv(\n",
    "            os.path.join(run_path, 'rel_regret_vs_time.csv'),\n",
    "            sep=',',\n",
    "            names=['time', 'rel_regret']\n",
    "        )\n",
    "    )\n",
    "    with open(os.path.join(run_path, 'tag.txt'), 'r') as tag_file:\n",
    "        legend_holder.append(tag_file.readline().strip())\n",
    "\n",
    "# Styles are assigned to runs in directory-name order; sort_holder sets the drawing/legend order.\n",
    "marker_holder = ['s', '^', 'd', 'o']  # currently not passed to ax.plot\n",
    "color_holder = ['r', 'b', 'g', 'k']\n",
    "style_holder = ['solid', 'dashed', 'dashdot', 'dotted']\n",
    "sort_holder = [2, 1, 0, 3]\n",
    "if len(df_holder) > len(sort_holder):\n",
    "    # zip() below would otherwise silently drop the extra runs.\n",
    "    raise ValueError(f'{len(df_holder)} runs found but only {len(sort_holder)} styles defined')\n",
    "\n",
    "fontdict = {'size': 50}\n",
    "fig, ax = plt.subplots(figsize=(18, 12))\n",
    "# plt.yscale('log')\n",
    "ax.set_ylim(0, 0.8)\n",
    "ax.set_xlim(0, 820)\n",
    "ax.set_xlabel('Time (s)', fontdict=fontdict)\n",
    "ax.set_ylabel('Normalized Regret', fontdict=fontdict)\n",
    "ax.set_xticks(ticks=[0, 200, 400, 600, 800])\n",
    "ax.set_xticklabels(['0', '', '400', '', '800'])\n",
    "ax.tick_params(labelsize=fontdict['size'], width=4, length=8, grid_linewidth=8, pad=8)\n",
    "for side in ['top', 'bottom', 'left', 'right']:\n",
    "    ax.spines[side].set_linewidth(4)\n",
    "\n",
    "curves = sorted(\n",
    "    zip(df_holder, legend_holder, color_holder, style_holder, sort_holder),\n",
    "    key=lambda curve: curve[-1]\n",
    ")\n",
    "for run_df, label, color, style, _ in curves:\n",
    "    ax.plot(run_df['time'], run_df['rel_regret'], label=label, c=color, linewidth=5, linestyle=style)\n",
    "\n",
    "ax.legend(fontsize=fontdict['size'])\n",
    "# plt.savefig('/home/Ge-zhang/yhu-GNN_Efficient_Inference/ablation_to_plot_v2.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph plot\n",
    "\n",
    "Minimal `networkx` example: draw the complete graph on 5 nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# Complete graph on 5 nodes: every pair of nodes is joined by an edge.\n",
    "k5 = nx.complete_graph(5)\n",
    "nx.draw(k5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
