{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the repository root (the parent of the notebook's working directory)\n",
    "# importable so that project packages such as `relnet` (imported in the next\n",
    "# cell) resolve without being installed.\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import platform\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "from copy import copy\n",
    "import re\n",
    "\n",
    "# if platform.system() == 'Darwin':\n",
    "#     matplotlib.use(\"TkAgg\")\n",
    "\n",
    "import matplotlib.animation\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.legend_handler import HandlerLine2D\n",
    "\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "pd.options.mode.chained_assignment = None  # default='warn'\n",
    "\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pprint\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from relnet.agent.gnn.prediction_agent import *\n",
    "from relnet.io.storage import EvaluationStorage\n",
    "from relnet.io.file_paths import FilePaths\n",
    "from relnet.evaluation.eval_utils import *\n",
    "\n",
    "from relnet.state.state_generators import *\n",
    "from relnet.objective_functions.objective_functions import *\n",
    "\n",
    "# Topologies evaluated in the main experiments.\n",
    "graph_names = [\n",
    "    \"Aconet\", \"Agis\", \"Arnes\", \"Cernet\", \"Cesnet201006\", \"Grnet\",\n",
    "    \"Iij\", \"Internode\", \"Janetlense\", \"Karen\", \"Marnet\", \"Niif\",\n",
    "    \"PionierL3\", \"Sinet\", \"SwitchL3\", \"Ulaknet\", \"Uninett2011\",\n",
    "]\n",
    "# graph_names = [\"Aconet\"]\n",
    "\n",
    "# Suffix identifying the primary experiment run, and the routing models it covers.\n",
    "suffix = \"final1d\"\n",
    "exp_routing_models = [\"ssp\", \"ecmp\"]\n",
    "\n",
    "# One experiment id per (topology, routing model), topology-major: \"<graph>_<routing>_<suffix>\".\n",
    "exp_ids = [\n",
    "    f\"{gn}_{exp_type}_{suffix}\"\n",
    "    for gn in graph_names\n",
    "    for exp_type in exp_routing_models\n",
    "]\n",
    "\n",
    "# Evaluation results are read from this store; figures are written to fp_out.figures_dir.\n",
    "fp_out = FilePaths('/experiment_data', 'aggregate', setup_directories=True)\n",
    "storage = EvaluationStorage(fp_out)\n",
    "\n",
    "network_generator = TmGenStateGenerator\n",
    "objective_function = MLU\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def print_results(exp_ids, separate_by_hyp=False):\n",
    "    \"\"\"Print the prediction scores stored for each experiment in `exp_ids`.\n",
    "\n",
    "    Args:\n",
    "        exp_ids: experiment ids whose evaluation rows are fetched from the global `storage`.\n",
    "        separate_by_hyp: if True, print one score block per distinct `hyps_id`\n",
    "            (in order of first appearance); otherwise print scores over all rows.\n",
    "    \"\"\"\n",
    "    banner = \"*\" * 20\n",
    "    for exp_id in exp_ids:\n",
    "        print(banner)\n",
    "        print(exp_id)\n",
    "        print(banner)\n",
    "        results = storage.get_evaluation_data(exp_id)\n",
    "        print(\"Scores over all variations:\")\n",
    "\n",
    "        if not separate_by_hyp:\n",
    "            print_pred_scores(results)\n",
    "            continue\n",
    "\n",
    "        # dict.fromkeys de-duplicates while keeping first-seen order, so the output order\n",
    "        # is reproducible (iterating a set of strings is not, due to hash randomization).\n",
    "        for hyps_id in dict.fromkeys(row['hyps_id'] for row in results):\n",
    "            filtered_results = [row for row in results if row['hyps_id'] == hyps_id]\n",
    "            print(f\"All scores, hyp id {hyps_id}\")\n",
    "            print_pred_scores(filtered_results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# print_results(exp_ids, separate_by_hyp=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Mean perf per (routing model, topology, metric), one column per algorithm.\n",
    "table_data_df = get_results_table(storage, exp_ids, filter_best_demand_rep=True)\n",
    "out_df = table_data_df.pivot_table(\n",
    "    values='perf',\n",
    "    index=[\"routing_model\", \"graph_name\", \"metric\"],\n",
    "    columns=[\"algorithm\"],\n",
    ")\n",
    "# Raise pandas' row-display limit above the table's length.\n",
    "pd.set_option('display.max_rows', len(out_df) + 1)\n",
    "# Highlight the minimum value in each row (the best algorithm when lower is better).\n",
    "out_df.style.highlight_min(color='lightgreen', axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Visualization of main results, grouped by topology"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def legend_handle_update(handle, orig):\n",
    "    \"\"\"Copy the style of `orig` onto the legend `handle`, then thicken its line.\"\"\"\n",
    "    handle.update_from(orig)\n",
    "    handle.set_linewidth(5)\n",
    "\n",
    "\n",
    "# If True, results are passed through normalize_results before plotting.\n",
    "normalize = True\n",
    "\n",
    "plot_df = table_data_df.copy(deep=True)\n",
    "if normalize:\n",
    "    plot_df = normalize_results(plot_df)\n",
    "\n",
    "# Internal algorithm ids -> display names; the dict order also fixes the bar/legend order.\n",
    "alg_display_names = {\"rgat_uniqueedge\": \"PEW (ours)\",\n",
    "                     \"rgat_uniform\": \"GAT\",\n",
    "                     \"mlp_default\": \"MLP\",\n",
    "                     \"sage_uniform\": \"GraphSAGE\",\n",
    "                     \"gcn_uniform\": \"GCN\"}\n",
    "\n",
    "plot_df['algorithm'] = pd.Categorical(plot_df['algorithm'].replace(alg_display_names),\n",
    "                                      categories=list(alg_display_names.values()),\n",
    "                                      ordered=True)\n",
    "\n",
    "sns.set(font_scale=3)\n",
    "plt.rc('font', family='serif')\n",
    "mpl.rcParams['text.usetex'] = True\n",
    "\n",
    "dims = (8.26 * 4, 8.26 * 2)\n",
    "fig, axes = plt.subplots(2, 1, figsize=dims, squeeze=False)\n",
    "\n",
    "# The y-axis label has to match the `normalize` switch above.\n",
    "y_label = \"Normalized MSE\" if normalize else \"MSE\"\n",
    "\n",
    "for i, routing_model in enumerate([\"ssp\", \"ecmp\"]):\n",
    "    ax = axes[i][0]\n",
    "    ax.tick_params(axis='x', rotation=45)\n",
    "    subplot_df = plot_df[plot_df[\"routing_model\"] == routing_model]\n",
    "    sns.barplot(data=subplot_df, x=\"graph_name\", y=\"perf\", hue=\"algorithm\", ax=ax, dodge=True)\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel(y_label, fontsize=\"large\")\n",
    "    # A negative pad moves the title down, inside the axes.\n",
    "    ax.set_title(routing_model.upper(), fontdict={\"fontsize\": \"x-large\"}, pad=-50)\n",
    "\n",
    "    # Remove seaborn's automatic legend; only the bottom (ECMP) panel gets a legend.\n",
    "    if ax.legend_ is not None:\n",
    "        ax.legend_.remove()\n",
    "\n",
    "    if i == 1:\n",
    "        ax.legend(handles, labels, loc='upper right', borderaxespad=0.25, fontsize=\"medium\", ncol=1,\n",
    "                  handler_map={plt.Line2D: HandlerLine2D(update_func=legend_handle_update)})\n",
    "\n",
    "fig.tight_layout()\n",
    "filename = f\"primaryresults_{suffix}.pdf\"\n",
    "fig.savefig(fp_out.figures_dir / filename, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Table of ranking metrics (win rate and mean reciprocal rank)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_ranking_table_rows(results_df, setting):\n",
    "    \"\"\"Compute win rate (WR) and mean reciprocal rank (MRR) of every algorithm.\n",
    "\n",
    "    Algorithms are ranked by `perf` within each (routing_model, graph_name, metric)\n",
    "    combination (rank 1 = lowest value; ties broken by algorithm name). Rows of the\n",
    "    `predict_mean` baseline are excluded from the ranking.\n",
    "\n",
    "    Args:\n",
    "        results_df: long-format results with columns `routing_model`, `graph_name`,\n",
    "            `metric`, `algorithm` and `perf`.\n",
    "        setting: label copied into the `setting` field of every returned row.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with keys `routing_model`, `algorithm`, `metric` (\"wr\" or\n",
    "        \"mrr\"), `value` and `setting`. WR is the percentage of ranked combinations\n",
    "        in which the algorithm ranked first; MRR is the mean of 1 / rank.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "\n",
    "    ranked_df = results_df[results_df['algorithm'] != 'predict_mean']\n",
    "    prdf = pd.pivot_table(ranked_df, values='perf', index=[\"routing_model\", \"graph_name\", \"metric\"],\n",
    "                          columns=[\"algorithm\"])\n",
    "    # Sorted so row order and tie-breaking are reproducible (set iteration order is not).\n",
    "    all_algos = sorted(prdf.columns)\n",
    "    prdf = prdf.reset_index()\n",
    "\n",
    "    # routing model -> algorithm -> ranks (one per combination). Seeded in the display\n",
    "    # order used so far; any other routing model in the data is added on demand.\n",
    "    rankings = {'ssp': {}, 'ecmp': {}}\n",
    "    for model, perfs in zip(prdf['routing_model'], prdf[all_algos].to_numpy(dtype=float)):\n",
    "        # argsort of argsort yields 0-based ranks; the stable sort breaks ties by column order.\n",
    "        perfs_ranked = np.argsort(np.argsort(perfs, kind='stable')) + 1.0\n",
    "        model_ranks = rankings.setdefault(model, {})\n",
    "        for algo, rank in zip(all_algos, perfs_ranked):\n",
    "            model_ranks.setdefault(algo, []).append(rank)\n",
    "\n",
    "    for routing, routing_ranks in rankings.items():\n",
    "        if not routing_ranks:\n",
    "            continue  # routing model absent from results_df\n",
    "        print(f\"routing model {routing}\")\n",
    "        for algo in all_algos:\n",
    "            rr = np.asarray(routing_ranks[algo])\n",
    "\n",
    "            num_wins = np.count_nonzero(np.isclose(rr, 1.0))\n",
    "            # Normalize by the number of ranked combinations (the same population the\n",
    "            # MRR mean uses), not by the global list of topology names.\n",
    "            wr = (num_wins / len(rr)) * 100\n",
    "            mrr = np.mean(1.0 / rr)\n",
    "\n",
    "            rows.append({\"routing_model\": routing,\n",
    "                         \"algorithm\": algo,\n",
    "                         \"metric\": \"wr\",\n",
    "                         \"value\": wr,\n",
    "                         \"setting\": setting})\n",
    "\n",
    "            rows.append({\"routing_model\": routing,\n",
    "                         \"algorithm\": algo,\n",
    "                         \"metric\": \"mrr\",\n",
    "                         \"value\": mrr,\n",
    "                         \"setting\": setting})\n",
    "\n",
    "    return rows\n",
    "\n",
    "    \n",
    "# Ranking metrics for the original topologies and for their topology variations.\n",
    "# The loop uses its own variable names so that running this cell does not overwrite\n",
    "# the notebook-level `suffix` and `exp_ids` that the cells above rely on.\n",
    "all_rows = []\n",
    "suffixes = [\"final1d\", \"topvarfinal1d\"]\n",
    "settings = [\"original graph\", \"topology variations\"]\n",
    "\n",
    "for setting_suffix, setting in zip(suffixes, settings):\n",
    "    setting_exp_ids = [f\"{gn}_{exp_type}_{setting_suffix}\"\n",
    "                       for gn in graph_names\n",
    "                       for exp_type in exp_routing_models]\n",
    "    setting_results = get_results_table(storage, setting_exp_ids, filter_best_demand_rep=True)\n",
    "    all_rows.extend(get_ranking_table_rows(setting_results, setting))\n",
    "\n",
    "rdf = pd.DataFrame(all_rows)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Long-format ranking metrics: one row per (routing_model, algorithm, metric, setting).\n",
    "rdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Wide table: one row per (routing_model, metric, setting), one column per algorithm.\n",
    "out_df = pd.pivot_table(rdf, values=['value'], index=[\"routing_model\", \"metric\", \"setting\"], columns=[\"algorithm\"])\n",
    "# SSP rows before ECMP rows. mergesort is stable, so the (metric, setting) order produced\n",
    "# by pivot_table is kept within each routing model (the default quicksort is not stable).\n",
    "out_df = out_df.sort_values(by=[\"routing_model\"], ascending=[False], axis=0, kind=\"mergesort\")\n",
    "# Show with 3 decimals. style.format returns a new Styler, so it only takes effect as the\n",
    "# cell's displayed value; out_df itself stays a DataFrame for the LaTeX export below.\n",
    "out_df.style.format(\"{:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Substring replacements that turn the pandas LaTeX output into the paper's labels.\n",
    "# Keys such as \"mlp\\_default\" expect underscores escaped as \"\\_\" by to_latex (escape=True below).\n",
    "replace_dict = {\n",
    "    \"mrr\": \"MRR $\\\\uparrow$\",\n",
    "    \"wr\": \"WR $\\\\uparrow$\",\n",
    "    \"ssp\": \"SSP\",\n",
    "    \"ecmp\": \"ECMP\",\n",
    "    \"mlp\\\\_default\": \"MLP\",\n",
    "    \"rgat\\\\_uniform\": \"GAT\",\n",
    "    \"rgat\\\\_uniqueedge\": \"PEW (ours)\",\n",
    "    \"gcn\\\\_uniform\": \"GCN\",\n",
    "    \"sage\\\\_uniform\": \"GraphSAGE\",\n",
    "\n",
    "    \"routing\\\\_model\": \"$\\\\mathscr{R}$\",  # \\mathscr needs a LaTeX package such as mathrsfs\n",
    "\n",
    "    \"original graph\": \"Original\",\n",
    "    \"topology variations\": \"Variations\",\n",
    "    \"setting\": \"\",\n",
    "    \"algorithm\": \"\",\n",
    "    \"value\": \"\",\n",
    "    \"metric\": \"Metric\"\n",
    "}\n",
    "\n",
    "# escape=True pins the escaping the keys above rely on: it was the default before\n",
    "# pandas 2.0, which changed the default to no escaping.\n",
    "latex_string = out_df.to_latex(float_format=\"{:0.3f}\".format, escape=True)\n",
    "for old, new in replace_dict.items():\n",
    "    latex_string = latex_string.replace(old, new)\n",
    "\n",
    "table_filename = \"topology_variations.tex\"\n",
    "with open(fp_out.figures_dir / table_filename, \"w\", encoding=\"utf-8\") as fh:\n",
    "    fh.write(latex_string)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
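    "## Demand representation experiments\n",
    "\n",
    "Compares the *raw* and *sum* demand representations for PEW (ours) and GAT as the number of training DMs varies."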
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "graph_names = [\n",
    "    \"Aconet\", \"Agis\", \"Arnes\", \"Cernet\", \"Cesnet201006\", \"Grnet\",\n",
    "    \"Iij\", \"Internode\", \"Janetlense\", \"Karen\", \"Marnet\", \"Niif\",\n",
    "    \"PionierL3\", \"Sinet\", \"SwitchL3\", \"Ulaknet\", \"Uninett2011\",\n",
    "]\n",
    "# graph_names=[\"Aconet\"]\n",
    "\n",
    "# Algorithm-name substrings of the two methods compared below (PEW and GAT).\n",
    "methods_keep = [\"uniqueedge\", \"uniform\"]\n",
    "\n",
    "# Experiment ids end in dms_root + one of dms_suffixes. The suffixes appear to encode\n",
    "# dms_mults positionally (\"0d05\" ~ 0.05); dms_mults itself is not used in this notebook.\n",
    "dms_root = \"final\"\n",
    "dms_mults = [0.05, 0.1, 0.25, 0.5, 1]\n",
    "dms_suffixes = [\"0d05\", \"0d1\", \"0d25\", \"0d5\", \"1d\"]\n",
    "exp_routing_models = [\"ssp\", \"ecmp\"]\n",
    "\n",
    "# Topology-major, then routing model, then dms suffix.\n",
    "dmrep_exp_ids = [\n",
    "    f\"{gn}_{et}_{dms_root}{ds}\"\n",
    "    for gn in graph_names\n",
    "    for et in exp_routing_models\n",
    "    for ds in dms_suffixes\n",
    "]\n",
    "\n",
    "dmrep_full_df = get_results_table(storage, dmrep_exp_ids, filter_best_demand_rep=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "dmrep_dfs = []\n",
    "for method_keep in methods_keep:\n",
    "    # Rows whose algorithm id contains this method's name (plain substring match). NOTE: \"uniform\"\n",
    "    # is also a substring of other ids used in this notebook (gcn_uniform, sage_uniform).\n",
    "    dmrep_df = dmrep_full_df[dmrep_full_df['algorithm'].str.contains(method_keep, regex=False)]\n",
    "    # Exactly two demand-representation variants are expected; raw will come before sum after sort.\n",
    "    # Fail loudly instead of silently comparing the wrong pair of algorithms.\n",
    "    dm_alg_names = sorted(set(dmrep_df['algorithm']))\n",
    "    if len(dm_alg_names) != 2:\n",
    "        raise ValueError(f\"expected 2 demand-representation variants for '{method_keep}', got {dm_alg_names}\")\n",
    "\n",
    "    dmrep_pivot = pd.pivot_table(dmrep_df, values='perf',\n",
    "                                 index=[\"graph_name\", \"routing_model\", \"dms_mult\", \"agent_seed\"],\n",
    "                                 columns=[\"algorithm\"])\n",
    "    dmrep_pivot = dmrep_pivot.reset_index()\n",
    "\n",
    "    raw_perf = dmrep_pivot[dm_alg_names[0]]\n",
    "    sum_perf = dmrep_pivot[dm_alg_names[1]]\n",
    "    # Relative difference (raw - sum) / (raw + sum); for positive perf values it is\n",
    "    # negative when the raw variant has the lower perf.\n",
    "    dmrep_pivot['raw_minus_summed'] = (raw_perf - sum_perf) / (raw_perf + sum_perf)\n",
    "    dmrep_pivot = dmrep_pivot.drop(columns=dm_alg_names)\n",
    "    dmrep_pivot['method'] = method_keep\n",
    "    dmrep_dfs.append(dmrep_pivot)\n",
    "    print(dmrep_pivot)\n",
    "\n",
    "dmrep_df_final = pd.concat(dmrep_dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def dmrep_legend_handle_update(handle, orig):\n",
    "    \"\"\"Copy the style of `orig` onto the legend `handle`, then draw it with a thick line.\"\"\"\n",
    "    handle.update_from(orig)\n",
    "    handle.set_linewidth(15)\n",
    "\n",
    "\n",
    "# Internal method ids -> display names; applied to a plotting copy so dmrep_df_final is left unchanged.\n",
    "dmrep_display_names = {\"uniqueedge\": \"PEW (ours)\", \"uniform\": \"GAT\"}\n",
    "dmrep_plot_df = dmrep_df_final.assign(method=dmrep_df_final['method'].replace(dmrep_display_names))\n",
    "\n",
    "sns.set(font_scale=6.5)\n",
    "plt.rc('font', family='serif')\n",
    "mpl.rcParams['text.usetex'] = True\n",
    "mpl.rcParams[\"lines.linewidth\"] = 5\n",
    "mpl.rcParams[\"lines.markersize\"] = 25\n",
    "dims = (8.26 * 4, 8.26 * 3.5)\n",
    "fig, axes = plt.subplots(1, 2, figsize=dims, squeeze=False, sharey=True)\n",
    "\n",
    "# Hand-tuned data coordinates of the rotated \"... demands better\" annotations.\n",
    "raw_x, sum_x = (50, 50)\n",
    "raw_y, sum_y = (0.10, -0.22)\n",
    "\n",
    "for i, routing_model in enumerate(exp_routing_models):\n",
    "    ax = axes[0][i]\n",
    "    # .copy() so the column assignment below does not write into a slice of dmrep_plot_df\n",
    "    # (chained assignment; its warning is silenced by pd.options.mode.chained_assignment).\n",
    "    sp_df = dmrep_plot_df[dmrep_plot_df['routing_model'] == routing_model].copy()\n",
    "    # dms_mult is scaled by 1000 and shown as the number of training DMs.\n",
    "    sp_df[\"dms_mult\"] = sp_df[\"dms_mult\"] * 1000\n",
    "    sns.lineplot(data=sp_df, x=\"dms_mult\", y=\"raw_minus_summed\", hue=\"method\", ax=ax, marker=\"o\")\n",
    "\n",
    "    ax.axhline(y=0, color='k', linestyle='--')\n",
    "\n",
    "    # Both panels have the same hue levels; the last panel's handles feed the figure legend.\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    if ax.legend_ is not None:\n",
    "        ax.legend_.remove()\n",
    "\n",
    "    # NOTE(review): with raw as the first sorted variant, y > 0 means raw has the larger NMSE\n",
    "    # (sum is better), yet the \"raw demands better\" label is anchored at raw_y > 0. Verify the\n",
    "    # two annotations are not swapped.\n",
    "    ax.text(raw_x, raw_y, \"\\\\textit{raw} \\n demands \\n better\", fontsize=\"large\", rotation=90)\n",
    "    ax.text(sum_x, sum_y, \"\\\\textit{sum} \\n demands \\n better\", fontsize=\"large\", rotation=90)\n",
    "\n",
    "    ax.set_title(routing_model.upper(), fontdict={\"fontsize\": \"x-large\"})\n",
    "    # Real backslash before '#' for usetex; a bare \"\\#\" is an invalid Python escape sequence.\n",
    "    ax.set_xlabel(\"\\\\# of training DMs\", fontsize=\"small\")\n",
    "    ax.set_ylabel(\"\\\\textit{raw} NMSE $-$ \\\\textit{sum} NMSE\", fontsize=\"large\")\n",
    "\n",
    "    ax.tick_params(axis='x', which='major', labelsize=56)\n",
    "    ax.tick_params(axis='x', which='minor', labelsize=40)\n",
    "\n",
    "fig.legend(handles, labels, loc='upper center', borderaxespad=-0.05, fontsize=\"medium\", ncol=2,\n",
    "           handler_map={plt.Line2D: HandlerLine2D(update_func=dmrep_legend_handle_update)})\n",
    "fig.subplots_adjust(wspace=0.05, hspace=0.05)\n",
    "\n",
    "filename = \"dmrep_all.pdf\"\n",
    "fig.savefig(fp_out.figures_dir / filename, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3-relnet",
   "language": "python",
   "name": "relnet"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}