{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from fractions import Fraction\n",
    "\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Apply the seaborn theme first: it overwrites rcParams, so the overrides below must come after it.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "# Render all text through LaTeX (needs a TeX installation) in 8pt Computer Modern serif.\n",
    "font_size = 8\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.usetex': True,\n",
    "    'font.size': font_size,\n",
    "    'axes.labelsize': font_size,\n",
    "    'xtick.labelsize': font_size,\n",
    "    'ytick.labelsize': font_size,\n",
    "    'legend.fontsize': font_size,\n",
    "})\n",
    "markersize = 3\n",
    "margin_title_size = 8  # NOTE(review): not referenced by any cell in this notebook\n",
    "\n",
    "# Figure dimensions in inches (the unit of Figure.set_size_inches, used by the plot cells).\n",
    "# NOTE(review): line_width presumably matches the target document's text width; confirm.\n",
    "plot_height = 0.7\n",
    "line_width = 5.1\n",
    "\n",
    "\n",
    "def float_to_latex_fraction(x, limit_denominator=10):\n",
    "    \"\"\"Render a tick value as an inline LaTeX fraction.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : matplotlib Text, str, or real number\n",
    "        A tick label such as an element of ``ax.get_yticklabels()``, its\n",
    "        text, or the tick value itself. For text, the number is read from\n",
    "        the first ``{...}`` group if there is one (as in a mathdefault-wrapped\n",
    "        label), otherwise from the whole text with ``$`` removed.\n",
    "    limit_denominator : int, default 10\n",
    "        Largest denominator allowed when approximating the value.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str\n",
    "        ``$n$`` for whole numbers, otherwise a math-mode ``frac`` with the\n",
    "        minus sign written in front of the fraction for negative values.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no number can be parsed from the label text.\n",
    "    \"\"\"\n",
    "    if hasattr(x, \"get_text\"):\n",
    "        text = x.get_text()\n",
    "    elif isinstance(x, str):\n",
    "        text = x\n",
    "    else:\n",
    "        text = None\n",
    "\n",
    "    if text is None:\n",
    "        value = float(x)\n",
    "    else:\n",
    "        match = re.search(r'\\{([^}]+)\\}', text)\n",
    "        raw = match.group(1) if match else text.strip().strip('$')\n",
    "        # Matplotlib typesets minus signs as U+2212, which float() does not accept.\n",
    "        value = float(raw.replace('−', '-'))\n",
    "\n",
    "    frac = Fraction(value).limit_denominator(limit_denominator)\n",
    "    if frac.denominator == 1:\n",
    "        return f\"${frac.numerator}$\"\n",
    "    sign = \"-\" if frac < 0 else \"\"\n",
    "    return f\"${sign}\\\\frac{{{abs(frac.numerator)}}}{{{frac.denominator}}}$\"\n",
    "\n",
    "\n",
    "def set_yticklabels_as_fractions(ax, limit_denominator=10):\n",
    "    \"\"\"Label the y ticks of ``ax`` with LaTeX fractions.\n",
    "\n",
    "    Installs a formatter instead of passing a fixed list to\n",
    "    ``ax.set_yticklabels``: a fixed label list is not tied to tick\n",
    "    positions, so labels land on the wrong ticks whenever the locator\n",
    "    recomputes them (matplotlib warns about a FixedFormatter used without a\n",
    "    FixedLocator). The formatter receives the numeric tick value, so it also\n",
    "    does not depend on the current label text.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib Axes\n",
    "        Axes whose y-axis major ticks are relabelled.\n",
    "    limit_denominator : int, default 10\n",
    "        Largest denominator allowed in the fractions.\n",
    "    \"\"\"\n",
    "    from matplotlib.ticker import FuncFormatter\n",
    "\n",
    "    ax.yaxis.set_major_formatter(\n",
    "        FuncFormatter(lambda value, pos: float_to_latex_fraction(value, limit_denominator))\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each experiment directory is named \"key1=value1-key2=value2-...\" and holds a metrics CSV.\n",
    "RESULTS_DIR = \"results/wind\"\n",
    "# Split only on hyphens that start a new \"key=\" pair, so values such as \"1e-05\" stay intact.\n",
    "PARAM_SEPARATOR = re.compile(r\"-(?=[A-Za-z_]\\w*=)\")\n",
    "\n",
    "results = []\n",
    "for root, dirs, files in os.walk(RESULTS_DIR):\n",
    "    dirs.sort()  # make the traversal (and hence row) order deterministic\n",
    "    for file in sorted(files):\n",
    "        if not file.endswith(\".csv\"):\n",
    "            continue\n",
    "        experiment_name = os.path.basename(root)\n",
    "        if \"=\" not in experiment_name:\n",
    "            continue  # not an experiment directory, e.g. a CSV directly under RESULTS_DIR\n",
    "        experiment_params = PARAM_SEPARATOR.split(experiment_name)\n",
    "        # split on the first \"=\" only, so a value containing \"=\" is kept whole\n",
    "        experiment_dict = dict(param.split(\"=\", 1) for param in experiment_params)\n",
    "\n",
    "        metrics_path = os.path.join(root, file)\n",
    "        # only the first row of each metrics CSV is used\n",
    "        metrics_dict = pd.read_csv(metrics_path).to_dict(orient=\"records\")[0]\n",
    "\n",
    "        # metric columns win over directory parameters on a name clash\n",
    "        results.append(experiment_dict | metrics_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target dtypes: the directory-name parameters arrive as strings.\n",
    "COLUMN_DTYPES = {\n",
    "    'time': int,\n",
    "    'level': int,\n",
    "    'step_minutes': int,\n",
    "    'num_layers': int,\n",
    "    'seed': int,\n",
    "    'max_ell_variational': int,\n",
    "    'num_test_samples': int,\n",
    "    'max_ell_prior': int,\n",
    "    'num_samples': int,\n",
    "    'total_hidden_variance': float,\n",
    "    'num_iters': int,\n",
    "    'lr': float,\n",
    "    'num_hours': int,\n",
    "    'mse': float,\n",
    "    'pnll': float,\n",
    "}\n",
    "# Levels to plot and their altitude labels; the level filter below is derived from these keys.\n",
    "LEVEL_LABELS = {0: \"5.5 km\", 7: \"2.0 km\", 15: \"0.1 km\"}\n",
    "SELECTED_RUNS = (\n",
    "    f\"level in {tuple(LEVEL_LABELS)}\"\n",
    "    \" & num_layers in (1, 2, 3, 4, 5)\"\n",
    "    \" & total_hidden_variance == 0.0001\"\n",
    "    \" & step_minutes == 1\"\n",
    "    \" & num_hours == 24\"\n",
    "    \" & seed == 0\"\n",
    "    \" & max_ell_variational == 9\"\n",
    "    \" & max_ell_prior == 9\"\n",
    "    \" & num_samples == 3\"\n",
    "    \" & num_iters == 1000\"\n",
    "    \" & lr == 0.01\"\n",
    ")\n",
    "\n",
    "data = (\n",
    "    pd.DataFrame(results)\n",
    "    .astype(COLUMN_DTYPES)\n",
    "    .query(SELECTED_RUNS)\n",
    "    # Sort on the numeric level before it is replaced by its label, so rows stay in level order 0, 7, 15.\n",
    "    .sort_values(by=[\"level\", \"time\", \"num_layers\"])\n",
    "    .replace({\"level\": LEVEL_LABELS})\n",
    "    # NOTE(review): presumably turns a 0-based month index into 1..12 (the plots use xticks 1..12); confirm.\n",
    "    .assign(time=lambda df: df.time + 1)\n",
    "    .rename(columns={\n",
    "        \"level\": \"Altitude\",\n",
    "        \"num_layers\": \"Number of Layers\",\n",
    "        \"mse\": \"MSE\",\n",
    "        \"pnll\": \"NLPD\",\n",
    "        \"time\": \"Month\",\n",
    "    })\n",
    ")\n",
    "# A typo in any filter value would otherwise yield silently empty plots.\n",
    "assert not data.empty, f\"No runs match the selection: {SELECTED_RUNS}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLPD per month: one panel per altitude, one line per number of layers.\n",
    "g = sns.FacetGrid(\n",
    "    data,\n",
    "    col='Altitude',\n",
    "    hue=\"Number of Layers\",\n",
    ")\n",
    "g.map(sns.lineplot, \"Month\", \"NLPD\", marker=\"o\", markersize=markersize)\n",
    "g.add_legend(\n",
    "    loc=\"upper left\",\n",
    "    bbox_to_anchor=(1.0, 0.85),\n",
    ")\n",
    "# Raw string: \"\\#\" is an invalid Python escape (warns); with usetex, LaTeX needs the backslash to print \"#\".\n",
    "g.legend.set_title(r\"\\# Layers\", prop={'size': 8})\n",
    "g.set(xticks=list(range(1, 13)))  # one tick per month\n",
    "g.set_titles(col_template=\"Altitude $=$ {col_name}\")\n",
    "\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width / 1.2, plot_height / 1.2 * 2.0)\n",
    "\n",
    "# plt.savefig(\"plots/hodge-nlpd_vs_month_num_layers_and_altitude.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE per month: one panel per altitude, one line per number of layers.\n",
    "g = sns.FacetGrid(\n",
    "    data,\n",
    "    col='Altitude',\n",
    "    hue=\"Number of Layers\",\n",
    ")\n",
    "g.map(sns.lineplot, \"Month\", \"MSE\", marker=\"o\", markersize=markersize)\n",
    "g.add_legend(\n",
    "    loc=\"upper left\",\n",
    "    bbox_to_anchor=(1.0, 0.85),\n",
    ")\n",
    "# Raw string: \"\\#\" is an invalid Python escape (warns); with usetex, LaTeX needs the backslash to print \"#\".\n",
    "g.legend.set_title(r\"\\# Layers\", prop={'size': 8})\n",
    "g.set(xticks=list(range(1, 13)))  # one tick per month\n",
    "g.set_titles(col_template=\"Altitude $=$ {col_name}\")\n",
    "\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width / 1.2, plot_height / 1.2 * 2.0)\n",
    "\n",
    "# plt.savefig(\"plots/hodge-mse_vs_month_num_layers_and_altitude.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLPD against number of layers, one line per altitude. Bands show mean +/- 1 standard deviation\n",
    "# over the rows aggregated at each x (presumably the 12 months; confirm).\n",
    "errorbar = ('sd', 1)\n",
    "\n",
    "g = sns.FacetGrid(data, hue=\"Altitude\", margin_titles=True, sharey=False)\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"NLPD\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    loc=\"upper left\",\n",
    "    title=\"Altitude\",\n",
    "    bbox_to_anchor=(1.0, 1.15),\n",
    "    ncol=1,\n",
    ")\n",
    "g.legend.set_title(\"Altitude\", prop={'size': 8})  # re-set only to apply the font size\n",
    "g.set(xticks=[1, 2, 3, 4, 5])  # the layer counts kept by the data selection\n",
    "# Raw string: \"\\#\" is an invalid Python escape (warns); with usetex, LaTeX needs the backslash to print \"#\".\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width / 5.5, plot_height)\n",
    "\n",
    "# Negative padding pulls the tick labels in towards the axes of this very small figure.\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='x', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "\n",
    "# plt.savefig(\"plots/hodge-nlpd_vs_num_layers_and_altitude-sd1.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger variant of the NLPD-vs-layers plot, with the altitude legend laid out in one row above the axes.\n",
    "errorbar = ('sd', 1)  # mean +/- 1 standard deviation\n",
    "\n",
    "g = sns.FacetGrid(data, hue=\"Altitude\", margin_titles=True, sharey=False)\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"NLPD\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    loc=\"upper left\",\n",
    "    title=\"\",\n",
    "    bbox_to_anchor=(0.05, 1.10),\n",
    "    ncol=3,\n",
    "    prop={'size': 8},\n",
    "    columnspacing=-1.0\n",
    ")\n",
    "# NOTE(review): magic offsets, presumably tuned by eye for this figure size.\n",
    "for txt in g.legend.get_texts():\n",
    "    txt.set_ha(\"center\") # horizontal alignment of text item\n",
    "    txt.set_x(-26) # x-position\n",
    "    txt.set_y(8) # y-position\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])  # the layer counts kept by the data selection\n",
    "# Raw string: \"\\#\" is an invalid Python escape (warns); with usetex, LaTeX needs the backslash to print \"#\".\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "# top < 1 leaves room for the legend, which is anchored above the axes.\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=0.85, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width / 3, plot_height * 1.5)\n",
    "\n",
    "# Negative padding pulls the tick labels in towards the axes of this small figure.\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='x', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "\n",
    "# plt.savefig(\"plots/hodge-nlpd_vs_num_layers_and_altitude-sd1-larger.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE against number of layers, one line per altitude, with the legend laid out in one row above the axes.\n",
    "errorbar = ('sd', 1)  # mean +/- 1 standard deviation\n",
    "\n",
    "g = sns.FacetGrid(data, hue=\"Altitude\", margin_titles=True, sharey=False)\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"MSE\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    loc=\"upper left\",\n",
    "    title=\"\",\n",
    "    bbox_to_anchor=(0.05, 1.10),\n",
    "    ncol=3,\n",
    "    prop={'size': 8},\n",
    "    columnspacing=-1.0\n",
    ")\n",
    "# NOTE(review): magic offsets, presumably tuned by eye for this figure size.\n",
    "for txt in g.legend.get_texts():\n",
    "    txt.set_ha(\"center\") # horizontal alignment of text item\n",
    "    txt.set_x(-26) # x-position\n",
    "    txt.set_y(8) # y-position\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])  # the layer counts kept by the data selection\n",
    "# Raw string: \"\\#\" is an invalid Python escape (warns); with usetex, LaTeX needs the backslash to print \"#\".\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "# top < 1 leaves room for the legend, which is anchored above the axes.\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=0.85, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width / 3, plot_height * 1.5)\n",
    "\n",
    "# Negative padding pulls the tick labels in towards the axes of this small figure.\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='x', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "\n",
    "# plt.savefig(\"plots/hodge-mse_vs_num_layers_and_altitude-sd1.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
