{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd \n",
    "import plotly.express as px \n",
    "import seaborn as sns \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Apply the seaborn theme first; the rcParams overrides below are set after it.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "# LaTeX-rendered serif typography and font sizes for all figure text.\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.usetex': True,\n",
    "    'font.size': 9,\n",
    "    'axes.labelsize': 10,\n",
    "    'xtick.labelsize': 9,\n",
    "    'ytick.labelsize': 9,\n",
    "    'legend.fontsize': 10,\n",
    "})\n",
    "\n",
    "# Marker size passed to the line plots; margin_title_size is not referenced in the cells shown here.\n",
    "markersize = 3\n",
    "margin_title_size = 9\n",
    "\n",
    "# Final figure size in inches, applied as set_size_inches(line_width, plot_height).\n",
    "line_width, plot_height = 5.1, 0.7\n",
    "\n",
    "\n",
    "from fractions import Fraction\n",
    "import re\n",
    "\n",
    "\n",
    "def float_to_latex_fraction(x, limit_denominator=10):\n",
    "    x = float(re.search(r'\\{([^}]+)\\}', x.get_text()).group(1).replace('−', '-'))\n",
    "    frac = Fraction(x).limit_denominator(limit_denominator)\n",
    "    if frac.denominator == 1:\n",
    "        return f\"${frac.numerator}$\"\n",
    "    return f\"$\\\\frac{{{frac.numerator}}}{{{frac.denominator}}}$\"\n",
    "\n",
    "\n",
    "def set_yticklabels_as_fractions(ax):\n",
    "    ax.set_yticklabels([float_to_latex_fraction(t) for t in ax.get_yticklabels()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one record per metrics CSV found under results/uci/.\n",
    "# The directory holding each CSV encodes the run's parameters as\n",
    "# \"key=value\" pairs joined by \"-\".\n",
    "results = []\n",
    "for root, _dirs, files in os.walk(\"results/uci/\"):\n",
    "    for file in (name for name in files if name.endswith(\".csv\")):\n",
    "        # \"-\" separates the pairs, so a literal -1 value is rewritten to \"None\" before splitting.\n",
    "        experiment_name = os.path.basename(root).replace(\"=-1\", \"=None\")\n",
    "        print(experiment_name)\n",
    "\n",
    "        experiment_dict = {}\n",
    "        for param in experiment_name.split(\"-\"):\n",
    "            pieces = param.split(\"=\")\n",
    "            experiment_dict[pieces[0]] = pieces[1]\n",
    "\n",
    "        # Only the first row of each metrics file is used.\n",
    "        metrics_dict = pd.read_csv(os.path.join(root, file)).to_dict(orient=\"records\")[0]\n",
    "\n",
    "        # On a key clash the metrics value overrides the parameter parsed from the name.\n",
    "        results.append({**experiment_dict, **metrics_dict})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dtypes for the columns used downstream; num_iters is cast only after rows lacking it are dropped.\n",
    "column_dtypes = {\n",
    "    \"dataset_name\": str,\n",
    "    \"model_name\": str,\n",
    "    \"num_layers\": int,\n",
    "    \"seed\": int,\n",
    "    \"mse\": float,\n",
    "    \"nlpd\": float,\n",
    "}\n",
    "\n",
    "# Runs with 5000 iterations, seeds 0-4 and kernel_max_ell == -1; for kin8mn and power\n",
    "# only rows whose batch_size is the string '1000' are kept.\n",
    "run_filter = (\n",
    "    \"num_iters == 5000 & seed in [0, 1, 2, 3, 4] & kernel_max_ell == -1 & (dataset_name not in ['kin8mn', 'power'] | batch_size == '1000')\"\n",
    ")\n",
    "\n",
    "# The two models and five datasets that are compared in the plot.\n",
    "model_dataset_filter = (\n",
    "    \"model_name in ['euclidean+inducing_points', 'residual+spherical_harmonic_features']\"\n",
    "    \"& dataset_name in ['yacht', 'energy', 'concrete', 'kin8mn', 'power']\"\n",
    ")\n",
    "\n",
    "data = (\n",
    "    pd.DataFrame(results)\n",
    "    .astype(column_dtypes)\n",
    "    .dropna(subset=[\"num_iters\"])\n",
    "    .astype({\"num_iters\": int})\n",
    "    .sort_values(by=[\"model_name\", \"num_layers\", \"seed\"])\n",
    "    # The string \"None\" becomes -1 so that run_filter can compare against a number.\n",
    "    .replace(\"None\", -1)\n",
    "    .drop(columns=[\"Unnamed: 0\"])\n",
    "    .query(run_filter)\n",
    "    .drop_duplicates()\n",
    "    .query(model_dataset_filter)\n",
    ")\n",
    "\n",
    "# Per-dataset batch size and dimension, keyed by dataset name.\n",
    "dataset_to_batch_size = dict(yacht=277, energy=691, concrete=927, kin8mn=1000, power=1000)\n",
    "dataset_to_dimension = dict(yacht=6, energy=8, concrete=8, kin8mn=8, power=4)\n",
    "\n",
    "# Replace batch_size with the per-dataset value, attach the dimension, and order the\n",
    "# rows by batch size (ties broken by dataset name).\n",
    "dataset_names = data[\"dataset_name\"]\n",
    "data = data.assign(\n",
    "    batch_size=dataset_names.map(dataset_to_batch_size),\n",
    "    dimension=dataset_names.map(dataset_to_dimension),\n",
    ").sort_values(by=[\"batch_size\", \"dataset_name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLPD against the number of layers: one facet per dataset, one line per model.\n",
    "errorbar = ('sd', 1)\n",
    "\n",
    "# Human-readable column names and model labels; each facet title also carries the\n",
    "# dataset's batch size (B) and dimension (D).\n",
    "plot_data = data.rename(columns={\n",
    "    \"dataset_name\": \"Dataset\",\n",
    "    \"model_name\": \"Model\",\n",
    "    \"num_layers\": \"Number of Layers\",\n",
    "    \"nlpd\": \"NLPD\",\n",
    "    \"mse\": \"MSE\"\n",
    "}).replace(\n",
    "    {\n",
    "        \"Model\": {\n",
    "            \"residual+spherical_harmonic_features\": \"Residual (IV)\",\n",
    "            \"euclidean+inducing_points\": \"Euclidean (PI)\",\n",
    "        },\n",
    "    }\n",
    ").assign(\n",
    "    Dataset=lambda x: x[\"Dataset\"].str.capitalize()\n",
    ").assign(\n",
    "    Dataset=lambda x: x[\"Dataset\"] + \"\\nB=\" + x[\"batch_size\"].astype(str) + \", D=\" + x[\"dimension\"].astype(str),\n",
    ")\n",
    "\n",
    "g = sns.FacetGrid(\n",
    "    plot_data, \n",
    "    col=\"Dataset\", \n",
    "    margin_titles=False, \n",
    "    sharey=False, \n",
    "    hue=\"Model\",\n",
    "    gridspec_kws={\"wspace\":0.25, \"hspace\": 0.0}\n",
    ")\n",
    "\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"NLPD\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    title=\"\",\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 1.95),\n",
    "    ncol=2,\n",
    ")\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])\n",
    "# \"#\" has to be escaped for LaTeX; the backslash is doubled so that Python does not\n",
    "# read it as an (invalid) string escape. The resulting label text is unchanged.\n",
    "g.set_axis_labels(\"\\\\# Layers\")\n",
    "g.set_titles(col_template=\"{col_name}\", size=9)\n",
    "\n",
    "# g.axes[0, 4].set_yticks([-0.1, 0.0])\n",
    "\n",
    "# Set the final geometry before relabelling the y ticks: the automatic tick locator\n",
    "# chooses ticks from the axes size, so labels derived at a different size would no\n",
    "# longer line up with the ticks that are actually drawn.\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width, plot_height)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    set_yticklabels_as_fractions(ax)\n",
    "    ax.tick_params(axis='both', which='major', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "\n",
    "# plt.tight_layout()\n",
    "# plt.savefig(\"./plots/uci-nlpd_vs_num_layers-euclidean_and_residual-sd1-size_optimised.pdf\", bbox_inches='tight', pad_inches=0.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
