{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "various analysis\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "from copy import deepcopy, copy\n",
    "from collections import defaultdict\n",
    "import glob\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import linear_model\n",
    "import scipy.sparse as sp\n",
    "from sklearn.preprocessing import normalize\n",
    "from collections import defaultdict\n",
    "\n",
    "import ppr\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "from plotlib import *\n",
    "from script_utils import load_SDMP, load_SDMP_MLP, latexTableBase\n",
    "from path_dict import *\n",
    "from train import distill_train as train\n",
    "\n",
    "# Silence all Python warnings for the rest of the session.\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): presumably a workaround for CA-bundle/SSL issues when downloading data; confirm it is still needed.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Device that models, feature matrices and sparse operators below are moved to via `.to(DEVICE)`.\n",
    "DEVICE = 'cuda:0'\n",
    "\n",
    "# line_profiler provides the `%lprun` magic; autoreload re-imports edited local modules.\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %lprun -f\n",
    "# Parameter settings and data loading\n",
    "# All paths are relative, so the notebook is expected to run from its own directory.\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "# Teacher GNN outputs, laid out as <dataset>/<model>/seed_<seed>.\n",
    "TARGET_ROOT_FOLDER = \"../result/baselines\"\n",
    "# Per (dataset, model) outputs of the ablation loop: saved MLPs and all_res_dict.pkl.\n",
    "LOG_ROOT_FOLDER = \"../script_logs/neighbour_alg\"\n",
    "\n",
    "def train_and_eval(features, teacher_emb, g, train_conf):\n",
    "    \"\"\"Distill a fresh MLP on `features` and return it together with its test accuracy.\n",
    "\n",
    "    Args:\n",
    "        features: node feature matrix (one row per node); its column count is the MLP input size.\n",
    "        teacher_emb: teacher GNN embeddings, forwarded unchanged to `distill_train`.\n",
    "        g: graph providing `num_classes`, `ndata['label']` and `test_idx`.\n",
    "        train_conf: dict with at least 'hidden_size', 'hidden_layer', 'dropout' and\n",
    "            'batch_size', plus whatever `distill_train` reads.\n",
    "\n",
    "    Returns:\n",
    "        (model, acc): the MLP loaded with the state returned by `distill_train`, and its\n",
    "        test accuracy as a Python scalar.\n",
    "    \"\"\"\n",
    "    in_size = features.shape[1]\n",
    "    out_size = g.num_classes\n",
    "    hidden_size = [train_conf['hidden_size']] * train_conf['hidden_layer']\n",
    "\n",
    "    model = MLP(in_size, hidden_size, out_size, dropout=train_conf['dropout']).to(DEVICE)\n",
    "    # `train` hands back the best state_dict as pickled bytes.\n",
    "    best_model_state = train(DEVICE, features, teacher_emb, g, model, train_conf, verbose=False)\n",
    "    model.load_state_dict(pickle.loads(best_model_state))\n",
    "\n",
    "    # as_tensor: no extra copy (and no copy-construct warning) when labels are already a tensor.\n",
    "    labels = torch.as_tensor(g.ndata['label']).to(DEVICE)\n",
    "    test_dataloader = PlainLoader(features.to(DEVICE), labels,\n",
    "                                  train_conf[\"batch_size\"],\n",
    "                                  g.test_idx.to(DEVICE))\n",
    "    acc = evaluate(model, test_dataloader)\n",
    "    return model, acc.item()\n",
    "\n",
    "# def gen_equal_feature(sdmp, all_h):\n",
    "#     ThetaT = sdmp.ThetaT.tocsr()\n",
    "#     data = np.ones(ThetaT.data.shape[0])\n",
    "#     indices, indptr = ThetaT.indices, ThetaT.indptr\n",
    "#     equal_ThetaT = sp.csr_matrix((data, indices, indptr), shape=ThetaT.shape)\n",
    "#     equal_ThetaT = row_norm(equal_ThetaT)\n",
    "#     torch_equal_ThetaT = scipy_coo_to_torch_sparse(equal_ThetaT.tocoo()).to(DEVICE)\n",
    "#     equal_ThetaT_feature = torch.sparse.mm(torch_equal_ThetaT, all_h)\n",
    "#     return equal_ThetaT_feature.detach()\n",
    "\n",
    "def gen_equal_feature(sdmp, all_h):\n",
    "    \"\"\"Aggregate `all_h` over SDMP's receptive fields with uniform weights.\n",
    "\n",
    "    The sparsity pattern of `sdmp.ThetaT` is kept but every stored weight becomes 1,\n",
    "    and each row sum is divided by that row's number of stored entries, i.e. every node\n",
    "    takes the plain average of the `all_h` rows it attends to. Rows without entries are\n",
    "    divided by 1 and therefore stay zero.\n",
    "\n",
    "    Returns:\n",
    "        The aggregated features on DEVICE, detached from the autograd graph.\n",
    "    \"\"\"\n",
    "    theta_csr = sdmp.ThetaT.tocsr()\n",
    "    indptr = theta_csr.indptr\n",
    "    uniform_theta = sp.csr_matrix((np.ones(theta_csr.data.shape[0]), theta_csr.indices, indptr),\n",
    "                                  shape=theta_csr.shape)\n",
    "    torch_uniform_theta = scipy_coo_to_torch_sparse(uniform_theta.tocoo()).to(DEVICE)\n",
    "\n",
    "    # stored entries per row, clamped to 1 so empty rows do not divide by zero\n",
    "    entries_per_row = np.diff(indptr)\n",
    "    entries_per_row[entries_per_row == 0] = 1\n",
    "    divisor = torch.from_numpy(entries_per_row).to(DEVICE).reshape([-1, 1])\n",
    "\n",
    "    averaged = torch.sparse.mm(torch_uniform_theta, all_h)\n",
    "    averaged /= divisor\n",
    "    return averaged.detach()\n",
    "\n",
    "def gen_k_hop_feature(g, A, k):\n",
    "    \"\"\"Return A^k X with X = g.ndata[\"feat\"], computed by k sparse products on DEVICE.\n",
    "\n",
    "    k == 0 returns the device-moved features unchanged. The result is detached.\n",
    "    \"\"\"\n",
    "    propagated = g.ndata[\"feat\"].to(DEVICE)\n",
    "    adj_torch = scipy_coo_to_torch_sparse(A.tocoo()).to(DEVICE)\n",
    "    for _hop in range(k):\n",
    "        propagated = torch.sparse.mm(adj_torch, propagated)\n",
    "    return propagated.detach()\n",
    "\n",
    "def gen_k_hop_tranc_feature(g, A, receptive_field_size):\n",
    "    \"\"\"Aggregate node features through a row-wise top-k truncation of A.\n",
    "\n",
    "    Row i keeps only its `receptive_field_size[i]` largest stored values (stored zeros\n",
    "    are dropped too); the truncated matrix is multiplied with g.ndata[\"feat\"].\n",
    "\n",
    "    Args:\n",
    "        g: graph whose ndata[\"feat\"] holds the node features.\n",
    "        A: scipy sparse matrix (converted to CSR). It is not modified; only a separate\n",
    "            truncated matrix is built.\n",
    "        receptive_field_size: per-row budget (length A.shape[0]) or a single int used\n",
    "            for every row.\n",
    "\n",
    "    Returns:\n",
    "        The aggregated features on DEVICE, detached.\n",
    "    \"\"\"\n",
    "    A = A.tocsr()\n",
    "    n_rows = A.shape[0]\n",
    "    # broadcast so a scalar budget works as well as a per-row array\n",
    "    budget = np.broadcast_to(np.asarray(receptive_field_size), (n_rows,))\n",
    "    indptr, data = A.indptr, A.data\n",
    "    keep = np.zeros(data.shape[0], dtype=bool)\n",
    "    kept_per_row = np.zeros(n_rows, dtype=np.int64)\n",
    "    for i in range(n_rows):\n",
    "        start, end = indptr[i], indptr[i + 1]\n",
    "        row_data = data[start:end]\n",
    "        # largest values first: descending argsort, keep the first budget[i] positions\n",
    "        top = np.argsort(row_data)[::-1][:budget[i]]\n",
    "        top = top[row_data[top] != 0]\n",
    "        keep[start + top] = True\n",
    "        kept_per_row[i] = top.shape[0]\n",
    "    # build the truncated CSR from the kept positions (column order within rows preserved)\n",
    "    new_indptr = np.concatenate(([0], np.cumsum(kept_per_row)))\n",
    "    truncated = sp.csr_matrix((data[keep], A.indices[keep], new_indptr), shape=A.shape)\n",
    "    print(f\"Trancated average nonezeros: {len(truncated.data)/truncated.shape[0]}\")\n",
    "    torch_A = scipy_coo_to_torch_sparse(truncated.tocoo()).to(DEVICE)\n",
    "    feat = g.ndata[\"feat\"].to(DEVICE)\n",
    "    feat = torch.sparse.mm(torch_A, feat)\n",
    "    return feat.detach()\n",
    "\n",
    "def gen_pprgo_tranc_feature(g, A, receptive_field_size, alpha=0.5, eps=1e-4):\n",
    "    \"\"\"Aggregate node features with a top-k approximate PPR matrix (`ppr.topk_ppr_matrix`).\n",
    "\n",
    "    Args:\n",
    "        g: graph whose ndata[\"feat\"] holds the node features.\n",
    "        A: adjacency matrix passed as `adj_matrix`; all A.shape[0] nodes are queried.\n",
    "        receptive_field_size: forwarded unchanged as `topk`.\n",
    "        alpha: forwarded to `topk_ppr_matrix` (presumably the PPR teleport probability);\n",
    "            the default 0.5 is the previously hard-coded value.\n",
    "        eps: forwarded to `topk_ppr_matrix` (presumably the push-approximation threshold);\n",
    "            the default 1e-4 is the previously hard-coded value.\n",
    "\n",
    "    Returns:\n",
    "        The aggregated features on DEVICE, detached.\n",
    "    \"\"\"\n",
    "    ppr_matrix = ppr.topk_ppr_matrix(adj_matrix=A,\n",
    "                                     alpha=alpha,\n",
    "                                     eps=eps,\n",
    "                                     idx=list(range(A.shape[0])),\n",
    "                                     topk=receptive_field_size)\n",
    "    print(f\"Trancated average nonezeros: {len(ppr_matrix.data)/ppr_matrix.shape[0]}\")\n",
    "    torch_A = scipy_coo_to_torch_sparse(ppr_matrix.tocoo()).to(DEVICE)\n",
    "    feat = g.ndata[\"feat\"].to(DEVICE)\n",
    "    feat = torch.sparse.mm(torch_A, feat)\n",
    "    return feat.detach()\n",
    "\n",
    "# def get_paths(name, model, result_root):\n",
    "#     model_paths = os.path.join(result_root, f\"{name}_{model}_MLP_TEST\")\n",
    "#     return model_paths\n",
    "\n",
    "def get_paths(name, model):\n",
    "    \"\"\"Return the result root registered in path_dict for the (dataset, model) pair.\"\"\"\n",
    "    entry = path_dict[(name, model)]\n",
    "    return entry[\"result_root_parent_folder\"]\n",
    "\n",
    "def dump_and_save_model_dict(model, path):\n",
    "    \"\"\"Pickle `model.state_dict()` and write the bytes to `path`.\n",
    "\n",
    "    The state dict is serialized before the file is opened, so a failure while\n",
    "    pickling no longer truncates an existing file at `path`.\n",
    "    \"\"\"\n",
    "    payload = pickle.dumps(model.state_dict())\n",
    "    with open(path, 'wb') as fout:\n",
    "        fout.write(payload)\n",
    "###########\n",
    "# for the ablation studies\n",
    "# conf_dict[dataset][variant]: MLP training overrides merged into the SDMP train_conf for\n",
    "# each ablation variant. Every dataset except ogbn-products uses the same settings\n",
    "# (a separate dict object per entry).\n",
    "conf_dict = defaultdict(dict, {\n",
    "    dataset: {\n",
    "        variant: {\"batch_size\": 512, \"epoch\": 500, \"hidden_layer\": 1, \"lr\": 0.001}\n",
    "        for variant in (\"equal\", \"k_hop\", \"k_tranc\", \"PPRGo_tranc\")\n",
    "    }\n",
    "    for dataset in (\"cora\", \"citeseer\", \"pubmed\", \"a-computer\", \"a-photo\", \"ogbn-arxiv\")\n",
    "})\n",
    "# ogbn-products: dataset-specific settings\n",
    "conf_dict[\"ogbn-products\"][\"equal\"] = {\"batch_size\": 1024, \"epoch\": 100, \"hidden_layer\": 1, \"lr\": 0.0001}\n",
    "conf_dict[\"ogbn-products\"][\"k_hop\"] = {\"batch_size\": 512, \"epoch\": 200, \"hidden_layer\": 1, \"lr\": 0.001}\n",
    "conf_dict[\"ogbn-products\"][\"k_tranc\"] = {\"batch_size\": 1024, \"epoch\": 100, \"hidden_layer\": 1, \"lr\": 0.001}\n",
    "conf_dict[\"ogbn-products\"][\"PPRGo_tranc\"] = {\"batch_size\": 1024, \"epoch\": 100, \"hidden_layer\": 1, \"lr\": 0.001}\n",
    "\n",
    "# Result keys (all_res) reported as table columns.\n",
    "testing_cases = ['origin', 'PPRGo', 'SGC', 'equal', \"SDMP\"]\n",
    "# (dataset, teacher GNN) pairs to evaluate; alternative selections are kept commented out.\n",
    "name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\")]\n",
    "# name_model_pool = [(\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\"), (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "# name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\"), (\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\"), (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "# name_model_pool = [(\"ogbn-arxiv\", \"DRGAT\"), (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "\n",
    "# name_model_pool = [(\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\")]\n",
    "# name_model_pool = [(\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: for every (dataset, teacher GNN) pair and every trained seed, compare the\n",
    "# SDMP MLP with MLPs distilled on alternative receptive fields:\n",
    "#   equal - SDMP's receptive field with uniform weights\n",
    "#   SGC   - A^hops truncated per row to SDMP's receptive-field size\n",
    "#   PPRGo - top-k approximate PPR with the same per-node budget\n",
    "# Accuracies are collected in all_res and pickled per (dataset, model).\n",
    "for name, model in name_model_pool:\n",
    "    print(\"=\"*21)\n",
    "    print(name, model)\n",
    "    # initializations\n",
    "    all_res = defaultdict(list)\n",
    "    cur_res_folder = os.path.join(LOG_ROOT_FOLDER, f\"{name}_{model}\")\n",
    "    os.makedirs(cur_res_folder, exist_ok=True)\n",
    "    res_root = get_paths(name, model)\n",
    "\n",
    "    # sorted(): os.listdir order is arbitrary; keep the per-seed result lists reproducible\n",
    "    for result_root_path in sorted(os.listdir(res_root)):\n",
    "        # load data\n",
    "        tic = time()\n",
    "        print(\"*\"*21)\n",
    "        result_root_path = os.path.join(res_root, result_root_path)\n",
    "        with open(os.path.join(result_root_path, \"GNN_data_split_seed.txt\"), 'r') as fin:\n",
    "            # strip(): a trailing newline must not leak into folder/file names below\n",
    "            seed_str = fin.read().strip()\n",
    "        # data loading\n",
    "        print(\"-\"*11)\n",
    "        print(\"Dataloading...\")\n",
    "        TARGET_GNN_FOLDER = os.path.join(TARGET_ROOT_FOLDER, name, model, f\"seed_{seed_str}\")\n",
    "        print(TARGET_GNN_FOLDER)\n",
    "\n",
    "        preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(result_root_path,\n",
    "                                                                TARGET_GNN_FOLDER,\n",
    "                                                                DATA_ROOT_FOLDER,\n",
    "                                                                device=DEVICE)\n",
    "        SDMP_MLP_model = load_SDMP_MLP(train_conf, sdmp, g, result_root_path, device=DEVICE)\n",
    "\n",
    "        ori_node_emb, all_h = sdmp.infer_torch_node_approximal_features(return_H=True)\n",
    "\n",
    "        print(\"The train conf:!!!\")\n",
    "        print(train_conf)\n",
    "        print('distill lam', train_conf['lam_distill'])\n",
    "\n",
    "        with open(os.path.join(result_root_path, \"GNN_f1.txt\"), 'r') as fin:\n",
    "            cur_line = fin.read()\n",
    "        # NOTE(review): eval() executes the file content; only acceptable for trusted, locally written files.\n",
    "        try:\n",
    "            target_GNN_acc = eval(cur_line)\n",
    "        except Exception:  # \"<label>:<value>\" format - evaluate the part after the colon\n",
    "            target_GNN_acc = eval(cur_line.split(':')[1])\n",
    "        all_res['origin'].append(target_GNN_acc)\n",
    "        ## Load teacher embedding\n",
    "        # np.load(...)['arr_0'] needs the teacher .npz file inside the seed folder, not the folder\n",
    "        # itself. TODO confirm train_conf['teacher_path'] is always set.\n",
    "        GNN_TEACHER_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['teacher_path'])\n",
    "        teacher_emb = torch.from_numpy(np.load(GNN_TEACHER_PATH)['arr_0'])\n",
    "        print(f\"Data loading finished in {time()-tic:.1f} s. \")\n",
    "        ###########\n",
    "        # Get the performance of the SDMP model\n",
    "        print(\"-\"*11)\n",
    "        print(\"Get original SDMP performance...\")\n",
    "        tic = time()\n",
    "        test_dataloader = PlainLoader(ori_node_emb.to(DEVICE), torch.tensor(g.ndata['label']).to(DEVICE),\n",
    "                                    train_conf[\"batch_size\"], g.test_idx.to(DEVICE))\n",
    "        SDMP_acc = evaluate(SDMP_MLP_model, test_dataloader)\n",
    "        all_res['SDMP'].append(SDMP_acc.item())\n",
    "        print(f\"Original SDMP finished in {time()-tic:.1f} s. \")\n",
    "\n",
    "        # Get the performance of the equal weight setting\n",
    "        print(\"-\"*11)\n",
    "        print(\"Get equal weight performance...\")\n",
    "        tic = time()\n",
    "        equal_train_conf = deepcopy(train_conf)\n",
    "        equal_train_conf.update(conf_dict[name][\"equal\"])\n",
    "        print(equal_train_conf)\n",
    "        ## Construct the new features\n",
    "        equal_ThetaT_feature = gen_equal_feature(sdmp, all_h)\n",
    "        trained_model, equal_acc = train_and_eval(equal_ThetaT_feature, teacher_emb, g, equal_train_conf)\n",
    "        all_res['equal'].append(equal_acc)\n",
    "\n",
    "        model_path = os.path.join(cur_res_folder, f\"model_equal_s{seed_str}.pkl\")\n",
    "        dump_and_save_model_dict(trained_model, model_path)\n",
    "\n",
    "        print(f\"Equal weights finished in {time()-tic:.1f} s. \")\n",
    "        # SDMP's per-node receptive-field size (stored entries per row of ThetaT); it is\n",
    "        # the per-row budget of both truncated baselines below.\n",
    "        csr_theta = sdmp.ThetaT.tocsr()\n",
    "        receptive_field_size = csr_theta.indptr[1:] - csr_theta.indptr[:-1]\n",
    "        print(f\"Average number of receptive node from SDMP: {np.mean(receptive_field_size):.1f} +- {np.std(receptive_field_size):.1f}\")\n",
    "\n",
    "        # Get the performance of SGC (A^hops truncated to the SDMP budget)\n",
    "        hops = 2 if name in [\"ogbn-products\"] else 3\n",
    "        print(\"-\"*11)\n",
    "        print(f\"Get {hops} hop performance...\")\n",
    "        tic = time()\n",
    "        A = preprocesser.get_A_pow(hops)\n",
    "        k_hop_tranc_feat = gen_k_hop_tranc_feature(g, A, receptive_field_size)\n",
    "\n",
    "        k_hop_train_conf = deepcopy(train_conf)\n",
    "        k_hop_train_conf.update(conf_dict[name][\"k_tranc\"])\n",
    "        print(k_hop_train_conf)\n",
    "        trained_model, A_hop_acc = train_and_eval(k_hop_tranc_feat, teacher_emb, g, k_hop_train_conf)\n",
    "        all_res['SGC'].append(A_hop_acc)\n",
    "\n",
    "        model_path = os.path.join(cur_res_folder, f\"model_SGC_s{seed_str}.pkl\")\n",
    "        dump_and_save_model_dict(trained_model, model_path)\n",
    "\n",
    "        print(f\"Get {hops} trancated hop performance finished in {time()-tic:.1f} s.\")\n",
    "        # Get the performance of PPRGo (same receptive_field_size budget as above)\n",
    "        print(\"-\"*11)\n",
    "        print(\"Get PPRGo performance...\")\n",
    "        tic = time()\n",
    "\n",
    "        A = torch_sparse_to_scipy_coo(g.adj()).tocsr()\n",
    "\n",
    "        PPRGo_tranc_feat = gen_pprgo_tranc_feature(g, A, receptive_field_size)\n",
    "\n",
    "        pprgo_tranc_train_conf = deepcopy(train_conf)\n",
    "        pprgo_tranc_train_conf.update(conf_dict[name][\"PPRGo_tranc\"])\n",
    "        print(pprgo_tranc_train_conf)\n",
    "        trained_model, PPRGo_acc = train_and_eval(PPRGo_tranc_feat, teacher_emb, g, pprgo_tranc_train_conf)\n",
    "        all_res['PPRGo'].append(PPRGo_acc)\n",
    "\n",
    "        model_path = os.path.join(cur_res_folder, f\"model_PPRGo_s{seed_str}.pkl\")\n",
    "        dump_and_save_model_dict(trained_model, model_path)\n",
    "\n",
    "        print(f\"Get PPRGo trancated hop performance finished in {time()-tic:.1f} s.\")\n",
    "\n",
    "        # end of seed for\n",
    "\n",
    "    print(all_res)\n",
    "    all_res_path = os.path.join(cur_res_folder, \"all_res_dict.pkl\")\n",
    "    with open(all_res_path, \"wb\") as fout:\n",
    "        pickle.dump(all_res, fout)\n",
    "\n",
    "    # end of name model for"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load and present all results\n",
    "class ablationTable(latexTableBase):\n",
    "    \"\"\"LaTeX table: one row per (dataset, model), one 'mean +- std' cell (in %) per header key.\n",
    "\n",
    "    Reads the all_res_dict.pkl written per (dataset, model) under LOG_ROOT_FOLDER;\n",
    "    `self.header` is presumably the `testing_cases` list passed to the constructor.\n",
    "    \"\"\"\n",
    "    def gen_a_row(self, name_model_tuple):\n",
    "        data, name = name_model_tuple\n",
    "        cur_row = [f\"{data}/{name}\"]\n",
    "        # same location the ablation loop writes all_res_dict.pkl to\n",
    "        res_path = os.path.join(LOG_ROOT_FOLDER, f\"{data}_{name}\", \"all_res_dict.pkl\")\n",
    "        try:\n",
    "            with open(res_path, \"rb\") as fin:\n",
    "                all_res = pickle.load(fin)\n",
    "        except (OSError, EOFError, pickle.UnpicklingError) as err:\n",
    "            # Not computed (yet) or unreadable: pad with placeholders so the row stays aligned\n",
    "            # with the header instead of silently emitting only the name column.\n",
    "            print(f\"No results for {data}/{name}: {err}\")\n",
    "            return cur_row + [\"-\"] * len(self.header)\n",
    "        for each_test in self.header:\n",
    "            cur_row.append(f\"{np.mean(all_res[each_test])*100:.2f}$\\pm${np.std(all_res[each_test])*100:.2f}\")\n",
    "        return cur_row\n",
    "\n",
    "my_tab = ablationTable(name_model_pool, testing_cases)\n",
    "print(my_tab.table_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
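    "## Obsolete code segments\n",
    "\n",
    "Superseded experiments (per-dataset breakdown saved to `../script_logs/breakdown`), kept for reference only."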
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\")]\n",
    "# name_model_pool = [(\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\")]\n",
    "# name_model_pool = [(\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "# name_model_pool = [(\"citeseer\", \"geomGCN\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"exphormer\")]\n",
    "# name_model_pool = [(\"citeseer\", \"SAGE\"), (\"a-computer\", \"SAGE\"), (\"a-photo\", \"SAGE\")]\n",
    "\n",
    "# Only the first MAX_SEED seed folders (in sorted order) are evaluated per (dataset, model).\n",
    "MAX_SEED = 1\n",
    "for name, model in name_model_pool:\n",
    "    all_ori, all_SDMP, all_equal = [], [], []\n",
    "    all_k_hop = defaultdict(list)\n",
    "    all_k_tranc = defaultdict(list)\n",
    "    hop_to_test = 2 if name in [\"ogbn-products\"] else 3\n",
    "\n",
    "    # Older result layout <root>/<name>_<model>_MLP_TEST. get_paths(name, model) takes no\n",
    "    # root argument, so the former 3-argument call raised a TypeError.\n",
    "    res_root = os.path.join(\"../result/best\", f\"{name}_{model}_MLP_TEST\")\n",
    "    seed_cnt = 0\n",
    "    for result_root_path in sorted(os.listdir(res_root)):\n",
    "        tic = time()\n",
    "        print(\"*\"*21)\n",
    "        result_root_path = os.path.join(res_root, result_root_path)\n",
    "        with open(os.path.join(result_root_path, \"GNN_data_split_seed.txt\"), 'r') as fin:\n",
    "            # strip(): a trailing newline must not leak into the folder name below\n",
    "            seed_str = fin.read().strip()\n",
    "        # data loading\n",
    "        print(\"-\"*11)\n",
    "        print(\"Dataloading...\")\n",
    "        TARGET_GNN_FOLDER = os.path.join(path_dict[(name, model)][\"TARGET_GNN_parent_folder\"], \"seed_\"+seed_str)\n",
    "        print(result_root_path)\n",
    "        print(TARGET_GNN_FOLDER)\n",
    "\n",
    "        preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(result_root_path,\n",
    "                                                                TARGET_GNN_FOLDER,\n",
    "                                                                DATA_ROOT_FOLDER,\n",
    "                                                                device=DEVICE)\n",
    "        SDMP_MLP_model = load_SDMP_MLP(train_conf, sdmp, g, result_root_path, device=DEVICE)\n",
    "\n",
    "        ori_node_emb, all_h = sdmp.infer_torch_node_approximal_features(return_H=True)\n",
    "\n",
    "        with open(os.path.join(result_root_path, \"GNN_f1.txt\"), 'r') as fin:\n",
    "            cur_line = fin.read()\n",
    "        # NOTE(review): eval() executes the file content; only acceptable for trusted, locally written files.\n",
    "        try:\n",
    "            target_GNN_acc = eval(cur_line)\n",
    "        except Exception:  # \"<label>:<value>\" format - evaluate the part after the colon\n",
    "            target_GNN_acc = eval(cur_line.split(':')[1])\n",
    "        all_ori.append(target_GNN_acc)\n",
    "        ## Load teacher embedding\n",
    "        GNN_TEACHER_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['teacher_path'])\n",
    "        teacher_emb = torch.from_numpy(np.load(GNN_TEACHER_PATH)['arr_0'])\n",
    "        print(f\"Data loading finished in {time()-tic:.1f} s. \")\n",
    "        ###########\n",
    "        # Get the performance of the SDMP model\n",
    "        print(\"-\"*11)\n",
    "        print(\"Get original SDMP performance...\")\n",
    "        tic = time()\n",
    "        test_dataloader = PlainLoader(ori_node_emb.to(DEVICE), torch.tensor(g.ndata['label']).to(DEVICE),\n",
    "                                    train_conf[\"batch_size\"], g.test_idx.to(DEVICE))\n",
    "        SDMP_acc = evaluate(SDMP_MLP_model, test_dataloader)\n",
    "        all_SDMP.append(SDMP_acc.item())\n",
    "        print(f\"Original SDMP finished in {time()-tic:.1f} s. \")\n",
    "        ###########\n",
    "        # Get the performance of the equal weight setting\n",
    "        print(\"-\"*11)\n",
    "        print(\"Get equal weight performance...\")\n",
    "        tic = time()\n",
    "        equal_train_conf = deepcopy(train_conf)\n",
    "        equal_train_conf.update(conf_dict[name][\"equal\"])\n",
    "        print(equal_train_conf)\n",
    "        ## Construct the new features\n",
    "        equal_ThetaT_feature = gen_equal_feature(sdmp, all_h)\n",
    "        _, equal_acc = train_and_eval(equal_ThetaT_feature, teacher_emb, g, equal_train_conf)\n",
    "        all_equal.append(equal_acc)\n",
    "\n",
    "        print(f\"Equal weights finished in {time()-tic:.1f} s. \")\n",
    "\n",
    "        ###########\n",
    "        # k hop neighbours, truncated to the mean SDMP receptive-field size\n",
    "        mean_receptive = int(len(sdmp.ThetaT.tocoo().data) / g.num_nodes())\n",
    "        print(f\"Average number of receptive node from SDMP: {mean_receptive}\")\n",
    "        for hops in [hop_to_test]:\n",
    "            print(\"-\"*11)\n",
    "            print(f\"Get {hops} hop performance...\")\n",
    "            tic = time()\n",
    "            A = preprocesser.get_A_pow(hops)\n",
    "            # gen_k_hop_tranc_feature indexes its budget per row, so expand the scalar mean\n",
    "            k_hop_tranc_feat = gen_k_hop_tranc_feature(g, A, np.full(A.shape[0], mean_receptive))\n",
    "\n",
    "            k_hop_train_conf = deepcopy(train_conf)\n",
    "            k_hop_train_conf.update(conf_dict[name][\"k_tranc\"])\n",
    "            print(k_hop_train_conf)\n",
    "            _, A_hop_acc = train_and_eval(k_hop_tranc_feat, teacher_emb, g, k_hop_train_conf)\n",
    "            all_k_tranc[hops].append(A_hop_acc)\n",
    "            print(f\"Get {hops} trancated hop performance finished in {time()-tic:.1f} s.\")\n",
    "\n",
    "        ###########\n",
    "        # k hop neighbours (untruncated A^hops X)\n",
    "        for hops in [hop_to_test]:\n",
    "            print(\"-\"*11)\n",
    "            print(f\"Get {hops} hop performance...\")\n",
    "            tic = time()\n",
    "            A_1 = preprocesser.get_A_pow(1)\n",
    "            k_hop_feat = gen_k_hop_feature(g, A_1, hops)\n",
    "\n",
    "            k_hop_train_conf = deepcopy(train_conf)\n",
    "            k_hop_train_conf.update(conf_dict[name][\"k_hop\"])\n",
    "            print(k_hop_train_conf)\n",
    "            _, A_hop_acc = train_and_eval(k_hop_feat, teacher_emb, g, k_hop_train_conf)\n",
    "            all_k_hop[hops].append(A_hop_acc)\n",
    "            print(f\"Get {hops} hop performance finished in {time()-tic:.1f} s.\")\n",
    "        ###############\n",
    "        # end of seed loop\n",
    "        seed_cnt += 1\n",
    "        if seed_cnt == MAX_SEED:\n",
    "            break\n",
    "\n",
    "    table_str = \"ori, SDMP, equal, tranc, hop\\n\"\n",
    "    table_str += name+\"/\"+model + \" & \"\n",
    "    table_str += table_ready_mean_std(all_ori) + \" & \"\n",
    "    table_str += table_ready_mean_std(all_SDMP) + \" & \"\n",
    "    table_str += table_ready_mean_std(all_equal) + \" & \"\n",
    "    table_str += table_ready_mean_std(all_k_tranc[hop_to_test]) + \" & \"\n",
    "    table_str += table_ready_mean_std(all_k_hop[hop_to_test]) + \"\\\\\\\\\"\n",
    "    print(table_str)\n",
    "\n",
    "    os.makedirs(\"../script_logs/breakdown\", exist_ok=True)\n",
    "    with open(f\"../script_logs/breakdown/{name}_{model}.pkl\", \"wb\") as fout:\n",
    "        pickle.dump([all_ori, all_SDMP, all_equal, all_k_tranc, all_k_hop], fout)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load and present all results\n",
    "# name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\")]\n",
    "# name_model_pool = [(\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\")]\n",
    "# name_model_pool = [(\"citeseer\", \"geomGCN\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"exphormer\")]\n",
    "name_model_pool = [(\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "# name_model_pool = [(\"citeseer\", \"SAGE\"), (\"a-computer\", \"SAGE\"), (\"a-photo\", \"SAGE\")]\n",
    "class ablationTable(latexTableBase):\n",
    "    \"\"\"Breakdown table: mean test accuracy (%) of SDMP, equal weights, truncated k-hop and\n",
    "    full k-hop per (dataset, model), read from the pickles in ../script_logs/breakdown.\n",
    "    \"\"\"\n",
    "    def gen_a_row(self, dict_log):\n",
    "        dataset, gnn = dict_log[0], dict_log[1]\n",
    "        with open(f\"../script_logs/breakdown/{dataset}_{gnn}.pkl\", \"rb\") as fin:\n",
    "            _all_ori, all_SDMP, all_equal, all_k_tranc, all_k_hop = pickle.load(fin)\n",
    "        # ogbn-products was evaluated with 2 hops, every other dataset with 3\n",
    "        hop = 2 if dataset in [\"ogbn-products\"] else 3\n",
    "        # all_ori (the teacher GNN accuracy) is loaded but not reported\n",
    "        columns = [all_SDMP, all_equal, all_k_tranc[hop], all_k_hop[hop]]\n",
    "        return [f\"{dataset}/{gnn}\"] + [f\"{np.mean(col)*100:.2f}\" for col in columns]\n",
    "\n",
    "my_tab = ablationTable(name_model_pool, [\"dummpy\"]*5)\n",
    "print(my_tab.table_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
