{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "various analysis\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "from tqdm import tqdm\n",
    "import glob\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "from copy import deepcopy\n",
    "\n",
    "# import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import linear_model\n",
    "import scipy.sparse as sp\n",
    "from sklearn.preprocessing import normalize\n",
    "from collections import defaultdict\n",
    "from dgl.dataloading import MultiLayerFullNeighborSampler\n",
    "\n",
    "\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "from plotlib import *\n",
    "from script_utils import load_SDMP, confHolder, latexTableBase\n",
    "from path_dict import path_dict, name_model_pool\n",
    "from plotlib import *\n",
    "\n",
    "# Global runtime settings for the whole notebook.\n",
    "# NOTE(review): this silences every warning, including numpy/scipy numerical ones.\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): with older `requests` versions an empty CURL_CA_BUNDLE turns off TLS\n",
    "# certificate verification for HTTPS downloads; presumably a workaround for a proxy or\n",
    "# certificate problem when fetching data - confirm it is still needed.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Device string passed to load_SDMP and graphTester in the cells below (assumes a CUDA GPU).\n",
    "DEVICE = 'cuda:0'\n",
    "\n",
    "# line_profiler provides %lprun; autoreload 2 reloads edited project modules before each cell runs.\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Profile a function with: %lprun -f <function> <statement>\n",
    "# Parameter settings and data loading\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "TARGET_ROOT_FOLDER = \"../result/baselines\"\n",
    "\n",
    "# Hop count k per dataset: used for the dense k-hop matrix (preprocesser.get_A_pow)\n",
    "# and for the per-hop breakdown (graphTester) in the analysis cells.\n",
    "DENSE_K = {\n",
    "    \"cora\": 3,\n",
    "    \"pubmed\": 3,\n",
    "    \"citeseer\": 3,\n",
    "    \"a-computer\": 3,\n",
    "    \"a-photo\": 3,\n",
    "    \"ogbn-arxiv\": 3,\n",
    "    \"ogbn-products\": 2,\n",
    "}\n",
    "# Set to an int (e.g. 1) to use the same k for every dataset; None keeps DENSE_K.\n",
    "DENSE_K_OVERRIDE = None\n",
    "\n",
    "# defaultdict: looking up a dataset that is not listed yields an empty dict.\n",
    "conf_dict = defaultdict(dict)\n",
    "for _dataset, _k in DENSE_K.items():\n",
    "    conf_dict[_dataset][\"dense_k\"] = _k if DENSE_K_OVERRIDE is None else DENSE_K_OVERRIDE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hop-wise analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_paths(name, model, result_root=None):\n",
    "    \"\"\"Return the first run folder stored for a (dataset, model) pair.\n",
    "\n",
    "    Args:\n",
    "        name: dataset name, e.g. \"cora\".\n",
    "        model: model name, e.g. \"SAGE\".\n",
    "        result_root: optional root folder. When given, runs are looked up in\n",
    "            ``{result_root}/{name}_{model}_MLP_TEST``; otherwise the parent folder is\n",
    "            ``path_dict[name, model][\"result_root_parent_folder\"]``.\n",
    "\n",
    "    Returns:\n",
    "        The lexicographically first entry of that folder. Sorting makes the choice\n",
    "        independent of the filesystem's listing order.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the folder is missing or empty.\n",
    "    \"\"\"\n",
    "    if result_root is None:\n",
    "        model_paths = path_dict[name, model][\"result_root_parent_folder\"]\n",
    "    else:\n",
    "        model_paths = os.path.join(result_root, f\"{name}_{model}_MLP_TEST\")\n",
    "    candidates = sorted(glob.glob(os.path.join(model_paths, \"*\")))\n",
    "    if not candidates:\n",
    "        raise FileNotFoundError(f\"no run folder found under {model_paths!r} for ({name}, {model})\")\n",
    "    return candidates[0]\n",
    "\n",
    "class receptiveNode_all_fast:\n",
    "    \"\"\"Per-row stored-entry counts (\"receptive field sizes\") of three sparse matrices.\n",
    "\n",
    "    Each input must expose an ``indptr`` array (CSR layout); the count of row ``i`` is\n",
    "    ``indptr[i + 1] - indptr[i]``, i.e. the number of stored entries in that row.\n",
    "\n",
    "    Attributes:\n",
    "        list_ori_receptive: counts for ``A`` (in this notebook, the dense k-hop matrix).\n",
    "        list_candi_receptive: counts for ``Candidate``.\n",
    "        list_SDMP_receptive: counts for ``Theta``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, A, Theta, Candidate):\n",
    "        # np.diff(indptr) is exactly indptr[1:] - indptr[:-1]\n",
    "        self.list_ori_receptive = np.diff(A.indptr)\n",
    "        self.list_candi_receptive = np.diff(Candidate.indptr)\n",
    "        self.list_SDMP_receptive = np.diff(Theta.indptr)\n",
    "\n",
    "# Hop-wise receptive-field statistics for each (dataset, model) pair, stored per pair in\n",
    "# all_res_dict: \"*_all\" hold per-node stored-entry counts of the dense k-hop matrix (\"ori\"),\n",
    "# the candidate matrix (\"candi\") and the non-negative learned Theta (\"SDMP\");\n",
    "# \"*_hop\" hold the per-hop breakdowns computed by graphTester on a node sample.\n",
    "all_res_dict = defaultdict(dict)\n",
    "\n",
    "# Pairs analysed in this run (for the full table use e.g. the imported name_model_pool).\n",
    "# A separate name keeps the imported `name_model_pool` intact for the later cells.\n",
    "hop_name_model_pool = [(\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "SEED = 0  # the first seed is sufficient for current analysis\n",
    "HOP_SAMPLE_SIZE = 1000  # `sample` argument of graphTester\n",
    "\n",
    "for name, model in hop_name_model_pool:\n",
    "    # data loading\n",
    "    cur_conf = conf_dict[name]\n",
    "    TARGET_GNN_FOLDER = os.path.join(TARGET_ROOT_FOLDER, name, model, f\"seed_{SEED}\")\n",
    "    result_root_path = get_paths(name, model)\n",
    "\n",
    "    preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(result_root_path,\n",
    "                                                            TARGET_GNN_FOLDER,\n",
    "                                                            DATA_ROOT_FOLDER,\n",
    "                                                            device=DEVICE)\n",
    "    # copy=True: tocsr() may return the same object when ThetaT is already CSR, in which\n",
    "    # case the clipping below would modify the loaded model's matrix in place.\n",
    "    Theta = sdmp.ThetaT.tocsr(copy=True)\n",
    "    print(f\"{name}/{model}: {len(Theta.data)} stored entries, negative entry count: {np.sum(Theta.data < 0)}\")\n",
    "    # Drop negative weights: zero them, then remove the explicit zeros from the structure.\n",
    "    Theta.data[Theta.data < 0] = 0\n",
    "    Theta.eliminate_zeros()\n",
    "    print(f\"{name}/{model}: {Theta.nnz} entries kept, {np.sum(Theta.data > 0)} positive\")\n",
    "\n",
    "    # Get the full aggregation results\n",
    "    dense_neigh = preprocesser.get_A_pow(cur_conf[\"dense_k\"])\n",
    "    receptive = receptiveNode_all_fast(dense_neigh, Theta, preprocesser.theta_cand)\n",
    "    cur_res = all_res_dict[(name, model)]\n",
    "    cur_res[\"ori_all\"] = receptive.list_ori_receptive\n",
    "    cur_res[\"candi_all\"] = receptive.list_candi_receptive\n",
    "    cur_res[\"SDMP_all\"] = receptive.list_SDMP_receptive\n",
    "\n",
    "    # get the per hop results\n",
    "    hop_tester = graphTester(g, layer=cur_conf[\"dense_k\"], sample=HOP_SAMPLE_SIZE, device=DEVICE)\n",
    "    hop_tester.gen_breakdown_base()\n",
    "    SDMP_overlaps = hop_tester.csr_receptive_overlap(Theta)\n",
    "    candi_overlaps = hop_tester.csr_receptive_overlap(preprocesser.theta_cand)\n",
    "    cur_res[\"ori_hop\"] = hop_tester.exclusive_breakdown_neighbour\n",
    "    cur_res[\"SDMP_hop\"] = SDMP_overlaps\n",
    "    cur_res[\"candi_hop\"] = candi_overlaps\n",
    "\n",
    "# generate the desired table\n",
    "def _mean_len_str(per_sample, idx):\n",
    "    \"\"\"Mean size of entry ``idx`` over all samples, formatted with one decimal.\"\"\"\n",
    "    return f\"{np.mean([len(each[idx]) for each in per_sample]):.1f}\"\n",
    "\n",
    "\n",
    "def wrap_one_row(name, model, res_dict):\n",
    "    \"\"\"Build one table row of mean receptive-field sizes for a (dataset, model) pair.\n",
    "\n",
    "    Args:\n",
    "        name, model: row label parts, rendered as ``\"{name}/{model}\"``.\n",
    "        res_dict: one value of ``all_res_dict``; ``SDMP_all``/``candi_all``/``ori_all``\n",
    "            are per-node counts, ``SDMP_hop``/``candi_hop``/``ori_hop`` are per-sample\n",
    "            sequences of per-hop node collections.\n",
    "\n",
    "    Returns:\n",
    "        list of str: the label; the SDMP, candidate and original overall means; the\n",
    "        (SDMP, candidate, original) means of each of the k hops, k being the length of\n",
    "        ``ori_hop[0]``; and the SDMP and candidate means of the last per-hop entry.\n",
    "    \"\"\"\n",
    "    row = [f\"{name}/{model}\"]\n",
    "    row.extend(f\"{np.mean(res_dict[key]):.1f}\" for key in (\"SDMP_all\", \"candi_all\", \"ori_all\"))\n",
    "\n",
    "    n_hops = len(res_dict[\"ori_hop\"][0])\n",
    "    for hop in range(n_hops):\n",
    "        row.extend(_mean_len_str(res_dict[key], hop) for key in (\"SDMP_hop\", \"candi_hop\", \"ori_hop\"))\n",
    "    # NOTE(review): the last entry of the SDMP/candidate breakdowns presumably counts nodes\n",
    "    # beyond the k hops (one extra bucket) - confirm against graphTester.\n",
    "    row.extend(_mean_len_str(res_dict[key], -1) for key in (\"SDMP_hop\", \"candi_hop\"))\n",
    "    return row\n",
    "    \n",
    "def wrap_results(all_res_dict):\n",
    "    \"\"\"Assemble the hop-wise table: a header row followed by one row per (dataset, model).\n",
    "\n",
    "    The header is derived from the rows rather than hard-coded for 3 hops, so it stays\n",
    "    aligned when a dataset uses another ``dense_k`` (ogbn-products uses 2). Rows with\n",
    "    fewer hops than the widest row are padded with \"-\" before their two trailing\n",
    "    columns, so every row has the header's width.\n",
    "\n",
    "    Args:\n",
    "        all_res_dict: mapping ``(name, model) -> res_dict`` as filled by the hop-wise loop.\n",
    "\n",
    "    Returns:\n",
    "        list of rows (lists of str), header first. With no results only the 3-hop\n",
    "        header is returned.\n",
    "    \"\"\"\n",
    "    rows = [wrap_one_row(key[0], key[1], res) for key, res in all_res_dict.items()]\n",
    "    # wrap_one_row emits: label + 3 overall columns + 3 per hop + 2 trailing columns.\n",
    "    hop_counts = [(len(row) - 6) // 3 for row in rows]\n",
    "    max_hops = max(hop_counts, default=3)\n",
    "    hop_label = f\"{max_hops}-hop Neigh.\" if len(set(hop_counts)) <= 1 else \"k-hop Neigh.\"\n",
    "\n",
    "    header = [\"Data/Model\", \"SDGNN\", \"Candidate\", hop_label]\n",
    "    header += [\"SDGNN\", \"Candidate\", \"Original\"] * max_hops\n",
    "    header += [\"SDGNN\", \"Candidate\"]\n",
    "\n",
    "    list_to_show = [header]\n",
    "    for row, n_hops in zip(rows, hop_counts):\n",
    "        padding = [\"-\"] * (3 * (max_hops - n_hops))\n",
    "        list_to_show.append(row[:-2] + padding + row[-2:])\n",
    "    return list_to_show\n",
    "\n",
    "list_to_show = wrap_results(all_res_dict)\n",
    "table_res_str = latexTableBase.gen_table(list_to_show)\n",
    "print(table_res_str)\n",
    "\n",
    "# Set to True to also pickle the raw hop-wise results (all_res_dict) for later reuse.\n",
    "SAVE_HOP_RESULTS = False\n",
    "HOP_RESULTS_PATH = \"../script_logs/table_3_1000_samples_short_case.pkl\"\n",
    "if SAVE_HOP_RESULTS:\n",
    "    os.makedirs(os.path.dirname(HOP_RESULTS_PATH), exist_ok=True)\n",
    "    with open(HOP_RESULTS_PATH, \"wb\") as fout:\n",
    "        pickle.dump(all_res_dict, fout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Receptive label analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_paths(name, model, result_root=None):\n",
    "    \"\"\"Return the first run folder stored for a (dataset, model) pair.\n",
    "\n",
    "    Args:\n",
    "        name: dataset name, e.g. \"cora\".\n",
    "        model: model name, e.g. \"SAGE\".\n",
    "        result_root: optional root folder. When given, runs are looked up in\n",
    "            ``{result_root}/{name}_{model}_MLP_TEST``; otherwise the parent folder is\n",
    "            ``path_dict[name, model][\"result_root_parent_folder\"]``.\n",
    "\n",
    "    Returns:\n",
    "        The lexicographically first entry of that folder. Sorting makes the choice\n",
    "        independent of the filesystem's listing order.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the folder is missing or empty.\n",
    "    \"\"\"\n",
    "    if result_root is None:\n",
    "        model_paths = path_dict[name, model][\"result_root_parent_folder\"]\n",
    "    else:\n",
    "        model_paths = os.path.join(result_root, f\"{name}_{model}_MLP_TEST\")\n",
    "    candidates = sorted(glob.glob(os.path.join(model_paths, \"*\")))\n",
    "    if not candidates:\n",
    "        raise FileNotFoundError(f\"no run folder found under {model_paths!r} for ({name}, {model})\")\n",
    "    return candidates[0]\n",
    "\n",
    "# Receptive-label statistics of the non-negative learned Theta for each (dataset, model) pair.\n",
    "# One pickle per pair is written to save_folder; the next cell builds the summary table.\n",
    "# all_res_dict is deliberately not reset here: this cell does not use it, and resetting it\n",
    "# would discard the hop-wise results that later cells inspect.\n",
    "result_root = \"../result/best\"\n",
    "save_folder = \"../script_logs/label_ratio\"\n",
    "os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "# Pairs analysed in this run (the full pool is listed in the summary cell below).\n",
    "# A separate name keeps the imported `name_model_pool` intact for the later cells.\n",
    "label_name_model_pool = [(\"citeseer\", \"geomGCN\")]\n",
    "\n",
    "for name, model in label_name_model_pool:\n",
    "    # data loading\n",
    "    cur_conf = conf_dict[name]\n",
    "    TARGET_GNN_FOLDER = path_dict[(name, model)][\"TARGET_GNN_FOLDER\"]\n",
    "    result_root_path = get_paths(name, model, result_root)\n",
    "\n",
    "    preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(result_root_path,\n",
    "                                                            TARGET_GNN_FOLDER,\n",
    "                                                            DATA_ROOT_FOLDER,\n",
    "                                                            device=DEVICE)\n",
    "    # copy=True: tocsr() may return the same object when ThetaT is already CSR, in which\n",
    "    # case the clipping below would modify the loaded model's matrix in place.\n",
    "    Theta = sdmp.ThetaT.tocsr(copy=True)\n",
    "    print(f\"{name}/{model}: {len(Theta.data)} stored entries, negative entry count: {np.sum(Theta.data < 0)}\")\n",
    "    # Drop negative weights: zero them, then remove the explicit zeros from the structure.\n",
    "    Theta.data[Theta.data < 0] = 0\n",
    "    Theta.eliminate_zeros()\n",
    "    print(f\"{name}/{model}: {Theta.nnz} entries kept, {np.sum(Theta.data > 0)} positive\")\n",
    "\n",
    "    hop_tester = graphTester(g, layer=cur_conf[\"dense_k\"], sample=1000, device=DEVICE)\n",
    "\n",
    "    # get the label distribution results\n",
    "    list_ratio, list_weighted_ratio = hop_tester.csr_receptive_label_state(Theta)\n",
    "\n",
    "    # The file name must match the one read by the summary cell.\n",
    "    with open(os.path.join(save_folder, f\"{name}_{model}\"), 'wb') as fout:\n",
    "        pickle.dump([list_ratio, list_weighted_ratio], fout, protocol=4)\n",
    "\n",
    "    print(name, model, np.mean(list_ratio), np.mean(list_weighted_ratio))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the pickled label-ratio results of every (dataset, model) pair as a LaTeX table.\n",
    "# The folder is defined here (and must match the label-analysis cell) so that this cell\n",
    "# also runs on a fresh kernel without re-running that cell first.\n",
    "# pickle.load can execute code: only load files produced by this notebook.\n",
    "label_ratio_folder = \"../script_logs/label_ratio\"\n",
    "name_model_pool = [(\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"), (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"), (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"), (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"), (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\"), (\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\"), (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\")]\n",
    "\n",
    "#  collect all the results\n",
    "all_res = dict()\n",
    "table_list = list()\n",
    "for name, model in name_model_pool:\n",
    "    with open(os.path.join(label_ratio_folder, f\"{name}_{model}\"), 'rb') as fin:\n",
    "        list_ratio, list_weighted_ratio = pickle.load(fin)\n",
    "    all_res[name, model] = [list_ratio, list_weighted_ratio]\n",
    "    table_list.append([f\"{name}/{model}\",\n",
    "                       f\"{np.mean(list_ratio):.4f}\",\n",
    "                       f\"{np.mean(list_weighted_ratio):.4f}\"])\n",
    "\n",
    "table_str = latexTableBase.gen_table(table_list)\n",
    "print(table_str)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Display all configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the conf\n",
    "# Render selected SDMP hyper-parameters of every pair in `name_model_pool` as a LaTeX table.\n",
    "# NOTE(review): `name_model_pool` is whatever was bound last in the kernel - confirm it is the full pool.\n",
    "to_show = [\"theta_cand_k1\", \"theta_cand_k2\", \"theta_cand_fanout\", \"h_lr\", \"h_l2\"]\n",
    "all_conf = confHolder(path_dict, name_model_pool)\n",
    "selected_sdmp_conf = all_conf.sdmp_conf\n",
    "table_str_conf = all_conf.gen_latex_table(selected_sdmp_conf, to_show)\n",
    "print(table_str_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick inspection of one hop-wise result; the pair must be in the hop-wise cell's pool.\n",
    "# Stored keys: \"ori_all\", \"candi_all\", \"SDMP_all\" (per-node counts) and\n",
    "# \"ori_hop\", \"SDMP_hop\", \"candi_hop\" (per-sample hop breakdowns).\n",
    "# NOTE(review): \"SDMP_all\" is assumed to be the intended key; use \"SDMP_hop\" for the per-hop view.\n",
    "inspect_key = (\"ogbn-arxiv\", \"SAGE\")\n",
    "assert \"SDMP_all\" in all_res_dict.get(inspect_key, {}), f\"no hop-wise results for {inspect_key}\"\n",
    "all_res_dict[inspect_key][\"SDMP_all\"][0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
