{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext jupyter_spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deeprobust.graph.data import Dataset\n",
    "import pandas\n",
    "import numpy as np\n",
    "import scipy.sparse\n",
    "import networkx as nx\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict, OrderedDict, deque\n",
    "import signac\n",
    "import pickle\n",
    "import itertools\n",
    "from jupyter_spaces import get_spaces\n",
    "import jupyter_spaces\n",
    "from scipy.special import softmax\n",
    "import warnings\n",
    "import itertools\n",
    "import plotly\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "from plotly.subplots import make_subplots\n",
    "from IPython.display import HTML\n",
    "import re\n",
    "from ipywidgets import interact, fixed, Textarea, Layout\n",
    "import ipywidgets as widgets\n",
    "from pathlib import Path\n",
    "from io import StringIO\n",
    "import sys\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HRPath = Path(\"../../\").resolve()\n",
    "assert (HRPath / \"HeteroRobust\" / \"__init__.py\").exists()\n",
    "sys.path.append(str(HRPath))\n",
    "import HeteroRobust\n",
    "from HeteroRobust.attacks.modules.sparse_smoothing.cert import binary_certificate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project = signac.get_project(\"../../\")\n",
    "datasetRoot = \"../../datasets/data\"\n",
    "datasetDict = dict()\n",
    "for datasetName in [\"citeseer\", \"cora\"]:\n",
    "    datasetDict[datasetName] = Dataset(root=datasetRoot, name=datasetName)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun = pandas.read_csv(\"sparse-smoothing-cert.csv\", index_col=0, keep_default_na=False, na_values=[\"\"])\n",
    "\n",
    "df_expRun = df_expRun.drop('evasionJobID', 1)\n",
    "df_expRun = df_expRun.drop('poisonJobID', 1)\n",
    "\n",
    "incomplete_mask = df_expRun.cleanJobID.isnull()\n",
    "if incomplete_mask.sum() > 0:\n",
    "    warnings.warn(f\"{incomplete_mask.sum()} experiments are incomplete!\")\n",
    "df_expRun_Original = df_expRun.copy()\n",
    "df_expRun = df_expRun.loc[~incomplete_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun.loc[df_expRun.model_arg.isna(), \"model_arg\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2gcn_code_mask = (df_expRun.model == \"H2GCN\")\n",
    "re_h2gcn2 = re.compile(r\"$^\")\n",
    "h2gcn2_mask = df_expRun.model_arg.apply(lambda x: re_h2gcn2.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[h2gcn2_mask, \"model\"] = \"H2GCN-2\"\n",
    "\n",
    "h2gcn_code_mask = (df_expRun.model == \"H2GCN\")\n",
    "re_h2gcn1 = re.compile(r\"--network_setup M64-R-T1-G-V-C1-D[\\d\\.]*-MO\")\n",
    "h2gcn1_mask = df_expRun.model_arg.apply(lambda x: re_h2gcn1.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[h2gcn1_mask, \"model\"] = \"H2GCN-1\"\n",
    "\n",
    "h2gcn_code_mask = (df_expRun.model == \"H2GCN\")\n",
    "re_model = re.compile(r\"--network_setup I-T1-G-V-C1-M64-R-T2-G-V-C2-MO-R\")\n",
    "model_mask = df_expRun.model_arg.apply(lambda x: re_model.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[model_mask, \"model\"] = \"GraphSAGE\"\n",
    "\n",
    "h2gcn_code_mask = (df_expRun.model == \"H2GCN\")\n",
    "re_mlp = re.compile(r\"--network_setup M64-R-D0.5-MO\")\n",
    "mlp_mask = df_expRun.model_arg.apply(lambda x: re_mlp.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[mlp_mask, \"model\"] = \"MLP\"\n",
    "\n",
    "df_expRun.loc[df_expRun.model.isin([\"MultiLayerGCN\"]), \"model\"] = \"GCN\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2gcn_code_mask = (df_expRun.model == \"CPGNN\")\n",
    "re_model = re.compile(r\"--network_setup GGM64-VS-R-G-GMO-VS-E-BP2\")\n",
    "model_mask = df_expRun.model_arg.apply(lambda x: re_model.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[model_mask, \"model\"] = \"CPGNN-Cheby\"\n",
    "\n",
    "h2gcn_code_mask = (df_expRun.model == \"CPGNN\")\n",
    "re_model = re.compile(r\"--network_setup M64-R-MO-E-BP2\")\n",
    "model_mask = df_expRun.model_arg.apply(lambda x: re_model.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[model_mask, \"model\"] = \"CPGNN-MLP\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2gcn_code_mask = (df_expRun.model == \"GPRGNN\")\n",
    "re_model = re.compile(r\"--dropout 0.5\")\n",
    "model_mask = df_expRun.model_arg.apply(lambda x: re_model.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[model_mask, \"model\"] = \"GPRGNN-D0.5\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2gcn_code_mask = (df_expRun.model == \"FAGCN\")\n",
    "re_model = re.compile(r\"--nhid 64 --eps [0-8\\.]* --dropout 0.5\")\n",
    "model_mask = df_expRun.model_arg.apply(lambda x: re_model.search(x) is not None) & h2gcn_code_mask\n",
    "df_expRun.loc[model_mask, \"model\"] = \"FAGCN-Tune\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(df_expRun)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun_pivot = df_expRun.pivot_table(index=[\"AttackSession\", \"SessionConfig\", \"datasetName\", \"model\", \"model_arg\"], \n",
    "                      aggfunc=dict(\n",
    "                          cleanJobID=lambda x: \",\".join(x)\n",
    "                      ))\n",
    "df_expRun_pivot_flat = df_expRun_pivot.reset_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Robustness Certification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Basic Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Adapted from https://github.com/sigeisler/reliable_gnn_via_robust_aggregation/blob/main/experiment_smoothing.py\n",
    "\n",
    "def calc_certification_ratio(smoothing_result: dict, idx_selected: np.ndarray, labels: np.ndarray,\n",
    "                             mask: np.ndarray = None) -> np.ndarray:\n",
    "    \"\"\"Calculation of the certification ratio. `R(r_a, r_d)` in our paper.\n",
    "    Parameters\n",
    "    ----------\n",
    "    smoothing_result : Dict[str, Any]\n",
    "        Dictionary with smoothing results.\n",
    "    idx_selected : np.ndarray\n",
    "        Array containing the indices of e.g. the test nodes.\n",
    "    labels : np.ndarray, optional\n",
    "        Ground truth class labels.\n",
    "    mask : np.ndarray, optional\n",
    "        To select only a subset of nodes e.g. by degree, by default None.\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Bivariate certification ratio R(r_a, r_d).\n",
    "    \"\"\"\n",
    "    grid_lower = smoothing_result['grid_lower'][idx_selected]\n",
    "    grid_upper = smoothing_result['grid_upper'][idx_selected]\n",
    "    if mask is not None:\n",
    "        grid_lower = grid_lower[mask[idx_selected]]\n",
    "        grid_upper = grid_upper[mask[idx_selected]]\n",
    "\n",
    "    correctly_classified = (smoothing_result['votes'][idx_selected].argmax(1) == labels[idx_selected])\n",
    "    if mask is not None:\n",
    "        correctly_classified = correctly_classified[mask[idx_selected]]\n",
    "    heatmap_loup = (\n",
    "        (grid_lower > grid_upper)\n",
    "        & correctly_classified[:, None, None]\n",
    "    )\n",
    "\n",
    "    heatmap_loup = heatmap_loup.mean(0)\n",
    "    heatmap_loup[0, 0] = correctly_classified.mean()\n",
    "\n",
    "    return heatmap_loup, correctly_classified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "certGridDict = dict()\n",
    "def getCertGrid(jobID, key=\"rho_grid\", get_binary_cert=True):\n",
    "    if jobID not in certGridDict:\n",
    "        job = project.open_job(id=jobID)\n",
    "        with job.data:\n",
    "            assert job.data.certGrid.type.decode() == \"separate\"\n",
    "\n",
    "            rho_grid = np.array(job.data.certGrid.rho_grid)\n",
    "            max_ra = job.data.certGrid.max_ra\n",
    "            max_rd = job.data.certGrid.max_rd\n",
    "            heatmap = (rho_grid > 0.5).mean(0)\n",
    "            \n",
    "            preVotes = np.array(job.data.preVotes)\n",
    "            votes = np.array(job.data.votes)\n",
    "        \n",
    "            if get_binary_cert:\n",
    "                if job.doc.get(\"binary_certificate_nb\", False):\n",
    "                    grid_base = np.array(job.data.binary_certificate_nb.grid_base)\n",
    "                    grid_lower = np.array(job.data.binary_certificate_nb.grid_lower)\n",
    "                    grid_upper = np.array(job.data.binary_certificate_nb.grid_upper)\n",
    "                else:\n",
    "                    conf_alpha = job.sp.conf_alpha\n",
    "                    assert job.sp.sampleConfig.votes.pf_plus_att == 0 and job.sp.sampleConfig.votes.pf_minus_att == 0\n",
    "                    pf_plus = job.sp.sampleConfig.votes.pf_plus_adj\n",
    "                    pf_minus = job.sp.sampleConfig.votes.pf_minus_adj\n",
    "                    n_samples = job.sp.sampleConfig.votes.n_samples\n",
    "\n",
    "                    grid_base, grid_lower, grid_upper = binary_certificate(votes, preVotes, n_samples, conf_alpha, pf_plus, pf_minus)\n",
    "                    job.data.binary_certificate_nb = dict(\n",
    "                        grid_base=grid_base,\n",
    "                        grid_lower=grid_lower,\n",
    "                        grid_upper=grid_upper\n",
    "                    )\n",
    "                    job.doc.binary_certificate_nb = True\n",
    "            else:\n",
    "                grid_base = None\n",
    "                grid_lower = None\n",
    "                grid_upper = None\n",
    "                \n",
    "        certGridDict[jobID] = dict(\n",
    "            rho_grid=rho_grid,\n",
    "            max_ra=max_ra,\n",
    "            max_rd=max_rd,\n",
    "            heatmap=heatmap,\n",
    "            datasetName=job.sp.datasetName,\n",
    "            preVotes=preVotes,\n",
    "            votes=votes,\n",
    "            job=job,\n",
    "            grid_base=grid_base,\n",
    "            grid_lower=grid_lower,\n",
    "            grid_upper=grid_upper\n",
    "        )\n",
    "    return certGridDict[jobID][key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pivotACRadii(task_ind_list, df=df_expRun_pivot_flat):\n",
    "    print(task_ind_list)\n",
    "    plotDfList = []\n",
    "    \n",
    "    for task_ind in task_ind_list:\n",
    "        jobItem = df.iloc[task_ind]\n",
    "\n",
    "        jobIDList = jobItem.cleanJobID.split(\",\")\n",
    "        rdDataDict = dict()\n",
    "        maxRa = 0\n",
    "        maxRd = 0\n",
    "\n",
    "        for jobID in jobIDList:\n",
    "            rho_grid = getCertGrid(jobID, key=\"rho_grid\")\n",
    "            preVotes = getCertGrid(jobID, key=\"preVotes\")\n",
    "            votes = getCertGrid(jobID, key=\"votes\")\n",
    "            datasetName = getCertGrid(jobID, key=\"datasetName\")\n",
    "            job = getCertGrid(jobID, key=\"job\")\n",
    "            with job:\n",
    "                if Path(\"data.pkl\").exists():\n",
    "                    with open(\"data.pkl\", \"rb\") as dataFile:\n",
    "                        data = pickle.load(dataFile)\n",
    "                else:\n",
    "                    data = datasetDict[datasetName]\n",
    "            \n",
    "            grid_base = getCertGrid(jobID, key=\"grid_base\")\n",
    "            grid_lower = getCertGrid(jobID, key=\"grid_lower\")\n",
    "            grid_upper = getCertGrid(jobID, key=\"grid_upper\")\n",
    "            smoothing_result = {\n",
    "                'grid_base': grid_base,\n",
    "                'grid_lower': grid_lower,\n",
    "                'grid_upper': grid_upper,\n",
    "                'votes': votes,\n",
    "                'pre_votes': preVotes\n",
    "            }\n",
    "            heatmap, preVoteCorrectMask = calc_certification_ratio(smoothing_result, data.idx_test, data.labels)\n",
    "            certResultNodes = (grid_lower > grid_upper)[data.idx_test]\n",
    "            \n",
    "            maxRa = max(getCertGrid(jobID, key=\"max_ra\"), maxRa + 1)\n",
    "            maxRd = max(getCertGrid(jobID, key=\"max_rd\"), maxRd + 1)\n",
    "            \n",
    "            certResultNodesCorrect = certResultNodes[preVoteCorrectMask, :, :]\n",
    "            radiiMat = np.zeros((certResultNodesCorrect.shape[0], 2))\n",
    "            for i in range(certResultNodesCorrect.shape[0]):\n",
    "                _cert_slice = certResultNodesCorrect[i, :, :]\n",
    "                _cert_slice[0, 0] = True\n",
    "                wRes = np.where(_cert_slice)\n",
    "                radiiMat[i, :] = (wRes[0].max(), wRes[1].max())\n",
    "            avgRadii = radiiMat.mean(0)\n",
    "            \n",
    "            rdDataDict[jobID] = (heatmap, avgRadii, preVoteCorrectMask)\n",
    "    \n",
    "        acScoreMat = np.zeros(len(rdDataDict))\n",
    "        avgRadiiMat = np.zeros((len(rdDataDict), 2))\n",
    "        avgAccMat = np.zeros(len(rdDataDict))\n",
    "        for i, (key, value) in enumerate(rdDataDict.items()): \n",
    "            value_flat = value[0].flatten()\n",
    "            ac_score = value_flat.sum() - value[0][0, 0]\n",
    "            acScoreMat[i] = ac_score\n",
    "            \n",
    "            avgRadiiMat[i, :] = value[1]\n",
    "            avgAccMat[i] = value[2].mean()\n",
    "        avgRadiiMatMean = avgRadiiMat.mean(0)\n",
    "        avgRadiiMatStd = avgRadiiMat.std(0)\n",
    "        plotDf = pandas.Series({\n",
    "            \"model\": jobItem.model,\n",
    "            \"model_arg\": jobItem.model_arg,\n",
    "            \"exp_count\": len(rdDataDict),\n",
    "            \"max_ra\": maxRa,\n",
    "            \"max_rd\": maxRd,\n",
    "            \"AC_mean\": acScoreMat.mean(),\n",
    "            \"AC_std\": acScoreMat.std(),\n",
    "            \"ra_mean\": avgRadiiMatMean[0],\n",
    "            \"ra_std\": avgRadiiMatStd[0],\n",
    "            \"rd_mean\": avgRadiiMatMean[1],\n",
    "            \"rd_std\": avgRadiiMatStd[1],\n",
    "            \"acc_mean\": avgAccMat.mean(),\n",
    "            \"acc_std\": avgAccMat.std()\n",
    "        })\n",
    "        plotDfList.append(plotDf)\n",
    "    \n",
    "    resultDf = pandas.DataFrame(data=plotDfList)\n",
    "    return resultDf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results on Multiple Seeds\n",
    "\n",
    "- Ratio of **all nodes** that are certified to be robust without conditioning on correctness\n",
    "- Ratio of **all nodes** that are **both** correctly classified and certified to be robust\n",
    "- Ratio of **correctly classified nodes** that are certified to be robust"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### LaTeX Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelHeadDict = {\n",
    "    \"H2GCN-2\": r\"\\textbf{H$_2$GCN} & \\checkmark & &\",\n",
    "    \"GraphSAGE\": r\"\\textbf{GraphSAGE} & \\checkmark & &\",\n",
    "    \"CPGNN-MLP\": r\"\\textbf{CPGNN} & \\checkmark & &\", \n",
    "    \"GPRGNN\": r\"\\textbf{GPR-GNN} & \\checkmark & &\",\n",
    "    \"FAGCN\": r\"\\textbf{FAGCN} & \\checkmark & &\", \n",
    "    \"#sep\": r\"\"\"\\noalign{\\vskip 0.25ex}\n",
    "\\cdashline{1-3}[0.8pt/2pt]\n",
    "\\cdashline{5-8}[0.8pt/2pt]\n",
    "\\cdashline{10-13}[0.8pt/2pt]\n",
    "\\noalign{\\vskip 0.25ex}\"\"\",\n",
    "    \"GAT\": r\"\\textbf{GAT} & & &\",\n",
    "    \"GCN\": r\"\\textbf{GCN} & & &\"\n",
    "}\n",
    "datasetNameStr = r\"\\multirow{modelCount}{*}{\\rotatebox[origin=c]{90}{\\textbf{datasetName}}}\"\n",
    "accuracyStdStr = r\"acc\\tiny{$\\pm$std}\"\n",
    "cellColorDict = {\n",
    "    \"1st\": r\"\\cellcolor{blue!20}\",\n",
    "    \"incomplete\": r\"\\cellcolor{red!20}\"\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, datasetTableDict, cellColorDict=None, showTable=False):\n",
    "    datasetResultDict = dict()\n",
    "    modelResultDict = dict()\n",
    "    for datasetName, table in datasetTableDict.items():\n",
    "        if showTable:\n",
    "            print(datasetName)\n",
    "            display(table)\n",
    "        table = table.loc[{key for key in modelHeadDict if not key.startswith(\"#\")}]\n",
    "        tableResultDict = {model: deque() for model in modelHeadDict}\n",
    "        if datasetName is not None:\n",
    "            dNameStr = datasetNameStr.replace(\"datasetName\", datasetName)\n",
    "            dNameStr = \"\\n\" + dNameStr.replace(\"modelCount\", str(len(modelHeadDict))) + \"\\n\"\n",
    "        else:\n",
    "            dNameStr = \"\\n\"\n",
    "        for ind, model in enumerate(modelHeadDict):\n",
    "            if model.startswith(\"#\"):\n",
    "                continue\n",
    "            for col in [\"AC\", \"ra\", \"rd\", \"acc\"]:\n",
    "                if col != \"acc\":\n",
    "                    dTable = table.round({f\"{col}_mean\": 2, f\"{col}_std\": 2})\n",
    "                else:\n",
    "                    dTable = table.round({f\"{col}_mean\": 4, f\"{col}_std\": 4})\n",
    "                \n",
    "                \n",
    "                # Mean and STD\n",
    "                if model not in dTable.index:\n",
    "                    modelResultStr = f'{accuracyStdStr.replace(\"acc\", \"nan\").replace(\"std\", \"nan\")}'\n",
    "                    if cellColorDict:\n",
    "                        modelResultStr += cellColorDict[\"incomplete\"]\n",
    "                else:\n",
    "                    mean_value = dTable.loc[model, f\"{col}_mean\"]\n",
    "                    std_value = dTable.loc[model, f\"{col}_std\"]\n",
    "                    if col in [\"ra\", \"rd\"] and mean_value == 0 and std_value == 0:\n",
    "                        modelResultStr = \"-\"\n",
    "                        mean_value = \"-\"\n",
    "                    elif col != \"acc\":\n",
    "                        modelResultStr = f'{accuracyStdStr.replace(\"acc\", f\"{mean_value:.2f}\").replace(\"std\", f\"{std_value:.2f}\")}'\n",
    "                    else:\n",
    "                        modelResultStr = f'{accuracyStdStr.replace(\"acc\", f\"{mean_value*100:.2f}\").replace(\"std\", f\"{std_value*100:.2f}\")}'\n",
    "                    if dTable.loc[model, \"exp_count\"] != 3 and cellColorDict:\n",
    "                        modelResultStr += cellColorDict[\"incomplete\"]\n",
    "                    elif mean_value == dTable[f\"{col}_mean\"].max() and cellColorDict:\n",
    "                        modelResultStr += cellColorDict[\"1st\"]\n",
    "                    \n",
    "                tableResultDict[model].append(modelResultStr)\n",
    "        \n",
    "            if ind == 0:\n",
    "                tableResultDict[model].appendleft(dNameStr)\n",
    "            else:\n",
    "                tableResultDict[model].appendleft(\"\\n\")\n",
    "        \n",
    "        datasetResultDict[datasetName] = tableResultDict\n",
    "    \n",
    "    output = \"\"\n",
    "    for model, value in modelHeadDict.items():\n",
    "        modelStr = value\n",
    "        if not model.startswith(\"#\"):\n",
    "            for ind, tableResultDict in enumerate(datasetResultDict.values()):\n",
    "                modelStr += \" & \".join(tableResultDict[model])\n",
    "                if ind != len(datasetResultDict) - 1:\n",
    "                    modelStr += \" & \"\n",
    "            modelStr += r\" \\\\\"\n",
    "        output += modelStr + \"\\n\\n\"\n",
    "    print(\"===\")\n",
    "    \n",
    "    display(Textarea(\n",
    "        value=output,\n",
    "        layout=Layout(width=\"auto\"),\n",
    "        rows=output.count(\"\\n\") + 5\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Main Paper - Top"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"cora\"]].reset_index()\n",
    "cora_df = pivotACRadii(np.where(~df_expRun_set.model.isin([\"H2GCN-1\", \"CPGNN-Cheby\"]))[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"citeseer\"]].reset_index()\n",
    "citeseer_df = pivotACRadii(np.where(~df_expRun_set.model.isin([\"H2GCN-1\", \"CPGNN-Cheby\"]))[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"Cora\": cora_df,\n",
    "    \"Citeseer\": citeseer_df\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Main Paper - Bottom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"fb100\"]].reset_index()\n",
    "fb100_df = pivotACRadii(np.where(~df_expRun_set.model.isin([\"H2GCN-1\", \"CPGNN-Cheby\"]))[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"snap-patents-downsampled\"]].reset_index()\n",
    "snap_df = pivotACRadii(np.where(~df_expRun_set.model.isin([\"H2GCN-1\", \"CPGNN-Cheby\"]))[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"FB100\": fb100_df,\n",
    "    \"Snap\": snap_df\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.0 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"cora\"]].reset_index()\n",
    "cora_df_plus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.0 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"cora\"]].reset_index()\n",
    "cora_df_minus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"Cora\": cora_df_plus,\n",
    "    None: cora_df_minus\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.0 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"citeseer\"]].reset_index()\n",
    "df_plus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.0 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", \"citeseer\"]].reset_index()\n",
    "df_minus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"Citeseer\": df_plus,\n",
    "    None: df_minus\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasetName = \"fb100\"\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.0 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", datasetName]].reset_index()\n",
    "df_plus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.0 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", datasetName]].reset_index()\n",
    "df_minus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"FB100\": df_plus,\n",
    "    None: df_minus\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasetName = \"snap-patents-downsampled\"\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.001 --pf_minus_adj 0.0 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", datasetName]].reset_index()\n",
    "df_plus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "df_expRun_set = df_expRun_pivot.loc[pandas.IndexSlice[\"SparseSmoothingSession\", \"--pf_plus_adj 0.0 --pf_minus_adj 0.4 --pf_plus_att 0.0 --pf_minus_att 0.0 --conf_alpha 0.01\", datasetName]].reset_index()\n",
    "df_minus = pivotACRadii(np.where(df_expRun_set.model != \"H2GCN-1\")[0], df=df_expRun_set).set_index(\"model\")\n",
    "\n",
    "genTeXTable(modelHeadDict, datasetNameStr, accuracyStdStr, {\n",
    "    \"Snap\": df_plus,\n",
    "    None: df_minus\n",
    "}, cellColorDict=cellColorDict, showTable=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "toc-autonumbering": false,
  "toc-showcode": false,
  "toc-showmarkdowntxt": true,
  "toc-showtags": false,
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
