{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jupyter_spaces provides the %%space cell magic: the Evasion and Poison sections run in separate\n",
    "# spaces so they can reuse the same variable names (df_subtask, defenseTableFull, ...).\n",
    "%load_ext jupyter_spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deeprobust.graph.data import Dataset\n",
    "from hrdataset import CustomDataset\n",
    "import pandas\n",
    "import numpy as np\n",
    "import scipy.sparse\n",
    "import networkx as nx\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import defaultdict\n",
    "import signac\n",
    "import pickle\n",
    "import itertools\n",
    "import tqdm\n",
    "\n",
    "from jupyter_spaces import get_spaces\n",
    "import jupyter_spaces\n",
    "from scipy.special import softmax\n",
    "import warnings\n",
    "import itertools\n",
    "import plotly\n",
    "import copy\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "from plotly.subplots import make_subplots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the signac project two directories up; every job ID listed in the CSV is opened from it.\n",
    "project = signac.get_project(\"../../\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_acc_margin(predictions, label):\n",
    "    \"\"\"Score class predictions against ground-truth labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    predictions : array-like\n",
    "        1-D vector of per-class scores for a single node, or 2-D array with one\n",
    "        row of per-class scores per node. At least two classes are required,\n",
    "        because the margin is measured against the runner-up class.\n",
    "    label : int or array-like of int\n",
    "        Ground-truth class of the node (1-D case) or of every row (2-D case).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    acc : bool (1-D input) or bool ndarray (2-D input)\n",
    "        Whether the top-scoring class equals ``label``.\n",
    "    pred : scalar or ndarray\n",
    "        Score assigned to the ground-truth class.\n",
    "    margin : scalar or ndarray\n",
    "        ``pred`` minus the runner-up score when correct, ``pred`` minus the top\n",
    "        score when wrong (so it is <= 0 for misclassified nodes).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``predictions`` is not 1-D or 2-D, or has fewer than two classes.\n",
    "    \"\"\"\n",
    "    predictions = np.asarray(predictions)\n",
    "    if predictions.ndim not in (1, 2):\n",
    "        raise ValueError(f\"Unsupported dim for predictions: {predictions.ndim}\")\n",
    "    n_classes = predictions.shape[-1]\n",
    "    if n_classes < 2:\n",
    "        raise ValueError(f\"Need at least two classes to compute a margin, got {n_classes}\")\n",
    "\n",
    "    # Treat a single score vector as a one-row batch so both cases share one code path.\n",
    "    scores = np.atleast_2d(predictions)\n",
    "    rows = np.arange(scores.shape[0])\n",
    "    labels = np.broadcast_to(label, rows.shape)\n",
    "\n",
    "    # argpartition only guarantees that the last two columns hold the top-2 classes,\n",
    "    # which is all that is needed and is cheaper than a full argsort.\n",
    "    top2 = np.argpartition(scores, (-1, -2), axis=-1)\n",
    "    max_ind, second_ind = top2[:, -1], top2[:, -2]\n",
    "\n",
    "    acc = max_ind == labels\n",
    "    pred = scores[rows, labels]\n",
    "    # Correct: lead over the runner-up. Wrong: (non-positive) gap to the winning class.\n",
    "    # np.where picks one branch per row instead of blending both arithmetically.\n",
    "    margin = pred - np.where(acc, scores[rows, second_ind], scores[rows, max_ind])\n",
    "\n",
    "    if predictions.ndim == 1:\n",
    "        return bool(acc[0]), pred[0], margin[0]\n",
    "    return acc, pred, margin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the working directory: the CSV path and the signac project path are relative to it.\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
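    "# Load CSV\n",
    "\n",
    "Read the experiment index `nettack-adj-only.csv`, reshape it to one row per experiment and attack phase (evasion / poison), and drop attack phases marked `N/A` as well as rows missing the attack or clean job ID."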
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The CSV keeps the evasion and poison attack job IDs in separate evasionJobID / poisonJobID columns.\n",
    "# keep_default_na=False keeps literal strings such as 'N/A' as text; only empty cells become NaN.\n",
    "df_expRun = pandas.read_csv(\"./nettack-adj-only.csv\", index_col=0, keep_default_na=False, na_values=[\"\"])\n",
    "# Reshape to one row per (experiment, attack phase): 'attackIDType' names the source column,\n",
    "# 'attackID' holds the job ID.\n",
    "df_expRun = df_expRun.melt(id_vars=[col for col in df_expRun.columns if col not in {\"evasionJobID\", \"poisonJobID\"}],\n",
    "                           var_name='attackIDType',\n",
    "                           value_name='attackID')\n",
    "# Keyword form: passing `axis` positionally was deprecated in pandas 1.3 and is rejected by pandas 2.x.\n",
    "df_expRun = df_expRun.drop(columns='Attack Phase')\n",
    "\n",
    "# Drop attack phases marked 'N/A' in the CSV (presumably phases not run for that experiment).\n",
    "na_mask = (df_expRun['attackID'] == 'N/A')\n",
    "df_expRun = df_expRun[~na_mask]\n",
    "\n",
    "# Rows missing the attack job or the clean baseline job cannot be evaluated by the analysis cells.\n",
    "incomplete_mask = (df_expRun.attackID.isnull() | df_expRun.cleanJobID.isnull())\n",
    "n_incomplete = int(incomplete_mask.sum())\n",
    "if n_incomplete > 0:\n",
    "    warnings.warn(f\"{n_incomplete} experiments are incomplete!\")\n",
    "# Keep the pre-filter frame so the incomplete rows can be inspected in the next cell.\n",
    "df_expRun_Original = df_expRun.copy()\n",
    "df_expRun = df_expRun.loc[~incomplete_mask]\n",
    "\n",
    "df_expRun_evasion = df_expRun[df_expRun['attackIDType'] == 'evasionJobID']\n",
    "df_expRun_poison = df_expRun[df_expRun['attackIDType'] == 'poisonJobID']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiments excluded by the load step because attackID or cleanJobID is missing.\n",
    "df_expRun_Original.loc[incomplete_mask.to_numpy()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets in HETERO_DATASETS are read from ../../datasets/data/<name>.pkl by the analysis cells\n",
    "# below; any other dataset is read from the perturbation job's own data.pkl.\n",
    "HETERO_DATASETS = ['fb100', 'twitch-tw', 'snap-patent-downsampled']\n",
    "HOMO_DATASETS = ['citeseer', 'cora']\n",
    "# Alias kept for backward compatibility with the original misspelled name.\n",
    "HOMO_DATASETES = HOMO_DATASETS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
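    "# Evasion (Post-training Attack)\n",
    "\n",
    "For each evasion job, compare the defense model's clean predictions with its attacked predictions on every target node (accuracy, ground-truth confidence and classification margin), then summarize the results per dataset and model."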
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `evasion`\n",
    "# Build one per-target-node table per evasion attack job, comparing the clean model's predictions\n",
    "# with the attacked ones: accuracy, ground-truth confidence and margin (see get_acc_margin).\n",
    "# NOTE(review): this cell is a near copy of the Poison cell (only the attacked-prediction key\n",
    "# differs); consider moving the shared loop into a function.\n",
    "df_subtask = df_expRun_evasion\n",
    "# Cache keyed by perturbJobID so each perturbation job's files are loaded only once.\n",
    "perturbDataDict = dict()\n",
    "# attack job id -> per-target-node result table\n",
    "defenseTableDict = dict()\n",
    "\n",
    "for tid, tdata in df_subtask.iterrows():\n",
    "    if tdata.perturbJobID not in perturbDataDict:\n",
    "        perturbJob = project.open_job(id=tdata.perturbJobID)\n",
    "        # Relative paths below resolve inside the perturbation job's workspace (signac `with job:`).\n",
    "        with perturbJob:\n",
    "            \n",
    "            # NOTE(review): pickle.load can execute arbitrary code; only load files produced by this project.\n",
    "            with open(\"perturbDict.pkl\", \"rb\") as dataFile:\n",
    "                dict_pertubation = pickle.load(dataFile)\n",
    "            datasetName_ = perturbJob.sp['datasetName']\n",
    "            print(datasetName_)\n",
    "            # Datasets in HETERO_DATASETS come from the shared dataset directory; others from the job's data.pkl.\n",
    "            if datasetName_ in HETERO_DATASETS:\n",
    "                with open(f\"../../datasets/data/{datasetName_}.pkl\", \"rb\") as dataFile:\n",
    "                    dataset = pickle.load(dataFile)\n",
    "                    print(dataset)\n",
    "            else:\n",
    "                with open(\"data.pkl\", \"rb\") as dataFile:\n",
    "                    dataset = pickle.load(dataFile)\n",
    "                    print(dataset)\n",
    "\n",
    "        # dict_pertubation is cached with the dataset but not read anywhere else in this cell.\n",
    "        perturbDataDict[tdata.perturbJobID] = dict(\n",
    "            dict_pertubation=dict_pertubation,\n",
    "            dataset=dataset\n",
    "        )\n",
    "    else:\n",
    "        dict_pertubation = perturbDataDict[tdata.perturbJobID][\"dict_pertubation\"]\n",
    "        dataset = perturbDataDict[tdata.perturbJobID][\"dataset\"]\n",
    "\n",
    "    # Load the attack job's predictions (job data) and its per-target result table (resultTable.csv).\n",
    "    job = project.open_job(id=tdata.attackID)\n",
    "    # print(f\"a:{job.id}\")\n",
    "    assert job.sp.use_runner\n",
    "    with job:\n",
    "        with job.data.open(mode=\"r\"):\n",
    "            dict_prediction = {key: np.array(val) for key, val in job.data.predictionDict.items()}\n",
    "        with open(f\"resultTable.csv\", \"r\") as f:\n",
    "            resultTable = pandas.read_csv(f, index_col=0)\n",
    "    perturb_name = tdata.perturb_prefix\n",
    "    # Prediction keys refer to the defense model as '<model>_p'.\n",
    "    DEFENSE_MODEL = f\"{tdata.model}_p\"\n",
    "    # Unused in this cell (the Poison cell uses it in its attacked-prediction key).\n",
    "    defenseModelType = tdata.model\n",
    "\n",
    "    # Clean (unattacked) predictions of the same defense model, softmax-normalized per row.\n",
    "    clean_job = project.open_job(id=tdata.cleanJobID)\n",
    "    # print(f\"c:{clean_job.id}\")\n",
    "    with clean_job.data.open(mode=\"r\"):\n",
    "        dict_prediction_clean = {key: np.array(val) for key, val in clean_job.data.predictionDict.items()}\n",
    "    prediction_result_clean = softmax(np.array(dict_prediction_clean[f\"f:{DEFENSE_MODEL}@clean\"]), axis=1)\n",
    "\n",
    "\n",
    "    # Index by target node; the job's 'acc' column is the accuracy under attack.\n",
    "    defenseResultTable = resultTable[[\"target_node\", \"acc\"]].set_index(\n",
    "                \"target_node\").rename(columns=dict(acc=\"acc_attack\"))\n",
    "    # Clean metrics are computed for all nodes, then restricted to the target nodes.\n",
    "    acc_clean, ground_truth_confidence_clean, margin_clean = get_acc_margin(prediction_result_clean, dataset.labels)\n",
    "    defenseResultTable[\"label\"] = dataset.labels[defenseResultTable.index]\n",
    "    defenseResultTable[\"acc_clean\"] = acc_clean[defenseResultTable.index]\n",
    "    defenseResultTable[\"pred_clean\"] = ground_truth_confidence_clean[defenseResultTable.index]\n",
    "    defenseResultTable[\"margin_clean\"] = margin_clean[defenseResultTable.index]\n",
    "\n",
    "    # Attacked predictions are keyed per target node as 'e:<DEFENSE_MODEL>@<perturb_prefix>_<node>';\n",
    "    # only that node's own row of the stored prediction matrix is scored.\n",
    "    for cur_node in defenseResultTable.index:\n",
    "        prediction_result_attack = softmax(dict_prediction[f\"e:{DEFENSE_MODEL}@{perturb_name}_{cur_node}\"][cur_node, :])\n",
    "        acc_s, confidence_attack_s, margin_attack_s = get_acc_margin(prediction_result_attack, dataset.labels[cur_node])\n",
    "        # Sanity check: accuracy recomputed from the stored predictions must match resultTable.\n",
    "        assert defenseResultTable.at[cur_node, \"acc_attack\"] == acc_s\n",
    "        defenseResultTable.at[cur_node, \"pred_attack\"] = confidence_attack_s\n",
    "        defenseResultTable.at[cur_node, \"margin_attack\"] = margin_attack_s\n",
    "\n",
    "    # Deltas are attacked minus clean, so negative values mean the attack lowered the true-class score.\n",
    "    defenseResultTable[\"pred_delta\"] = defenseResultTable[\"pred_attack\"] - defenseResultTable[\"pred_clean\"]\n",
    "    defenseResultTable[\"margin_delta\"] = defenseResultTable[\"margin_attack\"] - defenseResultTable[\"margin_clean\"]\n",
    "\n",
    "    # One boolean '<key>_group' column per target-node group in the job's statepoint.\n",
    "    for key, nodeList in job.sp.targetNodes.items():\n",
    "        defenseResultTable[f\"{key}_group\"] = defenseResultTable.index.isin(nodeList)\n",
    "    defenseTableDict[job.id] = defenseResultTable\n",
    "        \n",
    "\n",
    "# print(\"!\")\n",
    "# Stack the per-job tables under an outer 'attackID' index level.\n",
    "defenseTableFull = pandas.concat(defenseTableDict.values(), keys=defenseTableDict.keys(), names=[\"attackID\"])\n",
    "display(defenseTableFull)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `evasion`\n",
    "# Join each experiment's metadata row (df_subtask) onto its per-node results via attackID.\n",
    "defenseTableFullRH = defenseTableFull.sort_index().reset_index()\n",
    "defenseTableFullExp = (\n",
    "    df_subtask\n",
    "    .merge(defenseTableFullRH, on=[\"attackID\"], how=\"outer\")\n",
    "    .assign(model_with_arg=lambda frame: frame[\"model\"] + \":\" + frame[\"model_arg\"].fillna(\"\"))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `evasion`\n",
    "group_key = \"DATASET\"\n",
    "pivot_index = [\"model_with_arg\", \"perturbJobID\"]\n",
    "defensePivotDict = dict()\n",
    "\n",
    "for h in sorted(defenseTableFullExp[group_key].unique()):\n",
    "    defenseTableFullR = defenseTableFullExp.loc[defenseTableFullExp[group_key] == h]\n",
    "    # Mean per (model, perturbation job). The string 'mean' replaces np.mean, which recent pandas\n",
    "    # versions warn about; the resulting column label ('mean') is unchanged.\n",
    "    defensePivot = defenseTableFullR.pivot_table(values=[\"acc_clean\", \"acc_attack\"], index=pivot_index,\n",
    "                              aggfunc={\"acc_clean\": [\"mean\"], \"acc_attack\": [\"mean\"]})\n",
    "    defensePivot[\"acc_delta\"] = defensePivot.acc_attack - defensePivot.acc_clean\n",
    "    # NOTE: inf/NaN for a perturbation job whose clean accuracy is 0.\n",
    "    defensePivot[\"acc_delta_rel\"] = (defensePivot.acc_attack - defensePivot.acc_clean) / defensePivot.acc_clean\n",
    "    defensePivot2 = defenseTableFullR.pivot_table(values=[\"margin_clean\", \"margin_attack\", \"margin_delta\"], index=pivot_index,\n",
    "                              aggfunc={\"margin_clean\": [\"mean\"], \"margin_attack\": [\"mean\"], \"margin_delta\": [\"mean\"]})\n",
    "\n",
    "    # Append per-model subtotal rows (mean and sample std across perturbation jobs). Both statistics\n",
    "    # are taken BEFORE any row is inserted: computing the std after the subtotal_mean row had been\n",
    "    # appended counted that mean as an extra observation and understated the spread.\n",
    "    # The std is NaN for a model with a single perturbation job.\n",
    "    for pivot in (defensePivot, defensePivot2):\n",
    "        for key2 in pivot.index.levels[0]:\n",
    "            group = pivot.loc[key2]\n",
    "            group_mean, group_std = group.mean(axis=0), group.std(axis=0)\n",
    "            pivot.loc[(key2, 'subtotal_mean'), :] = group_mean\n",
    "            pivot.loc[(key2, 'subtotal_std'), :] = group_std\n",
    "\n",
    "    defensePivot = pandas.concat([defensePivot, defensePivot2], axis=1).sort_index()\n",
    "\n",
    "    defensePivotDict[h] = defensePivot.loc[pandas.IndexSlice[:, [\"subtotal_mean\", \"subtotal_std\"]], :]\n",
    "    defensePivotDict[h].columns = defensePivotDict[h].columns.droplevel(1)\n",
    "\n",
    "    display(h)\n",
    "    # A bare expression inside a loop is never rendered, so the styled table is displayed explicitly.\n",
    "    display(defensePivotDict[h].style.format(dict(acc_attack=\"{:.2%}\", acc_clean=\"{:.2%}\", acc_delta=\"{:.2%}\", acc_delta_rel=\"{:.2%}\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
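    "# Poison (Pre-training Attack)\n",
    "\n",
    "Same analysis as the evasion section, but the attacked predictions come from the poisoning (pre-training) attack jobs."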
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `poison`\n",
    "# Build one per-target-node table per poisoning attack job, comparing the clean model's predictions\n",
    "# with the attacked ones: accuracy, ground-truth confidence and margin (see get_acc_margin).\n",
    "# NOTE(review): this cell is a near copy of the Evasion cell (only the attacked-prediction key\n",
    "# differs); consider moving the shared loop into a function.\n",
    "df_subtask = df_expRun_poison\n",
    "# Cache keyed by perturbJobID so each perturbation job's files are loaded only once.\n",
    "perturbDataDict = dict()\n",
    "# attack job id -> per-target-node result table\n",
    "defenseTableDict = dict()\n",
    "\n",
    "for tid, tdata in df_subtask.iterrows():\n",
    "    if tdata.perturbJobID not in perturbDataDict:\n",
    "        perturbJob = project.open_job(id=tdata.perturbJobID)\n",
    "        # print(\"p\")\n",
    "        # Relative paths below resolve inside the perturbation job's workspace (signac `with job:`).\n",
    "        with perturbJob:\n",
    "            # NOTE(review): pickle.load can execute arbitrary code; only load files produced by this project.\n",
    "            with open(\"perturbDict.pkl\", \"rb\") as dataFile:\n",
    "                dict_pertubation = pickle.load(dataFile)\n",
    "            datasetName_ = perturbJob.sp['datasetName']\n",
    "            # Datasets in HETERO_DATASETS come from the shared dataset directory; others from the job's data.pkl.\n",
    "            if datasetName_ in HETERO_DATASETS:\n",
    "                with open(f\"../../datasets/data/{datasetName_}.pkl\", \"rb\") as dataFile:\n",
    "                    dataset = pickle.load(dataFile)\n",
    "                    print(dataset)\n",
    "            else:\n",
    "                with open(\"data.pkl\", \"rb\") as dataFile:\n",
    "                    dataset = pickle.load(dataFile)\n",
    "                    print(dataset)\n",
    "        # dict_pertubation is cached with the dataset but not read anywhere else in this cell.\n",
    "        perturbDataDict[tdata.perturbJobID] = dict(\n",
    "            dict_pertubation=dict_pertubation,\n",
    "            dataset=dataset\n",
    "        )\n",
    "    else:\n",
    "        dict_pertubation = perturbDataDict[tdata.perturbJobID][\"dict_pertubation\"]\n",
    "        dataset = perturbDataDict[tdata.perturbJobID][\"dataset\"]\n",
    "\n",
    "    # Load the attack job's predictions (job data) and its per-target result table (resultTable.csv).\n",
    "    job = project.open_job(id=tdata.attackID)\n",
    "    assert job.sp.use_runner\n",
    "    with job:\n",
    "        with job.data.open(mode=\"r\"):\n",
    "            dict_prediction = {key: np.array(val) for key, val in job.data.predictionDict.items()}\n",
    "        with open(f\"resultTable.csv\", \"r\") as f:\n",
    "            resultTable = pandas.read_csv(f, index_col=0)\n",
    "    perturb_name = tdata.perturb_prefix\n",
    "    # Clean prediction keys refer to the defense model as '<model>_p'.\n",
    "    DEFENSE_MODEL = f\"{tdata.model}_p\"\n",
    "    # The attacked-prediction key below uses the bare model name instead.\n",
    "    defenseModelType = tdata.model\n",
    "\n",
    "    # Clean (unattacked) predictions of the same defense model, softmax-normalized per row.\n",
    "    clean_job = project.open_job(id=tdata.cleanJobID)\n",
    "    with clean_job.data.open(mode=\"r\"):\n",
    "        dict_prediction_clean = {key: np.array(val) for key, val in clean_job.data.predictionDict.items()}\n",
    "    prediction_result_clean = softmax(np.array(dict_prediction_clean[f\"f:{DEFENSE_MODEL}@clean\"]), axis=1)\n",
    "\n",
    "\n",
    "    # Index by target node; the job's 'acc' column is the accuracy under attack.\n",
    "    defenseResultTable = resultTable[[\"target_node\", \"acc\"]].set_index(\n",
    "                \"target_node\").rename(columns=dict(acc=\"acc_attack\"))\n",
    "    # Clean metrics are computed for all nodes, then restricted to the target nodes.\n",
    "    acc_clean, ground_truth_confidence_clean, margin_clean = get_acc_margin(prediction_result_clean, dataset.labels)\n",
    "    defenseResultTable[\"label\"] = dataset.labels[defenseResultTable.index]\n",
    "    defenseResultTable[\"acc_clean\"] = acc_clean[defenseResultTable.index]\n",
    "    defenseResultTable[\"pred_clean\"] = ground_truth_confidence_clean[defenseResultTable.index]\n",
    "    defenseResultTable[\"margin_clean\"] = margin_clean[defenseResultTable.index]\n",
    "\n",
    "    # Attacked predictions are keyed per target node as 'p:<model>@0@<perturb_prefix>_<node>';\n",
    "    # only that node's own row of the stored prediction matrix is scored.\n",
    "    for cur_node in defenseResultTable.index:\n",
    "        prediction_result_attack = softmax(dict_prediction[f\"p:{defenseModelType}@0@{perturb_name}_{cur_node}\"][cur_node, :])\n",
    "        acc_s, confidence_attack_s, margin_attack_s = get_acc_margin(prediction_result_attack, dataset.labels[cur_node])\n",
    "        # Sanity check: accuracy recomputed from the stored predictions must match resultTable.\n",
    "        assert defenseResultTable.at[cur_node, \"acc_attack\"] == acc_s\n",
    "        defenseResultTable.at[cur_node, \"pred_attack\"] = confidence_attack_s\n",
    "        defenseResultTable.at[cur_node, \"margin_attack\"] = margin_attack_s\n",
    "\n",
    "    # Deltas are attacked minus clean, so negative values mean the attack lowered the true-class score.\n",
    "    defenseResultTable[\"pred_delta\"] = defenseResultTable[\"pred_attack\"] - defenseResultTable[\"pred_clean\"]\n",
    "    defenseResultTable[\"margin_delta\"] = defenseResultTable[\"margin_attack\"] - defenseResultTable[\"margin_clean\"]\n",
    "\n",
    "    # One boolean '<key>_group' column per target-node group in the job's statepoint.\n",
    "    for key, nodeList in job.sp.targetNodes.items():\n",
    "        defenseResultTable[f\"{key}_group\"] = defenseResultTable.index.isin(nodeList)\n",
    "    defenseTableDict[job.id] = defenseResultTable\n",
    "        \n",
    "\n",
    "# print(\"!\")\n",
    "# Stack the per-job tables under an outer 'attackID' index level.\n",
    "defenseTableFull = pandas.concat(defenseTableDict.values(), keys=defenseTableDict.keys(), names=[\"attackID\"])\n",
    "display(defenseTableFull)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `poison`\n",
    "# Join each experiment's metadata row (df_subtask) onto its per-node results via attackID.\n",
    "defenseTableFullRH = defenseTableFull.sort_index().reset_index()\n",
    "defenseTableFullExp = (\n",
    "    df_subtask\n",
    "    .merge(defenseTableFullRH, on=[\"attackID\"], how=\"outer\")\n",
    "    .assign(model_with_arg=lambda frame: frame[\"model\"] + \":\" + frame[\"model_arg\"].fillna(\"\"))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%space `poison`\n",
    "group_key = \"DATASET\"\n",
    "pivot_index = [\"model_with_arg\", \"perturbJobID\"]\n",
    "defensePivotDict = dict()\n",
    "\n",
    "for h in sorted(defenseTableFullExp[group_key].unique()):\n",
    "    defenseTableFullR = defenseTableFullExp.loc[defenseTableFullExp[group_key] == h]\n",
    "    # Mean per (model, perturbation job). The string 'mean' replaces np.mean, which recent pandas\n",
    "    # versions warn about; the resulting column label ('mean') is unchanged.\n",
    "    defensePivot = defenseTableFullR.pivot_table(values=[\"acc_clean\", \"acc_attack\"], index=pivot_index,\n",
    "                              aggfunc={\"acc_clean\": [\"mean\"], \"acc_attack\": [\"mean\"]})\n",
    "    defensePivot[\"acc_delta\"] = defensePivot.acc_attack - defensePivot.acc_clean\n",
    "    # NOTE: inf/NaN for a perturbation job whose clean accuracy is 0.\n",
    "    defensePivot[\"acc_delta_rel\"] = (defensePivot.acc_attack - defensePivot.acc_clean) / defensePivot.acc_clean\n",
    "    defensePivot2 = defenseTableFullR.pivot_table(values=[\"margin_clean\", \"margin_attack\", \"margin_delta\"], index=pivot_index,\n",
    "                              aggfunc={\"margin_clean\": [\"mean\"], \"margin_attack\": [\"mean\"], \"margin_delta\": [\"mean\"]})\n",
    "\n",
    "    # Append per-model subtotal rows (mean and sample std across perturbation jobs). Both statistics\n",
    "    # are taken BEFORE any row is inserted: computing the std after the subtotal_mean row had been\n",
    "    # appended counted that mean as an extra observation and understated the spread.\n",
    "    # The std is NaN for a model with a single perturbation job.\n",
    "    for pivot in (defensePivot, defensePivot2):\n",
    "        for key2 in pivot.index.levels[0]:\n",
    "            group = pivot.loc[key2]\n",
    "            group_mean, group_std = group.mean(axis=0), group.std(axis=0)\n",
    "            pivot.loc[(key2, 'subtotal_mean'), :] = group_mean\n",
    "            pivot.loc[(key2, 'subtotal_std'), :] = group_std\n",
    "\n",
    "    defensePivot = pandas.concat([defensePivot, defensePivot2], axis=1).sort_index()\n",
    "\n",
    "    defensePivotDict[h] = defensePivot.loc[pandas.IndexSlice[:, [\"subtotal_mean\", \"subtotal_std\"]], :]\n",
    "    defensePivotDict[h].columns = defensePivotDict[h].columns.droplevel(1)\n",
    "\n",
    "    display(h)\n",
    "    # A bare expression inside a loop is never rendered, so the styled table is displayed explicitly.\n",
    "    display(defensePivotDict[h].style.format(dict(acc_attack=\"{:.2%}\", acc_clean=\"{:.2%}\", acc_delta=\"{:.2%}\", acc_delta_rel=\"{:.2%}\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
