{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d850bd96",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Rough plan\n",
    "\n",
    "\n",
    "- retrieve data points from primarynovar experiment. for each graph, get datapoints with the eval performance.\n",
    "    + move some of the result processing code to the codebase, shared between notebooks by now\n",
    "- retrieve the topologies of the graphs\n",
    "- compute some properties of interest and set as columns on the dataframe\n",
    "- for each property:\n",
    "    + create a new plot: ssp on the left, ecmp on the right\n",
    "    + x-axis: (normalised) difficulty of predicting link utilization\n",
    "    + y-axis: the property\n",
    "    + individual points should have the name of the graph written so it's easier to spot patterns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ae8a71",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import platform\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "from copy import copy\n",
    "import re\n",
    "\n",
    "# if platform.system() == 'Darwin':\n",
    "#     matplotlib.use(\"TkAgg\")\n",
    "\n",
    "import matplotlib.animation\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set()\n",
    "import pandas as pd\n",
    "pd.options.mode.chained_assignment = None  # default='warn'\n",
    "\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pprint\n",
    "import networkx as nx\n",
    "\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from relnet.agent.gnn.prediction_agent import *\n",
    "from relnet.io.storage import EvaluationStorage\n",
    "from relnet.io.file_paths import FilePaths\n",
    "from relnet.evaluation.eval_utils import *\n",
    "\n",
    "from relnet.state.state_generators import *\n",
    "from relnet.objective_functions.objective_functions import *\n",
    "\n",
    "graph_names=[\"Aconet\", \"Agis\", \"Arnes\", \"Cernet\", \"Cesnet201006\", \"Grnet\", \"Iij\", \"Internode\", \\\n",
    "             \"Janetlense\", \"Karen\", \"Marnet\", \"Niif\", \"PionierL3\", \"Sinet\", \"SwitchL3\", \"Ulaknet\", \"Uninett2011\"]\n",
    "\n",
    "# graph_names = [\"Aconet\"]\n",
    "\n",
    "exp_routing_models = [\"ssp\", \"ecmp\"]\n",
    "\n",
    "suffix = \"final1d\"\n",
    "# suffix = \"final0d1\"\n",
    "\n",
    "exp_ids = []\n",
    "\n",
    "for gn in graph_names:\n",
    "    for exp_type in exp_routing_models:\n",
    "        exp_ids.append(f\"{gn}_{exp_type}_{suffix}\")\n",
    "    \n",
    "\n",
    "fp_out = FilePaths('/experiment_data', 'aggregate', setup_directories=True)\n",
    "storage = EvaluationStorage(fp_out)\n",
    "\n",
    "network_generator = TmGenStateGenerator\n",
    "objective_function = MLU\n",
    "\n",
    "topological_properties = [\"num_nodes\", \"num_edges\", \"diameter\", \"edge_density\", \"num_flows\",\n",
    "                          \"capacity_var\", \"degree_var\", \"wbetweenness_var\"] # \"betweenness_var\"\n",
    "\n",
    "\n",
    "properties_plot = [\"num_nodes\", \"diameter\", \"edge_density\",\n",
    "                          \"capacity_var\", \"degree_var\", \"wbetweenness_var\"] # \"betweenness_var\"\n",
    "\n",
    "properties_table = [\"num_nodes\", \"num_edges\", \"diameter\", \"edge_density\", \"num_flows\"]\n",
    "\n",
    "\n",
    "keep_all_seeds = False\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "table_data_df = get_results_table(storage, exp_ids, filter_best_demand_rep=True)\n",
    "table_data_df = normalize_results(table_data_df)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# table_data_df"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def rgb2hex(r,g,b):\n",
    "    return \"#{:02x}{:02x}{:02x}\".format(r,g,b)\n",
    "\n",
    "\n",
    "def compute_property(graph_state, prop_name):\n",
    "    if prop_name == \"num_nodes\":\n",
    "        return graph_state.num_nodes\n",
    "    elif prop_name == \"num_edges\":\n",
    "        return graph_state.num_edges\n",
    "    elif prop_name == \"edge_density\":\n",
    "        return graph_state.num_edges / graph_state.num_nodes\n",
    "    elif prop_name == \"diameter\":\n",
    "        g_nx = graph_state.to_nx_graph()\n",
    "        return nx.diameter(g_nx)\n",
    "    elif prop_name == \"num_flows\":\n",
    "        return (graph_state.num_nodes * graph_state.num_nodes) * 3000\n",
    "    elif prop_name == \"capacity_var\":\n",
    "        capacities = np.array([graph_state.get_edge_property(e, GraphState.CAPACITY_EPROP_NAME) for e in graph_state.edge_list])\n",
    "        cap_var = np.var(capacities)\n",
    "        return cap_var\n",
    "    elif prop_name == \"degree_var\":\n",
    "        g_nx = graph_state.to_nx_graph()\n",
    "        deg_centralities = np.array([v for n, v in nx.degree_centrality(g_nx).items()])\n",
    "        deg_var = np.var(deg_centralities)\n",
    "        return deg_var\n",
    "    elif prop_name == \"betweenness_var\":\n",
    "        g_nx = graph_state.to_nx_graph()\n",
    "        bw_centralities = np.array([v for n, v in nx.betweenness_centrality(g_nx).items()])\n",
    "        bw_var = np.var(bw_centralities)\n",
    "        return bw_var\n",
    "    elif prop_name == \"wbetweenness_var\":\n",
    "        g_nx = graph_state.to_nx_graph()\n",
    "        capacities_dict = {e: graph_state.get_edge_property(e, GraphState.CAPACITY_EPROP_NAME) for e in graph_state.edge_list}\n",
    "        nx.set_edge_attributes(g_nx, capacities_dict, \"weight\")\n",
    "        wbw_centralities = np.array([v for n, v in nx.betweenness_centrality(g_nx, weight=\"weight\").items()])\n",
    "        wbw_var = np.var(wbw_centralities)\n",
    "        return wbw_var"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "plot_targets = [\"normalized_mse\", \"rgat_to_pew\", \"mlp_to_pew\", \"sage_to_pew\", \"gcn_to_pew\"]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "agg_index = [\"graph_name\", \"routing_model\", \"dms_mult\"] + ([\"agent_seed\"] if keep_all_seeds else [])\n",
    "agg_df = pd.pivot_table(table_data_df, values='perf',\n",
    "                             index=agg_index,\n",
    "                             columns=[\"algorithm\"])\n",
    "agg_df = agg_df.reset_index()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "agg_df[\"normalized_mse\"] = agg_df[\"rgat_uniqueedge\"].copy()\n",
    "agg_df[\"rgat_to_pew\"] = (agg_df[\"rgat_uniqueedge\"] - agg_df[\"rgat_uniform\"]) / (agg_df[\"rgat_uniform\"]) * 100.\n",
    "agg_df[\"mlp_to_pew\"] = (agg_df[\"rgat_uniqueedge\"] - agg_df[\"mlp_default\"]) / (agg_df[\"mlp_default\"]) * 100.\n",
    "agg_df[\"sage_to_pew\"] = (agg_df[\"rgat_uniqueedge\"] - agg_df[\"sage_uniform\"]) / (agg_df[\"sage_uniform\"]) * 100.\n",
    "agg_df[\"gcn_to_pew\"] = (agg_df[\"rgat_uniqueedge\"] - agg_df[\"gcn_uniform\"]) / (agg_df[\"gcn_uniform\"]) * 100.\n",
    "agg_df = agg_df.drop(columns=[\"mlp_default\", \"rgat_uniform\", \"rgat_uniqueedge\", \"sage_uniform\", \"gcn_uniform\", \"dms_mult\"])\n",
    "\n",
    "agg_df"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from relnet.state.graph_dataset import GraphDataset\n",
    "\n",
    "props_dict = {gn: {} for gn in graph_names}\n",
    "\n",
    "for exp_id in exp_ids:\n",
    "    fp_in = FilePaths('/experiment_data', exp_id, setup_directories=False)\n",
    "    gn, routing_model, _ = tuple(exp_id.split(\"_\"))\n",
    "    gds = GraphDataset(fp_in, gn, objective_function.name, network_generator.name)\n",
    "    orig_G = gds.load_graph_file(gds.metadata_dict['global_metadata']['original_graph_hash'])\n",
    "    \n",
    "    for prop in topological_properties:\n",
    "        val = compute_property(orig_G, prop)\n",
    "        props_dict[gn][prop] = val\n",
    "\n",
    "# props_dict"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "for prop in topological_properties:\n",
    "    agg_df[prop] = agg_df.apply(lambda x: props_dict[x['graph_name']][prop], axis=1)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "agg_df"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "show_R = False\n",
    "\n",
    "sns.set(font_scale=6.5)\n",
    "plt.rc('font', family='serif')\n",
    "mpl.rcParams['text.usetex'] = True\n",
    "mpl.rcParams[\"lines.linewidth\"] = 5\n",
    "mpl.rcParams[\"lines.markersize\"] = 72\n",
    "\n",
    "display_names = {\"num_nodes\" : \"number of \\n nodes $N$\",\n",
    "                \"diameter\": \"diameter\",\n",
    "                \"edge_density\": \"edge \\n density $\\\\frac{m}{N}$\",\n",
    "                \"capacity_var\": \"capacity \\n variance\",\n",
    "                \"degree_var\": \"degree \\n variance\",\n",
    "                \"wbetweenness_var\": \"weighted betweenness \\n variance\"}\n",
    "\n",
    "palettes = {\n",
    "    \"normalized_mse\": sns.color_palette([rgb2hex(89,116,162)]), # deep\n",
    "    \"rgat_to_pew\": sns.color_palette([rgb2hex(203, 138, 102)]), # gist_earth\n",
    "    \"mlp_to_pew\": sns.color_palette([rgb2hex(96, 157, 111)]),\n",
    "    \"sage_to_pew\": sns.color_palette([rgb2hex(180, 95, 98)]),\n",
    "    \"gcn_to_pew\": sns.color_palette([rgb2hex(133, 122, 169)])\n",
    "}\n",
    "ylims = {\n",
    "    \"normalized_mse\": [-0.1, 1.0],\n",
    "    \"rgat_to_pew\": None,\n",
    "    \"mlp_to_pew\": None,\n",
    "    \"sage_to_pew\": None,\n",
    "    \"gcn_to_pew\": None\n",
    "}\n",
    "\n",
    "metric_displays = {\n",
    "    \"normalized_mse\": \"\\n PEW NMSE\",\n",
    "    \"rgat_to_pew\": \"\\% change \\n GAT to PEW\",\n",
    "    \"mlp_to_pew\": \"\\% change \\n MLP to PEW\",\n",
    "    \"sage_to_pew\": \"\\% change \\n GraphSAGE to PEW\",\n",
    "    \"gcn_to_pew\": \"\\% change \\n GCN to PEW\",\n",
    "}\n",
    "\n",
    "plot_df = agg_df.rename(columns=display_names)\n",
    "\n",
    "# print(plot_df)\n",
    "\n",
    "for plot_target in plot_targets:\n",
    "    dims = (8.26 * 1.5 * len(properties_plot), 8.26 * 3.5)\n",
    "    fig, axes = plt.subplots(2, len(properties_plot), figsize=dims, squeeze=False, sharey=True, sharex=False)\n",
    "    for j, prop in enumerate(display_names.values()):\n",
    "        for i, routing_model in enumerate(exp_routing_models):\n",
    "            ax = axes[i][j]\n",
    "\n",
    "            sp_df = plot_df[plot_df['routing_model']==routing_model]\n",
    "            sp_df = sp_df.reset_index()\n",
    "\n",
    "            sp_df['dummy_var'] = [1.] * len(sp_df)\n",
    "    #         print(sp_df)\n",
    "\n",
    "            if plot_target != \"normalized_mse\":\n",
    "                ax.axhline(y=0, color='k', linestyle='--')\n",
    "\n",
    "                if j == len(properties_plot) - 1:\n",
    "                    first_x, second_x = 0.035, 0.035\n",
    "                    first_y, second_y = 10, -80\n",
    "\n",
    "                    if plot_target == \"rgat_to_pew\":\n",
    "                        ax.text(first_x, first_y, \"GAT \\n better\", fontsize=\"small\", rotation=90)\n",
    "                    elif plot_target == \"mlp_to_pew\":\n",
    "                        ax.text(first_x, first_y, \"MLP \\n better\", fontsize=\"small\", rotation=90)\n",
    "                    elif plot_target == \"sage_to_pew\":\n",
    "                        ax.text(first_x, first_y, \"SAGE \\n better\", fontsize=\"small\", rotation=90)\n",
    "                    elif plot_target == \"gcn_to_pew\":\n",
    "                        ax.text(first_x, first_y, \"GCN \\n better\", fontsize=\"small\", rotation=90)\n",
    "\n",
    "                    ax.text(second_x, second_y, \"PEW \\n better\", fontsize=\"small\", rotation=90)\n",
    "\n",
    "            sns.scatterplot(data=sp_df, x=prop, y=plot_target, ax=ax, style=\"graph_name\", palette=palettes[plot_target], hue=\"dummy_var\")\n",
    "\n",
    "            handles, labels = ax.get_legend_handles_labels()\n",
    "\n",
    "            if i == 0:\n",
    "                # ax.set_xticks([])\n",
    "                ax.set_xlabel('')\n",
    "\n",
    "            if i == 0 and j == 0:\n",
    "                ax.set_ylabel(f'SSP {metric_displays[plot_target]}')\n",
    "            elif i == 1 and j == 0:\n",
    "                ax.set_ylabel(f'ECMP {metric_displays[plot_target]}')\n",
    "\n",
    "            ax.legend_.remove()\n",
    "\n",
    "            if ylims[plot_target] is not None:\n",
    "                ax.set_ylim(ylims[plot_target])\n",
    "\n",
    "\n",
    "\n",
    "            if show_R:\n",
    "                pearson_r = sp.stats.pearsonr(sp_df[prop], sp_df[plot_target])\n",
    "                if keep_all_seeds:\n",
    "                    r_text = f\"$R^2$: {pearson_r[0]:.3f}\\n p-val: {pearson_r[1]:.3f}\"\n",
    "                else:\n",
    "                    r_text = f\"$R^2$: {pearson_r[0]:.3f}\"\n",
    "                ax.text(0.8, 0.9, r_text, ha='center', va='center', transform=ax.transAxes)\n",
    "            # ax.title.set_text(routing_model.upper())\n",
    "\n",
    "            #         if not keep_all_seeds:\n",
    "            #             for line in range(0,sp_df.shape[0]):\n",
    "            #                 x = getattr(sp_df, prop)[line]\n",
    "            #                 y = getattr(sp_df, plot_target)[line]\n",
    "\n",
    "            #                 max_x = sp_df[prop].max()\n",
    "            #                 max_y = sp_df[plot_target].max()\n",
    "\n",
    "            #                 graph_name = sp_df.graph_name[line]\n",
    "            #                 ax.text(x - ((max_x / 125) * len(graph_name)),\n",
    "            #                         y + (max_y / 12.5),\n",
    "            #                         graph_name,\n",
    "            #                         horizontalalignment='left', size=32, color='black', weight='semibold') # 'xx-small'\n",
    "\n",
    "\n",
    "    bap = 0.85 if plot_target == \"normalized_mse\" else 0.25\n",
    "    fig.legend(handles[3:], labels[3:], loc='upper center', borderaxespad=bap, fontsize=\"small\", ncol=6)\n",
    "    # fig.tight_layout()\n",
    "    plt.subplots_adjust(wspace=0.05,\n",
    "                        hspace=0.15)\n",
    "\n",
    "    filename = f\"topvspredictability_{suffix}_{plot_target}_{keep_all_seeds}.pdf\"\n",
    "    plt.savefig(fp_out.figures_dir / filename, bbox_inches='tight')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "table_rows = []\n",
    "\n",
    "for k, v in props_dict.items():\n",
    "    row = {}\n",
    "    row['graph_name'] = k\n",
    "    for pname, value in v.items():\n",
    "        if pname in properties_table:\n",
    "            row[pname] = value\n",
    "    table_rows.append(row)\n",
    "    \n",
    "props_df = pd.DataFrame(table_rows)\n",
    "# table_df.reset_index(drop=True, inplace=True)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "props_df"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "filename = 'graph_metadata.tex'\n",
    "texfile =  str(fp_out.figures_dir / filename)\n",
    "fh = open(texfile, 'w')\n",
    "table_colformat = f\"lrrrrr\"\n",
    "props_df.to_latex(buf=fh, column_format=table_colformat, index=False, float_format=\"%.2f\")\n",
    "fh.close()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "replace_dict = {\n",
    "        r\"graph\\\\_name\": r\"Graph\",\n",
    "        r\"num\\\\_nodes\": r\"$N$\",\n",
    "        r\"num\\\\_edges\": r\"$m$\",\n",
    "        r\"diameter\": r\"Diam.\",\n",
    "        r\"edge\\\\_density\": r\"$\\\\frac{m}{N}$\",\n",
    "        r\"num\\\\_flows\": r\"Flows in $\\\\mathcal{D}$\",\n",
    "    }\n",
    "\n",
    "with open(texfile, 'r') as f:\n",
    "    raw_content = f.read()\n",
    "\n",
    "processed_content = raw_content\n",
    "for orig, targ in replace_dict.items():\n",
    "    processed_content = re.sub(orig, targ, processed_content, flags = re.M)\n",
    "\n",
    "with open(texfile, 'w') as g:\n",
    "    g.write(processed_content)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}