{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Global seaborn look for every figure below: paper-sized fonts on a white grid.\n",
    "sns.set_context(\"paper\", font_scale=1)\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# Shared colour palette; later cells index into it (p[1:5] for the bars, p[3] for error bars).\n",
    "# The 8-digit entries are #RRGGBBAA, i.e. they carry an explicit (opaque) alpha channel.\n",
    "p = [\n",
    "    \"#000000\", \"#E69F00\", \"#56B4E9\", \"#009E73\", \"#FB6467FF\",\n",
    "    \"#808282\", \"#F0E442\", \"#440154FF\", \"#0072B2\", \"#D55E00\",\n",
    "    \"#CC79A7\", \"#C2CD23\", \"#918BC3\", \"#FFFFFF\",\n",
    "]\n",
    "\n",
    "# Last expression of the cell: renders the palette as swatches for a quick visual check.\n",
    "sns.color_palette(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder locations: point BASE_PATH at the real output root before running.\n",
    "BASE_PATH = Path(\"some/base/path\")\n",
    "RESULTS_FROM_EVALUATE_MODELS = BASE_PATH.joinpath(\n",
    "    \"where\", \"the\", \"results\", \"from\", \"evaluate\", \"are\", \"stored\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_clean_text(txt, return_names=False):\n",
    "    \"\"\"Turn a FacetGrid subplot title into a human-readable label.\n",
    "\n",
    "    Titles are expected in the form 'Type = <data_type> | ID = <run_id>', which is\n",
    "    what sns.FacetGrid(..., row='Type', col='ID') produces.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    txt : str\n",
    "        Raw subplot title.\n",
    "    return_names : bool, default False\n",
    "        Selects the return format (see below).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str or tuple of str\n",
    "        If return_names is False, a display label such as\n",
    "        'Small Real | SAGE Non-Deterministic Data'.\n",
    "        If True, (data_label, data_type, run_id): data_label is 'Small Real' or\n",
    "        'Medium-Large Real'; data_type and run_id are the stripped raw values.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    IndexError\n",
    "        If txt does not contain two ' =' separators.\n",
    "    \"\"\"\n",
    "    parts = txt.split(\" =\")\n",
    "    data_type = parts[1].split(\" |\")[0].strip()\n",
    "    run_id = parts[2].strip()\n",
    "\n",
    "    data_label = \"Small Real\" if \"sreal\" in data_type else \"Medium-Large Real\"\n",
    "    if return_names:\n",
    "        return data_label, data_type, run_id\n",
    "\n",
    "    model = \"GIN\" if \"gin\" in run_id else \"SAGE\"\n",
    "    # Test the 'nd_' prefix rather than any 'nd' substring: a hypothetical 'd_second_best_gin'\n",
    "    # contains 'nd' (in 'second') but is a 'd_' run under the d_/nd_ naming used in id_runs.\n",
    "    determinism = \"Non-Deterministic Data\" if run_id.startswith(\"nd_\") else \"Deterministic Data\"\n",
    "    return f\"{data_label} | {model} {determinism}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs selected for comparison: short label -> run directory name in the results tree.\n",
    "id_runs = {\n",
    "    \"nd_absolute_best_sage\": \"TALOS_20240216-185022-TorchTrainer_995a1ad7\",\n",
    "    \"nd_absolute_best_gin\": \"TALOS_20240201-180351-TorchTrainer_d9c8c887\",\n",
    "    \"d_absolute_best_gin\": \"TALOS_20240121-203233-TorchTrainer_e9bc79c6\",\n",
    "}\n",
    "\n",
    "# Inverse mapping (run directory -> short label), used to filter and tag runs while loading.\n",
    "runs_id = {run_dir: label for label, run_dir in id_runs.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the evaluation CSVs of the selected runs into a single frame.\n",
    "# Layout: RESULTS_FROM_EVALUATE_MODELS / <data_type> / <model_type> / <run_dir> / {params.json, *.csv}\n",
    "TYPES = [\"sreal\", \"mlreal\"]\n",
    "dfs = []\n",
    "\n",
    "for data_type in TYPES:\n",
    "    type_dir = RESULTS_FROM_EVALUATE_MODELS / data_type\n",
    "    # Only descend into sub-directories (a stray file such as .DS_Store would make the\n",
    "    # inner listing fail) and sort every level so the row order of `df` is reproducible.\n",
    "    for model_dir in sorted(d for d in type_dir.iterdir() if d.is_dir()):\n",
    "        for run_dir in sorted(d for d in model_dir.iterdir() if d.is_dir()):\n",
    "            if run_dir.name not in runs_id:\n",
    "                continue\n",
    "\n",
    "            # Training config of the run; loaded for reference, not used by the cells below.\n",
    "            with open(run_dir / \"params.json\") as json_data:\n",
    "                params = json.load(json_data)[\"train_loop_config\"]\n",
    "\n",
    "            # Match on the suffix, not a substring, so e.g. 'x.csv.bak' is ignored.\n",
    "            for csv_path in sorted(run_dir.glob(\"*.csv\")):\n",
    "                run_df = pd.read_csv(csv_path)\n",
    "                run_df[\"ID\"] = runs_id[run_dir.name]\n",
    "                run_df[\"Type\"] = data_type\n",
    "                dfs.append(run_df)\n",
    "\n",
    "if not dfs:\n",
    "    # pd.concat([]) would otherwise fail with an opaque 'No objects to concatenate'.\n",
    "    raise ValueError(f\"No evaluation CSVs for runs {sorted(runs_id)} under {RESULTS_FROM_EVALUATE_MODELS}\")\n",
    "\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_vars = [\"ID\", \"Type\"]\n",
    "\n",
    "# Columns not turned into bars; the plotting cell reads some of them from `df` directly.\n",
    "_excluded_cols = [\n",
    "    \"VAL/RealData Loss\",\n",
    "    \"rand_model_sd_RealLoss\",\n",
    "    \"rand_ind_RealLoss\",\n",
    "    \"rand_res_mean_RealLoss\",\n",
    "]\n",
    "# Metrics shown as bars; their order here is the order they appear in melt_df.\n",
    "_bar_metrics = [\"ValLoss\", \"TestLoss\", \"RealDataLoss\", \"rand_model_mean_RealLoss\"]\n",
    "\n",
    "# Long format: one row per (input row, metric), metric name in 'Metric', value in 'Value'.\n",
    "melt_df = df.drop(columns=_excluded_cols).melt(\n",
    "    id_vars=id_vars,\n",
    "    value_vars=_bar_metrics,\n",
    "    var_name=\"Metric\",\n",
    "    value_name=\"Value\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One bar chart per (data type, run): rows = Type, columns = ID.\n",
    "METRIC_ORDER = [\"ValLoss\", \"TestLoss\", \"RealDataLoss\", \"rand_model_mean_RealLoss\"]\n",
    "METRIC_LABELS = [\"Val\", \"Test\", \"Real\", \"Random Real\"]\n",
    "# Bars are drawn in METRIC_ORDER, so this is the patch index of the random-model bar.\n",
    "RAND_BAR_IDX = METRIC_ORDER.index(\"rand_model_mean_RealLoss\")\n",
    "\n",
    "sns.set_context(\"paper\", font_scale=2.5)\n",
    "g = sns.FacetGrid(melt_df, col=\"ID\", row=\"Type\", aspect=1, height=11)\n",
    "g.map_dataframe(\n",
    "    sns.barplot,\n",
    "    y=\"Value\",\n",
    "    x=\"Metric\",\n",
    "    hue=\"Metric\",\n",
    "    # Explicit order keeps the bars aligned with METRIC_LABELS and RAND_BAR_IDX.\n",
    "    order=METRIC_ORDER,\n",
    "    hue_order=METRIC_ORDER,\n",
    "    orient=\"x\",\n",
    "    palette=p[1:5],\n",
    ")\n",
    "\n",
    "# g.axes_dict maps (Type, ID) -> Axes. Look up each subplot's own statistics instead of\n",
    "# using df row 0 (same baselines on every subplot) or the flat axis index as a df row label\n",
    "# (which is only right if df happens to be ordered exactly like the grid).\n",
    "for (data_type, run_id), ax in g.axes_dict.items():\n",
    "    run_rows = df[(df[\"Type\"] == data_type) & (df[\"ID\"] == run_id)]\n",
    "    if run_rows.empty:\n",
    "        continue  # empty facet: nothing to annotate\n",
    "    run_stats = run_rows.iloc[0]\n",
    "\n",
    "    # Horizontal reference lines from the rand_* columns of this subplot's run.\n",
    "    ax.axhline(y=run_stats[\"rand_ind_RealLoss\"], color=\"r\", linestyle=\"--\")\n",
    "    ax.axhline(y=run_stats[\"rand_res_mean_RealLoss\"], color=\"b\", linestyle=\"--\")\n",
    "\n",
    "    # Value label centred on top of every bar.\n",
    "    for bar in ax.patches:\n",
    "        height = bar.get_height()\n",
    "        label_x = bar.get_x() + bar.get_width() / 2\n",
    "        ax.text(label_x, height, f\"{height:.4f}\", ha=\"center\", va=\"bottom\")\n",
    "\n",
    "    # Error bar on the random-model bar from rand_model_sd_RealLoss (presumably its std. dev.).\n",
    "    if len(ax.patches) > RAND_BAR_IDX:\n",
    "        rand_bar = ax.patches[RAND_BAR_IDX]\n",
    "        ax.errorbar(\n",
    "            x=rand_bar.get_x() + rand_bar.get_width() / 2,\n",
    "            y=rand_bar.get_height(),\n",
    "            yerr=run_stats[\"rand_model_sd_RealLoss\"],\n",
    "            fmt=\"none\",\n",
    "            ecolor=p[3],\n",
    "            elinewidth=2,\n",
    "            capsize=5,\n",
    "            capthick=2,\n",
    "            zorder=1,\n",
    "        )\n",
    "\n",
    "# Replace the default 'Type = ... | ID = ...' titles with readable labels.\n",
    "for ax in g.axes.flat:\n",
    "    title = ax.get_title()\n",
    "    if \"dropout\" in title:\n",
    "        p1_norm, data_type, run_id = get_clean_text(title, return_names=True)\n",
    "        run_mask = (df[\"ID\"] == run_id) & (df[\"Type\"] == data_type)\n",
    "        m3, m1 = df.loc[run_mask, \"Dropout\"].iloc[0].split(\" \")\n",
    "        ax.set_title(f\"{p1_norm} | M3: {m3}, M1: {m1}\", color=\"black\", alpha=1, fontweight=\"bold\")\n",
    "    else:\n",
    "        ax.set_title(get_clean_text(title), color=\"black\", fontweight=\"bold\")\n",
    "\n",
    "g.set_xlabels(\"\")\n",
    "g.set_xticklabels(METRIC_LABELS)\n",
    "g.despine(left=True)\n",
    "g.figure.tight_layout()\n",
    "g.figure.savefig(RESULTS_FROM_EVALUATE_MODELS / \"all_runs_metrics.pdf\", dpi=1200)\n",
    "plt.close(g.figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
