{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd \n",
    "import plotly.express as px \n",
    "import seaborn as sns \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "# Matplotlib text settings: serif Computer Modern, rendered through LaTeX.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        'font.family': 'serif',\n",
    "        'font.serif': ['Computer Modern Roman'],\n",
    "        'text.usetex': True,\n",
    "        'font.size': 9,\n",
    "        'axes.labelsize': 10,\n",
    "        'xtick.labelsize': 9,\n",
    "        'ytick.labelsize': 9,\n",
    "        'legend.fontsize': 10,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Shared plot constants.\n",
    "markersize = 3  # handed to sns.lineplot in the plotting cells\n",
    "margin_title_size = 9\n",
    "\n",
    "# Figure width and height, passed to Figure.set_size_inches in the plotting cells.\n",
    "line_width = 5.1\n",
    "plot_height = 0.7\n",
    "\n",
    "\n",
    "from fractions import Fraction\n",
    "import re\n",
    "\n",
    "\n",
    "def float_to_latex_fraction(x, limit_denominator=10):\n",
    "    x = float(re.search(r'\\{([^}]+)\\}', x.get_text()).group(1).replace('−', '-'))\n",
    "    frac = Fraction(x).limit_denominator(limit_denominator)\n",
    "    if frac.denominator == 1:\n",
    "        return f\"${frac.numerator}$\"\n",
    "    return f\"$\\\\frac{{{frac.numerator}}}{{{frac.denominator}}}$\"\n",
    "\n",
    "\n",
    "def set_yticklabels_as_fractions(ax):\n",
    "    \"\"\"Replace the y tick labels of ``ax`` with LaTeX fractions.\n",
    "\n",
    "    The current tick positions are fixed before the labels are replaced, so\n",
    "    each fixed label stays attached to the tick it was computed from.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes whose y tick labels are rewritten in place.\n",
    "    \"\"\"\n",
    "    ticks = ax.get_yticks()\n",
    "    labels = [float_to_latex_fraction(t) for t in ax.get_yticklabels()]\n",
    "    # set_yticks may widen the view to include every tick; remember the limits to restore them.\n",
    "    ylim = ax.get_ylim()\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels(labels)\n",
    "    ax.set_ylim(ylim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one record per metrics CSV found below ./results/synthetic/.\n",
    "results = []\n",
    "for root, _dirs, files in os.walk(\"./results/synthetic/\"):\n",
    "    # The directory name encodes hyperparameters as \"key=value\" pairs joined by \"-\".\n",
    "    experiment_dict = {}\n",
    "    last_key = None\n",
    "    for segment in os.path.basename(root).split(\"-\"):\n",
    "        # partition keeps any further \"=\" inside the value instead of dropping it.\n",
    "        key, sep, value = segment.partition(\"=\")\n",
    "        if sep:\n",
    "            experiment_dict[key] = value\n",
    "            last_key = key\n",
    "        elif last_key is not None:\n",
    "            # No \"=\": this is the tail of a value that itself contains \"-\" (e.g. \"1e-05\").\n",
    "            experiment_dict[last_key] += \"-\" + segment\n",
    "    if not experiment_dict:\n",
    "        # The directory name carries no parameters (e.g. the walk root itself); skip its files.\n",
    "        continue\n",
    "\n",
    "    for file in files:\n",
    "        if not file.endswith(\".csv\"):\n",
    "            continue\n",
    "        records = pd.read_csv(os.path.join(root, file)).to_dict(orient=\"records\")\n",
    "        if not records:\n",
    "            # A CSV without data rows has no metrics to report.\n",
    "            continue\n",
    "        # Only the first row of each CSV is used, as before.\n",
    "        results.append(experiment_dict | records[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the results table, cast the columns used below, and order the rows.\n",
    "data = (\n",
    "    pd.DataFrame(results)\n",
    "    .astype(\n",
    "        dict(\n",
    "            num_train=int,\n",
    "            model_name=str,\n",
    "            num_layers=int,\n",
    "            seed=int,\n",
    "            nlpd=float,\n",
    "            mse=float,\n",
    "        )\n",
    "    )\n",
    "    .sort_values(by=[\"num_train\", \"num_layers\"])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to display names for columns and models, then keep only the N = 800 runs.\n",
    "data = (\n",
    "    data\n",
    "    .rename(\n",
    "        columns=dict(\n",
    "            num_train=\"Number of Training Points\",\n",
    "            num_layers=\"Number of Layers\",\n",
    "            nlpd=\"NLPD\",\n",
    "            mse=\"MSE\",\n",
    "            model_name=\"Model\",\n",
    "        )\n",
    "    )\n",
    "    .replace(\n",
    "        {\n",
    "            \"Model\": {\n",
    "                \"residual+spherical_harmonic_features\": \"Residual (IV)\",\n",
    "                \"residual+inducing_points\": \"Residual (PI)\",\n",
    "                \"euclidean_with_geometric_input+inducing_points\": \"Baseline\",\n",
    "                \"residual+hodge+spherical_harmonic_features\": \"Hodge (IV)\",\n",
    "            },\n",
    "        }\n",
    "    )\n",
    "    .query(\"`Number of Training Points` == 800\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLPD against number of layers: one line per model, one facet per training-set size.\n",
    "errorbar = ('sd', 1)\n",
    "\n",
    "g = sns.FacetGrid(\n",
    "    data, \n",
    "    col=\"Number of Training Points\", \n",
    "    hue=\"Model\", \n",
    "    margin_titles=True,\n",
    "    gridspec_kws={\"wspace\":0.20, \"hspace\": 0.0},\n",
    ")\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"NLPD\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    title=\"\",\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 1.75),\n",
    "    ncol=4,\n",
    "    fontsize=9,\n",
    ")\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])\n",
    "# Raw string: the backslash must reach LaTeX (text.usetex is on) and \"\\#\" is not a valid Python escape.\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "g.set_titles(col_template=\"N = {col_name}\", size=9)\n",
    "\n",
    "# Let the axes fill the whole figure, then shrink the figure to its final size.\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width, plot_height)\n",
    "\n",
    "# set_yticklabels_as_fractions(g.axes[0, 0])\n",
    "for ax in g.axes.flat:\n",
    "    # Negative padding pulls the tick labels towards the axes.\n",
    "    ax.tick_params(axis='x', which='major', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "# plt.savefig(\"./plots/synthetic-nlpd_vs_num_layers_and_num_training-all_models-sd1.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE against number of layers: one line per model, one facet per training-set size.\n",
    "errorbar = ('sd', 1)\n",
    "\n",
    "g = sns.FacetGrid(\n",
    "    data, \n",
    "    col=\"Number of Training Points\", \n",
    "    hue=\"Model\", \n",
    "    margin_titles=True,\n",
    "    gridspec_kws={\"wspace\":0.19, \"hspace\": 0.0},\n",
    ")\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"MSE\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    title=\"\",\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 1.75),\n",
    "    ncol=4,\n",
    "    fontsize=9,\n",
    ")\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])\n",
    "# Raw string: the backslash must reach LaTeX (text.usetex is on) and \"\\#\" is not a valid Python escape.\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "g.set_titles(col_template=\"N = {col_name}\", size=9)\n",
    "# Start the y axis at zero and use three fixed ticks.\n",
    "g.set(ylim=(0, None))\n",
    "g.set(yticks=[0.0000, 0.0025, 0.0050])\n",
    "\n",
    "# Let the axes fill the whole figure, then shrink the figure to its final size.\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width, plot_height)\n",
    "\n",
    "# set_yticklabels_as_fractions(g.axes[0, 0])\n",
    "for ax in g.axes.flat:\n",
    "    # Negative padding pulls the tick labels towards the axes.\n",
    "    ax.tick_params(axis='x', which='major', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "# plt.savefig(\"./plots/synthetic-mse_vs_num_layers_and_num_training-all_models-sd1.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
