{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Modify the output and input model folder\n",
    "\n",
    "Test the impact of neighbor sampling size for inference\n",
    "\n",
    "\"\"\"\n",
    "import os\n",
    "# NOTE(review): presumably blanks the CA bundle so HTTPS downloads skip certificate checks -- confirm this is intended.\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\" \n",
    "# os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "import sys\n",
    "sys.path.extend([\"../\", \"./\"])\n",
    "import random\n",
    "import io\n",
    "\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import dgl\n",
    "import dgl.data\n",
    "import dgl.nn as dglnn\n",
    "from dgl.data import AsNodePredDataset\n",
    "from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler\n",
    "from tqdm import tqdm\n",
    "\n",
    "from model import SAGE\n",
    "from utils import torch_f1, load_train_conf\n",
    "from data_utils import load_data\n",
    "\n",
    "# Prefer the first CUDA device, but fall back to the CPU so the notebook still runs on GPU-less machines.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "# Utility functions\n",
    "def evaluate(mode, dataloader):\n",
    "    \"\"\"Score the notebook-level `model` on every mini-batch produced by `dataloader`.\n",
    "\n",
    "    NOTE(review): `mode` is never read in the body and `model` is resolved from\n",
    "    the notebook's global namespace -- possibly a typo for `model`; confirm.\n",
    "\n",
    "    Returns the result of `torch_f1` on the concatenated predictions and labels.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    labels, preds = [], []\n",
    "    with torch.no_grad():\n",
    "        for _, _, blocks in dataloader:\n",
    "            # Input features come from the first block, labels from the last one.\n",
    "            feats = blocks[0].srcdata['feat']\n",
    "            labels.append(blocks[-1].dstdata['label'])\n",
    "            preds.append(model(blocks, feats))\n",
    "    return torch_f1(torch.cat(preds), torch.cat(labels))\n",
    "\n",
    "def layerwise_infer(device, graph, nid, model, batch_size, is_sample=False, sample_size=2):\n",
    "    \"\"\"Run `model.inference` over `graph` and score the rows selected by `nid`.\n",
    "\n",
    "    `is_sample` and `sample_size` are forwarded unchanged to `model.inference`.\n",
    "    Labels are read from `graph.ndata['label']` and moved to the predictions' device\n",
    "    before both are handed to `torch_f1`, whose result is returned.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # Per the original note, the predictions presumably live on the model's buffer device.\n",
    "        all_pred = model.inference(\n",
    "            graph, device, batch_size, is_sample=is_sample, sample_size=sample_size\n",
    "        )\n",
    "        picked = all_pred[nid]\n",
    "        target = graph.ndata['label'][nid].to(picked.device)\n",
    "        return torch_f1(picked, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration\n",
    "\n",
    "## ogbn-products\n",
    "data_name = \"ogbn-products\"\n",
    "model_folder = \"../result/ogbn-products/SAGE_time/Layer3/\"\n",
    "test_sample_size_pool = [20, 40]\n",
    "\n",
    "# Files looked up inside the model folder\n",
    "train_conf_path = os.path.join(model_folder, \"train_conf.yml\")\n",
    "model_path = os.path.join(model_folder, \"model_stat\")\n",
    "\n",
    "# Training configuration read from train_conf.yml\n",
    "train_conf = load_train_conf(train_conf_path)\n",
    "\n",
    "# Load the graph and move it to the target device in one step\n",
    "g = load_data(data_name).to(device)\n",
    "\n",
    "# Build the network from the graph's feature/class sizes and the training config\n",
    "in_size = g.ndata['feat'].shape[1]\n",
    "out_size = g.num_classes\n",
    "model = SAGE(\n",
    "    in_size,\n",
    "    train_conf[\"hidden_size\"],\n",
    "    out_size,\n",
    "    GNN_layer=train_conf[\"hidden_layer\"],\n",
    ")\n",
    "\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that reroutes torch storage loading so tensors are mapped to the CPU.\"\"\"\n",
    "\n",
    "    def find_class(self, module, name):\n",
    "        \"\"\"Return a CPU-mapping loader for `torch.storage._load_from_bytes`; defer to the base class otherwise.\"\"\"\n",
    "        wants_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')\n",
    "        if not wants_storage_loader:\n",
    "            return super().find_class(module, name)\n",
    "\n",
    "        def _load_on_cpu(raw):\n",
    "            return torch.load(io.BytesIO(raw), map_location='cpu')\n",
    "\n",
    "        return _load_on_cpu\n",
    "\n",
    "# Read the checkpoint inside a context manager so the file handle is closed right after unpickling.\n",
    "# NOTE(review): unpickling can execute code embedded in the file -- only load checkpoints from a trusted source.\n",
    "with open(model_path, \"rb\") as model_file:\n",
    "    model_state = CPU_Unpickler(model_file).load()\n",
    "model.load_state_dict(model_state)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores from is_sample=True runs are keyed by sample size; is_sample=False scores go in a flat list.\n",
    "p_train_p_test = defaultdict(list)\n",
    "p_train_f_test = []\n",
    "\n",
    "for run_idx in range(1):\n",
    "    print(run_idx)\n",
    "    # Reference run with is_sample=False\n",
    "    acc = layerwise_infer(device, g, g.test_idx, model, batch_size=32, is_sample=False)\n",
    "    p_train_f_test.append(acc.item())\n",
    "    # One is_sample=True run per configured sample size\n",
    "    for test_sample_size in test_sample_size_pool:\n",
    "        acc = layerwise_infer(\n",
    "            device, g, g.test_idx, model,\n",
    "            batch_size=32, is_sample=True, sample_size=test_sample_size,\n",
    "        )\n",
    "        p_train_p_test[test_sample_size].append(acc.item())\n",
    "    \n",
    "    # model.load_state_dict(pickle.loads(full_neighbor_model_state))\n",
    "    # acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=train_conf[\"batch_size\"], is_sample=False)\n",
    "    # f_train_f_test.append(acc.item())\n",
    "    # for test_sample_size in test_sample_size_pool:\n",
    "    #     acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=train_conf[\"batch_size\"], is_sample=True, sample_size=test_sample_size)\n",
    "    #     f_train_p_test[test_sample_size].append(acc.item())\n",
    "    \n",
    "    \n",
    "# print(\"Test Accuracy {:.4f}\".format(acc.item()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output to latex\n",
    "def to_latex(res):\n",
    "    \"\"\"Render `res` as its mean, a LaTeX plus-minus command, and its std, each with four decimals.\n",
    "\n",
    "    Args:\n",
    "        res: sequence of numbers accepted by `np.mean` and `np.std`.\n",
    "\n",
    "    Returns:\n",
    "        str: the formatted fragment, without surrounding `$` delimiters.\n",
    "    \"\"\"\n",
    "    mean = np.mean(res)\n",
    "    std = np.std(res)\n",
    "    # Raw string: a plain literal here would contain the invalid escape sequence backslash-p.\n",
    "    res_str = r\"{:.4f}\\pm{:.4f}\".format(mean, std)\n",
    "    return res_str\n",
    "\n",
    "# One entry per sample size in the pool, then the is_sample=False entry last\n",
    "res = [to_latex(p_train_p_test[size]) for size in test_sample_size_pool]\n",
    "res.append(to_latex(p_train_f_test))\n",
    "# Wrap each entry in inline-math delimiters and join them with column separators\n",
    "res = [f\"${cell}$\" for cell in res]\n",
    "res_str = \" & \".join(res)\n",
    "print(res_str)\n",
    "\n",
    "# res = []\n",
    "# for size in test_sample_size_pool:\n",
    "#     res.append(to_latex(f_train_p_test[size]))\n",
    "# res.append(to_latex(f_train_f_test))\n",
    "# res = [\"$\"+i+\"$\" for i in res]\n",
    "# res_str = \" \".join(res)\n",
    "# print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
