{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Sparse Decomposition of Message Passing\n",
    "\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import ast\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "import glob\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn import linear_model\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "from model import SAGE\n",
    "from script_utils import load_SDMP, load_SDMP_MLP\n",
    "from path_dict import *\n",
    "from script_utils import latexTableBase\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): an empty CURL_CA_BUNDLE is a common workaround that disables TLS\n",
    "# certificate verification for HTTP downloads; drop it if downloads work without it.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Compute device; override with the SDMP_DEVICE environment variable.\n",
    "# Falls back to CPU when the machine has no 4th GPU (cuda:3).\n",
    "DEVICE = os.environ.get(\"SDMP_DEVICE\", 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu')\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %lprun -f\n",
    "\n",
    "# Parameter settings and data loading\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "\n",
    "\n",
    "#############################################\n",
    "# Experiment selection: keep exactly one CONF_NAME / TARGET_GNN_FOLDER pair active.\n",
    "# TARGET_GNN_FOLDER holds the trained target GNN (checkpoint, its config and metric file).\n",
    "# # cora\n",
    "CONF_NAME = \"cora/SDMP/cora_SDMP_base.yml\"\n",
    "TARGET_GNN_FOLDER = '../result/baselines/cora/SAGE/seed_0'\n",
    "\n",
    "# # pubmed\n",
    "# CONF_NAME = \"pubmed/SDMP/pubmed_SDMP_base.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/pubmed/SAGE'\n",
    "\n",
    "# # reddit\n",
    "# CONF_NAME = \"reddit/SDMP/reddit_SDMP_base.yml\"\n",
    "# CONF_NAME = \"reddit/SDMP/reddit_SDMP_no_feature_norm.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/reddit/SAGE'\n",
    "\n",
    "# # OGBN-products\n",
    "# CONF_NAME = \"ogbn-products/SDMP/ogbn-products_SDMP_base.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/ogbn-products/SAGE'\n",
    "# CONF_NAME = \"ogbn-products/SDMP/ogbn-products_SDMP_RevGNN-112.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/ogbn-products/RevGNN-112'\n",
    "\n",
    "# # OGBN-arxiv\n",
    "# TARGET_GNN_FOLDER = '../result/ogbn-arxiv/DRGAT/seed_1'\n",
    "# CONF_NAME = 'ogbn-arxiv/SDMP/ogbn-arxiv_SDMP_DRGAT_tune_1.yml'\n",
    "\n",
    "# # a-computer\n",
    "# CONF_NAME = \"a-computer/SDMP/a-computer_SDMP_base.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/a-computer/SAGE'\n",
    "\n",
    "# # a-photo\n",
    "# CONF_NAME = \"a-photo/SDMP/a-photo_SDMP_SAGE_best_search.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/a-photo/SAGE/seed_0/'\n",
    "\n",
    "# ablation studies\n",
    "# CONF_NAME = \"ablation/cand_select/cand_select_cora_full.yml\"\n",
    "# TARGET_GNN_FOLDER = '../result/cora/geomGCN/seed_0'\n",
    "\n",
    "#######################################\n",
    "# Initialization and hyper-parameter setting\n",
    "\n",
    "_TIME_ZONE = 0  # hour offset added to the run timestamp\n",
    "TIMESTAMP = time()\n",
    "TIMESTAMP_FORMATTED = datetime.datetime.fromtimestamp(\n",
    "    int(TIMESTAMP)+_TIME_ZONE*3600).strftime('%Y%m%d-%H%M%S')\n",
    "HOST_NAME = socket.gethostname()\n",
    "\n",
    "conf_path = os.path.join(CONF_ROOT_FOLDER, CONF_NAME)\n",
    "train_conf = load_sdmp_conf_with_default(conf_path)\n",
    "\n",
    "DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf[\"name\"])\n",
    "os.makedirs(DATA_FOLDER, exist_ok=True)\n",
    "\n",
    "# sys.stdout = Logger(os.path.join(RES_FOLDER, \"log.txt\"))\n",
    "\n",
    "print(train_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "######################################\n",
    "# Artifacts of the target GNN: checkpoint, its config and its recorded metric.\n",
    "h_model_path = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_path'])\n",
    "h_conf_path = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_conf_path'])\n",
    "GNN_ACC_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_metric_path'])\n",
    "\n",
    "# Target GNN f1. The metric file is parsed as a Python literal; if that fails, the text\n",
    "# after the last ':' is parsed instead (e.g. a `name: value` line). ast.literal_eval is\n",
    "# used instead of eval() so the file's content cannot execute code.\n",
    "with open(GNN_ACC_PATH, 'r') as fin:\n",
    "    metric_text = fin.read().strip()\n",
    "try:\n",
    "    ori_f1 = ast.literal_eval(metric_text)\n",
    "except (ValueError, SyntaxError):\n",
    "    ori_f1 = ast.literal_eval(metric_text.split(':')[-1].strip())\n",
    "print(f\"Target GNN f1: {ori_f1}\")\n",
    "\n",
    "# Data loading and preprocessing; exposes theta_cand, h_init_theta, X and target,\n",
    "# which the SDMP model is built from.\n",
    "preprocesser = SDMPDataPre(train_conf[\"name\"], train_conf[\"feature_normalize\"],\n",
    "                           train_conf['target_h_mode'],\n",
    "                           h_conf_path, h_model_path, train_conf[\"target_h_model\"],\n",
    "                           train_conf[\"h_init_theta_mode\"], train_conf[\"h_init_theta_k\"],\n",
    "                           train_conf[\"h_init_theta_k_fanout\"],\n",
    "                           train_conf[\"theta_cand_mode\"], train_conf[\"theta_cand_k2\"],\n",
    "                           train_conf[\"theta_cand_k1\"], train_conf[\"theta_cand_fanout\"],\n",
    "                           train_conf[\"theta_cand_add_self\"],\n",
    "                           train_conf,\n",
    "                           use_cache=True,\n",
    "                           cache_path=os.path.join(DATA_FOLDER, \"SDMPPre\"),\n",
    "                           save_cache=False,\n",
    "                           device=DEVICE)\n",
    "preprocesser.disp_states()\n",
    "theta_cand, h_init_theta, X, target = preprocesser.theta_cand, preprocesser.h_init_theta, preprocesser.X, preprocesser.target\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verification only: rebuild the preprocessing with theta_cand_mode forced to \"dense\"\n",
    "# (the only argument that differs from the configured preprocessing) and compare its\n",
    "# candidates against graphTester's neighbourhood breakdown (100 sampled nodes, 3 layers).\n",
    "# Results are kept under dense_* names so that theta_cand / h_init_theta / X / target,\n",
    "# which the SDMP model is built from, are not silently overwritten by this cell.\n",
    "dense_preprocesser = SDMPDataPre(train_conf[\"name\"], train_conf[\"feature_normalize\"],\n",
    "                                 train_conf['target_h_mode'],\n",
    "                                 h_conf_path, h_model_path, train_conf[\"target_h_model\"],\n",
    "                                 train_conf[\"h_init_theta_mode\"], train_conf[\"h_init_theta_k\"],\n",
    "                                 train_conf[\"h_init_theta_k_fanout\"],\n",
    "                                 \"dense\", train_conf[\"theta_cand_k2\"],\n",
    "                                 train_conf[\"theta_cand_k1\"], train_conf[\"theta_cand_fanout\"],\n",
    "                                 train_conf[\"theta_cand_add_self\"],\n",
    "                                 train_conf,\n",
    "                                 use_cache=True,\n",
    "                                 cache_path=os.path.join(DATA_FOLDER, \"SDMPPre\"),\n",
    "                                 save_cache=False,\n",
    "                                 device=DEVICE)\n",
    "dense_preprocesser.disp_states()\n",
    "dense_theta_cand, dense_h_init_theta = dense_preprocesser.theta_cand, dense_preprocesser.h_init_theta\n",
    "\n",
    "## verify the sampling results\n",
    "g = load_data(train_conf[\"name\"])\n",
    "\n",
    "graph_test = graphTester(g, sample=100, layer=3, device=DEVICE)\n",
    "graph_test.gen_breakdown_base()\n",
    "print(\"all_neighbours_stats\")\n",
    "print(\"exclusive\")\n",
    "_ = graph_test.disp_neigh_stats()\n",
    "print(\"inclusive\")\n",
    "_ = graph_test.disp_neigh_stats(exclusive=False)\n",
    "\n",
    "### theta cand\n",
    "print(\"theta_candi overlap stats\")\n",
    "_, _ = graph_test.disp_overlap(dense_theta_cand, exclusive=True)\n",
    "_, _ = graph_test.disp_overlap(dense_theta_cand, exclusive=False)\n",
    "\n",
    "### h cand\n",
    "print(\"h overlap stats\")\n",
    "_, _ = graph_test.disp_overlap(dense_h_init_theta, exclusive=True)\n",
    "_, _ = graph_test.disp_overlap(dense_h_init_theta, exclusive=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the node-label tensor of the loaded graph\n",
    "g.ndata[\"label\"].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build SDMP from the preprocessed features/candidates and fit it.\n",
    "print(\"Initializing the model...\")\n",
    "test = SDMP(\n",
    "    X, target, theta_cand, h_init_theta, train_conf,\n",
    "    device=DEVICE,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "print(\"Starting to fit...\")\n",
    "test.fit(eval_level=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the fitted model together with the configuration it was built from.\n",
    "SAVE_DIR = os.path.join(RES_ROOT_FOLDER, \"test\")\n",
    "test.save(SAVE_DIR)\n",
    "export_train_conf(os.path.join(SAVE_DIR, \"conf.yml\"), train_conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Debug LassoLars on a dumped failure case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the inputs of a LassoLars call that an earlier training run dumped to disk.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open dumps you created.\n",
    "# NOTE(review): assumes RES_ROOT_FOLDER ('../result') is the repo's result/ folder;\n",
    "# adjust the path if the dump lives elsewhere.\n",
    "bug_path = os.path.join(RES_ROOT_FOLDER, \"ogbn-products_SAGE_SDMP\",\n",
    "                        \"blackhole_2023-08-31_17:38:35\", \"debug.pkl\")\n",
    "\n",
    "with open(bug_path, \"rb\") as fin:\n",
    "    bug = pickle.load(fin)\n",
    "\n",
    "cur = bug[-1]  # inspect the last recorded case\n",
    "\n",
    "# Dedicated names: reusing X / y here would overwrite the feature matrix X that the\n",
    "# SDMP cells read.\n",
    "H_dbg = torch.from_numpy(cur['H'])\n",
    "y_dbg = torch.from_numpy(target[cur['node_id']]).reshape(-1, 1)\n",
    "print('X, y shape', H_dbg.shape, y_dbg.shape)\n",
    "my_gram = torch.mm(H_dbg.t(), H_dbg)\n",
    "Xy = torch.mm(H_dbg.t(), y_dbg)\n",
    "print('gram, Xy shape', my_gram.shape, Xy.shape)\n",
    "npXy = Xy.cpu().detach().numpy()\n",
    "npX = H_dbg.cpu().detach().numpy()\n",
    "npy = y_dbg.cpu().detach().numpy()\n",
    "npgram = my_gram.cpu().detach().numpy()\n",
    "cur"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replay the dumped case: positive LassoLars with a precomputed Gram matrix and X^T y.\n",
    "# NOTE(review): `normalize` is deprecated in recent scikit-learn and absent from the\n",
    "# newest releases; check it against the installed version.\n",
    "reg = linear_model.LassoLars(\n",
    "    alpha=0,\n",
    "    max_iter=40,\n",
    "    precompute=npgram,\n",
    "    fit_intercept=False,\n",
    "    normalize=False,\n",
    "    positive=True,\n",
    ")\n",
    "reg.fit(X=npX, y=npy, Xy=npXy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Speed test: SDMP + MLP inference latency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_inference_time(model, sdmp, device, g, eval_size = 1000, val_ind=None):\n",
    "    \"\"\"Time SDMP feature construction and model inference on validation nodes.\n",
    "\n",
    "    Args:\n",
    "        model: network applied to the SDMP features (switched to eval mode here).\n",
    "        sdmp: fitted SDMP object; `efficient_prepare()` is called once, then\n",
    "            `efficient_node_wise_infer(x)` builds the features of every batch.\n",
    "        device: device the labels are moved to before scoring.\n",
    "        g: graph exposing `val_idx`, `num_nodes()` and `ndata['label']`.\n",
    "        eval_size: number of validation nodes to time.\n",
    "        val_ind: positions into `g.val_idx`; the first `eval_size` are timed.\n",
    "            Defaults to the seed-666 shuffle of all positions that the\n",
    "            speed-test loop builds, so no global `all_ind` is required.\n",
    "\n",
    "    Returns:\n",
    "        (sampling_time, infer_time): per-batch wall-clock seconds spent fetching\n",
    "        a batch and building its features, and spent in the model forward pass.\n",
    "    \"\"\"\n",
    "    if val_ind is None:\n",
    "        # Same order as `random.seed(666); random.shuffle(...)`, without\n",
    "        # touching the global RNG state.\n",
    "        val_ind = list(range(g.val_idx.size()[0]))\n",
    "        random.Random(666).shuffle(val_ind)\n",
    "    sampling_time, infer_time = [], []\n",
    "    val_idx = g.val_idx[val_ind[:eval_size]].to(\"cpu\")\n",
    "    tic_data_loader = time()\n",
    "    # NOTE(review): the third PlainLoader argument (1) looks like the batch size - confirm.\n",
    "    val_dataloader = PlainLoader(torch.from_numpy(np.arange(g.num_nodes())),\n",
    "                                 g.ndata['label'].to(\"cpu\"),\n",
    "                                 1, val_idx)\n",
    "    print(\"data loader created in {:.6f} seconds\".format(time() - tic_data_loader))\n",
    "    model.eval()\n",
    "    sdmp.efficient_prepare()\n",
    "    with torch.no_grad():\n",
    "        total_loss = 0\n",
    "        y_hat_list, y_list = [], []\n",
    "        tic_sampling = time()\n",
    "        for x, y in tqdm(val_dataloader):\n",
    "            # Sampling interval: from the end of the previous step until this\n",
    "            # batch's features are ready (includes fetching the batch).\n",
    "            feat = sdmp.efficient_node_wise_infer(x)\n",
    "            sampling_time.append(time() - tic_sampling)\n",
    "            tic_infer = time()\n",
    "            y_hat = model(feat)\n",
    "            infer_time.append(time() - tic_infer)\n",
    "            y = y.to(device)\n",
    "            # Not timed; loss.item() also waits for queued device work to finish.\n",
    "            loss = F.cross_entropy(y_hat, y)\n",
    "            total_loss += loss.item()\n",
    "            y_hat_list.append(y_hat)\n",
    "            y_list.append(y)\n",
    "            torch.cuda.empty_cache()\n",
    "            tic_sampling = time()\n",
    "        metric = torch_f1(torch.cat(y_hat_list), torch.cat(y_list))\n",
    "    print(f\"f1 score {metric}\")\n",
    "    return sampling_time, infer_time\n",
    "\n",
    "def get_paths(name, model, result_root):\n",
    "    \"\"\"Return the run folder stored under `<result_root>/<name>_<model>_MLP_TEST`.\n",
    "\n",
    "    Entries are sorted so the pick is deterministic (plain glob order depends on\n",
    "    the filesystem). Raises FileNotFoundError naming the searched folder when\n",
    "    nothing matches, instead of a bare StopIteration.\n",
    "    \"\"\"\n",
    "    model_paths = os.path.join(result_root, f\"{name}_{model}_MLP_TEST\")\n",
    "    candidates = sorted(glob.glob(f\"{model_paths}/*\"))\n",
    "    if not candidates:\n",
    "        raise FileNotFoundError(f\"no run folder found under {model_paths}\")\n",
    "    return candidates[0]\n",
    "# Latency is measured on CPU; this overrides the DEVICE set in the first cell.\n",
    "DEVICE = \"cpu\"\n",
    "print(\"=\"*10+f\"Test on {DEVICE}\")\n",
    "res_root = \"../result/best\"  # holds one <name>_<model>_MLP_TEST folder per pair\n",
    "save_res_root = \"../script_logs/speedtest/SDMP\"  # per-pair timing pickles go here\n",
    "os.makedirs(save_res_root, exist_ok=True)\n",
    "\n",
    "# (dataset, target GNN) pairs to benchmark\n",
    "name_model_pool = [\n",
    "    (\"cora\", \"SAGE\"), (\"cora\", \"geomGCN\"),\n",
    "    (\"citeseer\", \"SAGE\"), (\"citeseer\", \"geomGCN\"),\n",
    "    (\"pubmed\", \"SAGE\"), (\"pubmed\", \"geomGCN\"),\n",
    "    (\"a-computer\", \"SAGE\"), (\"a-computer\", \"exphormer\"),\n",
    "    (\"a-photo\", \"SAGE\"), (\"a-photo\", \"exphormer\"),\n",
    "    (\"ogbn-arxiv\", \"SAGE\"), (\"ogbn-arxiv\", \"DRGAT\"),\n",
    "    (\"ogbn-products\", \"SAGE\"), (\"ogbn-products\", \"RevGNN-112\"),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Benchmark every (dataset, model) pair and store the raw per-query timings.\n",
    "EVAL_SIZE = 1000  # validation nodes timed per pair\n",
    "\n",
    "for name, model in name_model_pool:\n",
    "    cur_result_root = get_paths(name, model, res_root)\n",
    "\n",
    "    with open(os.path.join(cur_result_root, \"GNN_data_split_seed.txt\"), 'r') as fin:\n",
    "        # strip() so a trailing newline cannot leak into the folder name below\n",
    "        seed_str = fin.read().strip()\n",
    "\n",
    "    TARGET_GNN_FOLDER = os.path.join(path_dict[(name, model)][\"TARGET_GNN_parent_folder\"], \"seed_\"+seed_str)\n",
    "    print(cur_result_root)\n",
    "    print(TARGET_GNN_FOLDER)\n",
    "\n",
    "    preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(cur_result_root,\n",
    "                                                            TARGET_GNN_FOLDER,\n",
    "                                                            DATA_ROOT_FOLDER,\n",
    "                                                            device=DEVICE)\n",
    "    SDMP_MLP_model = load_SDMP_MLP(train_conf, sdmp, g, cur_result_root, device=DEVICE)\n",
    "\n",
    "    # Fixed (seed 666) shuffle of validation positions; the first EVAL_SIZE are\n",
    "    # timed and also used for the Theta statistics below.\n",
    "    all_ind = list(range(g.val_idx.size()[0]))\n",
    "    random.seed(666)\n",
    "    random.shuffle(all_ind)\n",
    "\n",
    "    all_sample_time, all_infer_time = [], []\n",
    "\n",
    "    sampling_time, infer_time = test_inference_time(SDMP_MLP_model, sdmp, DEVICE, g, eval_size=EVAL_SIZE)\n",
    "    all_sample_time.append(sampling_time)\n",
    "    all_infer_time.append(infer_time)\n",
    "\n",
    "    # Non-zeros per row of ThetaT for the timed positions.\n",
    "    # NOTE(review): all_ind holds positions into g.val_idx, not node ids; if ThetaT rows\n",
    "    # are indexed by node id this should use g.val_idx[all_ind[:EVAL_SIZE]] - confirm.\n",
    "    theta_rows = sdmp.ThetaT[all_ind[:EVAL_SIZE]].tocsr()\n",
    "    nonzero_cnt = np.diff(theta_rows.indptr)\n",
    "    print(f\"Theta nonzero cnt mean {np.mean(nonzero_cnt):.1f}, std {np.std(nonzero_cnt):.1f}\")\n",
    "\n",
    "    with open(os.path.join(save_res_root, f\"{name}_{model}.pkl\"), \"wb\") as fout:\n",
    "        pickle.dump([all_sample_time, all_infer_time], fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Summary table: mean per-query latency (feature construction + forward pass) in ms.\n",
    "# Reuses name_model_pool / save_res_root from the speed-test configuration cell.\n",
    "# NOTE(review): unlike gen_res_table, this mean includes the first (warm-up) query.\n",
    "table_list = []\n",
    "for name, model in name_model_pool:\n",
    "    with open(os.path.join(save_res_root, f\"{name}_{model}.pkl\"), \"rb\") as fin:\n",
    "        [all_sample_time, all_infer_time] = pickle.load(fin)\n",
    "    # each pickle holds one run; element 0 is its list of per-query timings\n",
    "    total_time = np.asarray(all_sample_time[0]) + np.asarray(all_infer_time[0])\n",
    "    table_list.append([f\"{name}/{model}\", f\"{np.mean(total_time)*1000:.1f}\"])\n",
    "\n",
    "table_str = latexTableBase.gen_table(table_list)\n",
    "\n",
    "print(table_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# final string generation with various functions\n",
    "\n",
    "def gen_res_table(stime, itime, func_list, func_name_list, head_list, skip_first=1):\n",
    "    \"\"\"Build LaTeX table rows with latency statistics in milliseconds.\n",
    "\n",
    "    Args:\n",
    "        stime, itime: one list of per-query sampling / inference times (seconds)\n",
    "            per table column; column k uses stime[k] and itime[k].\n",
    "        func_list: statistics applied to each timing list (e.g. np.mean).\n",
    "        func_name_list: row label for each entry of func_list.\n",
    "        head_list: column headers, one per timing list.\n",
    "        skip_first: leading measurements dropped as warm-up (default 1). A list\n",
    "            that is not longer than this is used whole rather than emptied.\n",
    "\n",
    "    Returns:\n",
    "        The rows joined by newlines: an \\\\hline, the header row, one row per\n",
    "        statistic formatted as \"sampling/inference/total\", then a closing \\\\hline.\n",
    "    \"\"\"\n",
    "    def apply_func(time_list, func):\n",
    "        return [func(t[skip_first:] if len(t) > skip_first else t) for t in time_list]\n",
    "\n",
    "    # Per-query total latency; loop-invariant, so computed once.\n",
    "    all_list = [[a + b for a, b in zip(s, i)] for s, i in zip(stime, itime)]\n",
    "\n",
    "    rows = [\"\\\\hline\",\n",
    "            \"Neighbor size & \" + \" & \".join(head_list) + \" \\\\\\\\ \\\\hline\"]\n",
    "    for name, func in zip(func_name_list, func_list):\n",
    "        svals = apply_func(stime, func)\n",
    "        ivals = apply_func(itime, func)\n",
    "        alls = apply_func(all_list, func)\n",
    "        comb = [\"{:.5f}/{:.5f}/{:.5f}\".format(i*1000, j*1000, k*1000) for i, j, k in zip(svals, ivals, alls)]\n",
    "        rows.append(name + \" & \" + \" & \".join(comb) + \" \\\\\\\\\")\n",
    "    return \"\\n\".join(rows) + \" \\\\hline\"\n",
    "        \n",
    "# (row label, statistic) pairs for the detailed latency table\n",
    "stat_rows = [\n",
    "    (\"mean\", np.mean),\n",
    "    (\"90-percentile\", lambda x: np.percentile(x, 90)),\n",
    "    (\"99-percentile\", lambda x: np.percentile(x, 99)),\n",
    "    (\"max\", np.max),\n",
    "    (\"std\", np.std),\n",
    "]\n",
    "func_name_list = [label for label, _ in stat_rows]\n",
    "func_list = [fn for _, fn in stat_rows]\n",
    "head_list = [\"SDMP\"]\n",
    "\n",
    "# NOTE(review): all_sample_time / all_infer_time hold the timings loaded last (the final\n",
    "# pair in name_model_pool), so this table describes a single dataset/model.\n",
    "res_str = gen_res_table(all_sample_time, all_infer_time, func_list, func_name_list, head_list)\n",
    "print(res_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
