{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\"\"\"\n",
    "Ablation plots: relative regret vs. time for each ablation scenario (cand_select, warming).\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import linear_model\n",
    "import scipy.sparse as sp\n",
    "from sklearn.preprocessing import normalize\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "from plotlib import *\n",
    "from script_utils import load_SDMP\n",
    "from path_dict import *\n",
    "\n",
    "# Silence all Python warnings so the plotting output stays readable.\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): an empty CURL_CA_BUNDLE is a known trick for disabling TLS certificate\n",
    "# verification in some `requests` versions. Nothing visible in this notebook downloads\n",
    "# data, so confirm whether the star-imported helpers actually need this before keeping it.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Torch device setting; not referenced by the cells visible in this notebook.\n",
    "DEVICE = 'cpu'\n",
    "\n",
    "# line_profiler is only needed for the commented-out %lprun below; the extension must be\n",
    "# installed in the kernel environment or this line errors on a fresh Run All.\n",
    "%load_ext line_profiler\n",
    "# Reload edited project modules (data_utils, plotlib, ...) before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %lprun -f\n",
    "# Parameter settings and data loading\n",
    "# Root folders, relative to the kernel's current working directory.\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "\n",
    "# Ablation results are read from <ROOT_Ablation_Folder>/<scenario>/<name>_<model>/<run>/.\n",
    "ROOT_Ablation_Folder = \"../result/ablation\"\n",
    "# Figures are written to <fig_root>/<scenario>/.\n",
    "fig_root = \"../figures\"\n",
    "\n",
    "def load_plot_data(root_path):\n",
    "    \"\"\"Load the relative-regret-vs-time curve of every run under ``root_path``.\n",
    "\n",
    "    Each sub-directory of ``root_path`` is treated as one run and must contain\n",
    "    ``rel_regret_vs_time.csv``: two comma-separated columns (time, rel_regret).\n",
    "    Because ``names`` is supplied, the first line is read as data, not a header.\n",
    "\n",
    "    Args:\n",
    "        root_path: folder whose sub-directories are the individual runs.\n",
    "\n",
    "    Returns:\n",
    "        (all_x, all_y): parallel lists of pandas Series, one entry per run,\n",
    "        ordered by sub-directory name (lexicographic). Curve ``i`` therefore\n",
    "        always lines up with the caller's ``i``-th legend entry.\n",
    "    \"\"\"\n",
    "    all_x, all_y = [], []\n",
    "    # os.listdir order is arbitrary, so sort to map runs to legends deterministically.\n",
    "    # Also skip stray files (e.g. .DS_Store) that are not run directories.\n",
    "    run_dirs = sorted(\n",
    "        d for d in os.listdir(root_path)\n",
    "        if os.path.isdir(os.path.join(root_path, d))\n",
    "    )\n",
    "    for d in run_dirs:\n",
    "        print(f\"Loading {d}\")\n",
    "        cur_data = pd.read_csv(os.path.join(root_path, d, 'rel_regret_vs_time.csv'),\n",
    "                               sep=',',\n",
    "                               names=['time', 'rel_regret'])\n",
    "        all_x.append(cur_data[\"time\"])\n",
    "        all_y.append(cur_data[\"rel_regret\"])\n",
    "\n",
    "    return all_x, all_y\n",
    "\n",
    "# Per-(dataset, model) plot settings; a pair with no entry maps to an empty dict.\n",
    "conf_dict = defaultdict(dict, {\n",
    "    # cand_select scenario, pubmed + SAGE: x-axis range (the x-axis is Time(s))\n",
    "    (\"pubmed\", \"SAGE\"): {\"x_lim\": [-1, 7000]},\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset / model pair whose ablation results are plotted in the cells below\n",
    "name = \"pubmed\"\n",
    "model = \"SAGE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scenario = \"cand_select\"\n",
    "\n",
    "# Runs live in <ROOT_Ablation_Folder>/<scenario>/<name>_<model>/; the PDF goes to <fig_root>/<scenario>/\n",
    "cur_folder = os.path.join(ROOT_Ablation_Folder, scenario, f\"{name}_{model}\")\n",
    "cur_fig_folder = os.path.join(fig_root, scenario)\n",
    "os.makedirs(cur_fig_folder, exist_ok=True)\n",
    "\n",
    "x_list, y_list = load_plot_data(cur_folder)\n",
    "\n",
    "###########\n",
    "# One legend entry and one style per run, matched by position. This assumes the run\n",
    "# directories are returned in K_1 order (TODO confirm against the directory names).\n",
    "legend_list = [\"$K_1=1$\", \"$K_1=2$\", \"$K_1=3$\", \"$K_1=4$\"]\n",
    "style_list = ['solid', 'dashed', 'dashdot', 'dotted']\n",
    "assert len(x_list) == len(legend_list), (\n",
    "    f\"{cur_folder}: found {len(x_list)} runs but {len(legend_list)} legend entries\")\n",
    "\n",
    "plot_curves(x_list,\n",
    "            y_list,\n",
    "            legend_list,\n",
    "            xlim=conf_dict[(name, model)][\"x_lim\"],\n",
    "            figure_path=os.path.join(cur_fig_folder, f\"{name}_{model}_{scenario}.pdf\"),\n",
    "            styles=style_list,\n",
    "            x_label='Time(s)',\n",
    "            y_label='NR (log scale)',\n",
    "            figsize=[12, 6],\n",
    "            y_log=True,\n",
    "            widths=[4] * len(x_list),\n",
    "            ylim=[0.02, 1],\n",
    "            fontsize=32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scenario = \"warming\"\n",
    "\n",
    "# Runs live in <ROOT_Ablation_Folder>/<scenario>/<name>_<model>/; the PDF goes to <fig_root>/<scenario>/\n",
    "cur_folder = os.path.join(ROOT_Ablation_Folder, scenario, f\"{name}_{model}\")\n",
    "cur_fig_folder = os.path.join(fig_root, scenario)\n",
    "os.makedirs(cur_fig_folder, exist_ok=True)\n",
    "\n",
    "x_list, y_list = load_plot_data(cur_folder)\n",
    "\n",
    "###########\n",
    "# One legend entry and one style per run, matched by position. This assumes the run\n",
    "# directories are returned in k order (TODO confirm against the directory names).\n",
    "legend_list = [\"k=1\", \"k=2\", \"k=3\", \"k=4\", \"k=5\"]\n",
    "# Five curves need five distinct styles. matplotlib has only four named ones, so the\n",
    "# fifth is an explicit (offset, on-off dash sequence) tuple (dash-dot-dot).\n",
    "# NOTE(review): assumes plot_curves forwards styles to matplotlib's linestyle.\n",
    "style_list = ['solid', 'dashed', 'dashdot', 'dotted', (0, (3, 1, 1, 1, 1, 1))]\n",
    "assert len(x_list) == len(legend_list), (\n",
    "    f\"{cur_folder}: found {len(x_list)} runs but {len(legend_list)} legend entries\")\n",
    "\n",
    "# conf_dict's x_lim is listed under cand_select, so no x-limit is applied here.\n",
    "plot_curves(x_list,\n",
    "            y_list,\n",
    "            legend_list,\n",
    "            figure_path=os.path.join(cur_fig_folder, f\"{name}_{model}_{scenario}.pdf\"),\n",
    "            styles=style_list,\n",
    "            x_label='Time(s)',\n",
    "            y_label='NR (log scale)',\n",
    "            figsize=[6, 6],\n",
    "            # The y-label promises a log scale; request it explicitly, as the cand_select plot does.\n",
    "            y_log=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
