{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent of the current working directory on the import path\n",
    "# (presumably so the `relnet` imports in the next cell resolve).\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import platform\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "from copy import copy\n",
    "import re\n",
    "\n",
    "# if platform.system() == 'Darwin':\n",
    "#     matplotlib.use(\"TkAgg\")\n",
    "\n",
    "import matplotlib.animation\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pprint\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from relnet.agent.gnn.prediction_agent import *\n",
    "from relnet.io.storage import EvaluationStorage\n",
    "from relnet.io.file_paths import FilePaths\n",
    "from relnet.evaluation.eval_utils import *\n",
    "\n",
    "from relnet.state.state_generators import *\n",
    "from relnet.objective_functions.objective_functions import *\n",
    "\n",
    "# Names of the graphs whose experiments are aggregated in this notebook.\n",
    "graph_names = [\n",
    "    \"Aconet\", \"Agis\", \"Arnes\", \"Cernet\", \"Cesnet201006\", \"Grnet\", \"Iij\", \"Internode\",\n",
    "    \"Janetlense\", \"Karen\", \"Marnet\", \"Niif\", \"PionierL3\", \"Sinet\", \"SwitchL3\", \"Ulaknet\", \"Uninett2011\",\n",
    "]\n",
    "# Single-graph override for quick runs:\n",
    "# graph_names = [\"Aconet\"]\n",
    "\n",
    "suffix = \"final1d\"\n",
    "exp_types = [\"ssp\", \"ecmp\"]\n",
    "\n",
    "# One experiment id per (graph, experiment type) pair, grouped by graph.\n",
    "exp_ids = [f\"{gn}_{exp_type}_{suffix}\" for gn in graph_names for exp_type in exp_types]\n",
    "\n",
    "# Output location for this notebook; figures are saved under fp_out.figures_dir.\n",
    "fp_out = FilePaths('/experiment_data', 'aggregate', setup_directories=True)\n",
    "storage = EvaluationStorage(fp_out)\n",
    "\n",
    "network_generator = TmGenStateGenerator\n",
    "objective_function = MLU\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Hyperparameter data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Grouped hyperparameter results, loaded for the first experiment id only.\n",
    "hyps_df, hyps_dict = storage.get_grouped_hyp_data(exp_ids[0], {}, False)\n",
    "\n",
    "# Raise the pandas display limit so that every row of the table is rendered.\n",
    "pd.set_option('display.max_rows', len(hyps_df) + 1)\n",
    "hyps_df.sort_values(by=['agent_name', 'avg_perf'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Show the second value returned by storage.get_grouped_hyp_data for exp_ids[0].\n",
    "hyps_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Columns of the expanded per-agent tables that are never plotted as hyperparameters.\n",
    "non_hyp_cols = {\"avg_perf\", \"network_generator\", \"objective_function\", \"graph_id\", \"agent_name\",\n",
    "                \"batch_size\", \"input_layer_heads\", \"activation_fn\", \"subgraph_agg\"}\n",
    "\n",
    "# For every experiment: save one line plot of avg_perf against each hyperparameter, per agent.\n",
    "for experiment_id in exp_ids:\n",
    "    param_spaces, df = storage.get_hyperparameter_optimisation_data(experiment_id, {}, train_individually=False)\n",
    "    latest_experiment = storage.get_experiment_details(experiment_id)\n",
    "    agent_names = latest_experiment[\"agents\"]\n",
    "\n",
    "    # agent name -> table with one row per result row, widened with the\n",
    "    # hyperparameter values that the row's hyperparams_id refers to.\n",
    "    agent_hyperparam_dfs = {}\n",
    "    for agent_name in set(df['agent_name']):\n",
    "        # DataFrame.drop returns a new frame, so its result has to be kept for\n",
    "        # the column to actually be removed.\n",
    "        subset = df[df['agent_name'] == agent_name].drop(columns=['agent_name'])\n",
    "        agent_param_space = param_spaces[objective_function.name][agent_name]\n",
    "\n",
    "        expanded_data = []\n",
    "        for _, row in subset.iterrows():\n",
    "            row_copy = dict(row)\n",
    "            row_copy.update(agent_param_space[row['hyperparams_id']])\n",
    "            expanded_data.append(row_copy)\n",
    "\n",
    "        agent_hyperparam_dfs[agent_name] = pd.DataFrame(expanded_data).drop(columns=['hyperparams_id'])\n",
    "\n",
    "    for agent_name in agent_names:\n",
    "        if agent_name not in agent_hyperparam_dfs:\n",
    "            continue\n",
    "\n",
    "        # Map booleans to 0/1 on a new frame rather than mutating the stored one,\n",
    "        # so re-running this step starts from the same data.\n",
    "        hyperparams_df = agent_hyperparam_dfs[agent_name].replace({False: 0, True: 1})\n",
    "\n",
    "        # Keep the frame's column order so figures are produced in a stable order.\n",
    "        hyperparam_cols = [col for col in hyperparams_df.columns if col not in non_hyp_cols]\n",
    "        for hyperparam_name in hyperparam_cols:\n",
    "            title = f\"{agent_name}-{network_generator.name}-{objective_function.name}-{hyperparam_name}-all\"\n",
    "            filename = f\"{experiment_id}-hyperparams-{title}.pdf\"\n",
    "\n",
    "            fig, ax = plt.subplots()\n",
    "            try:\n",
    "                ax.set_title(title)\n",
    "                sns.lineplot(data=hyperparams_df, x=hyperparam_name, y=\"avg_perf\", ax=ax)\n",
    "                fig.savefig(fp_out.figures_dir / filename, bbox_inches='tight')\n",
    "            finally:\n",
    "                # Many figures are created in this loop; always release them.\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Eval curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from relnet.objective_functions.objective_functions import *\n",
    "# Labels used in figures, keyed by the objective function's `name` attribute.\n",
    "objective_function_display_names = {MLU.name: \"MLU\"}\n",
    "\n",
    "# Resolution (dots per inch) passed to savefig for the evaluation-curve figures.\n",
    "fig_dpi = 200\n",
    "\n",
    "def set_latex_if_required():\n",
    "    \"\"\"Enable LaTeX text rendering in matplotlib's global rcParams.\n",
    "\n",
    "    Despite the name, no condition is checked: the setting is always turned on.\n",
    "    \"\"\"\n",
    "    mpl.rcParams.update({'text.usetex': True})\n",
    "\n",
    "def plot_eval_histories(results_df,\n",
    "                      figure_save_path,\n",
    "                      separate_seeds=True):\n",
    "    \"\"\"Plot evaluation performance against epoch and save the figure to disk.\n",
    "\n",
    "    One panel is drawn per distinct value of the ``objective_function`` column,\n",
    "    all in a single row.\n",
    "\n",
    "    Args:\n",
    "        results_df: DataFrame providing the columns ``objective_function``,\n",
    "            ``timestep`` (x axis) and ``perf`` (y axis); ``model_seed`` is also\n",
    "            required when ``separate_seeds`` is true.\n",
    "        figure_save_path: Destination handed to ``Figure.savefig``.\n",
    "        separate_seeds: If true, colour the lines by ``model_seed``; otherwise\n",
    "            no hue is passed to seaborn.\n",
    "\n",
    "    Note:\n",
    "        The seaborn / matplotlib settings applied here (font scale, line width,\n",
    "        serif font, LaTeX text) are global and remain in effect after the call.\n",
    "    \"\"\"\n",
    "    sns.set(font_scale=2.5)\n",
    "    plt.rcParams[\"lines.linewidth\"] = 1\n",
    "    plt.rc('font', family='serif')\n",
    "    set_latex_if_required()\n",
    "\n",
    "    objs = results_df[\"objective_function\"].unique()\n",
    "    num_objs = len(objs)\n",
    "\n",
    "    # One square panel per objective, side by side.\n",
    "    dims = (8.26 * num_objs, 8.26)\n",
    "    fig, axes = plt.subplots(1, num_objs, sharex='none', sharey='none', figsize=dims, squeeze=False)\n",
    "\n",
    "    pad = 2.5  # in points\n",
    "\n",
    "    try:\n",
    "        for i, obj in enumerate(objs):\n",
    "            filtered_data = results_df[results_df['objective_function'] == obj]\n",
    "            filtered_data = filtered_data.rename(columns={\"timestep\": \"epoch\",\n",
    "                                                          \"perf\": \"Evaluation performance\"}).reset_index()\n",
    "\n",
    "            ax = axes[0][i]\n",
    "            sns.lineplot(data=filtered_data, x=\"epoch\", y=\"Evaluation performance\",\n",
    "                         ax=ax, hue=(\"model_seed\" if separate_seeds else None))\n",
    "\n",
    "            if i == 0:\n",
    "                # Raw string: '\\m' is not a valid escape sequence in a normal literal.\n",
    "                ax.set_ylabel(r'$\\mathbf{G}^{eval}$ performance', size=\"small\")\n",
    "            else:\n",
    "                ax.set_ylabel('')\n",
    "\n",
    "            # Label every panel with its objective, placed to the left of the\n",
    "            # y-axis label. Objectives without a display name fall back to the\n",
    "            # raw value instead of raising KeyError.\n",
    "            display_name = objective_function_display_names.get(obj, obj)\n",
    "            ax.annotate(str(display_name), xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),\n",
    "                        rotation=90,\n",
    "                        xycoords=ax.yaxis.label, textcoords='offset points',\n",
    "                        size='medium', ha='right', va='center')\n",
    "\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(figure_save_path, bbox_inches='tight', dpi=fig_dpi)\n",
    "    finally:\n",
    "        # Release the figure even when plotting or saving fails.\n",
    "        plt.close(fig)\n",
    "\n",
    "    plt.rcParams[\"lines.linewidth\"] = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "separate_seeds = True\n",
    "nrows_to_skip = 5\n",
    "\n",
    "# For every experiment, agent and hyperparameter setting: fetch the evaluation\n",
    "# curves and save them as a figure. The optimal setting gets an \"_OPT\" suffix.\n",
    "for exp_id in exp_ids:\n",
    "    fp_in = FilePaths('/experiment_data', exp_id, setup_directories=False)\n",
    "\n",
    "    # Everything below is looked up with the id of the experiment handled in\n",
    "    # this iteration, so the cell does not rely on variables from other cells.\n",
    "    experiment_details = storage.get_experiment_details(exp_id)\n",
    "    agent_names = experiment_details[\"agents\"]\n",
    "    model_seeds = experiment_details['experiment_conditions']['experiment_params']['model_seeds']\n",
    "\n",
    "    optimal_hyps = storage.retrieve_optimal_hyperparams(exp_id, {}, False)\n",
    "    hyp_data = storage.get_grouped_hyp_data(exp_id, {}, False)[1][objective_function.name]\n",
    "\n",
    "    for agent_name in agent_names:\n",
    "        # Only agents with hyperparameter data can be plotted.\n",
    "        # NOTE(review): assumes hyp_data is a mapping keyed by agent name - confirm.\n",
    "        if agent_name not in hyp_data:\n",
    "            continue\n",
    "\n",
    "        opt_hyps_setting = optimal_hyps[(network_generator.name, objective_function.name, agent_name)]\n",
    "        optimal_hyps_id = int(opt_hyps_setting[1])\n",
    "\n",
    "        for hyp_id in (int(hid) for hid in hyp_data[agent_name].keys()):\n",
    "            opt_marker = \"_OPT\" if hyp_id == optimal_hyps_id else \"\"\n",
    "            try:\n",
    "                data_df = storage.fetch_all_eval_curves(agent_name, hyp_id, fp_in, [objective_function.name],\n",
    "                                                        [network_generator.name],\n",
    "                                                        model_seeds,\n",
    "                                                        train_individually=False,\n",
    "                                                        nrows_to_skip=nrows_to_skip\n",
    "                                                        )\n",
    "\n",
    "                eval_plot_filename = f'{exp_id}-eval_curves_{agent_name}-{hyp_id}{opt_marker}.pdf'\n",
    "                plot_eval_histories(data_df, fp_out.figures_dir / eval_plot_filename, separate_seeds=separate_seeds)\n",
    "            except ValueError as err:\n",
    "                # Best effort: keep going, but report what was skipped instead of\n",
    "                # hiding it.\n",
    "                print(f\"Skipping eval curves for {exp_id} / {agent_name} / hyperparams {hyp_id}: {err}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3-relnet",
   "language": "python",
   "name": "relnet"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}