{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "End-to-end MLP baseline: train one MLP per dataset, report its test accuracy,\n",
    "and measure its per-node inference time.\n",
    "\"\"\"\n",
    "import os\n",
    "# NOTE(review): presumably a workaround for certificate errors when downloading\n",
    "# datasets. With some `requests` versions an empty CURL_CA_BUNDLE disables TLS\n",
    "# certificate verification entirely -- confirm this is still needed.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "import sys\n",
    "sys.path.extend([\"../\"])  # so the project modules one level up can be imported\n",
    "import random\n",
    "from time import time\n",
    "import io\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "import dgl\n",
    "import dgl.data\n",
    "from dgl.data import AsNodePredDataset\n",
    "\n",
    "from model import MLP\n",
    "from data_utils import PlainLoader as MyLoader\n",
    "from data_utils import load_data\n",
    "from utils import evaluate, Logger, load_train_conf\n",
    "from script_utils import latexTableBase\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "DEVICE = 'cuda:4'  # training device; later cells rebind DEVICE\n",
    "\n",
    "# Utility functions\n",
    "\n",
    "def train(device, g, model, train_conf):\n",
    "    \"\"\"Train `model` on the training split of `g`; return the best-validation checkpoint.\n",
    "\n",
    "    Args:\n",
    "        device: device that features, labels and split indices are moved to.\n",
    "        g: graph from `load_data`; must provide `train_idx`, `val_idx`,\n",
    "            `ndata['feat']` and `ndata['label']`.\n",
    "        model: MLP to optimise (expected to already be on `device`).\n",
    "        train_conf: dict with `batch_size`, `lr`, `weight_decay` and `epoch`.\n",
    "\n",
    "    Returns:\n",
    "        Pickled `state_dict` (bytes) of the epoch with the highest validation\n",
    "        score reported by `evaluate`.\n",
    "    \"\"\"\n",
    "    train_idx = g.train_idx.to(device)\n",
    "    val_idx = g.val_idx.to(device)\n",
    "    features = g.ndata['feat'].to(device)\n",
    "    labels = g.ndata['label'].to(device)\n",
    "\n",
    "    train_dataloader = MyLoader(features, labels, train_conf[\"batch_size\"], train_idx)\n",
    "    val_dataloader = MyLoader(features, labels, train_conf[\"batch_size\"], val_idx)\n",
    "\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf['weight_decay'])\n",
    "\n",
    "    # Start below any reachable score so the first epoch is always checkpointed.\n",
    "    # Starting at 0 returned None when validation accuracy stayed at 0, and the\n",
    "    # caller then crashed in f.write(best_state) / pickle.loads(best_state).\n",
    "    best_state, best_val, best_epoch = None, -1.0, 0\n",
    "    for epoch in tqdm(range(train_conf[\"epoch\"])):\n",
    "        model.train()\n",
    "        for x, y in train_dataloader:\n",
    "            loss = F.cross_entropy(model(x), y)\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        acc = evaluate(model, val_dataloader).item()\n",
    "        if acc > best_val:\n",
    "            best_val = acc\n",
    "            # Serialize now: state_dict() tensors share storage with the live\n",
    "            # parameters and would keep changing in later epochs.\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "    print(\"Epoch {:05d} | F1_micro {:.4f} \"\n",
    "          .format(best_epoch, best_val))\n",
    "\n",
    "    return best_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Parameter settings: MLP hyper-parameters per dataset, keyed by the name passed\n",
    "# to `load_data`.\n",
    "#\n",
    "# A plain dict (rather than defaultdict(dict)) makes a misspelled dataset name fail\n",
    "# immediately with KeyError instead of yielding an empty config that only breaks\n",
    "# later on a missing hyper-parameter.\n",
    "_MLP_BASE_CONF = {\n",
    "    'batch_size': 512,\n",
    "    'epoch': 100,\n",
    "    'lr': 0.01,\n",
    "    'hidden_layer': 2,\n",
    "    'IS_ONE_LAYER': False,\n",
    "    'hidden_size': 64,\n",
    "    'weight_decay': 0.0005,\n",
    "    'dropout': 0.0,\n",
    "}\n",
    "\n",
    "def _mlp_conf(**overrides):\n",
    "    \"\"\"Return a new config: `_MLP_BASE_CONF` with `overrides` applied (key order kept).\"\"\"\n",
    "    unknown = set(overrides) - set(_MLP_BASE_CONF)\n",
    "    # Catch misspelled keys (e.g. 'weght_decay'), which readers would silently ignore.\n",
    "    assert not unknown, f\"unknown MLP hyper-parameters: {sorted(unknown)}\"\n",
    "    return {**_MLP_BASE_CONF, **overrides}\n",
    "\n",
    "train_conf_MLP_dict = {\n",
    "    'cora': _mlp_conf(epoch=500, weight_decay=0.005, dropout=0.6),\n",
    "    'citeseer': _mlp_conf(weight_decay=0.00005),\n",
    "    'pubmed': _mlp_conf(),\n",
    "    'a-computer': _mlp_conf(hidden_size=128),\n",
    "    'a-photo': _mlp_conf(hidden_size=128),\n",
    "    'ogbn-arxiv': _mlp_conf(batch_size=4096, epoch=200, lr=0.001, hidden_size=256, weight_decay=5e-6),\n",
    "    'ogbn-products': _mlp_conf(epoch=200, lr=0.001, hidden_size=256, weight_decay=5e-6),\n",
    "}\n",
    "\n",
    "# # Config the log path\n",
    "# log_path = os.path.join(\"../result/best/speedtest\", name+\"_MLP.txt\")\n",
    "# sys.stdout = Logger(log_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_pool = [0]\n",
    "name_pool = [\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "\n",
    "for name in name_pool:\n",
    "    print(\"*\"*21)\n",
    "    print(name)\n",
    "    train_conf = train_conf_MLP_dict[name]\n",
    "\n",
    "    model_folder = os.path.join(\"../result\", name, \"MLP\")\n",
    "    os.makedirs(model_folder, exist_ok=True)\n",
    "\n",
    "    all_acc = []\n",
    "    for seed in seed_pool:\n",
    "        print(seed)\n",
    "        g = load_data(name, seed=seed).to(DEVICE)\n",
    "        # Build the MLP; IS_ONE_LAYER passes hidden_size=None (presumably no hidden layer).\n",
    "        in_size = g.ndata['feat'].shape[1]\n",
    "        out_size = g.num_classes\n",
    "        hidden_size = None if train_conf[\"IS_ONE_LAYER\"] else [train_conf[\"hidden_size\"]] * train_conf['hidden_layer']\n",
    "\n",
    "        model = MLP(in_size, hidden_size, out_size, dropout=train_conf[\"dropout\"]).to(DEVICE)\n",
    "        best_model_state = train(DEVICE, g, model, train_conf)\n",
    "\n",
    "        with open(os.path.join(model_folder, \"state_dict_\"+str(seed)), \"wb\") as f:\n",
    "            f.write(best_model_state)\n",
    "\n",
    "        # Score the best-validation checkpoint on the test split.\n",
    "        model.load_state_dict(pickle.loads(best_model_state))\n",
    "        test_dataloader = MyLoader(g.ndata['feat'].to(DEVICE), g.ndata['label'].to(DEVICE),\n",
    "                                   train_conf[\"batch_size\"], g.test_idx.to(DEVICE))\n",
    "        acc = evaluate(model, test_dataloader)\n",
    "        all_acc.append(acc.item())\n",
    "\n",
    "    # Report (and save) the mean over all seeds rather than only the last seed's score.\n",
    "    mean_acc = float(np.mean(all_acc))\n",
    "    print(\"Test F1 score micro {:.4f}\".format(mean_acc))\n",
    "    with open(os.path.join(model_folder, \"f1.txt\"), \"w\") as fout:\n",
    "        fout.write(str(mean_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output to latex\n",
    "# Uses the most recent `all_acc` / `train_conf` globals (last dataset of the loop above).\n",
    "from utils import confidence_interval\n",
    "\n",
    "def to_latex(res):\n",
    "    \"\"\"Format accuracies (assumed fractions, hence the *100) as LaTeX-ready percentages.\n",
    "\n",
    "    Returns two lines: mean +- population std (np.std, ddof=0), and mean +- the\n",
    "    `interval` value returned by `confidence_interval` at 95% (presumably a half-width).\n",
    "    \"\"\"\n",
    "    mean = np.mean(res)\n",
    "    std = np.std(res)\n",
    "    _, interval = confidence_interval(res, confidence=0.95)\n",
    "    # Escape the backslash so the string holds the LaTeX macro without relying on\n",
    "    # Python's deprecated pass-through of unknown escape sequences.\n",
    "    res_str = \"Mean and std: {:.2f}$\\\\pm${:.2f}\\n\".format(mean*100, std*100)\n",
    "    res_str += f\"Mean and 95 interval: {mean*100:.2f}$\\\\pm${interval*100:.2f}\\n\"\n",
    "    return res_str\n",
    "\n",
    "res = [to_latex(all_acc)]\n",
    "res = [\"$\"+i+\"$\" for i in res]\n",
    "res_str = \" & \".join(res)\n",
    "print(res_str)\n",
    "print(train_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.stats as st\n",
    "\n",
    "# Student-t 95% confidence interval of the mean accuracy over seeds.\n",
    "data = all_acc\n",
    "if len(data) < 2:\n",
    "    # A t-interval needs df = n - 1 >= 1; with a single run sem() is NaN.\n",
    "    print(f\"Only {len(data)} run(s): add seeds to seed_pool to get a confidence interval.\")\n",
    "else:\n",
    "    # Confidence level passed positionally: the keyword is `alpha` in older SciPy\n",
    "    # and `confidence` in newer releases (where `alpha` was removed).\n",
    "    a, b = st.t.interval(0.95, df=len(data)-1, loc=np.mean(data), scale=st.sem(data))\n",
    "    # The interval is symmetric around the mean, so the half-width is the informative number.\n",
    "    print(f\"mean {np.mean(data):.4f}, 95% CI half-width {(b - a) / 2:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Effect of GNN-encoder embeddings used as MLP input features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mtrain(device, feat, g, model, train_conf):\n",
    "    \"\"\"Variant of `train` that learns from externally supplied node features `feat`.\n",
    "\n",
    "    Keeps the checkpoint with the best validation score and, for monitoring only,\n",
    "    prints validation and test accuracy after every epoch (the test score is never\n",
    "    used for model selection).\n",
    "\n",
    "    Args:\n",
    "        device: device that features, labels, indices and the model are moved to.\n",
    "        feat: node features indexed by node id; must have one row per node of `g`.\n",
    "        g: graph providing `ndata['label']` and `train_idx` / `val_idx` / `test_idx`.\n",
    "        model: MLP whose input size matches the feature width of `feat`.\n",
    "        train_conf: dict with `batch_size`, `lr`, `weight_decay` and `epoch`.\n",
    "\n",
    "    Returns:\n",
    "        Pickled `state_dict` (bytes) of the best-validation epoch.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `feat` and `g` disagree on the number of nodes.\n",
    "    \"\"\"\n",
    "    train_idx = g.train_idx.to(device)\n",
    "    val_idx = g.val_idx.to(device)\n",
    "    test_idx = g.test_idx.to(device)\n",
    "    features = feat.to(device)\n",
    "    labels = g.ndata['label'].to(device)\n",
    "    # ndata tensors have one row per node, so a mismatch means the embeddings were\n",
    "    # computed for a different graph than the one whose labels/splits are used here.\n",
    "    assert features.shape[0] == labels.shape[0], (\n",
    "        f\"feat has {features.shape[0]} rows but the graph has {labels.shape[0]} nodes\")\n",
    "\n",
    "    model = model.to(device)\n",
    "\n",
    "    batch_size = train_conf[\"batch_size\"]\n",
    "    train_dataloader = MyLoader(features, labels, batch_size, train_idx)\n",
    "    val_dataloader = MyLoader(features, labels, batch_size, val_idx)\n",
    "    test_dataloader = MyLoader(features, labels, batch_size, test_idx)\n",
    "\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf['weight_decay'])\n",
    "\n",
    "    # Start below any reachable score so the first epoch is always checkpointed\n",
    "    # (starting at 0 could return None and crash the caller's pickle.loads).\n",
    "    best_state, best_val, best_epoch = None, -1.0, 0\n",
    "    for epoch in range(train_conf[\"epoch\"]):\n",
    "        model.train()\n",
    "        for x, y in train_dataloader:\n",
    "            loss = F.cross_entropy(model(x), y)\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        acc = evaluate(model, val_dataloader).item()\n",
    "        if acc > best_val:\n",
    "            best_val = acc\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "        test_acc = evaluate(model, test_dataloader).item()\n",
    "        print(f\"e: {epoch}, val acc: {acc:.4f}, test acc: {test_acc:.4f}\")\n",
    "\n",
    "    print(\"Epoch {:05d} | F1_micro {:.4f} \"\n",
    "          .format(best_epoch, best_val))\n",
    "\n",
    "    return best_state\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test the external embedding with variation on the target\n",
    "# Perturb pre-computed GNN embeddings with noise and measure how well an MLP\n",
    "# still classifies from them. Expects `g` and `train_conf` for ogbn-products,\n",
    "# as left by the training loop above.\n",
    "all_acc = []\n",
    "\n",
    "# feat_path = '/home/yaochen/Working/localhashgnn/result/ogbn-arxiv/DRGAT/seed_0/out_emb_list.npz'\n",
    "feat_path = '/home/yaochen/Working/localhashgnn/result/ogbn-products/RevGNN-112/seed_0/out_emb_list.npz'\n",
    "lam = 0.07              # weight of the noise in the convex mix below\n",
    "NOISE_SEED = 0          # fixed so the perturbation (and the result) is reproducible\n",
    "MAX_TEST_NODES = 20000  # only the first MAX_TEST_NODES test nodes are scored\n",
    "\n",
    "hidden_size_pool = [[]]\n",
    "\n",
    "feat = torch.from_numpy(np.load(feat_path)['arr_0'])\n",
    "feat = torch.sigmoid(feat * 10)\n",
    "\n",
    "# Add noise: convex mix of the features with N(1, 1) samples.\n",
    "noise = torch.tensor(np.random.RandomState(NOISE_SEED).normal(1, 1, tuple(feat.size())), dtype=torch.float)\n",
    "old_feat = feat\n",
    "feat = old_feat * (1 - lam) + lam * noise\n",
    "\n",
    "# Relative squared perturbation ||feat - old_feat||^2 / ||old_feat||^2.\n",
    "print(torch.norm(feat - old_feat) ** 2 / torch.norm(old_feat) ** 2)\n",
    "\n",
    "# Copy instead of mutating the dict shared with train_conf_MLP_dict, and set the\n",
    "# real key: the old \"weght_decay\" typo left the configured weight decay in effect.\n",
    "train_conf = {**train_conf, \"weight_decay\": 1.0e-8, \"epoch\": 500}\n",
    "\n",
    "print(train_conf)\n",
    "\n",
    "for hidden_size in hidden_size_pool:\n",
    "    print(\"*\"*10)\n",
    "    in_size = feat.shape[1]\n",
    "    out_size = g.num_classes\n",
    "\n",
    "    model = MLP(in_size, hidden_size, out_size).to(DEVICE)\n",
    "    best_model_state = mtrain(DEVICE, feat, g, model, train_conf)\n",
    "\n",
    "    model.load_state_dict(pickle.loads(best_model_state))\n",
    "    test_dataloader = MyLoader(feat.to(DEVICE), g.ndata['label'].to(DEVICE),\n",
    "                               train_conf[\"batch_size\"], g.test_idx[:MAX_TEST_NODES].to(DEVICE))\n",
    "    acc = evaluate(model, test_dataloader)\n",
    "    all_acc.append(acc.item())\n",
    "\n",
    "    print(feat.shape)\n",
    "    print(feat_path)\n",
    "    print(hidden_size)\n",
    "    print(\"****Test F1 score micro {:.4f}\".format(acc.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test the external embedding\n",
    "# Train MLPs of several depths on pre-computed node embeddings of EMB_DATASET.\n",
    "all_acc = []\n",
    "\n",
    "# feat_pool = ['out_emb_list_1.npz', 'out_emb_list_2.npz', 'out_emb_list_3.npz']\n",
    "feat_pool = ['out_emb_list.npz']\n",
    "EMB_DATASET = 'ogbn-arxiv'\n",
    "feat_root = os.path.join('/home/yaochen/data/runs', EMB_DATASET, 'seed_0')\n",
    "\n",
    "hidden_size_pool = [[], [64]]\n",
    "\n",
    "# Load the graph the embeddings belong to instead of reusing whichever `g` an\n",
    "# earlier cell left behind: labels and split indices must match the feature rows.\n",
    "g = load_data(EMB_DATASET, seed=0)\n",
    "# Keep these checkpoints out of ../result/<name>/MLP so they never overwrite the\n",
    "# baseline state_dict_0 that the timing section loads.\n",
    "emb_model_folder = os.path.join(\"../result\", EMB_DATASET, \"MLP\", \"ext_emb\")\n",
    "os.makedirs(emb_model_folder, exist_ok=True)\n",
    "\n",
    "for feat_path in feat_pool:\n",
    "    feat = torch.from_numpy(np.load(os.path.join(feat_root, feat_path))['arr_0'])\n",
    "    for hidden_size in hidden_size_pool:\n",
    "        print(\"*\"*10)\n",
    "        in_size = feat.shape[1]\n",
    "        out_size = g.num_classes\n",
    "\n",
    "        model = MLP(in_size, hidden_size, out_size).to(DEVICE)\n",
    "        best_model_state = mtrain(DEVICE, feat, g, model, train_conf)\n",
    "\n",
    "        # Name the checkpoint after the feature file and the hidden layout\n",
    "        # (the old name used an undefined variable `I`).\n",
    "        hidden_tag = \"-\".join(map(str, hidden_size)) or \"linear\"\n",
    "        ckpt_name = \"state_dict_{}_{}\".format(os.path.splitext(feat_path)[0], hidden_tag)\n",
    "        with open(os.path.join(emb_model_folder, ckpt_name), \"wb\") as f:\n",
    "            f.write(best_model_state)\n",
    "\n",
    "        model.load_state_dict(pickle.loads(best_model_state))\n",
    "        test_dataloader = MyLoader(feat.to(DEVICE), g.ndata['label'].to(DEVICE),\n",
    "                                   train_conf[\"batch_size\"], g.test_idx.to(DEVICE))\n",
    "        acc = evaluate(model, test_dataloader)\n",
    "        all_acc.append(acc.item())\n",
    "\n",
    "        print(feat.shape)\n",
    "        print(feat_path)\n",
    "        print(hidden_size)\n",
    "        print(\"****Test F1 score micro {:.4f}\".format(acc.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test the graphSAGE\n",
    "# Build MLP input features from the SDMP preprocessing of a trained SAGE model\n",
    "# (TARGET_GNN_FOLDER); the MLP is trained on them in the next cell.\n",
    "# NOTE(review): rebinds the globals DEVICE, train_conf (now the SDMP config) and g,\n",
    "# which later cells read.\n",
    "from utils import load_sdmp_conf_with_default\n",
    "from data_utils import SDMPDataPre\n",
    "\n",
    "DEVICE=\"cuda:0\"\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "\n",
    "CONF_NAME = \"ogbn-products/SDMP/ogbn-products_SDMP_SAGE.yml\"\n",
    "TARGET_GNN_FOLDER = '../result/ogbn-products/SAGE'\n",
    "\n",
    "conf_path = os.path.join(CONF_ROOT_FOLDER, CONF_NAME)\n",
    "train_conf = load_sdmp_conf_with_default(conf_path)\n",
    "DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf[\"name\"])\n",
    "\n",
    "\n",
    "######################################\n",
    "# Checkpoint and config paths of the target SAGE model; the relative names come from the SDMP config.\n",
    "h_model_path = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_path'])\n",
    "h_conf_path = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_conf_path'])\n",
    "# Data loading and preparation; cache_path points at DATA_FOLDER/SDMPPre.\n",
    "# NOTE(review): every argument before use_cache is positional, so the order must\n",
    "# match SDMPDataPre's signature exactly.\n",
    "preprocesser = SDMPDataPre(train_conf[\"name\"], train_conf[\"feature_normalize\"],\n",
    "                           train_conf['target_h_mode'],\n",
    "                           h_conf_path, h_model_path, train_conf[\"target_h_model\"], \n",
    "                           train_conf[\"h_init_theta_mode\"], train_conf[\"h_init_theta_k\"],\n",
    "                           train_conf[\"h_init_theta_k_fanout\"],\n",
    "                           train_conf[\"theta_cand_mode\"], train_conf[\"theta_cand_k2\"],\n",
    "                           train_conf[\"theta_cand_k1\"], train_conf[\"theta_cand_fanout\"],\n",
    "                           train_conf[\"theta_cand_add_self\"],\n",
    "                           use_cache=True, cache_path=os.path.join(DATA_FOLDER, \"SDMPPre\"),\n",
    "                           device=DEVICE)\n",
    "preprocesser.disp_states()\n",
    "theta_cand, h_init_theta, X, target = preprocesser.theta_cand, preprocesser.h_init_theta, preprocesser.X, preprocesser.target\n",
    "\n",
    "g=load_data(train_conf[\"name\"])\n",
    "# NOTE(review): torch.from_numpy assumes `target` is a NumPy array -- confirm.\n",
    "feat = torch.from_numpy(target)\n",
    "# Overridden to None by the MLP cell that follows.\n",
    "hidden_size=[]\n",
    "# Bare last expression: Jupyter displays the SDMP learning rate as the cell output.\n",
    "train_conf[\"lr\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train an MLP (hidden_size=None) on the `feat` / `g` prepared by the previous cell.\n",
    "mlp_conf = {\n",
    "    \"batch_size\": 512,\n",
    "    \"epoch\": 500,\n",
    "    \"lr\": 0.0001,\n",
    "    \"weight_decay\": 5e-6,\n",
    "    \"dropout\": 0.0,\n",
    "}\n",
    "hidden_size = None\n",
    "N_RUNS = 1\n",
    "print(mlp_conf)\n",
    "\n",
    "# Save under a dedicated folder: writing state_dict_0 into ../result/<name>/MLP\n",
    "# would overwrite the baseline checkpoint that the timing section loads.\n",
    "probe_folder = os.path.join(RES_ROOT_FOLDER, train_conf[\"name\"], \"MLP_probe\")\n",
    "os.makedirs(probe_folder, exist_ok=True)\n",
    "all_acc = []\n",
    "\n",
    "for I in range(N_RUNS):\n",
    "    print(I)\n",
    "    in_size = feat.shape[1]\n",
    "    out_size = g.num_classes\n",
    "\n",
    "    model = MLP(in_size, hidden_size, out_size, dropout=mlp_conf[\"dropout\"]).to(DEVICE)\n",
    "    best_model_state = mtrain(DEVICE, feat, g, model, mlp_conf)\n",
    "\n",
    "    with open(os.path.join(probe_folder, \"state_dict_\"+str(I)), \"wb\") as f:\n",
    "        f.write(best_model_state)\n",
    "\n",
    "    model.load_state_dict(pickle.loads(best_model_state))\n",
    "    # Evaluate with this cell's own batch size, not the SDMP train_conf's.\n",
    "    test_dataloader = MyLoader(feat.to(DEVICE), g.ndata['label'].to(DEVICE),\n",
    "                               mlp_conf[\"batch_size\"], g.test_idx.to(DEVICE))\n",
    "    acc = evaluate(model, test_dataloader)\n",
    "    all_acc.append(acc.item())\n",
    "\n",
    "print(\"Test F1 score micro {:.4f}\".format(float(np.mean(all_acc))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hack: copy SAGE layer 2's weights into the MLP trained above and score it.\n",
    "from model import SAGE\n",
    "from utils import load_train_conf\n",
    "sage_folder = \"../result/ogbn-products/SAGE\"\n",
    "\n",
    "# SAGE hyper-parameters; only needed if the SAGE model itself is rebuilt.\n",
    "sage_conf = load_train_conf(os.path.join(sage_folder, \"train_conf.yml\"))\n",
    "model_path = os.path.join(sage_folder, \"model_stat\")\n",
    "\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that loads pickled torch storages onto the CPU.\"\"\"\n",
    "    def find_class(self, module, name):\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "        return super().find_class(module, name)\n",
    "\n",
    "# Unpickling can execute arbitrary code: only load checkpoints you produced.\n",
    "with open(model_path, \"rb\") as fin:\n",
    "    model_state = CPU_Unpickler(fin).load()\n",
    "\n",
    "# Overwrite the MLP's first layer with SAGE layer 2's fc_neigh weight and bias.\n",
    "mlp_state = pickle.loads(best_model_state)\n",
    "mlp_state[\"layers.0.weight\"] = model_state[\"layers.2.fc_neigh.weight\"].to(DEVICE)\n",
    "mlp_state[\"layers.0.bias\"] = model_state[\"layers.2.bias\"].to(DEVICE)\n",
    "model.load_state_dict(mlp_state)\n",
    "\n",
    "def _eval_split(idx):\n",
    "    \"\"\"Score from `evaluate` for `model` on features `feat`, restricted to node ids `idx`.\"\"\"\n",
    "    loader = MyLoader(feat.to(DEVICE), g.ndata['label'].to(DEVICE),\n",
    "                      train_conf[\"batch_size\"], idx.to(DEVICE))\n",
    "    return evaluate(model, loader).item()\n",
    "\n",
    "# inference and test\n",
    "print(\"test\", _eval_split(g.test_idx))\n",
    "print(\"val\", _eval_split(g.val_idx))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time measurement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_inference_time(model, device, g, eval_size=10000, all_ind=None, shuffle_seed=666):\n",
    "    \"\"\"Time per-node (batch size 1) MLP inference on a shuffled subset of validation nodes.\n",
    "\n",
    "    Args:\n",
    "        model: trained MLP.\n",
    "        device: device the model runs on; each sample is moved there inside the timed window.\n",
    "        g: graph providing `val_idx`, `ndata['feat']` and `ndata['label']`.\n",
    "        eval_size: number of validation nodes to time.\n",
    "        all_ind: optional permutation of positions into `g.val_idx`. When omitted it is\n",
    "            drawn with `random.Random(shuffle_seed)`, which gives the same order as\n",
    "            `random.seed(666); random.shuffle(...)` in the driver cell.\n",
    "        shuffle_seed: seed for that permutation.\n",
    "\n",
    "    Returns:\n",
    "        (sampling_time, infer_time): per-node seconds spent fetching the sample from\n",
    "        the loader, and moving it to `device` plus running the forward pass.\n",
    "    \"\"\"\n",
    "    from time import perf_counter  # monotonic, high-resolution clock for short intervals\n",
    "\n",
    "    if all_ind is None:\n",
    "        all_ind = list(range(g.val_idx.size()[0]))\n",
    "        random.Random(shuffle_seed).shuffle(all_ind)\n",
    "\n",
    "    on_cuda = torch.device(device).type == \"cuda\"\n",
    "    sampling_time, infer_time = [], []\n",
    "\n",
    "    val_idx = g.val_idx[all_ind[:eval_size]].to(\"cpu\")\n",
    "    tic_data_loader = perf_counter()\n",
    "    val_dataloader = MyLoader(g.ndata['feat'].to(\"cpu\"), g.ndata['label'].to(\"cpu\"),\n",
    "                              1, val_idx)\n",
    "    print(\"data loader created in {:.6f} seconds\".format(perf_counter() - tic_data_loader))\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        tic_sampling = perf_counter()\n",
    "        for x, y in val_dataloader:\n",
    "            sampling_time.append(perf_counter() - tic_sampling)\n",
    "            tic_infer = perf_counter()\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            y_hat = model(x)\n",
    "            if on_cuda:\n",
    "                # CUDA launches are asynchronous: wait for the forward pass to finish,\n",
    "                # otherwise only the launch overhead would be timed.\n",
    "                torch.cuda.synchronize(device)\n",
    "            infer_time.append(perf_counter() - tic_infer)\n",
    "            del x, y, y_hat\n",
    "            if on_cuda:\n",
    "                torch.cuda.empty_cache()  # not timed\n",
    "            tic_sampling = perf_counter()\n",
    "\n",
    "    return sampling_time, infer_time\n",
    "\n",
    "\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that reloads pickled torch storages on the CPU.\n",
    "\n",
    "    Tensors inside a `pickle.dumps(model.state_dict())` payload are restored through\n",
    "    `torch.storage._load_from_bytes` (as emitted by torch's storage pickling);\n",
    "    redirecting that hook to `torch.load(..., map_location='cpu')` lets checkpoints\n",
    "    written on a GPU load on any host. Unpickling runs arbitrary code, so only use\n",
    "    it on trusted files.\n",
    "    \"\"\"\n",
    "\n",
    "    def find_class(self, module, name):\n",
    "        wants_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')\n",
    "        if not wants_storage_loader:\n",
    "            return super().find_class(module, name)\n",
    "        return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "\n",
    "def load_MLP_model(g, model_folder, train_conf, device=\"cpu\"):\n",
    "    \"\"\"Rebuild the MLP described by `train_conf` and load its `state_dict_0` checkpoint.\n",
    "\n",
    "    Args:\n",
    "        g: graph, used only for the input width (`ndata['feat']`) and `num_classes`.\n",
    "        model_folder: folder holding the pickled `state_dict_0` written by the training loop.\n",
    "        train_conf: per-dataset config providing IS_ONE_LAYER, hidden_size and hidden_layer.\n",
    "        device: device the model is moved to.\n",
    "\n",
    "    Returns:\n",
    "        The MLP on `device`, in eval mode.\n",
    "    \"\"\"\n",
    "    in_size = g.ndata['feat'].shape[1]\n",
    "    out_size = g.num_classes\n",
    "    hidden_size = None if train_conf[\"IS_ONE_LAYER\"] else [train_conf[\"hidden_size\"]] * train_conf['hidden_layer']\n",
    "    model = MLP(in_size, hidden_size, out_size)\n",
    "    model_path = os.path.join(model_folder, \"state_dict_0\")\n",
    "\n",
    "    # `with` closes the file handle deterministically; CPU_Unpickler maps tensors\n",
    "    # saved from a GPU run onto the CPU.\n",
    "    with open(model_path, \"rb\") as fin:\n",
    "        model_state = CPU_Unpickler(fin).load()\n",
    "    model.load_state_dict(model_state)\n",
    "    return model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load each trained MLP and time its per-node inference on DEVICE.\n",
    "DEVICE = \"cpu\"\n",
    "print(\"=\" * 10 + f\"Test on {DEVICE}\")\n",
    "seed = 0\n",
    "save_res_root = \"../script_logs/speedtest/MLP\"\n",
    "os.makedirs(save_res_root, exist_ok=True)\n",
    "\n",
    "name_pool = [\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "\n",
    "for name in name_pool:\n",
    "    print(\"*\" * 21)\n",
    "    print(name)\n",
    "    model_folder = os.path.join(\"../result\", name, \"MLP\")\n",
    "    g = load_data(name, seed=seed).to(DEVICE)\n",
    "    train_conf = train_conf_MLP_dict[name]\n",
    "    mlp_model = load_MLP_model(g, model_folder, train_conf, device=DEVICE)\n",
    "\n",
    "    # Seeded shuffle of validation positions, kept in the global `all_ind`.\n",
    "    all_ind = list(range(g.val_idx.size()[0]))\n",
    "    random.seed(666)\n",
    "    random.shuffle(all_ind)\n",
    "\n",
    "    # One timing repetition per dataset, wrapped in lists for the saved format.\n",
    "    sampling_time, infer_time = test_inference_time(mlp_model, DEVICE, g)\n",
    "    all_sample_time, all_infer_time = [sampling_time], [infer_time]\n",
    "\n",
    "    out_path = os.path.join(save_res_root, f\"{name}.pkl\")\n",
    "    with open(out_path, \"wb\") as fout:\n",
    "        pickle.dump([all_sample_time, all_infer_time], fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the mean per-node latency (ms) of every dataset as a LaTeX table.\n",
    "name_pool = [\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "\n",
    "table_list = []\n",
    "for name in name_pool:\n",
    "    res_path = os.path.join(save_res_root, f\"{name}.pkl\")\n",
    "    with open(res_path, \"rb\") as fin:\n",
    "        all_sample_time, all_infer_time = pickle.load(fin)\n",
    "    # Total latency per node (sampling + inference) of the first timing repetition.\n",
    "    first_run_total = [s + i for s, i in zip(all_sample_time[0], all_infer_time[0])]\n",
    "    table_list.append([name, f\"{np.mean(first_run_total)*1000:.1f}\"])\n",
    "\n",
    "table_str = latexTableBase.gen_table(table_list)\n",
    "\n",
    "print(table_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# final string generation\n",
    "\n",
    "def gen_res_table(stime, itime, func_list, func_name_list, head_list):\n",
    "    \"\"\"Build LaTeX rows of latency statistics; each cell is sampling/inference/total in ms.\n",
    "\n",
    "    Args:\n",
    "        stime, itime: one list per column (parallel to head_list) of per-node\n",
    "            sampling / inference times in seconds.\n",
    "        func_list: statistics applied to each column, one output row per function.\n",
    "        func_name_list: row labels, parallel to func_list.\n",
    "        head_list: column headers.\n",
    "\n",
    "    Returns:\n",
    "        The table rows joined by newlines. The first measurement of every column\n",
    "        is skipped as warm-up.\n",
    "    \"\"\"\n",
    "    def apply_func(time_list, func):\n",
    "        return [func(i[1:]) for i in time_list]  # [1:] drops the warm-up sample\n",
    "\n",
    "    # Per-node total latency does not depend on the statistic: compute it once.\n",
    "    all_list = [[a + b for a, b in zip(s, i)] for s, i in zip(stime, itime)]\n",
    "\n",
    "    # Raw strings keep the LaTeX backslashes literal (a plain string would contain\n",
    "    # an invalid escape sequence, which Python warns about).\n",
    "    rows = [r\"\\hline\",\n",
    "            \"Neighbor size & \" + \" & \".join(head_list) + r\" \\\\ \\hline\"]\n",
    "    for name, func in zip(func_name_list, func_list):\n",
    "        svals = apply_func(stime, func)\n",
    "        ivals = apply_func(itime, func)\n",
    "        alls = apply_func(all_list, func)\n",
    "        comb = [\"{:.5f}/{:.5f}/{:.5f}\".format(i*1000, j*1000, k*1000) for i, j, k in zip(svals, ivals, alls)]\n",
    "        rows.append(name + \" & \" + \" & \".join(comb) + r\" \\\\\")\n",
    "\n",
    "    return \"\\n\".join(rows) + r\" \\hline\"\n",
    "        \n",
    "# Latency statistics of the most recently loaded timings\n",
    "# (`all_sample_time` / `all_infer_time` left by the previous cell).\n",
    "stat_specs = [\n",
    "    (\"mean\", np.mean),\n",
    "    (\"90-percentile\", lambda x: np.percentile(x, 90)),\n",
    "    (\"99-percentile\", lambda x: np.percentile(x, 99)),\n",
    "    (\"max\", np.max),\n",
    "    (\"std\", np.std),\n",
    "]\n",
    "func_name_list = [label for label, _ in stat_specs]\n",
    "func_list = [fn for _, fn in stat_specs]\n",
    "head_list = [\"MLP\"]\n",
    "\n",
    "res_str = gen_res_table(all_sample_time, all_infer_time, func_list, func_name_list, head_list)\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
