{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The line_profiler extension is already loaded. To reload it, use:\n",
      "  %reload_ext line_profiler\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Test various versions of simplified GCN\n",
    "\n",
    "Created by: Yaochen Hu\n",
    "Created on: Aug. 10, 2022\n",
    "\n",
    "\"\"\"\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "# NOTE(review): clearing CURL_CA_BUNDLE is presumably a workaround to skip TLS\n",
    "# certificate verification for downloads triggered later in this notebook.\n",
    "# Confirm it is still needed, since it applies to the whole kernel process.\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\" \n",
    "import sys\n",
    "sys.path.extend([\"../\"])\n",
    "import random\n",
    "from time import time\n",
    "import io\n",
    "\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "import dgl\n",
    "import dgl.data\n",
    "from dgl.data import AsNodePredDataset\n",
    "import yaml\n",
    "\n",
    "from model import MLP\n",
    "from data_utils import PlainLoader as MyLoader\n",
    "from data_utils import SGCFeatureGen\n",
    "from utils import evaluate\n",
    "\n",
    "device = 'cuda:1'\n",
    "\n",
    "\n",
    "# Utility functions\n",
    "\n",
    "def train(device, features, dataset, model, train_conf):\n",
    "    \"\"\"Train ``model`` on the training split and keep the best validation snapshot.\n",
    "\n",
    "    Args:\n",
    "        device: torch device (string or ``torch.device``) that the features, labels\n",
    "            and index tensors are moved to.\n",
    "        features: feature tensor handed to the loaders together with the labels.\n",
    "        dataset: object exposing ``train_idx``, ``val_idx`` and ``labels``.\n",
    "        model: module mapping a feature batch to class logits; updated in place.\n",
    "        train_conf: dict providing ``batch_size``, ``lr``, ``weight_decay`` and\n",
    "            ``epoch``.\n",
    "\n",
    "    Returns:\n",
    "        bytes: pickled ``state_dict`` of the epoch with the highest validation\n",
    "        score. When no epoch scores above 0 (or ``epoch`` is 0) the pickled\n",
    "        initial weights are returned, so callers never receive ``None``.\n",
    "    \"\"\"\n",
    "    # create sampler & dataloader\n",
    "    train_idx = dataset.train_idx.to(device)\n",
    "    val_idx = dataset.val_idx.to(device)\n",
    "    features = features.to(device)\n",
    "    labels = torch.tensor(dataset.labels).to(device)\n",
    "\n",
    "    train_dataloader = MyLoader(features, labels, train_conf[\"batch_size\"], train_idx)\n",
    "    val_dataloader = MyLoader(features, labels, train_conf[\"batch_size\"], val_idx)\n",
    "\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf['weight_decay'])\n",
    "\n",
    "    # Start from the untrained weights so a loadable snapshot always exists.\n",
    "    best_state = pickle.dumps(model.state_dict())\n",
    "    best_val, best_epoch = 0, 0\n",
    "    for epoch in tqdm(range(train_conf[\"epoch\"])):\n",
    "        model.train()\n",
    "        for x, y in train_dataloader:\n",
    "            y_hat = model(x)\n",
    "            loss = F.cross_entropy(y_hat, y)\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        # `evaluate` is a project helper; its result is read with .item() as a scalar score.\n",
    "        val_score = evaluate(model, val_dataloader).item()\n",
    "        if val_score > best_val:\n",
    "            best_val = val_score\n",
    "            # Serialise immediately: state_dict() references the live parameters,\n",
    "            # which keep changing in later epochs.\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "    print(\"Epoch {:05d} | f1_micro {:.4f} \".format(best_epoch, best_val))\n",
    "\n",
    "    return best_state\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "{'name': 'cora', 'dropout': 0.0, 'model': 'SGC', 'batch_size': 512, 'epoch': 100, 'lr': 0.5, 'weight_decay': 5e-06, 'hidden_size': 64, 'hidden_layer': 0, 'feature_normalize': 'row_sum', 'sgc_k': 3, 'sgc_is_add_original': False, 'sgc_is_add_ori': False}\n"
     ]
    }
   ],
   "source": [
    "# Parameter settings and data loading\n",
    "def get_index(dataset):\n",
    "    def mask_to_ind(mask):\n",
    "        return torch.tensor([i for i, flag in enumerate(mask) if flag])\n",
    "    dataset.train_idx = mask_to_ind(dataset.train_mask)\n",
    "    dataset.val_idx = mask_to_ind(dataset.val_mask)\n",
    "    dataset.test_idx = mask_to_ind(dataset.test_mask)\n",
    "    \n",
    "def load_train_conf(conf_path):\n",
    "    \"\"\"Parse the YAML file at ``conf_path`` and return its content.\"\"\"\n",
    "    # NOTE(review): full_load is fine for the project's own config files; prefer\n",
    "    # yaml.safe_load if configs could ever come from an untrusted source.\n",
    "    with open(conf_path, 'r') as conf_file:\n",
    "        return yaml.full_load(conf_file)\n",
    "\n",
    "# Root folders for dataset caches and experiment configs (relative paths).\n",
    "data_root_folder = \"../dataset\"\n",
    "conf_root_folder = \"../config\"\n",
    "\n",
    "#############################################\n",
    "# Dataset selection: keep exactly one of the groups below active.\n",
    "\n",
    "# cora\n",
    "dataset = dgl.data.CoraGraphDataset()\n",
    "get_index(dataset)\n",
    "conf_name = \"cora/SGC/cora_SGC0.yml\"  # MLP/SGC 0/1/2\n",
    "\n",
    "# pubmed\n",
    "# dataset = dgl.data.PubmedGraphDataset()\n",
    "# get_index(dataset)\n",
    "# conf_name = \"pubmed/pubmed_SGC0.yml\"   # MLP/SGC 0/1/2\n",
    "\n",
    "# # reddit\n",
    "# name = \"reddit\"\n",
    "# dataset = dgl.data.RedditDataset()\n",
    "# get_index(dataset)\n",
    "# conf_name = \"reddit/reddit_SGC1.yml\"   # MLP/SGC 0/1/2\n",
    "\n",
    "\n",
    "#######################################\n",
    "# Initialization\n",
    "conf_path = os.path.join(conf_root_folder, conf_name)\n",
    "train_conf = load_train_conf(conf_path)\n",
    "train_conf.setdefault(\"sgc_k\", 2)\n",
    "# The loaded config carries this flag under the key \"sgc_is_add_original\" (see the\n",
    "# printed config), while the rest of the notebook reads \"sgc_is_add_ori\". Honour the\n",
    "# long spelling instead of silently replacing the configured value with False.\n",
    "train_conf.setdefault(\"sgc_is_add_ori\", train_conf.get(\"sgc_is_add_original\", False))\n",
    "\n",
    "print(train_conf)\n",
    "\n",
    "# One width per hidden layer; an empty list when hidden_layer is 0.\n",
    "hidden_size = [train_conf['hidden_size']] * train_conf['hidden_layer']\n",
    "\n",
    "# exist_ok avoids the check-then-create race of a separate os.path.exists test.\n",
    "model_folder = os.path.join(\"res\", train_conf[\"model\"]+str(train_conf['hidden_layer']), train_conf['name'])\n",
    "os.makedirs(model_folder, exist_ok=True)\n",
    "\n",
    "data_folder = os.path.join(data_root_folder, train_conf['name'])\n",
    "os.makedirs(data_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating features...\n",
      "    Normalizing features...\n",
      "    Generating normalized adjacent matrix...\n",
      "    Computing feature propogation...\n",
      "        finished in 0.0073 s\n"
     ]
    }
   ],
   "source": [
    "# Preprocess the features with the project's SGCFeatureGen helper.\n",
    "feature_gen = SGCFeatureGen(\n",
    "    dataset,\n",
    "    k=train_conf[\"sgc_k\"],\n",
    "    device=device,\n",
    "    cache_path=data_folder,\n",
    "    use_cache=True,\n",
    "    return_torch=True,\n",
    "    need_k_adj=False,\n",
    "    compute_device=device,\n",
    "    feature_normalize=train_conf['feature_normalize'],\n",
    ")\n",
    "\n",
    "# Pick the feature matrix that matches the configured model.\n",
    "if train_conf[\"model\"] == \"MLP\":\n",
    "    features = feature_gen.torch_features\n",
    "elif train_conf[\"model\"] == \"SGC\" and train_conf[\"sgc_is_add_ori\"]:\n",
    "    features = feature_gen.all_features\n",
    "elif train_conf[\"model\"] == \"SGC\":\n",
    "    features = feature_gen.prop_features\n",
    "else:\n",
    "    raise ValueError(\"Not implemented model {}\".format(train_conf['model']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 213.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00016 | f1_micro 0.7780 \n",
      "1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 222.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00016 | f1_micro 0.7900 \n",
      "2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 220.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00013 | f1_micro 0.7800 \n",
      "3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 221.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00014 | f1_micro 0.7780 \n",
      "4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 222.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00014 | f1_micro 0.7780 \n",
      "Test f1_micro 0.7830\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "all_acc = []\n",
    "\n",
    "# Model dimensions and the test loader are identical for every run, so build them\n",
    "# once (train() likewise reuses its loaders across epochs).\n",
    "in_size = features.shape[1]\n",
    "out_size = dataset.num_classes\n",
    "test_dataloader = MyLoader(features.to(device), torch.tensor(dataset.labels).to(device),\n",
    "                           train_conf[\"batch_size\"], dataset.test_idx.to(device))\n",
    "\n",
    "for I in range(5):\n",
    "    print(I)\n",
    "\n",
    "    # Fresh MLP for every run.\n",
    "    model = MLP(in_size, hidden_size, out_size, dropout=train_conf['dropout']).to(device)\n",
    "    best_model_state = train(device, features, dataset, model, train_conf)\n",
    "\n",
    "    # Persist the pickled state_dict returned by train().\n",
    "    with open(os.path.join(model_folder, \"state_dict_\"+str(I)), \"wb\") as f:\n",
    "        f.write(best_model_state)\n",
    "\n",
    "    # Score the best validation snapshot (not the last epoch) on the test split.\n",
    "    model.load_state_dict(pickle.loads(best_model_state))\n",
    "    acc = evaluate(model, test_dataloader)\n",
    "    all_acc.append(acc.item())\n",
    "    # Report every run; previously only the last run's score was printed.\n",
    "    print(\"Test f1_micro {:.4f}\".format(acc.item()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "$0.785\\pm0.005$\n",
      "{'name': 'cora', 'dropout': 0.0, 'model': 'SGC', 'batch_size': 512, 'epoch': 100, 'lr': 0.5, 'weight_decay': 5e-06, 'hidden_size': 64, 'hidden_layer': 0, 'feature_normalize': 'row_sum', 'sgc_k': 3, 'sgc_is_add_original': False, 'sgc_is_add_ori': False}\n"
     ]
    }
   ],
   "source": [
    "# output to latex\n",
    "def to_latex(res, precision=3):\n",
    "    mean = np.mean(res)\n",
    "    std = np.std(res)\n",
    "    res_str = (\"{:.\"+str(precision)+\"f}\\pm{:.\"+str(precision)+\"f}\").format(mean, std)\n",
    "    return res_str\n",
    "\n",
    "# Wrap the summary in math delimiters; cells are joined with LaTeX column separators.\n",
    "res = [\"$\" + to_latex(all_acc) + \"$\"]\n",
    "res_str = \" & \".join(res)\n",
    "print(res_str)\n",
    "print(train_conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data loader created in 0.009240 seconds\n"
     ]
    }
   ],
   "source": [
    "# testing for the sampling time\n",
    "def test_inference_time(model, device, dataset, features, eval_size = 100, sample_ind=None):\n",
    "    \"\"\"Time batch fetching and model inference on single validation nodes.\n",
    "\n",
    "    Args:\n",
    "        model: module to benchmark; it is switched to eval mode.\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "        dataset: object exposing ``val_idx`` and ``labels``.\n",
    "        features: feature tensor; a CPU copy feeds the loader.\n",
    "        eval_size: number of validation nodes to time.\n",
    "        sample_ind: positions into ``dataset.val_idx`` to draw from. Defaults to\n",
    "            the notebook-level ``all_ind`` list, which keeps existing calls working.\n",
    "\n",
    "    Returns:\n",
    "        tuple(list, list): per-batch seconds spent waiting for the loader, and\n",
    "        per-batch seconds spent on the transfer to ``device`` plus the forward pass.\n",
    "    \"\"\"\n",
    "    # perf_counter is monotonic and high resolution, which suits sub-millisecond timings.\n",
    "    from time import perf_counter\n",
    "\n",
    "    if sample_ind is None:\n",
    "        sample_ind = all_ind\n",
    "    sampling_time, infer_time = [], []\n",
    "    on_cuda = torch.device(device).type == \"cuda\"\n",
    "\n",
    "    val_idx = dataset.val_idx[sample_ind[:eval_size]].to(\"cpu\")\n",
    "    tic_data_loader = perf_counter()\n",
    "    val_dataloader = MyLoader(features.to(\"cpu\"), torch.tensor(dataset.labels).to(\"cpu\"),\n",
    "                               1, val_idx)\n",
    "    print(\"data loader created in {:.6f} seconds\".format(perf_counter() - tic_data_loader))\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        tic_sampling = perf_counter()\n",
    "        for x, y in val_dataloader:\n",
    "            sampling_time.append(perf_counter() - tic_sampling)\n",
    "            tic_infer = perf_counter()\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            y_hat = model(x)\n",
    "            if on_cuda:\n",
    "                # CUDA kernels are launched asynchronously; wait for them so the\n",
    "                # timer covers the computation and not just the launch.\n",
    "                torch.cuda.synchronize(device)\n",
    "            infer_time.append(perf_counter() - tic_infer)\n",
    "            # Release the batch before the next sampling measurement starts.\n",
    "            del x, y, y_hat\n",
    "            torch.cuda.empty_cache()\n",
    "            tic_sampling = perf_counter()\n",
    "\n",
    "    return sampling_time, infer_time\n",
    "\n",
    "# Shuffled positions into the validation split; test_inference_time reads `all_ind`.\n",
    "all_ind = list(range(dataset.val_idx.shape[0]))\n",
    "random.shuffle(all_ind)\n",
    "\n",
    "# One timing series per benchmarked configuration (a single one here).\n",
    "all_sample_time = []\n",
    "all_infer_time = []\n",
    "\n",
    "sampling_time, infer_time = test_inference_time(model, device, dataset, features)\n",
    "all_sample_time += [sampling_time]\n",
    "all_infer_time += [infer_time]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\hline\n",
      "Neighbor size & SGC \\\\ \\hline\n",
      "mean & 0.00003/0.00015 \\\\\n",
      "90-percentile & 0.00003/0.00017 \\\\\n",
      "99-percentile & 0.00004/0.00020 \\\\\n",
      "max & 0.00006/0.00020 \\\\\n",
      "std & 0.00000/0.00001 \\\\ \\hline\n"
     ]
    }
   ],
   "source": [
    "# final string generation\n",
    "\n",
    "def gen_res_table(stime, itime, func_list, func_name_list, head_list):\n",
    "    def apply_func(time_list, func):\n",
    "        return [func(i[1:]) for i in time_list]\n",
    "    rows = []\n",
    "    # first row\n",
    "    rows.append(\"\\hline\")\n",
    "    cur_row = \"Neighbor size & \" + \" & \".join(head_list) + \" \\\\\\\\ \\hline\"\n",
    "    rows.append(cur_row)\n",
    "    \n",
    "    for name, func in zip(func_name_list, func_list):\n",
    "        svals = apply_func(stime, func)\n",
    "        ivals = apply_func(itime, func)\n",
    "        comb = [\"{:.5f}/{:.5f}\".format(i, j) for i, j in zip(svals, ivals)]\n",
    "        cur_row = name + \" & \" + \" & \".join(comb) + \" \\\\\\\\\"\n",
    "        rows.append(cur_row)\n",
    "        \n",
    "    table_str = \"\\n\".join(rows)\n",
    "    \n",
    "    table_str += \" \\hline\"\n",
    "    return table_str\n",
    "        \n",
    "# (row label, reducer) pairs for the timing table, in display order.\n",
    "stat_specs = [\n",
    "    (\"mean\", np.mean),\n",
    "    (\"90-percentile\", lambda x: np.percentile(x, 90)),\n",
    "    (\"99-percentile\", lambda x: np.percentile(x, 99)),\n",
    "    (\"max\", np.max),\n",
    "    (\"std\", np.std),\n",
    "]\n",
    "func_name_list = [label for label, _ in stat_specs]\n",
    "func_list = [reducer for _, reducer in stat_specs]\n",
    "head_list = [\"SGC\"]\n",
    "\n",
    "res_str = gen_res_table(all_sample_time, all_infer_time, func_list, func_name_list, head_list)\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "# Rebuild the architecture the same way the training cell does: the saved weights\n",
    "# were fitted on `features` (the preprocessed matrix), so the input layer is sized\n",
    "# from it rather than from the raw dataset features, whose width may differ.\n",
    "in_size = features.shape[1]\n",
    "out_size = dataset.num_classes\n",
    "model = MLP(in_size, hidden_size, out_size, dropout=train_conf['dropout'])\n",
    "model_path = os.path.join(model_folder, \"state_dict_0\")\n",
    "\n",
    "class CPU_Unpickler(pickle.Unpickler):\n",
    "    def find_class(self, module, name):\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "        else: return super().find_class(module, name)\n",
    "\n",
    "# NOTE: unpickling can execute code stored in the file; only load checkpoints\n",
    "# written by this notebook. The context manager closes the file handle, which a\n",
    "# bare open() call leaves open.\n",
    "with open(model_path, \"rb\") as state_file:\n",
    "    model_state = CPU_Unpickler(state_file).load()\n",
    "model.load_state_dict(model_state)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
