{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "    \n",
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Global seaborn theme for the figures below: \"paper\" context with enlarged fonts on a white grid.\n",
    "sns.set_context(\"paper\", font_scale=2)\n",
    "sns.set_style(\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to configure before running the notebook (the values below are placeholders).\n",
    "# BASE_PATH: folder in which the `hephaestus` package is located.\n",
    "BASE_PATH = Path(\"path/to/the/folder/hephaestus/is/located\")\n",
    "\n",
    "# Folder under which model_storage/models (one sub-directory per experiment) is located.\n",
    "BASE_EXPERIMENT_DIRECTORY = Path(\"path/to/where/model_storage/models/is/located\")\n",
    "# A pathlib.Path like every other path here (instead of an os.path string), so all paths\n",
    "# in this notebook are built the same way.\n",
    "MODELS_RESULTS = Path(\"model_storage\", \"models\")\n",
    "\n",
    "# Fail fast when BASE_PATH has not been configured: otherwise the mkdir below silently\n",
    "# creates the placeholder directory tree and later cells fail with confusing errors.\n",
    "if not BASE_PATH.is_dir():\n",
    "    raise FileNotFoundError(f\"BASE_PATH does not exist: {BASE_PATH.resolve()}\")\n",
    "\n",
    "# All figures produced by this notebook are written here.\n",
    "SAVE_PATH = BASE_PATH / \"plots\" / \"compare_models\"\n",
    "SAVE_PATH.mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve before chdir so that a relative BASE_PATH still points to the same folder afterwards.\n",
    "_base_path_str = str(BASE_PATH.resolve())\n",
    "os.chdir(_base_path_str)\n",
    "# sys.path entries must be str: the import system ignores other types (e.g. pathlib.Path),\n",
    "# so inserting BASE_PATH itself would not make `hephaestus` importable. The membership\n",
    "# check keeps re-runs of this cell from adding duplicate entries.\n",
    "if _base_path_str not in sys.path:\n",
    "    sys.path.insert(0, _base_path_str)\n",
    "\n",
    "import hephaestus.utils.general_utils as hutils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each model configuration (\"<prefix>_<gnn>\", e.g. \"d_gcn\", \"nd_gin\") to the experiment\n",
    "# directories under BASE_EXPERIMENT_DIRECTORY / MODELS_RESULTS holding its trials. Later cells\n",
    "# treat keys starting with \"nd_\" differently (they additionally melt the q* metrics).\n",
    "experiment_map = {\n",
    "    \"d_gcn\": [\"TALOS_20240110-163351\", \"TAUROI_KHALKEOI_20231231-004707\"],\n",
    "    \"d_gat\": [\"TALOS_20240113-155714\", \"TAUROI_KHALKEOI_20240101-170302\"],\n",
    "    \"d_sage\": [\"TALOS_20240119-150636\", \"TAUROI_KHALKEOI_20240108-205055\"],\n",
    "    \"d_gin\": [\"TALOS_20240121-203233\", \"TAUROI_KHALKEOI_20240106-193249\"],\n",
    "    \"nd_gin\": [\"TALOS_20240129-182636\", \"TALOS_20240201-180351\"],\n",
    "    \"nd_gat\": [\"TALOS_20240203-234629\", \"TALOS_20240209-190133\"],\n",
    "    \"nd_gcn\": [\"TALOS_20240212-124657\"],\n",
    "    \"nd_sage\": [\"TALOS_20240216-185022\"],\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse of experiment_map: experiment directory name -> model configuration key.\n",
    "experiment_map_2 = {\n",
    "    run_name: config\n",
    "    for config, run_names in experiment_map.items()\n",
    "    for run_name in run_names\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the inverted mapping as a quick sanity check of the experiment directory names.\n",
    "experiment_map_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the progress.csv of every trial, grouped by model configuration:\n",
    "#   experiment_progress[config] -> list with one DataFrame per loaded trial\n",
    "#   experiment_stats[config]    -> number of trials loaded for that configuration\n",
    "# Both dicts are derived from experiment_map so their keys cannot drift out of sync with it.\n",
    "experiment_stats = {config: 0 for config in experiment_map}\n",
    "experiment_progress = {config: [] for config in experiment_map}\n",
    "\n",
    "models_dir = BASE_EXPERIMENT_DIRECTORY / MODELS_RESULTS\n",
    "# sorted() makes the concatenation order of the trials reproducible.\n",
    "for experiment in sorted(os.listdir(models_dir)):\n",
    "    experiment_dir = models_dir / experiment\n",
    "    if not experiment_dir.is_dir():\n",
    "        continue  # ignore stray files next to the experiment folders\n",
    "\n",
    "    # Look the configuration up before reading any trial: unknown experiments are skipped\n",
    "    # here instead of raising a KeyError from inside the FileNotFoundError handler.\n",
    "    config = experiment_map_2.get(experiment)\n",
    "    if config is None:\n",
    "        print(experiment, \"does not exist in the given dict\")\n",
    "        continue\n",
    "\n",
    "    trial_cnt = 0\n",
    "    for trial in sorted(os.listdir(experiment_dir)):\n",
    "        if \"TorchTrainer\" not in trial:\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            df = pd.read_csv(experiment_dir / trial / \"progress.csv\")\n",
    "        except (FileNotFoundError, pd.errors.EmptyDataError):\n",
    "            # Not fatal: the trial was interrupted before it wrote (any) progress.\n",
    "            print(config)\n",
    "            print(experiment, trial, \" is incomplete\")\n",
    "            continue\n",
    "\n",
    "        experiment_progress[config].append(df)\n",
    "        trial_cnt += 1\n",
    "\n",
    "    experiment_stats[config] += trial_cnt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics melted into long format for every configuration.\n",
    "value_vars_d = [\n",
    "    \"loss\",\n",
    "    \"train_loss\",\n",
    "    \"med_abs_error\",\n",
    "    \"max_abs_err_pattern\",\n",
    "    \"mean_worse_abs_errs_graph\",\n",
    "]\n",
    "# nd_ configurations additionally melt q11..q15, q21..q25, ..., q81..q85 (in that order).\n",
    "value_vars_nd = value_vars_d + [f\"q{i}{j}\" for i in range(1, 9) for j in range(1, 6)]\n",
    "\n",
    "# Columns of progress.csv kept as identifier columns in the long format.\n",
    "ID_VARS = [\n",
    "    \"timestamp\",\n",
    "    \"checkpoint_dir_name\",\n",
    "    \"should_checkpoint\",\n",
    "    \"done\",\n",
    "    \"training_iteration\",\n",
    "    \"trial_id\",\n",
    "    \"date\",\n",
    "    \"time_this_iter_s\",\n",
    "    \"time_total_s\",\n",
    "    \"pid\",\n",
    "    \"hostname\",\n",
    "    \"node_ip\",\n",
    "    \"time_since_restore\",\n",
    "    \"iterations_since_restore\",\n",
    "]\n",
    "# *_idx companions of the worst-case metrics; dropped before melting.\n",
    "DROP_COLS = [\"max_abs_err_pattern_idx\", \"mean_worse_abs_errs_graph_idx\"]\n",
    "\n",
    "experiment_dfs = {}\n",
    "for k, trial_frames in experiment_progress.items():\n",
    "    if not trial_frames:\n",
    "        # pd.concat([]) raises ValueError; skip configurations without loaded trials.\n",
    "        print(f\"No trials loaded for {k}; skipping\")\n",
    "        continue\n",
    "\n",
    "    # Wide frame: the progress.csv rows of all trials of this configuration, stacked.\n",
    "    experiment_dfs[k] = pd.concat(trial_frames)\n",
    "    vals = value_vars_nd if k.startswith(\"nd_\") else value_vars_d\n",
    "\n",
    "    # Long frame: one row per (progress row, metric), as consumed by the FacetGrid plots.\n",
    "    experiment_dfs[k + \"_melt\"] = pd.melt(\n",
    "        experiment_dfs[k].drop(DROP_COLS, axis=1),\n",
    "        id_vars=ID_VARS,\n",
    "        value_vars=vals,\n",
    "        value_name=\"Score\",\n",
    "        var_name=\"Metric\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _stack_melted(frames_by_key, prefix):\n",
    "    \"\"\"Concatenate the melted frames whose key starts with `prefix` into one long frame.\n",
    "\n",
    "    Each frame is tagged with a \"GNN\" column taken from its key, e.g. \"nd_gin_melt\" -> \"gin\".\n",
    "    The source frames are not modified.\n",
    "    \"\"\"\n",
    "    tagged = [\n",
    "        melted.assign(GNN=key.split(\"_\")[1])\n",
    "        for key, melted in frames_by_key.items()\n",
    "        if key.startswith(prefix) and key.endswith(\"_melt\")\n",
    "    ]\n",
    "    return pd.concat(tagged)\n",
    "\n",
    "\n",
    "d_all_dfs = _stack_melted(experiment_dfs, \"d_\")\n",
    "nd_all_dfs = _stack_melted(experiment_dfs, \"nd_\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger fonts for the exported figures below (overrides font_scale=2 from the first cell).\n",
    "sns.set_context(\"paper\", font_scale=2.5)\n",
    "sns.set_style(\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate colour list; in this notebook it is only rendered as a swatch at the end of this cell.\n",
    "# 8-digit hex codes (e.g. \"#FB6467FF\") carry an alpha channel.\n",
    "p = [\n",
    "    \"#000000\",\n",
    "    \"#E69F00\",\n",
    "    \"#56B4E9\",\n",
    "    \"#009E73\",\n",
    "    \"#FB6467FF\",\n",
    "    \"#808282\",\n",
    "    \"#F0E442\",\n",
    "    \"#440154FF\",\n",
    "    \"#0072B2\",\n",
    "    \"#D55E00\",\n",
    "    \"#CC79A7\",\n",
    "    \"#C2CD23\",\n",
    "    \"#918BC3\",\n",
    "    \"#FFFFFF\",\n",
    "]\n",
    "\n",
    "# Fixed colour per GNN type (values of the \"GNN\" column) used by the averaged plots.\n",
    "mpnn_pal = {\"gcn\": \"#E69F00\", \"gin\": \"#56B4E9\", \"sage\": \"#FB6467FF\", \"gat\": \"#009E73\"}\n",
    "\n",
    "sns.color_palette(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Averaged learning curves (mean with standard-error band) of the d_ configurations,\n",
    "# one panel per metric and one line per GNN type.\n",
    "g = sns.FacetGrid(\n",
    "    d_all_dfs,\n",
    "    sharex=True,\n",
    "    sharey=False,\n",
    "    col=\"Metric\",\n",
    "    col_wrap=3,\n",
    "    height=6,\n",
    "    aspect=19 / 11,\n",
    ")\n",
    "g.map_dataframe(\n",
    "    sns.lineplot,\n",
    "    x=\"training_iteration\",\n",
    "    y=\"Score\",\n",
    "    hue=\"GNN\",\n",
    "    palette=mpnn_pal,\n",
    "    estimator=\"mean\",\n",
    "    errorbar=\"se\",\n",
    "    legend=\"full\",\n",
    "    marker=\".\",\n",
    "    markeredgecolor=(0, 0, 0, 0.5),\n",
    ")\n",
    "g.add_legend(label_order=sorted(d_all_dfs[\"GNN\"].unique()))\n",
    "sns.move_legend(g, \"lower right\")\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.set_xlabel(\"Epochs\")\n",
    "    ax.set_ylabel(\"Squared Error\")\n",
    "\n",
    "# Human-readable panel titles. Raw strings keep the LaTeX backslashes from being parsed\n",
    "# as (invalid) Python escape sequences, which emit a SyntaxWarning on Python 3.12+.\n",
    "title_new_text = {\n",
    "    \"train_loss\": \"Train Loss\",\n",
    "    \"mean_worse_abs_errs_graph\": r\"Mean of the Absolute Error of the Worst $g \\in \\Omega$\",\n",
    "    \"med_abs_error\": \"Median Absolute Error\",\n",
    "    \"loss\": \"Validation Loss\",\n",
    "    \"max_abs_err_pattern\": \"Absolute Error of the Worst Predicted Pattern\",\n",
    "}\n",
    "\n",
    "for ax, title in zip(g.axes.flat, g.col_names):\n",
    "    ax.set_title(title_new_text.get(title, \"TITLENOTFOUND\"))\n",
    "\n",
    "g.figure.tight_layout()\n",
    "g.figure.savefig(SAVE_PATH / \"average_breakdown_d.pdf\", dpi=1200)\n",
    "plt.close(g.figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Averaged learning curves of the nd_ configurations for the shared metrics; the q* metrics\n",
    "# (names made of \"q\" followed by digits) are excluded from this figure.\n",
    "nd_main_metrics = nd_all_dfs[~nd_all_dfs[\"Metric\"].str.fullmatch(r\"q\\d+\")]\n",
    "\n",
    "g = sns.FacetGrid(\n",
    "    nd_main_metrics,\n",
    "    sharex=True,\n",
    "    sharey=False,\n",
    "    col=\"Metric\",\n",
    "    col_wrap=3,\n",
    "    height=6,\n",
    "    aspect=19 / 11,\n",
    ")\n",
    "g.map_dataframe(\n",
    "    sns.lineplot,\n",
    "    x=\"training_iteration\",\n",
    "    y=\"Score\",\n",
    "    hue=\"GNN\",\n",
    "    palette=mpnn_pal,\n",
    "    estimator=\"mean\",\n",
    "    errorbar=\"se\",\n",
    "    legend=\"full\",\n",
    "    marker=\".\",\n",
    "    markeredgecolor=(0, 0, 0, 0.5),\n",
    ")\n",
    "# Legend order is taken from the nd_ data actually plotted in this figure.\n",
    "g.add_legend(label_order=sorted(nd_main_metrics[\"GNN\"].unique()))\n",
    "sns.move_legend(g, \"lower right\")\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.set_xlabel(\"Epochs\")\n",
    "    ax.set_ylabel(\"Squared Error\")\n",
    "\n",
    "# Human-readable panel titles. Raw strings keep the LaTeX backslashes from being parsed\n",
    "# as (invalid) Python escape sequences, which emit a SyntaxWarning on Python 3.12+.\n",
    "title_new_text = {\n",
    "    \"train_loss\": \"Train Loss\",\n",
    "    \"mean_worse_abs_errs_graph\": r\"Mean of the Absolute Error of the Worst $g \\in \\Omega$\",\n",
    "    \"med_abs_error\": \"Median Absolute Error\",\n",
    "    \"loss\": \"Validation Loss\",\n",
    "    \"max_abs_err_pattern\": \"Absolute Error of the Worst Predicted Pattern\",\n",
    "}\n",
    "\n",
    "for ax, title in zip(g.axes.flat, g.col_names):\n",
    "    ax.set_title(title_new_text.get(title, \"TITLENOTFOUND\"))\n",
    "\n",
    "g.figure.tight_layout()\n",
    "g.figure.savefig(SAVE_PATH / \"average_breakdown_nd.pdf\", dpi=1200)\n",
    "plt.close(g.figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the trial with the lowest validation loss (\"loss\" column) among\n",
    "#   - all nd_ configurations              -> best_*_nd\n",
    "#   - the nd_ configuration with \"gin\"    -> best_*_nd_gin\n",
    "#   - all d_ configurations               -> best_*_d\n",
    "# best_value_* is that loss, best_value_idx_* its row index in progress.csv, best_epoch_* the\n",
    "# matching training_iteration, best_run_* the trial folder and best_expr_* the experiment folder.\n",
    "best_run_nd = \"\"\n",
    "best_expr_nd = \"\"\n",
    "best_value_nd = np.inf\n",
    "best_value_idx_nd = -1\n",
    "best_epoch_nd = -1\n",
    "\n",
    "best_run_d = \"\"\n",
    "best_expr_d = \"\"\n",
    "best_value_d = np.inf\n",
    "best_value_idx_d = -1\n",
    "best_epoch_d = -1\n",
    "\n",
    "best_run_nd_gin = \"\"\n",
    "best_expr_nd_gin = \"\"\n",
    "best_value_nd_gin = np.inf\n",
    "best_value_idx_nd_gin = -1\n",
    "best_epoch_nd_gin = -1\n",
    "\n",
    "models_dir = BASE_EXPERIMENT_DIRECTORY / MODELS_RESULTS\n",
    "for experiment in sorted(os.listdir(models_dir)):\n",
    "    # Experiments missing from experiment_map belong to no configuration and are skipped\n",
    "    # (instead of being counted as d_ runs).\n",
    "    config = experiment_map_2.get(experiment)\n",
    "    if config is None:\n",
    "        continue\n",
    "\n",
    "    for trial in sorted(os.listdir(models_dir / experiment)):\n",
    "        if \"TorchTrainer\" not in trial:\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            df = pd.read_csv(models_dir / experiment / trial / \"progress.csv\")\n",
    "        except (FileNotFoundError, pd.errors.EmptyDataError):\n",
    "            # This is not an error! It only means that a trial was interrupted and resumed later.\n",
    "            print(config)\n",
    "            print(experiment, trial, \" is incomplete\")\n",
    "            continue\n",
    "\n",
    "        if df[\"loss\"].isna().all():\n",
    "            continue  # no validation loss reported yet\n",
    "\n",
    "        min_value_id = df[\"loss\"].idxmin()\n",
    "        min_value = df.loc[min_value_id, \"loss\"]\n",
    "        # Report training_iteration (the \"Epochs\" axis of the plots), not the 0-based row index.\n",
    "        epoch = df.loc[min_value_id, \"training_iteration\"]\n",
    "\n",
    "        # Branch on the configuration first: an nd_ trial that does not beat the current nd_\n",
    "        # record must never be considered for the d_ record.\n",
    "        if config.startswith(\"nd_\"):\n",
    "            if min_value < best_value_nd:\n",
    "                best_value_nd, best_value_idx_nd, best_epoch_nd = min_value, min_value_id, epoch\n",
    "                best_run_nd, best_expr_nd = trial, experiment\n",
    "            if \"gin\" in config and min_value < best_value_nd_gin:\n",
    "                best_value_nd_gin, best_value_idx_nd_gin, best_epoch_nd_gin = min_value, min_value_id, epoch\n",
    "                best_run_nd_gin, best_expr_nd_gin = trial, experiment\n",
    "        elif min_value < best_value_d:\n",
    "            best_value_d, best_value_idx_d, best_epoch_d = min_value, min_value_id, epoch\n",
    "            best_run_d, best_expr_d = trial, experiment\n",
    "\n",
    "\n",
    "def _print_best(header, run, value, epoch, expr):\n",
    "    \"\"\"Print the summary of one best trial, or a notice when no trial qualified.\"\"\"\n",
    "    print(header)\n",
    "    if not run:\n",
    "        print(\"No completed trials found.\")\n",
    "        return\n",
    "    # Shorten the trial folder name to its first two \"_\"-separated parts.\n",
    "    run_pretty = \"_\".join(run.split(\"_\")[:2])\n",
    "    print(\n",
    "        f\"Best Run {run_pretty}, with {value} at epoch {epoch} in {expr} experiment ({experiment_map_2[expr]})!\"\n",
    "    )\n",
    "\n",
    "\n",
    "_print_best(\"For ND\", best_run_nd, best_value_nd, best_epoch_nd, best_expr_nd)\n",
    "_print_best(\"\\nFor ND - GIN\", best_run_nd_gin, best_value_nd_gin, best_epoch_nd_gin, best_expr_nd_gin)\n",
    "_print_best(\"\\nFor D\", best_run_d, best_value_d, best_epoch_d, best_expr_d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trials whose q* metrics are compared below. The ids are hard-coded; presumably they are\n",
    "# the best nd_ trials reported by the previous cell (TODO confirm whenever the runs change).\n",
    "BEST_ND_TRIAL_IDS = [\"995a1ad7\", \"d9c8c887\"]\n",
    "\n",
    "best_q_metrics = nd_all_dfs[\n",
    "    nd_all_dfs[\"Metric\"].str.fullmatch(r\"q\\d+\")\n",
    "    & nd_all_dfs[\"trial_id\"].isin(BEST_ND_TRIAL_IDS)\n",
    "]\n",
    "\n",
    "g = sns.FacetGrid(\n",
    "    best_q_metrics,\n",
    "    sharex=True,\n",
    "    sharey=False,\n",
    "    col=\"Metric\",\n",
    "    col_wrap=5,\n",
    "    height=6,\n",
    "    aspect=19 / 11,\n",
    ")\n",
    "g.map_dataframe(\n",
    "    sns.lineplot,\n",
    "    x=\"training_iteration\",\n",
    "    y=\"Score\",\n",
    "    estimator=\"mean\",\n",
    "    hue=\"GNN\",\n",
    "    # Same colours as mpnn_pal for the GNN types of the selected trials.\n",
    "    palette={\"gin\": \"#56B4E9\", \"sage\": \"#FB6467FF\"},\n",
    "    errorbar=None,\n",
    "    legend=\"full\",\n",
    "    marker=\".\",\n",
    "    markeredgecolor=(0, 0, 0, 0.5),\n",
    ")\n",
    "g.add_legend()\n",
    "sns.move_legend(g, \"center right\")\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.set_xlabel(\"Epochs\")\n",
    "    ax.set_ylabel(\"Squared Error\")\n",
    "\n",
    "g.figure.savefig(SAVE_PATH / \"best_quartile_breakdown_nd.pdf\", dpi=1200)\n",
    "plt.close(g.figure)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
