{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Test the impact of neighbor sampling size for inference\n",
    "\"\"\"\n",
    "import os\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\" \n",
    "# os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "import sys\n",
    "sys.path.extend([\"../\"])\n",
    "import random\n",
    "from time import time\n",
    "import io\n",
    "import glob\n",
    "\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchmetrics.functional as MF\n",
    "import dgl\n",
    "import dgl.data\n",
    "import dgl.nn as dglnn\n",
    "from dgl.data import AsNodePredDataset\n",
    "from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler\n",
    "from tqdm import tqdm\n",
    "\n",
    "from model import SAGE\n",
    "from data_utils import load_data\n",
    "from utils import Logger, load_train_conf\n",
    "from script_utils import latexTableBase\n",
    "\n",
    "# Device for the timing runs: load_SAGE_model moves the model here and\n",
    "# test_inference_time moves each sampled block here (e.g. 'cuda:3' for a GPU).\n",
    "DEVICE = \"cpu\"\n",
    "\n",
    "# Model loading helpers\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that materialises pickled torch storages on the CPU.\n",
    "\n",
    "    Lookups of ``torch.storage._load_from_bytes`` (the reconstructor torch uses\n",
    "    for pickled storages) are redirected to ``torch.load(..., map_location='cpu')``,\n",
    "    so e.g. weights saved from a GPU run can be read on a CPU-only host. Every\n",
    "    other global is resolved by the standard ``pickle.Unpickler``.\n",
    "\n",
    "    WARNING: unpickling can execute arbitrary code; only load trusted files.\n",
    "    \"\"\"\n",
    "\n",
    "    def find_class(self, module, name):\n",
    "        if (module, name) != ('torch.storage', '_load_from_bytes'):\n",
    "            return super().find_class(module, name)\n",
    "\n",
    "        def _load_on_cpu(raw_bytes):\n",
    "            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')\n",
    "\n",
    "        return _load_on_cpu\n",
    "\n",
    "def load_SAGE_model(g, train_conf, model_folder, device=\"cpu\"):\n",
    "    \"\"\"Rebuild a SAGE model from its training config and load its saved weights.\n",
    "\n",
    "    Args:\n",
    "        g: graph from ``load_data``; ``g.ndata['feat'].shape[1]`` is the input\n",
    "            width and ``g.num_classes`` the output width.\n",
    "        train_conf: training config; ``hidden_size``, ``hidden_layer`` and\n",
    "            ``dropout`` are read.\n",
    "        model_folder: directory containing the pickled state dict ``model_stat``.\n",
    "        device: device the model is moved to after loading.\n",
    "\n",
    "    Returns:\n",
    "        The SAGE model with the stored weights, on ``device``.\n",
    "    \"\"\"\n",
    "    in_size = g.ndata['feat'].shape[1]\n",
    "    out_size = g.num_classes\n",
    "    model = SAGE(in_size, train_conf[\"hidden_size\"], out_size,\n",
    "                 GNN_layer=train_conf[\"hidden_layer\"], dropout=train_conf[\"dropout\"])\n",
    "    model_path = os.path.join(model_folder, \"model_stat\")\n",
    "\n",
    "    # CPU_Unpickler remaps tensor storages to CPU; the `with` block guarantees\n",
    "    # the file handle is closed once the state dict has been read.\n",
    "    # NOTE: pickle can run arbitrary code on load - only use trusted result folders.\n",
    "    with open(model_path, \"rb\") as fin:\n",
    "        model_state = CPU_Unpickler(fin).load()\n",
    "    model.load_state_dict(model_state)\n",
    "    return model.to(device)\n",
    "\n",
    "def test_inference_time(g, model, device, is_sample=False, sample_size=10, eval_size=10000,\n",
    "                        val_order=None):\n",
    "    \"\"\"Time neighbour sampling and inference separately, one validation node at a time.\n",
    "\n",
    "    Every validation node is its own mini-batch (``batch_size=1``, shuffled order).\n",
    "    Per batch two wall-clock durations are recorded:\n",
    "      * sampling: from the end of the previous batch until the DataLoader yields\n",
    "        the next one (block construction + feat/label prefetch). The first entry\n",
    "        also covers DataLoader iterator start-up.\n",
    "      * inference: moving the blocks to ``device`` plus the forward pass. On CUDA\n",
    "        the device is synchronised before the timer stops, so queued kernels count.\n",
    "\n",
    "    Args:\n",
    "        g: graph with ``val_idx`` and node data ``feat`` / ``label``.\n",
    "        model: GNN exposing ``GNN_layer`` (number of layers to sample for).\n",
    "        device: device the sampled blocks are moved to before the forward pass.\n",
    "        is_sample: if True sample ``sample_size`` neighbours per layer,\n",
    "            otherwise use the full neighbourhood.\n",
    "        sample_size: fan-out per layer when ``is_sample`` is True.\n",
    "        eval_size: maximum number of validation nodes to time.\n",
    "        val_order: positions into ``g.val_idx``; the first ``eval_size`` are timed.\n",
    "            Defaults to all positions shuffled with seed 666 (the order the\n",
    "            timing sweep uses).\n",
    "\n",
    "    Returns:\n",
    "        (sampling_time, infer_time): lists of seconds, one entry per node.\n",
    "    \"\"\"\n",
    "    if val_order is None:\n",
    "        # Self-contained replacement for the notebook global ``all_ind``: same\n",
    "        # seed-666 shuffle, but no dependency on hidden kernel state.\n",
    "        val_order = list(range(g.val_idx.size()[0]))\n",
    "        random.Random(666).shuffle(val_order)\n",
    "    val_idx = g.val_idx[val_order[:eval_size]].to(\"cpu\")\n",
    "\n",
    "    if is_sample:\n",
    "        # same fan-out for every layer\n",
    "        sampler = NeighborSampler([sample_size] * model.GNN_layer,\n",
    "                                  prefetch_node_feats=['feat'],\n",
    "                                  prefetch_labels=['label'])\n",
    "    else:\n",
    "        sampler = MultiLayerFullNeighborSampler(model.GNN_layer,\n",
    "                                                prefetch_node_feats=['feat'],\n",
    "                                                prefetch_labels=['label'])\n",
    "    tic_data_loader = time()\n",
    "    val_dataloader = DataLoader(g, val_idx, sampler, device=\"cpu\",\n",
    "                                batch_size=1, shuffle=True,\n",
    "                                drop_last=False, num_workers=0)\n",
    "    print(\"data loader created in {:.6f} seconds\".format(time() - tic_data_loader))\n",
    "\n",
    "    on_cuda = torch.device(device).type == \"cuda\"\n",
    "    sampling_time, infer_time = [], []\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        tic_sampling = time()\n",
    "        for _input_nodes, _output_nodes, blocks in val_dataloader:\n",
    "            sampling_time.append(time() - tic_sampling)\n",
    "            tic_infer = time()\n",
    "            blocks = [block.to(device) for block in blocks]\n",
    "            y_hat = model(blocks, blocks[0].srcdata['feat'])\n",
    "            if on_cuda:\n",
    "                # CUDA ops are asynchronous; wait for them so the timer is honest.\n",
    "                torch.cuda.synchronize(device)\n",
    "            infer_time.append(time() - tic_infer)\n",
    "            # Release per-batch tensors before timing the next sample.\n",
    "            del blocks, y_hat\n",
    "            if on_cuda:\n",
    "                torch.cuda.empty_cache()\n",
    "            tic_sampling = time()\n",
    "    return sampling_time, infer_time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization\n",
    "# if not os.path.exists(model_folder):\n",
    "#     os.makedirs(model_folder)\n",
    "\n",
    "# g = g.to(\"cpu\")\n",
    "# # Config the log path\n",
    "# log_path = os.path.join(\"../result/best/speedtest\", name+f\"_SAGE_L{train_conf['hidden_layer']}.txt\")\n",
    "# sys.stdout = Logger(log_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Timing sweep: for each dataset, load the trained SAGE model (split seed 0)\n",
    "# and record per-node sampling / inference times for every fan-out in\n",
    "# test_sample_size_pool. Results go to <res_root>/<name>.pkl as\n",
    "# [all_sample_time, all_infer_time, test_sample_size_pool].\n",
    "root_folder = \"../result\"\n",
    "res_root = \"../script_logs/speedtest/SAGE\"\n",
    "if not os.path.exists(res_root):\n",
    "    os.makedirs(res_root)\n",
    "split_seed = 0\n",
    "test_sample_size_pool = [20]\n",
    "# Full benchmark: [\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "name_pool = [\"ogbn-products\"]\n",
    "\n",
    "for name in name_pool:\n",
    "    print(\"*\" * 21)\n",
    "    print(f\"Processing {name}...\")\n",
    "\n",
    "    # Trained artefacts live under <root_folder>/<name>/SAGE/seed_<split_seed>/.\n",
    "    model_folder = os.path.join(root_folder, name, f\"SAGE/seed_{split_seed}\")\n",
    "    conf_path = os.path.join(model_folder, \"train_conf.yml\")\n",
    "    train_conf = load_train_conf(conf_path)\n",
    "    g = load_data(name, seed=split_seed)\n",
    "    model = load_SAGE_model(g, train_conf, model_folder, device=DEVICE)\n",
    "\n",
    "    # Seed-666 shuffle of validation positions (all_ind); test_inference_time\n",
    "    # selects its first eval_size validation nodes from this permutation.\n",
    "    all_ind = list(range(g.val_idx.size()[0]))\n",
    "    random.seed(666)\n",
    "    random.shuffle(all_ind)\n",
    "\n",
    "    # One inner list of per-node timings per fan-out, in pool order.\n",
    "    all_sample_time, all_infer_time = [], []\n",
    "    for test_size in test_sample_size_pool:\n",
    "        sampling_time, infer_time = test_inference_time(\n",
    "            g, model, DEVICE, is_sample=True, sample_size=test_size)\n",
    "        all_sample_time.append(sampling_time)\n",
    "        all_infer_time.append(infer_time)\n",
    "    # For a trailing \"full\" column, also append the result of\n",
    "    # test_inference_time(g, model, DEVICE, is_sample=False).\n",
    "\n",
    "    with open(os.path.join(res_root, f\"{name}.pkl\"), \"wb\") as fout:\n",
    "        pickle.dump([all_sample_time, all_infer_time, test_sample_size_pool], fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean end-to-end latency (sampling + inference) per node in ms, for the first\n",
    "# fan-out stored for each dataset, rendered as a LaTeX table. The first timed\n",
    "# node is skipped as warm-up (its sampling time includes DataLoader start-up),\n",
    "# consistent with gen_res_table below.\n",
    "name_pool = [\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "\n",
    "table_list = []\n",
    "for name in name_pool:\n",
    "    with open(os.path.join(res_root, f\"{name}.pkl\"), \"rb\") as fin:\n",
    "        [all_sample_time, all_infer_time, test_sample_size_pool] = pickle.load(fin)\n",
    "    all_time_list = [[a + b for a, b in zip(s, i)] for s, i in zip(all_sample_time, all_infer_time)]\n",
    "    first_cfg_times = all_time_list[0]\n",
    "    steady_times = first_cfg_times[1:] or first_cfg_times  # keep the lone sample if only one\n",
    "    table_list.append([name, f\"{np.mean(steady_times) * 1000:.1f}\"])\n",
    "\n",
    "table_str = latexTableBase.gen_table(table_list)\n",
    "\n",
    "print(table_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final LaTeX table: one row per summary statistic, one column per configuration.\n",
    "\n",
    "def gen_res_table(stime, itime, func_list, func_name_list, head_list):\n",
    "    r\"\"\"Build LaTeX table rows summarising per-node latency statistics.\n",
    "\n",
    "    Args:\n",
    "        stime: one list of per-node sampling times (seconds) per configuration.\n",
    "        itime: matching lists of per-node inference times (seconds).\n",
    "        func_list: reducers applied to each timing list (e.g. ``np.mean``).\n",
    "        func_name_list: row labels, aligned with ``func_list``.\n",
    "        head_list: column headers, one per configuration.\n",
    "\n",
    "    Returns:\n",
    "        Newline-joined rows, each ending in ``\\\\``, opened and closed by\n",
    "        ``\\hline``. Each cell is ``sampling/inference/total`` in milliseconds\n",
    "        with 5 decimals.\n",
    "\n",
    "    The first measurement of every list is dropped as warm-up (its sampling\n",
    "    time also includes DataLoader iterator start-up).\n",
    "    \"\"\"\n",
    "    # Per-node end-to-end latency does not depend on the reducer: build it once.\n",
    "    total_time = [[s + i for s, i in zip(s_list, i_list)] for s_list, i_list in zip(stime, itime)]\n",
    "\n",
    "    def apply_func(time_list, func):\n",
    "        return [func(times[1:]) for times in time_list]  # [1:] skips the warm-up node\n",
    "\n",
    "    # Raw strings keep the LaTeX backslashes without invalid escape sequences.\n",
    "    rows = [r\"\\hline\", \"Neighbor size & \" + \" & \".join(head_list) + r\" \\\\ \\hline\"]\n",
    "    for name, func in zip(func_name_list, func_list):\n",
    "        cells = [\n",
    "            \"{:.5f}/{:.5f}/{:.5f}\".format(s * 1000, i * 1000, t * 1000)\n",
    "            for s, i, t in zip(apply_func(stime, func), apply_func(itime, func),\n",
    "                               apply_func(total_time, func))\n",
    "        ]\n",
    "        rows.append(name + \" & \" + \" & \".join(cells) + r\" \\\\\")\n",
    "    return \"\\n\".join(rows) + r\" \\hline\"\n",
    "\n",
    "# Summary statistics for the most recently loaded timings\n",
    "# (all_sample_time / all_infer_time / test_sample_size_pool from the cells above).\n",
    "func_list = [np.mean, lambda x: np.percentile(x, 90), lambda x: np.percentile(x, 99), np.max, np.std]\n",
    "func_name_list = [\"mean\", \"90-percentile\", \"99-percentile\", \"max\", \"std\"]\n",
    "# One header per sampled fan-out, plus \"full\" only when a full-neighbourhood\n",
    "# run was stored after them (that run is currently commented out in the sweep).\n",
    "n_full_cols = len(all_sample_time) - len(test_sample_size_pool)\n",
    "head_list = [str(i) for i in test_sample_size_pool] + [\"full\"] * n_full_cols\n",
    "assert len(head_list) == len(all_sample_time), \"expected one header per timing configuration\"\n",
    "\n",
    "res_str = gen_res_table(all_sample_time, all_infer_time, func_list, func_name_list, head_list)\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
