{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "from data_utils import load_data\n",
    "from path_dict import name_model_pool\n",
    "\n",
    "\n",
    "script_result_root = \"../script_logs/neighbour_alg\"  # one all_res_dict.pkl per dataset_model folder\n",
    "# Per-seed test_acc.txt files for the GLNN / NOSMOG runs. Point this at\n",
    "# \"../result/baselines\" to read the original runs instead of the split-based pseudo-runs.\n",
    "result_root = \"../result/baselines_bootstrap\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset, model) pair for ad-hoc inspection. Other presets:\n",
    "#   cora / citeseer / pubmed  -> SAGE or geomGCN\n",
    "#   a-computer / a-photo      -> SAGE or exphormer\n",
    "#   ogbn-arxiv                -> SAGE or DRGAT\n",
    "#   ogbn-products             -> SAGE or RevGNN-112\n",
    "# NOTE: the aggregation loop in the next cell iterates over name_model_pool and\n",
    "# rebinds `name` / `model`, so this choice does not affect the accuracy table.\n",
    "name = \"ogbn-products\"\n",
    "model = \"RevGNN-112\"\n",
    "\n",
    "# Methods reported per (dataset, model) pair in the accuracy table.\n",
    "method_pool = \"GLNN NOSMOG PPRGo SGC SDGNN\".split()\n",
    "# Not referenced in the visible cells.\n",
    "to_bootstrap = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_acc_MLP(res_folder, seed_list=range(10)):\n",
    "    \"\"\"Read one test accuracy per seed from `{res_folder}/seed_{k}/test_acc.txt`.\n",
    "\n",
    "    Args:\n",
    "        res_folder: directory holding one `seed_{k}` sub-folder per run.\n",
    "        seed_list: iterable of seed ids to read. The default 0-9 matches the 5 source\n",
    "            seeds x `split = 2` pseudo-runs written by the bootstrap cell. An immutable\n",
    "            `range` is used so the default can never be shared and mutated.\n",
    "\n",
    "    Returns:\n",
    "        list[float]: the first line of each file parsed as a float, in seed order.\n",
    "    \"\"\"\n",
    "    acc_list = []\n",
    "    for seed in seed_list:\n",
    "        cur_path = os.path.join(res_folder, f\"seed_{seed}\", \"test_acc.txt\")\n",
    "        with open(cur_path, \"r\") as fin:\n",
    "            acc_list.append(float(fin.readline()))\n",
    "    return acc_list\n",
    "\n",
    "def load_acc_ablation(res_path):\n",
    "    \"\"\"Unpickle and return the object stored at `res_path`.\n",
    "\n",
    "    Here it is the per-method results dict from the neighbour-algorithm logs\n",
    "    (the caller reads its \"PPRGo\", \"SGC\" and \"SDMP\" entries).\n",
    "    Only use on trusted files: unpickling can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    with open(res_path, \"rb\") as pickled:\n",
    "        return pickle.load(pickled)\n",
    "\n",
    "# Methods whose per-seed accuracies are stored as seed_{k}/test_acc.txt files.\n",
    "MLP_METHODS = (\"GLNN\", \"NOSMOG\")\n",
    "# Table key -> key inside all_res_dict.pkl (SDGNN is stored under \"SDMP\").\n",
    "# (\"SDGNN-equal\" would map to the \"equal\" key; currently disabled.)\n",
    "ABLATION_KEYS = {\"PPRGo\": \"PPRGo\", \"SGC\": \"SGC\", \"SDGNN\": \"SDMP\"}\n",
    "\n",
    "grand_acc_dict = dict()\n",
    "\n",
    "for name, model in name_model_pool:\n",
    "    # defaultdict: looking up a method that was not loaded returns [] instead of raising.\n",
    "    all_acc = defaultdict(list)\n",
    "    for method in MLP_METHODS:\n",
    "        res_path = os.path.join(result_root, name, f\"{model}_{method}\")\n",
    "        all_acc[method] = load_acc_MLP(res_path)\n",
    "    # PPRGo, SGC and SDGNN all come from one pickled dict per (dataset, model) pair.\n",
    "    res_path = os.path.join(script_result_root, f\"{name}_{model}\", \"all_res_dict.pkl\")\n",
    "    cur_dict = load_acc_ablation(res_path)\n",
    "    for method, stored_key in ABLATION_KEYS.items():\n",
    "        all_acc[method] = cur_dict[stored_key]\n",
    "    grand_acc_dict[name, model] = all_acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist {(dataset, model): {method: accuracies}} as a single pickle.\n",
    "acc_table_path = \"../result/all_acc_table.pkl\"\n",
    "with open(acc_table_path, \"wb\") as table_file:\n",
    "    pickle.dump(grand_acc_dict, table_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
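    "## Bootstrap for scenarios with insufficient runs\n",
    "\n",
    "Despite the name, this is not resampling with replacement. Each source seed's test set is randomly partitioned into `split` disjoint parts, and the accuracy on each part is saved as its own pseudo-run (`seed_{seed*split + part}`). With 5 source seeds and `split = 2`, this yields the 10 runs (`seed_0` to `seed_9`) that `load_acc_MLP` reads by default."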
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original per-seed predictions (out.npz) are read from here ...\n",
    "result_root_source = \"../result/baselines\"\n",
    "# ... and the split-based pseudo-runs are written here (same value as at the top).\n",
    "result_root = \"../result/baselines_bootstrap\"\n",
    "\n",
    "# Scenario to expand. Other presets (all NOSMOG): ogbn-products with SAGE or\n",
    "# RevGNN-112, and ogbn-arxiv with DRGAT.\n",
    "name = \"ogbn-arxiv\"\n",
    "model = \"SAGE\"\n",
    "scenario = \"NOSMOG\"\n",
    "\n",
    "import torchmetrics, torch\n",
    "def torch_f1(y_hats, ys, task='multilabel', average='micro'):\n",
    "    \"\"\"Compute an F1 score of predictions `y_hats` against targets `ys` via torchmetrics.\n",
    "\n",
    "    Args:\n",
    "        y_hats: (N, C) score tensor; C = y_hats.shape[1] is used as the class/label count.\n",
    "        ys: targets for the same N rows.\n",
    "        task: torchmetrics task name (\"binary\", \"multiclass\" or \"multilabel\").\n",
    "        average: torchmetrics averaging mode.\n",
    "\n",
    "    Returns:\n",
    "        The metric tensor produced by `torchmetrics.F1Score`.\n",
    "    \"\"\"\n",
    "    num_class = int(y_hats.shape[1])\n",
    "    # In the task-based torchmetrics API, multiclass metrics are sized by `num_classes`\n",
    "    # but multilabel ones by `num_labels`. Pass both so the default task='multilabel'\n",
    "    # can actually be constructed; the size argument that does not apply is ignored.\n",
    "    # NOTE(review): for single-label integer targets, task='multiclass' (whose micro-F1\n",
    "    # equals accuracy) is probably what is meant -- confirm against the label format.\n",
    "    f1_score = torchmetrics.F1Score(\n",
    "        task=task, num_classes=num_class, num_labels=num_class, average=average\n",
    "    )\n",
    "    return f1_score(y_hats, ys)\n",
    "\n",
    "def get_local_f1(y_hats, ys, index):\n",
    "    \"\"\"F1 (as computed by `torch_f1`) restricted to the rows selected by `index`.\n",
    "\n",
    "    `y_hats` may be a NumPy array or a tensor. `torch.as_tensor` avoids the full\n",
    "    copy (and, for tensor input, the copy-construct warning) of `torch.tensor`.\n",
    "    The returned metric tensor is detached from any autograd graph.\n",
    "    \"\"\"\n",
    "    scores = torch.as_tensor(y_hats)\n",
    "    return torch_f1(scores[index, :], ys[index]).detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split = 2             # disjoint test-set parts (pseudo-runs) per source seed\n",
    "num_source_seeds = 5  # original runs available under result_root_source\n",
    "\n",
    "target_folder_root = os.path.join(result_root, name, f\"{model}_{scenario}\")\n",
    "\n",
    "for seed in range(num_source_seeds):\n",
    "    source_folder = os.path.join(result_root_source, name, f\"{model}_{scenario}\", f\"seed_{seed}\")\n",
    "    g = load_data(name, seed=seed)\n",
    "    prediction = np.load(os.path.join(source_folder, \"out.npz\"))[\"arr_0\"]\n",
    "    labels = g.ndata['label']\n",
    "\n",
    "    # Score on the full test set, for comparison with the per-part scores below.\n",
    "    print(get_local_f1(prediction, labels, g.test_idx))\n",
    "\n",
    "    # Seeded permutation so the pseudo-runs are reproducible on every re-run.\n",
    "    generator = torch.Generator().manual_seed(seed)\n",
    "    num_test = g.test_idx.shape[0]\n",
    "    test_idx = g.test_idx[torch.randperm(num_test, generator=generator)]\n",
    "    # Integer boundaries covering every test node (part sizes differ by at most one),\n",
    "    # so the num_test % split remainder is not silently dropped.\n",
    "    bounds = [part * num_test // split for part in range(split + 1)]\n",
    "    for each_split in range(split):\n",
    "        cur_test_idx = test_idx[bounds[each_split]:bounds[each_split + 1]]\n",
    "        my_acc = get_local_f1(prediction, labels, cur_test_idx)\n",
    "\n",
    "        save_folder = os.path.join(target_folder_root, f\"seed_{seed*split + each_split}\")\n",
    "        os.makedirs(save_folder, exist_ok=True)\n",
    "        with open(os.path.join(save_folder, \"test_acc.txt\"), \"w\") as fout:\n",
    "            fout.write(str(my_acc.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
