{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c01b7039",
   "metadata": {},
   "source": [
    "# Table Aggregation\n",
    "\n",
    "In this notebook we aggregate the raw data into tables as used in our figures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "announced-sample",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def aggregate(df):\n",
    "    best_of_10 = df.groupby([\"sample_id\"]).max()\n",
    "    best_of_10[\"num_full_match\"] = (best_of_10[\"consistency\"] == 1.0)\n",
    "    res = best_of_10.groupby([\"num_nodes\", \"num_networks\"]).mean()\n",
    "    res[\"num_full_match\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_full_match\"]\n",
    "    return res\n",
    "\n",
    "def aggregate_3pred(df, consistency_column=\"overall\"):\n",
    "    best_of_10 = df.groupby([\"sample_id\"]).max()\n",
    "    best_of_10[\"num_full_match\"] = (best_of_10[\"overall\"] == 1.0)\n",
    "    best_of_10[\"num_close_match\"] = (best_of_10[\"overall\"] > 0.95)\n",
    "    best_of_10[\"num_close_match90\"] = (best_of_10[\"overall\"] > 0.9)\n",
    "    best_of_10[\"num_full_fwd\"] = (best_of_10[\"fwd\"] == 1.0)\n",
    "    best_of_10[\"num_reachable\"] = (best_of_10[\"reachable\"] == 1.0) # np.logical_and((best_of_10[\"fwd\"] == 1.0).values, (best_of_10[\"reachable\"] == 1.0).values)\n",
    "    best_of_10[\"num_trafficIsolation\"] = (best_of_10[\"trafficIsolation\"] == 1.0) # np.logical_and((best_of_10[\"fwd\"] == 1.0).values, (best_of_10[\"reachable\"] == 1.0).values)\n",
    "    res = best_of_10.groupby([\"num_nodes\", \"num_networks\"]).mean()\n",
    "    res[\"num_full_match\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_full_match\"]\n",
    "    res[\"num_full_fwd\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_full_fwd\"]\n",
    "    res[\"num_reachable\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_reachable\"]\n",
    "    res[\"num_trafficIsolation\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_trafficIsolation\"]\n",
    "    res[\"num_close_match\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_close_match\"]\n",
    "    res[\"num_close_match90\"] = best_of_10.groupby([\"num_nodes\",\"num_networks\"]).sum()[\"num_close_match90\"]\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6b2e41e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_paper(df, consistency_column=\"overall\", aggregate_by_n=False):\n",
    "    best_of_10 = df.groupby([\"sample_id\"]).max()\n",
    "    best_of_10[\"num_full_match\"] = (best_of_10[\"overall\"] == 1.0)\n",
    "    best_of_10[\"num_close_match\"] = (best_of_10[\"overall\"] > 0.95)\n",
    "    best_of_10[\"num_close_match90\"] = (best_of_10[\"overall\"] > 0.9)\n",
    "    best_of_10[\"num_full_fwd\"] = (best_of_10[\"fwd\"] == 1.0)\n",
    "    best_of_10[\"num_reachable\"] = (best_of_10[\"reachable\"] == 1.0) # np.logical_and((best_of_10[\"fwd\"] == 1.0).values, (best_of_10[\"reachable\"] == 1.0).values)\n",
    "    best_of_10[\"num_trafficIsolation\"] = (best_of_10[\"trafficIsolation\"] == 1.0) # np.logical_and((best_of_10[\"fwd\"] == 1.0).values, (best_of_10[\"reachable\"] == 1.0).values)\n",
    "    \n",
    "    num_nodes_values = list(sorted(best_of_10[\"num_nodes\"].values))\n",
    "    if len(num_nodes_values) / 3 != int(len(num_nodes_values) / 3):\n",
    "        print(\"warning: provided samples cannot be evenly distributed into 3 datasets small/medium/large\")\n",
    "    first_boundary = int(len(num_nodes_values) / 3)\n",
    "    second_boundary = int(len(num_nodes_values) * 2 / 3)\n",
    "    # print(\"first_boundary:\", num_nodes_values[first_boundary - 1])\n",
    "    # print(\"second_boundary:\", num_nodes_values[second_boundary - 1])\n",
    "    # print(\"third_boundary:\", num_nodes_values[-1])\n",
    "    # print(num_nodes_values)\n",
    "\n",
    "    best_of_10[\"group\"] = pd.cut(best_of_10[\"num_nodes\"], [0,18,39,153])\n",
    "    assert all([v == 8 for v in best_of_10.groupby(\"group\").count()[\"overall\"].values]), \"The provided samples cannot distributed into three datasets of equal size 8 with num_nodes boundaries [0, 18, 39, 153]: Instead got num_nodes values {}\".format(num_nodes_values)\n",
    "    \n",
    "    res = best_of_10.groupby(\"group\").mean()\n",
    "    std = best_of_10.groupby(\"group\").std(ddof=0)\n",
    "\n",
    "    for f in [\"overall\"]:\n",
    "        res[f + \"-std\"] = std[f] \n",
    "\n",
    "    res[\"num_full_match\"] = best_of_10.groupby(\"group\").sum()[\"num_full_match\"]\n",
    "    res[\"num_full_fwd\"] = best_of_10.groupby(\"group\").sum()[\"num_full_fwd\"]\n",
    "    res[\"num_reachable\"] = best_of_10.groupby(\"group\").sum()[\"num_reachable\"]\n",
    "    res[\"num_trafficIsolation\"] = best_of_10.groupby(\"group\").sum()[\"num_trafficIsolation\"]\n",
    "    res[\"num_close_match\"] = best_of_10.groupby(\"group\").sum()[\"num_close_match\"]\n",
    "    res[\"num_close_match90\"] = best_of_10.groupby(\"group\").sum()[\"num_close_match90\"]\n",
    "    return res\n",
    "    \n",
    "def to_latex_paper(df):\n",
    "    #print(to_latex(df))\n",
    "    lines = []\n",
    "    for row, group in zip(df.iloc, [\"S\", \"M\", \"L\"]):\n",
    "        # def group_name(g): return \"S\" if g == 0 else (\"M\" if g == 1 else \"L\")\n",
    "        # group = group_name(row[\"group\"])\n",
    "        columns = [\n",
    "            group,\n",
    "            \"%.2f\" % row[\"fwd\"],\n",
    "            \"%.2f\" % row[\"reachable\"],\n",
    "            \"%.2f\" % row[\"trafficIsolation\"],\n",
    "            \"\\\\textbf{%.2f}\" % row[\"overall\"] + \"$\\pm$%.2f\" % row[\"overall-std\"],\n",
    "            \"%d/8\" % row[\"num_full_match\"],\n",
    "            \"%d/8\" % row[\"num_close_match90\"],\n",
    "        ]\n",
    "        lines.append(\"& \" + \" & \".join(columns) + \"\\\\\\\\\")\n",
    "    return \"\\n\".join(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1bd2a2d",
   "metadata": {},
   "source": [
    "## Figure on average specification consistency for increasingly large topologies and specifications"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc9b516f",
   "metadata": {},
   "source": [
    "### 3x2, EXP_ID=eval-64-bgp-reqs-2-4shot python3 eval_consistency.py ../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt --num-samples=5 --random=1 --num-shots=4 --num-iterations=6 --dataset ./dataset-ported/bgp-qlty-reqs-2 --cpu=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c955397",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "$3\\times 2$ & S & 0.97 & 0.94 & 1.00 & \\textbf{0.96}$\\pm$0.07 & 6/8 & 6/8\\\\\n",
      "& M & 0.95 & 0.94 & 1.00 & \\textbf{0.94}$\\pm$0.08 & 5/8 & 5/8\\\\\n",
      "& L & 0.92 & 1.00 & 1.00 & \\textbf{0.94}$\\pm$0.06 & 4/8 & 4/8\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "f = \"specification-consistency/results-eval-64-bgp-qlty-reqs-2-4shot-15649.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(\"$3\\\\times 2$ \" + to_latex_paper(df_2))\n",
    "print(\"\\\\midrule\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ea8e75",
   "metadata": {},
   "source": [
    "### 3x8, EXP_ID=eval-64-bgp-reqs-8-4shot python3 eval_consistency.py ../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt --num-samples=5 --random=1 --num-shots=4 --num-iterations=6 --dataset ./dataset-ported/bgp-qlty-reqs-8 --cpu=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7cad34c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "$3\\times 8$ & S & 0.98 & 0.98 & 0.91 & \\textbf{0.96}$\\pm$0.05 & 4/8 & 7/8\\\\\n",
      "& M & 0.97 & 0.98 & 1.00 & \\textbf{0.98}$\\pm$0.03 & 4/8 & 8/8\\\\\n",
      "& L & 0.96 & 0.92 & 0.97 & \\textbf{0.95}$\\pm$0.03 & 1/8 & 8/8\\\\\n",
      "\\midrule\n"
     ]
    }
   ],
   "source": [
    "f = \"specification-consistency/results-eval-64-bgp-qlty-reqs-8-4shot-14793.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(\"$3\\\\times 8$ \" + to_latex_paper(df_2))\n",
    "print(\"\\\\midrule\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62a0a9c6",
   "metadata": {},
   "source": [
    "### 3x16, EXP_ID=eval-64-bgp-reqs-16-4shot python3 eval_consistency.py ../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt --num-samples=5 --random=1 --num-shots=4 --num-iterations=6 --dataset ./dataset-ported/bgp-qlty-reqs-16 --cpu=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "270f895f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "$3\\times 16$ & S & 0.98 & 0.92 & 0.95 & \\textbf{0.95}$\\pm$0.03 & 2/8 & 8/8\\\\\n",
      "& M & 0.95 & 0.95 & 0.98 & \\textbf{0.96}$\\pm$0.04 & 3/8 & 7/8\\\\\n",
      "& L & 0.94 & 0.91 & 0.95 & \\textbf{0.93}$\\pm$0.05 & 1/8 & 6/8\\\\\n"
     ]
    }
   ],
   "source": [
    "f = \"specification-consistency/results-eval-64-bgp-qlty-reqs-16-4shot-11562.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(\"$3\\\\times 16$ \" + to_latex_paper(df_2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6fb0096",
   "metadata": {},
   "source": [
    "## Figure on Baseline/Oneshot/Multi-Shot Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e6f4e108",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_latex_paper_baseline(df):\n",
    "    #print(to_latex(df))\n",
    "    lines = []\n",
    "    for row in df.iloc:\n",
    "        def group_name(g): return \"Small\" if g == 0 else (\"Medium\" if g == 1 else \"Large\")\n",
    "        group = group_name(row[\"num_nodes\"])\n",
    "        columns = [\n",
    "            group,\n",
    "            \"%.2f\" % row[\"overall\"] + \"$\\pm$%.2f\" % row[\"overall-std\"],\n",
    "        ]\n",
    "        lines.append(\"& \" + \" & \".join(columns) + \"\\\\\\\\\")\n",
    "    return \"\\n\".join(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d415a08e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random\n",
      "& Large & 0.87$\\pm$0.11\\\\\n",
      "& Large & 0.81$\\pm$0.10\\\\\n",
      "& Large & 0.80$\\pm$0.05\\\\\n",
      "1-Shot\n",
      "& Large & 0.94$\\pm$0.04\\\\\n",
      "& Large & 0.96$\\pm$0.04\\\\\n",
      "& Large & 0.93$\\pm$0.05\\\\\n",
      "4-Shot\n",
      "& Large & 0.95$\\pm$0.03\\\\\n",
      "& Large & 0.96$\\pm$0.04\\\\\n",
      "& Large & 0.93$\\pm$0.05\\\\\n"
     ]
    }
   ],
   "source": [
    "f = \"multishot-sampling/results-eval-64-bgp-qlty-reqs-16-random-13819.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(\"Random\")\n",
    "print(to_latex_paper_baseline(df_2))\n",
    "\n",
    "print(\"1-Shot\")\n",
    "f = \"multishot-sampling/results-eval-64-bgp-qlty-reqs-16-oneshot-14551.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(to_latex_paper_baseline(df_2))\n",
    "\n",
    "print(\"4-Shot\")\n",
    "f = \"specification-consistency/results-eval-64-bgp-qlty-reqs-16-4shot-11562.pkl-results.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(to_latex_paper_baseline(df_2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2eafa2d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8-Shot\n",
      "& Large & 0.95$\\pm$0.04\\\\\n",
      "& Large & 0.96$\\pm$0.04\\\\\n",
      "& Large & 0.94$\\pm$0.05\\\\\n"
     ]
    }
   ],
   "source": [
    "print(\"8-Shot\")\n",
    "f = \"multishot-sampling/results-eval-64-bgp-qlty-reqs-16-8shot.pkl\"\n",
    "df_2 = aggregate_paper(pd.read_pickle(f)).reset_index()\n",
    "print(to_latex_paper_baseline(df_2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4e471c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9c2aca3d0dfc99fb2efc8287a7808a7621f9ab3042aa12ae3d168b7497f5397d"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('nsynth-sub')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
