{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Baseline tests for graphSAGE\n",
    "\"\"\"\n",
    "import os\n",
    "# NOTE(review): an empty CURL_CA_BUNDLE is presumably a workaround to skip TLS certificate\n",
    "# verification in HTTP clients that read this variable (e.g. dataset downloads); it weakens\n",
    "# security - confirm it is still needed.\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\" \n",
    "# os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "import sys\n",
    "# make the project modules imported below (model, utils, data_utils, path_dict) importable\n",
    "# from the parent directory and the current directory\n",
    "sys.path.extend([\"../\", \"./\"])\n",
    "# NOTE(review): random, glob, nn, dgl.data, dglnn, AsNodePredDataset, get_index and split_dataset\n",
    "# are not referenced by the cells of this notebook\n",
    "import random\n",
    "import glob\n",
    "\n",
    "import yaml\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import dgl\n",
    "import dgl.data\n",
    "import dgl.nn as dglnn\n",
    "from dgl.data import AsNodePredDataset\n",
    "from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler\n",
    "from tqdm import tqdm\n",
    "\n",
    "from model import SAGE\n",
    "from utils import torch_f1\n",
    "from data_utils import get_index, split_dataset, load_data, hash_list\n",
    "from path_dict import train_conf_dict\n",
    "\n",
    "# device used for graphs, models and dataloaders in the cells below; hard-coded GPU index,\n",
    "# adjust for the local machine\n",
    "device = 'cuda:7'\n",
    "# line_profiler provides %lprun; autoreload 2 re-imports edited project modules before each cell\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Utility functions\n",
    "def evaluate(model, dataloader):\n",
    "    \"\"\"Score `model` on every mini-batch of `dataloader` with the project's `torch_f1`.\n",
    "\n",
    "    The model is switched to eval mode; predictions and labels from all batches are\n",
    "    concatenated and passed to a single `torch_f1(predictions, labels)` call.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    predictions, targets = [], []\n",
    "    with torch.no_grad():\n",
    "        for _, _, blocks in dataloader:\n",
    "            feats = blocks[0].srcdata['feat']\n",
    "            targets.append(blocks[-1].dstdata['label'])\n",
    "            predictions.append(model(blocks, feats))\n",
    "    return torch_f1(torch.cat(predictions), torch.cat(targets))\n",
    "\n",
    "def layerwise_infer(device, graph, nid, model, batch_size, is_sample=False, sample_size=2):\n",
    "    \"\"\"Run full-graph `model.inference` and compute `torch_f1` on the nodes `nid`.\n",
    "\n",
    "    Args:\n",
    "        device: forwarded to `model.inference`.\n",
    "        graph: DGL graph whose ndata contains 'label'.\n",
    "        nid: tensor of node indices to score; may live on any device.\n",
    "        model: network exposing `inference(graph, device, batch_size, is_sample=..., sample_size=...)`.\n",
    "        batch_size: forwarded to `model.inference`.\n",
    "        is_sample: forwarded to `model.inference` (presumably enables test-time neighbour sampling).\n",
    "        sample_size: forwarded to `model.inference` (presumably the per-node fan-out when sampling).\n",
    "\n",
    "    Returns:\n",
    "        Whatever `torch_f1(pred, label)` returns (callers in this notebook call `.item()` on it).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        pred = model.inference(graph, device, batch_size, is_sample=is_sample, sample_size=sample_size) # pred in buffer_device\n",
    "        # A CPU tensor cannot be indexed with CUDA indices, and `pred` may sit on a CPU buffer\n",
    "        # while `nid` is on the GPU: move the indices to each indexed tensor's device first.\n",
    "        pred = pred[nid.to(pred.device)]\n",
    "        all_labels = graph.ndata['label']\n",
    "        label = all_labels[nid.to(all_labels.device)].to(pred.device)\n",
    "        return torch_f1(pred, label)\n",
    "\n",
    "def train(device, g, model, train_conf, is_sample=False):\n",
    "    \"\"\"Train `model` on `g.train_idx` and return the weights with the best validation F1.\n",
    "\n",
    "    Args:\n",
    "        device: device the node indices are moved to and the dataloaders emit blocks on.\n",
    "        g: DGL graph with 'feat'/'label' node data and `train_idx`/`val_idx` attributes.\n",
    "        model: network exposing `GNN_layer` and called as `model(blocks, x)`.\n",
    "        train_conf: dict-like config with keys 'batch_size', 'lr', 'weight_decay', 'epoch',\n",
    "            'train_neighbor_size' (read only when is_sample=True) and, optionally,\n",
    "            'eval_batch_size' (validation batch size, default 256).\n",
    "        is_sample: True -> NeighborSampler with a fixed fan-out per layer;\n",
    "            False -> MultiLayerFullNeighborSampler (all neighbours).\n",
    "\n",
    "    Returns:\n",
    "        bytes: pickle.dumps(model.state_dict()) taken at the best validation epoch\n",
    "        (the untrained weights if 'epoch' is 0).\n",
    "    \"\"\"\n",
    "    # create sampler & dataloader\n",
    "    train_idx = g.train_idx.to(device)\n",
    "    val_idx = g.val_idx.to(device)\n",
    "    if is_sample:\n",
    "        # one fan-out entry per GNN layer\n",
    "        sampler = NeighborSampler([train_conf[\"train_neighbor_size\"]] * model.GNN_layer,\n",
    "                                  prefetch_node_feats=['feat'],\n",
    "                                  prefetch_labels=['label'])\n",
    "    else:\n",
    "        sampler = MultiLayerFullNeighborSampler(model.GNN_layer,\n",
    "                                                prefetch_node_feats=['feat'],\n",
    "                                                prefetch_labels=['label'])\n",
    "    use_uva = False\n",
    "    train_dataloader = DataLoader(g, train_idx, sampler, device=device,\n",
    "                                  batch_size=train_conf[\"batch_size\"], shuffle=True,\n",
    "                                  drop_last=False, num_workers=0,\n",
    "                                  use_uva=use_uva)\n",
    "\n",
    "    val_dataloader = DataLoader(g, val_idx, sampler, device=device,\n",
    "                                batch_size=train_conf.get(\"eval_batch_size\", 256), shuffle=True,\n",
    "                                drop_last=False, num_workers=0,\n",
    "                                use_uva=use_uva)\n",
    "\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf[\"weight_decay\"])\n",
    "\n",
    "    # Start below any attainable F1 so the first epoch is always snapshotted; starting at 0\n",
    "    # left best_state as None whenever validation F1 never exceeded 0.\n",
    "    best_state, best_val, best_epoch = None, float(\"-inf\"), 0\n",
    "    for epoch in tqdm(range(train_conf[\"epoch\"])):\n",
    "        model.train()\n",
    "        for _, _, blocks in train_dataloader:\n",
    "            x = blocks[0].srcdata['feat']\n",
    "            y = blocks[-1].dstdata['label']\n",
    "            y_hat = model(blocks, x)\n",
    "            loss = F.cross_entropy(y_hat, y)\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        val_f1 = evaluate(model, val_dataloader).item()\n",
    "        if val_f1 > best_val:\n",
    "            best_val = val_f1\n",
    "            # serialize now: state_dict() references the live tensors, which later steps update\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "    if best_state is None:\n",
    "        # no epochs were run; return the current weights rather than None\n",
    "        best_state = pickle.dumps(model.state_dict())\n",
    "\n",
    "    print(\"Epoch {:05d} | F1_micro {:.4f} \"\n",
    "           .format(best_epoch, best_val))\n",
    "\n",
    "    return best_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter settings and data loading\n",
    "# extra test-time neighbour fan-outs to evaluate besides full-neighbour inference (empty: full only)\n",
    "test_sample_size_pool = []\n",
    "\n",
    "# name_pool = [\"cora\", \"citeseer\", \"pubmed\"]\n",
    "# name_pool = [\"a-computer\", \"a-photo\"]\n",
    "# name_pool = [\"ogbn-arxiv\", \"ogbn-products\"]\n",
    "name_pool = [\"ogbn-products\"]\n",
    "# name_pool = [\"a-computer\", \"a-photo\", \"ogbn-arxiv\", \"ogbn-products\"]\n",
    "\n",
    "# split_seed_pool = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
    "split_seed_pool = [2, 3, 4, 5, 6, 7, 8, 9]\n",
    "\n",
    "for name in name_pool:\n",
    "    train_conf = train_conf_dict[name]\n",
    "    # test F1 per split seed of the model trained with sampled neighbours:\n",
    "    # full-neighbour test (list) and sampled test keyed by test fan-out (dict of lists)\n",
    "    p_train_f_test = []\n",
    "    p_train_p_test = defaultdict(list)  # filled only when test_sample_size_pool is non-empty\n",
    "    for round_idx, cur_seed in enumerate(split_seed_pool):\n",
    "        print(f\"round: {round_idx}, seed: {cur_seed}\")\n",
    "        # Initialization\n",
    "        model_folder = os.path.join(\"../result\", name, f\"SAGE/seed_{cur_seed}\") # for general test\n",
    "        # model_folder = os.path.join(\"../result\", name , \"SAGE_time\", \"Layer\"+str(train_conf['hidden_layer'])) # for layer time test\n",
    "        os.makedirs(model_folder, exist_ok=True)\n",
    "        model_path = os.path.join(model_folder, \"model_stat\")\n",
    "\n",
    "        # output the configurations\n",
    "        with open(os.path.join(model_folder, \"train_conf.yml\"), 'w') as fout:\n",
    "            yaml.dump(train_conf, fout)\n",
    "\n",
    "        g = load_data(name, seed=cur_seed)\n",
    "        g = g.to(device)\n",
    "\n",
    "        # create GraphSAGE model\n",
    "        in_size = g.ndata['feat'].shape[1]\n",
    "        out_size = g.num_classes\n",
    "        hidden_size = train_conf[\"hidden_size\"]\n",
    "\n",
    "        model = SAGE(in_size, hidden_size, out_size, GNN_layer=train_conf[\"hidden_layer\"], dropout=train_conf[\"dropout\"]).to(device)\n",
    "        # training loop with sampled neighbours; returns the pickled best-validation state_dict\n",
    "        partial_neighbor_model_state = train(device, g, model, train_conf, is_sample=True)\n",
    "\n",
    "        with open(model_path, \"wb\") as f:\n",
    "            f.write(partial_neighbor_model_state)\n",
    "\n",
    "        # model = SAGE(in_size, hidden_size, out_size).to(device)\n",
    "        # full_neighbor_model_state = train(device, g, dataset, model, train_conf, is_sample=False)\n",
    "\n",
    "        # final evaluation with the best-validation weights\n",
    "        model.load_state_dict(pickle.loads(partial_neighbor_model_state))\n",
    "        acc = layerwise_infer(device, g, g.test_idx, model, batch_size=256, is_sample=False)\n",
    "        p_train_f_test.append(acc.item())\n",
    "        for test_sample_size in test_sample_size_pool:\n",
    "            acc = layerwise_infer(device, g, g.test_idx, model, batch_size=256, is_sample=True, sample_size=test_sample_size)\n",
    "            p_train_p_test[test_sample_size].append(acc.item())\n",
    "\n",
    "        # save the f1 as well\n",
    "        f1_path = os.path.join(model_folder, \"f1.txt\")\n",
    "        with open(f1_path, \"w\") as fout:\n",
    "            fout.write(str(p_train_f_test[-1]))\n",
    "            fout.write(\"\\n\")\n",
    "            print(f\"Sample train full test of f1 {p_train_f_test[-1]} wrote to {f1_path}\")\n",
    "\n",
    "        # save the hash signatures\n",
    "        train_hash, val_hash, test_hash = hash_list(g.train_idx.tolist()), hash_list(g.val_idx.tolist()), hash_list(g.test_idx.tolist())\n",
    "\n",
    "        save_str = f\"train sig: {train_hash} | val sig: {val_hash} | test_sig: {test_hash}\\n\"\n",
    "        split_hash_path = os.path.join(model_folder, \"split_hash.txt\")\n",
    "        with open(split_hash_path, \"w\") as fout:\n",
    "            fout.write(save_str)\n",
    "            print(f\"Split signature wrote to {split_hash_path}\")\n",
    "\n",
    "        label_hash = hash_list(g.ndata[\"label\"][:1000].detach().cpu().numpy().tolist())\n",
    "        save_str = f\"First 1000 label hash is {label_hash}\\n\"\n",
    "        label_hash_path = os.path.join(model_folder, \"label_hash.txt\")\n",
    "        with open(label_hash_path, \"w\") as fout:\n",
    "            fout.write(save_str)\n",
    "            print(f\"Label signature wrote to {label_hash_path}\")\n",
    "\n",
    "        seed_path = os.path.join(model_folder, \"data_split_seed.txt\")\n",
    "        with open(seed_path, 'w') as fout:\n",
    "            fout.write(str(cur_seed))\n",
    "            print(f\"Current data split seed wrote to {seed_path}.\")\n",
    "\n",
    "        # prepare and save the cache for GLNN and NOSMOG\n",
    "        ## reload the checkpoint written above into a model built with the same configuration\n",
    "        model = SAGE(in_size, hidden_size, out_size, GNN_layer=train_conf[\"hidden_layer\"], dropout=train_conf[\"dropout\"]).to(device)\n",
    "\n",
    "        # model_path was written by this loop, so unpickling it is trusted\n",
    "        with open(model_path, 'rb') as fin:\n",
    "            model.load_state_dict(pickle.load(fin))\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            hid_embed, soft_label = model.inference_teacher(g, device, 256)\n",
    "\n",
    "        ## main computation\n",
    "        embed_path = os.path.join(model_folder, \"out_emb_list.npz\")\n",
    "        soft_label_path = os.path.join(model_folder, \"out.npz\")\n",
    "\n",
    "        np.savez(embed_path, hid_embed.detach().cpu().numpy())\n",
    "        np.savez(soft_label_path, soft_label.detach().cpu().numpy())\n",
    "        # end of the loop for one dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output to latex\n",
    "def to_latex(res):\n",
    "    \"\"\"Summarise a sequence of scores as a LaTeX 'mean plus-minus std' string.\n",
    "\n",
    "    Both statistics are multiplied by 100 (presumably turning fractions such as F1 into\n",
    "    percentages) and printed with two decimals; np.std uses its default ddof=0.\n",
    "    \"\"\"\n",
    "    mean = np.mean(res)\n",
    "    std = np.std(res)\n",
    "    # raw string: in a normal literal '\\p' is an invalid escape that triggers a DeprecationWarning\n",
    "    # (SyntaxWarning on newer Pythons); the produced text is unchanged.\n",
    "    res_str = r\"{:.2f}$\\pm${:.2f}\".format(mean * 100, std * 100)\n",
    "    return res_str\n",
    "\n",
    "# res = []\n",
    "# for size in test_sample_size_pool:\n",
    "#     res.append(to_latex(p_train_p_test[size]))\n",
    "# res.append(to_latex(p_train_f_test))\n",
    "# res = [\"$\"+i+\"$\" for i in res]\n",
    "# res_str = \" & \".join(res)\n",
    "# print(res_str)\n",
    "\n",
    "# NOTE(review): f_train_p_test / f_train_f_test below are not computed anywhere in this notebook.\n",
    "# res = []\n",
    "# for size in test_sample_size_pool:\n",
    "#     res.append(to_latex(f_train_p_test[size]))\n",
    "# res.append(to_latex(f_train_f_test))\n",
    "# res = [\"$\"+i+\"$\" for i in res]\n",
    "# res_str = \" \".join(res)\n",
    "# print(res_str)\n",
    "\n",
    "# Aggregate the per-seed test F1 values that the training cell writes to f1.txt\n",
    "# split_seed_pool = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
    "split_seed_pool = [0, 1, 2, 3, 4]\n",
    "res = []\n",
    "name=\"ogbn-products\"\n",
    "print(name)\n",
    "for seed in split_seed_pool:\n",
    "    res_path = os.path.join(\"../result\", name, f\"SAGE/seed_{seed}/f1.txt\")\n",
    "    with open(res_path, \"r\") as fin:\n",
    "        # the file holds one float; float() parses it without executing the contents as eval() would\n",
    "        res.append(float(fin.read()))\n",
    "print(to_latex(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
