{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\hline\n",
      "Neighbor size &  \\\\ \\hline\n",
      "mean & 0.02057/0.26347/0.28404 \\\\\n",
      "90-percentile & 0.02074/0.27561/0.29612 \\\\\n",
      "99-percentile & 0.02337/0.30494/0.32973 \\\\\n",
      "max & 0.72098/10.99777/11.01828 \\\\\n",
      "std & 0.00937/0.17435/0.17467 \\\\ \\hline\n",
      "********************\n",
      "MLP\n",
      "\\hline\n",
      "Neighbor size &  \\\\ \\hline\n",
      "mean & 0.03091/0.42880/0.45970 \\\\\n",
      "90-percentile & 0.03123/0.43869/0.47016 \\\\\n",
      "99-percentile & 0.03958/0.49878/0.53645 \\\\\n",
      "max & 0.16880/9.19557/9.22632 \\\\\n",
      "std & 0.00378/0.22619/0.22623 \\\\ \\hline\n",
      "********************\n",
      "MLP RevGNN-112\n",
      "\\hline\n",
      "Neighbor size &  \\\\ \\hline\n",
      "mean & 0.03099/0.42666/0.45765 \\\\\n",
      "90-percentile & 0.03147/0.43869/0.47016 \\\\\n",
      "99-percentile & 0.03934/0.48281/0.52286 \\\\\n",
      "max & 0.51808/9.16290/9.19509 \\\\\n",
      "std & 0.00606/0.18840/0.18853 \\\\ \\hline\n",
      "********************\n",
      "MLP SAGE\n",
      "\\hline\n",
      "Neighbor size &  \\\\ \\hline\n",
      "mean & 0.03113/0.40821/0.43934 \\\\\n",
      "90-percentile & 0.03147/0.42629/0.45824 \\\\\n",
      "99-percentile & 0.03839/0.45444/0.48780 \\\\\n",
      "max & 1.96528/2.33984/2.42591 \\\\\n",
      "std & 0.02066/0.03940/0.04489 \\\\ \\hline\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "from time import time\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "from models import Model\n",
    "import torch\n",
    "from dataloader import load_data\n",
    "from utils import get_training_config\n",
    "\n",
    "class Logger(object):\n",
    "    \"\"\"\n",
    "    Print hook in order to print the output to both\n",
    "    the terminal and some log file.\n",
    "\n",
    "    Meant to be installed with sys.stdout = Logger(path). If a hook of this\n",
    "    kind is already installed, the new one writes to the stream underneath\n",
    "    it instead of wrapping the old hook, so re-running the setup cell does\n",
    "    not write every message to the log several times.\n",
    "\n",
    "    Attributes that the hook does not define itself (isatty, encoding, ...)\n",
    "    are looked up on the wrapped stream.\n",
    "    \"\"\"\n",
    "    def __init__(self, log_path):\n",
    "        # Walk past hooks that are already installed. The check is duck-typed\n",
    "        # because re-running the cell creates a new Logger class, so\n",
    "        # isinstance() would not recognise instances of the previous one.\n",
    "        stream = sys.stdout\n",
    "        while hasattr(stream, \"terminal\") and hasattr(stream, \"log\"):\n",
    "            stream = stream.terminal\n",
    "        self.terminal = stream\n",
    "        # Append mode keeps the output of earlier runs; the explicit encoding\n",
    "        # makes the log independent of the platform's locale.\n",
    "        self.log = open(log_path, \"a\", encoding=\"utf-8\")\n",
    "\n",
    "    def __getattr__(self, name):\n",
    "        # Only called when normal lookup fails. The guard prevents infinite\n",
    "        # recursion when the instance has not been initialised yet.\n",
    "        if name in (\"terminal\", \"log\"):\n",
    "            raise AttributeError(name)\n",
    "        return getattr(self.terminal, name)\n",
    "\n",
    "    def write(self, message):\n",
    "        \"\"\"Write message to both targets, flush them, return its length.\"\"\"\n",
    "        self.terminal.write(message)\n",
    "        self.log.write(message)\n",
    "        self.flush()\n",
    "        return len(message)\n",
    "\n",
    "    def flush(self):\n",
    "        \"\"\"Flush the wrapped stream and the log file.\"\"\"\n",
    "        self.terminal.flush()\n",
    "        self.log.flush()\n",
    "\n",
    "log_path = \"log.txt\"\n",
    "# Re-running this cell must not stack hooks: peel off any hook that is\n",
    "# already installed (recognised by its terminal/log attributes) so that each\n",
    "# printed message is written to the log exactly once.\n",
    "while hasattr(sys.stdout, \"terminal\") and hasattr(sys.stdout, \"log\"):\n",
    "    sys.stdout = sys.stdout.terminal\n",
    "sys.stdout = Logger(log_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(\n",
    "    state_dict_path,\n",
    "    data_name,\n",
    "    data_cache_path,\n",
    "    dw_emb_path,\n",
    "    student_name,\n",
    "    config_path,\n",
    "    device='cpu',\n",
    "    n_data=5\n",
    "):\n",
    "    \"\"\"\n",
    "    Time row slicing and single-row forward passes of a saved model.\n",
    "\n",
    "    Args:\n",
    "        state_dict_path: file holding the state dict loaded into the model.\n",
    "        data_name: dataset name given to load_data and get_training_config.\n",
    "        data_cache_path: dataset location given to load_data.\n",
    "        dw_emb_path: file holding a tensor that is concatenated to the node\n",
    "            features along dim 1 (presumably DeepWalk position embeddings,\n",
    "            one row per node -- TODO confirm).\n",
    "        student_name: model name given to get_training_config.\n",
    "        config_path: config file given to get_training_config.\n",
    "        device: device for the features and for every loaded tensor.\n",
    "        n_data: upper bound on the number of rows that are timed.\n",
    "\n",
    "    Returns:\n",
    "        (stime, ctime): two lists of durations in seconds with one entry per\n",
    "        timed row. stime holds the time spent slicing the row out of the\n",
    "        feature matrix, ctime the time spent in the forward pass.\n",
    "    \"\"\"\n",
    "    # Function-scope import: perf_counter is a monotonic, high-resolution\n",
    "    # clock, better suited to sub-millisecond intervals than time.time().\n",
    "    from time import perf_counter\n",
    "\n",
    "    # get feature\n",
    "    g, labels, _, idx_val, _ = load_data(\n",
    "        data_name,\n",
    "        data_cache_path,\n",
    "        split_idx=0,\n",
    "        seed=0,\n",
    "        labelrate_train=20,\n",
    "        labelrate_val=30,\n",
    "    )\n",
    "\n",
    "    # Fixed-seed shuffle on a private generator: same order as\n",
    "    # random.seed(666) + random.shuffle, without reseeding the global RNG.\n",
    "    # NOTE(review): these are the positions 0..len(idx_val)-1, used below as\n",
    "    # row numbers of the full feature matrix, not the entries of idx_val --\n",
    "    # confirm that this is the intended sample.\n",
    "    all_ind = list(range(len(idx_val)))\n",
    "    random.Random(666).shuffle(all_ind)\n",
    "\n",
    "    # dw\n",
    "    # map_location keeps a tensor saved on another device loadable here.\n",
    "    position_feature = torch.load(dw_emb_path, map_location=device).to(device)\n",
    "    len_position_feature = position_feature.shape[-1]\n",
    "\n",
    "    # get model\n",
    "    conf = get_training_config(\n",
    "        config_path,\n",
    "        student_name,\n",
    "        data_name\n",
    "    )  # Note: student config\n",
    "    conf['feat_dim'] = g.ndata[\"feat\"].shape[1]\n",
    "    conf['label_dim'] = labels.int().max().item() + 1\n",
    "    conf['device'] = device\n",
    "    model = Model(conf, None, len_position_feature)\n",
    "    model.load_state_dict(torch.load(state_dict_path, map_location=device))\n",
    "    model.eval()\n",
    "\n",
    "    # inference\n",
    "    feats = torch.cat([g.ndata[\"feat\"].to(device), position_feature], dim=1)\n",
    "    # CUDA work is queued asynchronously, so it has to be waited for before\n",
    "    # the clock is read; on CPU no synchronisation is needed.\n",
    "    on_cuda = torch.device(device).type == \"cuda\"\n",
    "    stime, ctime = [], []\n",
    "    # no_grad: building autograd graphs is not part of inference and would\n",
    "    # inflate both the measured time and the memory use.\n",
    "    with torch.no_grad():\n",
    "        sample_tic = perf_counter()\n",
    "        for i in all_ind[:n_data]:\n",
    "            cur_x = feats[i, :].reshape(1, -1)\n",
    "            stime.append(perf_counter() - sample_tic)\n",
    "            inf_tic = perf_counter()\n",
    "            model(None, cur_x)\n",
    "            if on_cuda:\n",
    "                torch.cuda.synchronize()\n",
    "            ctime.append(perf_counter() - inf_tic)\n",
    "            sample_tic = perf_counter()\n",
    "\n",
    "    return stime, ctime\n",
    "\n",
    "# The two names select the result directory that is benchmarked.\n",
    "student_name, model_name = 'MLP', 'SAGE'\n",
    "\n",
    "# Banner that separates this run from earlier ones in the printed output.\n",
    "print(\"*\" * 20)\n",
    "print(student_name, model_name)\n",
    "\n",
    "# Time up to 10000 single-row forward passes; stime/ctime are consumed by\n",
    "# the table cell below.\n",
    "stime, ctime = inference(\n",
    "    state_dict_path=(\n",
    "        '../../result/ogbn-products/NOSMOG_' + model_name + '_' + student_name\n",
    "        + '/seed_0/model.pth'\n",
    "    ),\n",
    "    data_name='ogbn-products',\n",
    "    data_cache_path='../../dataset',\n",
    "    dw_emb_path=(\n",
    "        '../../result/ogbn-products/NOSMOG_' + model_name + '_' + student_name\n",
    "        + '/dw_emb.pt'\n",
    "    ),\n",
    "    student_name=student_name,\n",
    "    config_path='./train.conf.yaml',\n",
    "    n_data=10000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# final string generation\n",
    "\n",
    "def gen_res_table(stime, itime, func_list, func_name_list, head_list):\n",
    "    \"\"\"\n",
    "    Build the rows of a LaTeX table that summarises timing series.\n",
    "\n",
    "    Args:\n",
    "        stime: list of series (one per table column) of sampling times in\n",
    "            seconds.\n",
    "        itime: list of series of inference times in seconds, aligned with\n",
    "            stime.\n",
    "        func_list: statistics to report; each is called on one series and\n",
    "            must return a number.\n",
    "        func_name_list: row labels, aligned with func_list.\n",
    "        head_list: column titles placed after the fixed first header cell.\n",
    "\n",
    "    Returns:\n",
    "        The table rows joined by newlines. Every statistic cell has the form\n",
    "        sampling/inference/total with the values converted to milliseconds.\n",
    "    \"\"\"\n",
    "    def apply_func(time_list, func):\n",
    "        # The first measurement of each series is dropped before the\n",
    "        # statistic is taken (presumably a warm-up run -- TODO confirm).\n",
    "        return [func(series[1:]) for series in time_list]\n",
    "\n",
    "    # Element-wise total per series; it does not depend on the statistic, so\n",
    "    # it is computed once instead of once per row.\n",
    "    all_list = [[a + b for a, b in zip(s, i)] for s, i in zip(stime, itime)]\n",
    "\n",
    "    # Raw strings: the LaTeX backslashes are literal characters, not (invalid)\n",
    "    # Python escape sequences.\n",
    "    rows = [\n",
    "        r\"\\hline\",\n",
    "        \"Neighbor size & \" + \" & \".join(head_list) + r\" \\\\ \\hline\",\n",
    "    ]\n",
    "\n",
    "    for name, func in zip(func_name_list, func_list):\n",
    "        svals = apply_func(stime, func)\n",
    "        ivals = apply_func(itime, func)\n",
    "        alls = apply_func(all_list, func)\n",
    "        # seconds -> milliseconds\n",
    "        comb = [\n",
    "            \"{:.5f}/{:.5f}/{:.5f}\".format(s * 1000, i * 1000, t * 1000)\n",
    "            for s, i, t in zip(svals, ivals, alls)\n",
    "        ]\n",
    "        rows.append(name + \" & \" + \" & \".join(comb) + r\" \\\\\")\n",
    "\n",
    "    # The closing rule is appended to the last row, not put on its own line.\n",
    "    return \"\\n\".join(rows) + r\" \\hline\"\n",
    "        \n",
    "# Statistics reported per timing series; the labels below are aligned with\n",
    "# this list one to one.\n",
    "func_list = [\n",
    "    np.mean,\n",
    "    lambda samples: np.percentile(samples, 90),\n",
    "    lambda samples: np.percentile(samples, 99),\n",
    "    np.max,\n",
    "    np.std,\n",
    "]\n",
    "func_name_list = [\"mean\", \"90-percentile\", \"99-percentile\", \"max\", \"std\"]\n",
    "# No column titles are passed for this single-series table.\n",
    "head_list = []\n",
    "\n",
    "# Each timing list is wrapped in a list because the table takes one series\n",
    "# per column.\n",
    "res_str = gen_res_table(\n",
    "    stime=[stime],\n",
    "    itime=[ctime],\n",
    "    func_list=func_list,\n",
    "    func_name_list=func_name_list,\n",
    "    head_list=head_list,\n",
    ")\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
