{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Post analysis and show results\n",
    "\n",
    "\"\"\"\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import glob\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from utils import load_train_conf\n",
    "from path_dict import *\n",
    "from script_utils import latexTableBase\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "class analyzeLogBase:\n",
    "    \"\"\"\n",
    "    Load the logs of a collection of experiment trials and offer basic operations on them.\n",
    "\n",
    "    The logs are expected to follow the hierarchy log_dir/trial_dir. log_dir is a pre-specified\n",
    "    path that stores a bunch of experiments. trial_dir is usually an automatically generated name\n",
    "    with the format device_name+time_tag. Inside this dir there can be various files, including\n",
    "    the log of standard output, model results, various intermediate measure traces etc.\n",
    "    Initializing this class automatically loads all the logs in log_dir in a unified manner.\n",
    "\n",
    "    Input:\n",
    "    list_log_dir: the root folder path that contains the bunch of logs. If a list of root folders\n",
    "                  is given, the loader loops over all of them and generates one list of results.\n",
    "    name_to_load: the list of the names of the files in each log_dir/trial_dir to load. Only the\n",
    "                  file styles 'txt', 'pkl' and 'yml' are supported. None or an empty list loads\n",
    "                  all the files of each trial folder; files of any other style are then stored\n",
    "                  with the value None.\n",
    "    drop_suffex_key: whether to drop the suffix of the file name when building the key of a loaded\n",
    "                     file. If the shortened names conflict (e.g. log.pkl and log.txt), turn this\n",
    "                     off.\n",
    "\n",
    "    Result (stored in self.all_logs):\n",
    "    A list of dict, one dict per trial, holding all the loaded logs of that trial. Keys are the\n",
    "    file names and values are the content of the files. Additionally, the key 'log_dir' with the\n",
    "    value 'log_dir/this_trial_dir' is always added to each dict.\n",
    "\n",
    "    Example:\n",
    "    Given a log dir of\n",
    "    -foo_dir\n",
    "    ---server0_20220930\n",
    "    -----res.pkl\n",
    "    -----time.pkl\n",
    "    -----config.yml\n",
    "    ---server0_20220931\n",
    "    -----res.pkl\n",
    "    -----time.pkl\n",
    "    -----config.yml\n",
    "    After calling test=analyzeLogBase('foo_dir', ['res.pkl', 'config.yml']), test.all_logs is a\n",
    "    list of two dicts, where test.all_logs[0]['res'] stores the content of\n",
    "    'foo_dir/server0_20220930/res.pkl'\n",
    "    \"\"\"\n",
    "    def __init__(self, list_log_dir, name_to_load=None, drop_suffex_key=True):\n",
    "        self.list_log_dir = list_log_dir\n",
    "        if not isinstance(self.list_log_dir, list):\n",
    "            self.list_log_dir = [self.list_log_dir]\n",
    "        if name_to_load is None or name_to_load == []:\n",
    "            print(\"Did not provide name to load. Load all files. \")\n",
    "            self.name_to_load = None\n",
    "        else:\n",
    "            self.name_to_load = name_to_load\n",
    "        self.drop_suffex_key = drop_suffex_key\n",
    "\n",
    "        # Hand over the normalized self.name_to_load (not the raw argument), so that an empty\n",
    "        # list really loads every file, as announced by the message above.\n",
    "        self.all_logs = self.load_log_pool(\n",
    "            self.list_log_dir, self.name_to_load, drop_suffex_key=self.drop_suffex_key)\n",
    "\n",
    "    @staticmethod\n",
    "    def load_log(log_dir, name_to_load=None, drop_suffex_key=True):\n",
    "        \"\"\"\n",
    "        Load the files of one trial folder into a dict.\n",
    "\n",
    "        log_dir: path of the trial folder.\n",
    "        name_to_load: file names inside log_dir to load. None loads every regular file found.\n",
    "        drop_suffex_key: if True, the key is the file name up to its first '.', otherwise the\n",
    "                         full file name.\n",
    "\n",
    "        Returns the dict of loaded contents plus the entry 'log_dir' -> log_dir. Errors raised\n",
    "        while opening or parsing a file (e.g. a requested file is missing) are not caught here.\n",
    "        \"\"\"\n",
    "        def load_file(path):\n",
    "            \"\"\"Load one file according to its suffix; unsupported suffixes give None.\"\"\"\n",
    "            style = path.split(\".\")[-1]\n",
    "            if style == \"pkl\":\n",
    "                # NOTE: unpickling can execute arbitrary code; only load logs you trust.\n",
    "                with open(path, \"rb\") as fin:\n",
    "                    return pickle.load(fin)\n",
    "            if style == \"yml\":\n",
    "                return load_train_conf(path)\n",
    "            if style == \"txt\":\n",
    "                with open(path, \"r\") as fin:\n",
    "                    return fin.read()\n",
    "            print(f\"Unsupported file style with path {path}. Style can only be txt, pkl, yml.\")\n",
    "            # Return an explicit placeholder. Previously nothing was assigned on this path, so\n",
    "            # the return statement failed and the caller discarded the whole trial.\n",
    "            return None\n",
    "\n",
    "        if name_to_load is None:\n",
    "            # Only regular files can be loaded; sorting makes the key order deterministic.\n",
    "            name_to_load = sorted(\n",
    "                os.path.basename(this_path)\n",
    "                for this_path in glob.iglob(os.path.join(log_dir, \"*\"))\n",
    "                if os.path.isfile(this_path))\n",
    "        res_dict = {}\n",
    "        for this_name in name_to_load:\n",
    "            hash_key = this_name.split(\".\")[0] if drop_suffex_key else this_name\n",
    "            res_dict[hash_key] = load_file(os.path.join(log_dir, this_name))\n",
    "        res_dict[\"log_dir\"] = log_dir\n",
    "        return res_dict\n",
    "\n",
    "    @staticmethod\n",
    "    def load_log_pool(list_root_path, name_to_load, drop_suffex_key=True):\n",
    "        \"\"\"\n",
    "        Load every trial folder found directly under each root folder.\n",
    "\n",
    "        list_root_path: list of root folders, each one containing trial folders.\n",
    "        name_to_load, drop_suffex_key: forwarded to load_log for every trial.\n",
    "\n",
    "        Returns the list of per-trial dicts. A trial that fails to load is reported on standard\n",
    "        output and skipped, so one broken trial does not abort the whole analysis.\n",
    "        \"\"\"\n",
    "        res = []\n",
    "        for this_root_path in list_root_path:\n",
    "            # Sorted so that the order of the results does not depend on the file system.\n",
    "            for this_path in sorted(glob.iglob(f\"{this_root_path}/*\")):\n",
    "                if not os.path.isdir(this_path):\n",
    "                    # Stray files next to the trial folders are not trials.\n",
    "                    continue\n",
    "                try:\n",
    "                    this_res = analyzeLogBase.load_log(this_path, name_to_load, drop_suffex_key=drop_suffex_key)\n",
    "                except Exception as err:  # pylint: disable=broad-except\n",
    "                    print(f\"In commplete log {this_path}.\")\n",
    "                    print(f\"    Reason: {err!r}\")\n",
    "                    continue\n",
    "                res.append(this_res)\n",
    "                print(f\"Loaded log {this_path}.\")\n",
    "        return res\n",
    "\n",
    "    def disp_by_keys(self, keys):\n",
    "        \"\"\"Print 'key: value' for every given key of every loaded trial.\"\"\"\n",
    "        for this_log in self.all_logs:\n",
    "            for this_key in keys:\n",
    "                print(f\"{this_key}: {this_log[this_key]}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDMP Analysis base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the logs of every (name, model) pair and gather their summary statistics.\n",
    "all_res_dict = defaultdict(dict)\n",
    "all_root_path = \"../result/best\"\n",
    "\n",
    "name_to_load = [\"dict_log.txt\", \"dict_log.pkl\", \"dict_conf.yml\", \"conf.yml\", \"f1_micro.pkl\", \"GNN_f1.txt\", \"GNN_data_split_seed.txt\"]\n",
    "\n",
    "# load and gather results\n",
    "# NOTE(review): name_model_pool, path_dict and collect_mean_stds are not defined in this\n",
    "# notebook; presumably they come from `from path_dict import *` - confirm.\n",
    "for name, model in name_model_pool:\n",
    "    # res_path = os.path.join(all_root_path, f\"{name}_{model}_MLP_TEST\")\n",
    "    res_path = path_dict[name, model][\"result_root_parent_folder\"]\n",
    "    print(res_path)\n",
    "    try:\n",
    "        list_log = analyzeLogBase(res_path,\n",
    "                                  name_to_load=name_to_load,\n",
    "                                  drop_suffex_key=False).all_logs\n",
    "    except Exception as err:  # pylint: disable=broad-except\n",
    "        # The f prefix was missing, so the placeholders were printed literally.\n",
    "        print(f\"!!! Failed with {name}/{model}: {err!r}\")\n",
    "        continue\n",
    "    res_acc, res_reg, res_target = collect_mean_stds(list_log, return_target=True)\n",
    "    # 'log' and 'list_log' hold the same list; both keys are kept for the cells that read them.\n",
    "    all_res_dict[(name, model)][\"log\"] = list_log\n",
    "    all_res_dict[(name, model)][\"acc\"] = res_acc\n",
    "    all_res_dict[(name, model)][\"reg\"] = res_reg\n",
    "    all_res_dict[(name, model)][\"target\"] = res_target\n",
    "    all_res_dict[(name, model)][\"cnt\"] = len(res_acc)\n",
    "    all_res_dict[(name, model)][\"list_log\"] = list_log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t\tName/model & Original & Accuracy & NR & Mean Receptive Field\\\\\n",
      "\t\tcora/SAGE & 85.43$\\pm$1.22 & 85.53$\\pm$1.32 & 0.007$\\pm$0.002 & 35.2\\\\\n",
      "\t\tcora/geomGCN & 86.52$\\pm$0.98 & 86.64$\\pm$0.85 & 0.019$\\pm$0.008 & 40.2\\\\\n",
      "\t\tciteseer/SAGE & 72.41$\\pm$1.79 & 73.62$\\pm$2.04 & 0.033$\\pm$0.083 & 16.3\\\\\n",
      "\t\tciteseer/geomGCN & 79.91$\\pm$0.94 & 80.27$\\pm$1.11 & 0.036$\\pm$0.014 & 13.5\\\\\n",
      "\t\tpubmed/SAGE & 87.17$\\pm$0.56 & 87.10$\\pm$0.54 & 0.034$\\pm$0.002 & 25.5\\\\\n",
      "\t\tpubmed/geomGCN & 89.75$\\pm$0.43 & 89.77$\\pm$0.47 & 0.037$\\pm$0.005 & 16.3\\\\\n",
      "\t\ta-computer/SAGE & 89.06$\\pm$0.53 & 90.60$\\pm$0.52 & 0.007$\\pm$0.001 & 26.1\\\\\n",
      "\t\ta-computer/exphormer & 91.60$\\pm$0.58 & 94.29$\\pm$0.52 & 0.016$\\pm$0.001 & 47.3\\\\\n",
      "\t\ta-photo/SAGE & 92.90$\\pm$0.66 & 93.96$\\pm$0.39 & 0.003$\\pm$0.002 & 26.9\\\\\n",
      "\t\ta-photo/exphormer & 95.03$\\pm$0.48 & 96.73$\\pm$0.30 & 0.005$\\pm$0.001 & 46.1\\\\\n",
      "\t\togbn-arxiv/SAGE & 70.56$\\pm$0.31 & 70.50$\\pm$0.33 & 0.013$\\pm$0.000 & 31.3\\\\\n",
      "\t\togbn-arxiv/DRGAT & 73.78$\\pm$0.09 & 73.72$\\pm$0.11 & 0.077$\\pm$0.006 & 36.2\\\\\n",
      "\t\togbn-products/SAGE & 78.08$\\pm$0.65 & 78.07$\\pm$0.61 & 0.061$\\pm$0.001 & 24.0\\\\\n",
      "\t\togbn-products/RevGNN-112 & 82.79$\\pm$0.25 & 82.87$\\pm$0.21 & 0.029$\\pm$0.001 & 23.6\n"
     ]
    }
   ],
   "source": [
    "def extract_mean_receptive(list_log):\n",
    "    \"\"\"Average over the trials the first entry of the last 'Theta_col_nonzero_stat' record.\"\"\"\n",
    "    last_records = (trial['dict_log.pkl']['Theta_col_nonzero_stat'][-1] for trial in list_log)\n",
    "    return np.mean([record[0] for record in last_records])\n",
    "\n",
    "# show the results: one LaTeX table row per (name, model) pair\n",
    "all_res_list = []\n",
    "all_res_list.append([\"Name/model\", \"Original\", \"Accuracy\", \"NR\", \"Mean Receptive Field\"])\n",
    "\n",
    "for name, model in name_model_pool:\n",
    "    # .get avoids silently inserting an empty entry into the defaultdict.\n",
    "    cur_dict = all_res_dict.get((name, model), {})\n",
    "    if \"target\" not in cur_dict:\n",
    "        # The loading cell skips pairs that failed to load; skip them here too.\n",
    "        print(f\"!!! No results for {name}/{model}; row skipped.\")\n",
    "        continue\n",
    "    # \"\\\\pm\" is written with an escaped backslash: \"\\p\" is an invalid escape sequence.\n",
    "    cur_res = [\n",
    "        f\"{name}/{model}\",\n",
    "        f\"{100*np.mean(cur_dict['target']):.2f}$\\\\pm${100*np.std(cur_dict['target']):.2f}\",\n",
    "        f\"{100*np.mean(cur_dict['acc']):.2f}$\\\\pm${100*np.std(cur_dict['acc']):.2f}\",\n",
    "        f\"{np.mean(cur_dict['reg']):.3f}$\\\\pm${np.std(cur_dict['reg']):.3f}\",\n",
    "        # f\"{cur_dict['cnt']:d}\",\n",
    "        f\"{extract_mean_receptive(cur_dict['list_log']):.1f}\",\n",
    "    ]\n",
    "    all_res_list.append(cur_res)\n",
    "\n",
    "res_str = latexTableBase.gen_table(all_res_list)\n",
    "print(res_str)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['log', 'acc', 'reg', 'target', 'cnt'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: list the fields stored for the last (name, model) pair processed by the\n",
    "# table loop, which leaves its entry in `cur_dict`.\n",
    "cur_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "23.603542612615556"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: the value extract_mean_receptive reads, here for the first trial of the\n",
    "# `list_log` left behind by the most recently executed loading cell.\n",
    "list_log[0]['dict_log.pkl']['Theta_col_nonzero_stat'][-1][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDMP inductive training test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class dictTable(latexTableBase):\n",
    "    \"\"\"\n",
    "    LaTeX table with one row per trial of the inductive training test.\n",
    "\n",
    "    NOTE(review): gen_a_row and post_precessing are presumably hooks called by latexTableBase,\n",
    "    and self.table_element presumably holds the generated rows - confirm in script_utils.\n",
    "    \"\"\"\n",
    "\n",
    "    def gen_a_row(self, dict_log):\n",
    "        \"\"\"\n",
    "        Build the list of cell strings of one trial.\n",
    "\n",
    "        dict_log: dict of the loaded files of one trial keyed by full file name; the entries\n",
    "                  'dict_log.pkl', 'dict_log.txt', 'dict_conf.yml' and 'f1_micro.pkl' are read.\n",
    "\n",
    "        Returns the cells in the order: inductive_train_ratio, theta_n_nonzero, theta_cand_mode,\n",
    "        theta_cand_k1, theta_cand_k2, h_init_theta_mode, h_init_theta_k, the last\n",
    "        Theta_col_nonzero_stat record as first$\\\\pm$second entry, the last rel_regret, the time\n",
    "        token of the text log, and mean$\\\\pm$std of f1.\n",
    "        \"\"\"\n",
    "        dl = dict_log[\"dict_log.pkl\"]\n",
    "        dltxt = dict_log[\"dict_log.txt\"]\n",
    "        dc = dict_log[\"dict_conf.yml\"]\n",
    "        f1 = dict_log[\"f1_micro.pkl\"]\n",
    "\n",
    "        last_theta_stat = dl['Theta_col_nonzero_stat'][-1]\n",
    "        return [\n",
    "            str(dc['inductive_train_ratio']),\n",
    "            str(dc['theta_n_nonzero']),\n",
    "            dc['theta_cand_mode'],\n",
    "            str(dc['theta_cand_k1']),\n",
    "            str(dc['theta_cand_k2']),\n",
    "            dc['h_init_theta_mode'],\n",
    "            str(dc['h_init_theta_k']),\n",
    "            f\"{last_theta_stat[0]:.1f}$\\\\pm${last_theta_stat[1]:.1f}\",\n",
    "            f\"{dl['rel_regret'][-1]:.3f}\",\n",
    "            self.load_dict_log_time(dltxt),\n",
    "            f\"{np.mean(f1):.3f}$\\\\pm${np.std(f1):.3f}\",\n",
    "        ]\n",
    "\n",
    "    @staticmethod\n",
    "    def load_dict_log_time(file_str):\n",
    "        \"\"\"\n",
    "        Return the 4th space-separated token of the 5th line from the end of the log text.\n",
    "\n",
    "        NOTE(review): this relies on the exact layout of dict_log.txt; presumably that token is\n",
    "        the run time - confirm against the code that writes the log.\n",
    "        \"\"\"\n",
    "        time_line = file_str.split(\"\\n\")[-5]\n",
    "        return time_line.split(\" \")[3]\n",
    "\n",
    "    def post_precessing(self):\n",
    "        \"\"\"Sort the rows by (column 0, column 2, int(column 1), int(column 3)).\"\"\"\n",
    "        self.table_element.sort(key=lambda x: (x[0], x[2], int(x[1]), int(x[3])))\n",
    "        \n",
    "# Column titles of the table, in the order of the cells returned by dictTable.gen_a_row.\n",
    "# \"f1\\\\_micro\" is written with an escaped backslash: \"\\_\" is an invalid escape sequence.\n",
    "header = [\"ind ratio\", \"$\\\\vtheta$ cnt bd\", \"$\\\\vtheta$ mode\", \"$\\\\vtheta$ k1\", \"$\\\\vtheta$ k2\", \"$h$ mode\", \"$h$ k\", \"$\\\\vtheta$ 0 cnt\", \"n-Regret\", \"SDMP time\", \"f1\\\\_micro\"]\n",
    "name_to_load = [\"dict_log.txt\", \"dict_log.pkl\", \"ThetaT.pkl\", \"dict_conf.yml\", \"conf.yml\", \"f1_micro.pkl\"]\n",
    "\n",
    "# a-photo / exphormer\n",
    "list_log_paths = [\"../result/a-photo_exphormer_MLP_TEST\"]\n",
    "list_log = analyzeLogBase(list_log_paths, name_to_load=name_to_load, drop_suffex_key=False).all_logs\n",
    "table_str = dictTable(list_log, header=header).table_str\n",
    "\n",
    "print()\n",
    "print(table_str)\n",
    "print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82.79$\\pm$0.25\n"
     ]
    }
   ],
   "source": [
    "# Mean and std (in percent) of the test accuracy over the seed folders root + <seed>.\n",
    "root = \"../result/ogbn-products/RevGNN-112/seed_\"\n",
    "\n",
    "# for seed in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 234]:\n",
    "#     cur_folder_path = root+str(seed)\n",
    "#     with open(os.path.join(cur_folder_path, \"data_split_seed.txt\"), \"w\") as fout:\n",
    "#         fout.write(str(seed))\n",
    "\n",
    "res = []\n",
    "for seed in range(10):\n",
    "    cur_folder_path = root+str(seed)\n",
    "    with open(os.path.join(cur_folder_path, \"test_acc.txt\"), \"r\") as fin:\n",
    "        # res.append(eval(fin.read().split(':')[1]))\n",
    "        # NOTE(review): eval executes whatever the file contains; if the file only holds a\n",
    "        # number, float(...) or ast.literal_eval(...) would be the safe replacement - confirm.\n",
    "        res.append(eval(fin.read()))\n",
    "# \"\\\\pm\" is written with an escaped backslash: \"\\p\" is an invalid escape sequence.\n",
    "print(f\"{np.mean(res)*100:.2f}$\\\\pm${np.std(res)*100:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-off maintenance step: read the 4th line of every trial's log.txt, take the number at\n",
    "# its end and store it in GNN_data_split_seed.txt. WARNING: this writes into the result folders.\n",
    "root = \"../result/a-computer_exphormer_SDMP\"\n",
    "\n",
    "for file in os.listdir(root):\n",
    "    print(file)\n",
    "    split_target_path = os.path.join(root, file, \"GNN_data_split_seed.txt\")\n",
    "    log_path = os.path.join(root, file, \"log.txt\")\n",
    "    with open(log_path, \"r\") as fin:\n",
    "        # Skip the first three lines; the fourth one is the line of interest.\n",
    "        for _ in range(3):\n",
    "            fin.readline()\n",
    "        crit_line = fin.readline()\n",
    "    print(crit_line)\n",
    "    # Keep the whole run of digits at the end of the line. Taking the single character\n",
    "    # crit_line[-2] truncated multi-digit seeds (e.g. 234) and picked the wrong character\n",
    "    # when the line had no trailing newline.\n",
    "    stripped_line = crit_line.rstrip()\n",
    "    seed_num = stripped_line[len(stripped_line.rstrip(\"0123456789\")):]\n",
    "    if not seed_num:\n",
    "        print(f\"!!! No seed number at the end of line 4 of {log_path}; nothing written.\")\n",
    "        continue\n",
    "    with open(split_target_path, 'w') as fout:\n",
    "        fout.write(seed_num)\n",
    "    print(f\"seed num {seed_num}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "78.37$\\pm$0.55\n"
     ]
    }
   ],
   "source": [
    "# Mean and std (in percent) of the f1 score over the seed folders found in root.\n",
    "# `res` is reused by the statistics cells further down.\n",
    "root = \"../result/ogbn-products/SAGE\"\n",
    "\n",
    "res = []\n",
    "# Sorted so that the order of the scores does not depend on the file system.\n",
    "for file in sorted(os.listdir(root)):\n",
    "    # print(file)\n",
    "    # NOTE(review): seed_234 is left out on purpose by the original code; the reason is not\n",
    "    # recorded here - confirm.\n",
    "    if file == \"seed_234\":\n",
    "        continue\n",
    "    f1_path = os.path.join(root, file, \"f1.txt\")\n",
    "    with open(f1_path, \"r\") as fin:\n",
    "        # NOTE(review): eval executes whatever the file contains; if the file only holds a\n",
    "        # number, float(...) or ast.literal_eval(...) would be the safe replacement - confirm.\n",
    "        res.append(eval(fin.read()))\n",
    "\n",
    "# \"\\\\pm\" is written with an escaped backslash: \"\\p\" is an invalid escape sequence.\n",
    "print(f\"{np.mean(res)*100:.2f}$\\\\pm${np.std(res)*100:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean and std: 78.37$\\pm$0.55\n",
      "Mean and 0.99 interval: 78.37$\\pm$0.60\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Format the scores in `res` with the project helper, asking for a 0.99 confidence level.\n",
    "# NOTE(review): `res` is whatever list of scores the most recently executed cell above left\n",
    "# behind, and `res_str` overwrites the table string produced by an earlier cell.\n",
    "from utils import latex_sample_mean_std_confidence\n",
    "\n",
    "res_str = latex_sample_mean_std_confidence(res, confidence=0.99)\n",
    "\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.06907268136455509"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: square root of scipy.stats.sem of the scores in `res`.\n",
    "# NOTE(review): the square root of a standard error is not a usual statistic - confirm that\n",
    "# this is intended. The `stats` import made here is reused by a later cell.\n",
    "from scipy import stats\n",
    "np.sqrt(stats.sem(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.004771035310889357"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard error of the mean computed by hand: population std (np.std, ddof=0) divided by\n",
    "# sqrt(n - 1). Uses len(res) instead of the hard-coded 9, which was only right for 10 scores.\n",
    "np.std(res)/np.sqrt(len(res) - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.004771035310889357"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standard error of the mean of the scores in `res` via scipy.stats.sem; `stats` is imported\n",
    "# in an earlier cell. The saved output equals the hand-computed value of the cell above.\n",
    "stats.sem(res)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
