{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from data_utils import load_data\n",
    "\n",
    "\n",
    "result_root = \"../result\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# name, model = \"cora\", \"SAGE\"\n",
    "# name, model = \"cora\", \"geomGCN\"\n",
    "# name, model = \"citeseer\", \"SAGE\"\n",
    "# name, model = \"citeseer\", \"geomGCN\"\n",
    "# name, model = \"pubmed\", \"SAGE\"\n",
    "# name, model = \"pubmed\", \"geomGCN\"\n",
    "# name, model = \"a-computer\", \"SAGE\"\n",
    "# name, model = \"a-computer\", \"exphormer\"\n",
    "name, model = \"a-photo\", \"SAGE\"\n",
    "# name, model = \"a-photo\", \"exphormer\"\n",
    "# name, model = \"ogbn-arxiv\", \"SAGE\"\n",
    "# name, model = \"ogbn-arxiv\", \"DRGAT\"\n",
    "# name, model = \"ogbn-products\", \"SAGE\"\n",
    "# name, model = \"ogbn-products\", \"RevGNN-112\"\n",
    "\n",
    "seed = 0 # 0,1,2,3,4,5,6,7,8,9\n",
    "\n",
    "cur_target_path = os.path.join(result_root, name, model, f'seed_{seed}', 'out.npz')\n",
    "print(cur_target_path)\n",
    "\n",
    "g = load_data(name, seed=seed)\n",
    "target = np.load(cur_target_path)[\"arr_0\"]\n",
    "# g.train_idx, g.val_idx, g.test_idx\n",
    "# g.ndata['feat'], g.ndata['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchmetrics, torch\n",
    "def torch_f1(y_hats, ys, task='multilabel', average='micro'):\n",
    "    # num_class = len(Counter(ys.detach().cpu().numpy()).keys())\n",
    "    num_class = int(y_hats.shape[1])\n",
    "    f1_score = torchmetrics.F1Score(num_classes=num_class, task=task, average=average)\n",
    "    return f1_score(y_hats, ys)\n",
    "\n",
    "name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\"),(\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\"), (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "# name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\")]\n",
    "# name_model_pool = [(\"cora\", \"SAGE\")]\n",
    "all_cache = {}\n",
    "for name, model in name_model_pool:\n",
    "    all_res = []\n",
    "    for seed in range(10):\n",
    "        g = load_data(name, seed=seed)\n",
    "        cur_target_path = os.path.join(result_root, name, model, f'seed_{seed}', 'out.npz')\n",
    "        # print(\"!!!\", cur_target_path)\n",
    "        target = np.load(cur_target_path)[\"arr_0\"]\n",
    "        all_res.append(torch_f1(torch.tensor(target)[g.test_idx, :], g.ndata['label'][g.test_idx]).data)\n",
    "        \n",
    "    all_cache[name, model] = all_res\n",
    "\n",
    "for name, model in name_model_pool:\n",
    "    print(name, model, np.mean(all_cache[name, model])*100, np.std(all_cache[name, model])*100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
