{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\"\"\"\n",
    "Various analyses: accuracy of the SDMP MLP on original vs. permuted test samples.\n",
    "\"\"\"\n",
    "# pylint: disable=anomalous-backslash-in-string\n",
    "# pylint: disable=invalid-name\n",
    "# pylint: disable=import-error\n",
    "# pylint: disable=missing-function-docstring\n",
    "import os\n",
    "import sys\n",
    "sys.path.extend([\"../\"]) # pylint: disable=wrong-import-position\n",
    "import random\n",
    "from time import time\n",
    "import warnings\n",
    "import pickle\n",
    "import datetime\n",
    "import socket\n",
    "from copy import deepcopy, copy\n",
    "\n",
    "import shutil\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import glob\n",
    "\n",
    "# import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import linear_model\n",
    "import scipy.sparse as sp\n",
    "from sklearn.preprocessing import normalize\n",
    "from collections import defaultdict\n",
    "\n",
    "from data_utils import *\n",
    "from graph_dict import *\n",
    "from utils import *\n",
    "from plotlib import *\n",
    "from script_utils import load_SDMP, load_SDMP_MLP\n",
    "from path_dict import *\n",
    "\n",
    "# NOTE(review): this silences every warning for the whole session, including\n",
    "# deprecation and numerical warnings; narrow the filter if one starts to matter.\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): an empty CURL_CA_BUNDLE is commonly used to make HTTP clients skip\n",
    "# TLS certificate verification -- confirm this is intended before running the\n",
    "# notebook on an untrusted network.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Prefer the first GPU, but fall back to the CPU so the notebook still runs on\n",
    "# machines without CUDA instead of failing at the first `.to(DEVICE)`.\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# %lprun -f\n",
    "# Parameter settings and data loading\n",
    "DATA_ROOT_FOLDER = \"../dataset\"\n",
    "CONF_ROOT_FOLDER = \"../config\"\n",
    "RES_ROOT_FOLDER = \"../result\"\n",
    "\n",
    "def get_ori_acc(model, new_sample, ori_node_emb, g):\n",
    "    \"\"\"Evaluate `model` on the unperturbed embeddings of the sampled nodes.\n",
    "\n",
    "    Args:\n",
    "        model: classifier handed to `evaluate`.\n",
    "        new_sample: iterable of records whose first field is a node index.\n",
    "        ori_node_emb: node embeddings handed to `PlainLoader`\n",
    "            (presumably one row per node -- TODO confirm).\n",
    "        g: graph object; only `g.ndata['label']` is read.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `evaluate` returns for the loader built over the sampled nodes.\n",
    "    \"\"\"\n",
    "    sampled_ids = np.array([record[0] for record in new_sample])\n",
    "    eval_idx = torch.from_numpy(sampled_ids).to(DEVICE)\n",
    "    node_labels = torch.tensor(g.ndata['label']).to(DEVICE)\n",
    "    # 64 is the third positional argument of PlainLoader (presumably the batch size).\n",
    "    loader = PlainLoader(ori_node_emb, node_labels, 64, eval_idx)\n",
    "    return evaluate(model, loader)\n",
    "\n",
    "def get_perturb_acc(sdmp, model, new_sample, all_h, g):\n",
    "    \"\"\"Score `model` on features rebuilt from a perturbed copy of `sdmp.ThetaT`.\n",
    "\n",
    "    For every sample, each column index in row `idx` of the copied matrix that occurs\n",
    "    in `B_ori + C_ori` is swapped for the entry at the same position of\n",
    "    `B_new + C_new`; every other index is left untouched. `sdmp.ThetaT` itself is\n",
    "    never modified because the edits are applied to a deep copy.\n",
    "\n",
    "    Args:\n",
    "        sdmp: object exposing `ThetaT`; its `.rows` attribute is accessed, so this is\n",
    "            presumably a scipy LIL matrix -- TODO confirm.\n",
    "        model: callable mapping the rebuilt feature matrix to predictions.\n",
    "        new_sample: iterable of `(idx, label, B_ori, B_new, C_ori, C_new)` records.\n",
    "        all_h: right-hand operand of the `torch.sparse.mm` product that yields the\n",
    "            features.\n",
    "        g: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        The value of `torch_f1(y_hat, y)` for the sampled nodes.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if a sample's replacement list is shorter than its search list.\n",
    "    \"\"\"\n",
    "    # Ground-truth labels are the second field of every sample record.\n",
    "    y = torch.tensor([sample[1] for sample in new_sample]).to(DEVICE)\n",
    "\n",
    "    # Work on a private copy so the shared operator stays intact across calls.\n",
    "    ThetaT = deepcopy(sdmp.ThetaT)\n",
    "    for idx, _label, B_ori, B_new, C_ori, C_new in new_sample:\n",
    "        to_search = B_ori + C_ori\n",
    "        to_replace = B_new + C_new\n",
    "        cur_row = ThetaT.rows[idx]\n",
    "        # NOTE(review): overwriting indices in place may leave the row unsorted or\n",
    "        # with duplicate columns; verify the sparse format tolerates that.\n",
    "        for j, row_val in enumerate(cur_row):\n",
    "            try:\n",
    "                cur_row[j] = to_replace[to_search.index(row_val)]\n",
    "            except ValueError:\n",
    "                # row_val is not part of the perturbed set: keep it as is.\n",
    "                continue\n",
    "\n",
    "    test_idx = [sample[0] for sample in new_sample]\n",
    "    local_Theta = scipy_coo_to_torch_sparse(ThetaT[test_idx].tocoo()).to(DEVICE)\n",
    "    # Inference only: skip autograd bookkeeping for the forward pass.\n",
    "    with torch.no_grad():\n",
    "        feat = torch.sparse.mm(local_Theta, all_h)\n",
    "        y_hat = model(feat)\n",
    "    return torch_f1(y_hat, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_paths(name, model, result_root):\n",
    "    model_paths = os.path.join(result_root, f\"{name}_{model}_MLP_TEST\")\n",
    "    return next(glob.iglob(f\"{model_paths}/*\"))\n",
    "###########\n",
    "# Pick the (dataset, target GNN) pair to analyse: keep exactly one line active.\n",
    "name, model = \"cora\", \"SAGE\"\n",
    "# name, model = \"cora\", \"geomGCN\"\n",
    "# name, model = \"citeseer\", \"SAGE\"\n",
    "# name, model = \"citeseer\", \"geomGCN\"\n",
    "# name, model = \"pubmed\", \"SAGE\"\n",
    "# name, model = \"pubmed\", \"geomGCN\"\n",
    "# name, model = \"a-computer\", \"SAGE\"\n",
    "# name, model = \"a-computer\", \"exphormer\"\n",
    "# name, model = \"a-photo\", \"SAGE\"\n",
    "# name, model = \"a-photo\", \"exphormer\"\n",
    "# name, model = \"ogbn-arxiv\", \"SAGE\"\n",
    "# name, model = \"ogbn-arxiv\", \"DRGAT\"\n",
    "# name, model = \"ogbn-products\", \"SAGE\"\n",
    "# name, model = \"ogbn-products\", \"RevGNN-112\"\n",
    "\n",
    "TARGET_GNN_FOLDER = path_dict[(name, model)][\"TARGET_GNN_FOLDER\"]\n",
    "# Reuse the configured result root instead of repeating the literal path.\n",
    "result_root_path = get_paths(name, model, os.path.join(RES_ROOT_FOLDER, \"best\"))\n",
    "# Root of the pre-computed permutation samples. The default is a machine-specific\n",
    "# absolute path; set the SDMP_PERT_ROOT environment variable to point elsewhere.\n",
    "PERT_ROOT = os.environ.get(\n",
    "    \"SDMP_PERT_ROOT\",\n",
    "    \"/home/yaochen/AISTATS24_snap/AISTATS24_extened_results/outputs_permutation/transductive/\",\n",
    ")\n",
    "pert_test_root_path = os.path.join(PERT_ROOT, name, model + \"_MLP\")\n",
    "\n",
    "preprocesser, sdmp, g, sdmp_conf, train_conf = load_SDMP(result_root_path,\n",
    "                                                         TARGET_GNN_FOLDER,\n",
    "                                                         DATA_ROOT_FOLDER,\n",
    "                                                         device=DEVICE)\n",
    "mlp_model = load_SDMP_MLP(train_conf, sdmp, g, result_root_path, device=DEVICE)\n",
    "\n",
    "# return_H=True additionally yields `all_h`, which the perturbed evaluation needs.\n",
    "ori_node_emb, all_h = sdmp.infer_torch_node_approximal_features(return_H=True)\n",
    "# ori_node_emb = sdmp.infer_torch_node_approximal_features(return_H=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of pre-generated permutation files (seeds 0 .. NUM_SEEDS - 1) to evaluate.\n",
    "NUM_SEEDS = 10\n",
    "all_acc_ori, all_acc_per = [], []\n",
    "\n",
    "# `pert_test_root_path` is defined next to `name` / `model` in the loading cell, so\n",
    "# it is not rebuilt here and the two cannot drift apart.\n",
    "for seed in tqdm(range(NUM_SEEDS)):\n",
    "    cur_sample_path = os.path.join(pert_test_root_path, f\"permutation_info_seed_{seed}.pickle\")\n",
    "    # NOTE(review): pickle.load can execute arbitrary code from the file; only point\n",
    "    # this at sample files you generated yourself or otherwise trust.\n",
    "    with open(cur_sample_path, \"rb\") as fin:\n",
    "        new_sample = pickle.load(fin)\n",
    "\n",
    "    # Score the same sampled nodes before and after the perturbation.\n",
    "    SDMP_ori_acc = get_ori_acc(mlp_model, new_sample, ori_node_emb, g)\n",
    "    SDMP_new_acc = get_perturb_acc(sdmp, mlp_model, new_sample, all_h, g)\n",
    "    all_acc_ori.append(SDMP_ori_acc.cpu().item())\n",
    "    all_acc_per.append(SDMP_new_acc.cpu().item())\n",
    "\n",
    "print(\"ori\", all_acc_ori)\n",
    "print(latex_sample_mean_std_confidence(all_acc_ori, confidence=0.99))\n",
    "print(\"per\", all_acc_per)\n",
    "print(latex_sample_mean_std_confidence(all_acc_per, confidence=0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc-autonumbering": false,
  "toc-showmarkdowntxt": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
