{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import pandas as pd \n",
    "import plotly.express as px \n",
    "import seaborn as sns \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np \n",
    "\n",
    "\n",
    "# Seaborn theme first; the rcParams below are then applied on top of it.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "# Typeset all text through LaTeX in a serif face, with every text element at size 8.\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'text.usetex': True,\n",
    "    'font.size': 8,\n",
    "    'axes.labelsize': 8,\n",
    "    'xtick.labelsize': 8,\n",
    "    'ytick.labelsize': 8,\n",
    "    'legend.fontsize': 8,\n",
    "})\n",
    "\n",
    "# Marker size and facet-title font size handed to the plotting calls further down.\n",
    "markersize, margin_title_size = 3, 8\n",
    "\n",
    "# Figure size passed to set_size_inches as (line_width, plot_height).\n",
    "line_width, plot_height = 5.1, 0.7\n",
    "\n",
    "\n",
    "from fractions import Fraction\n",
    "import re\n",
    "\n",
    "\n",
    "def float_to_latex_fraction(x, limit_denominator=10):\n",
    "    x = float(re.search(r'\\{([^}]+)\\}', x.get_text()).group(1).replace('−', '-'))\n",
    "    frac = Fraction(x).limit_denominator(limit_denominator)\n",
    "    if frac.denominator == 1:\n",
    "        return f\"${frac.numerator}$\"\n",
    "    return f\"$\\\\frac{{{frac.numerator}}}{{{frac.denominator}}}$\"\n",
    "\n",
    "\n",
    "def set_yticklabels_as_fractions(ax):\n",
    "    \"\"\"Relabel the y ticks of ``ax`` with LaTeX fractions.\n",
    "\n",
    "    Each current y tick label is run through ``float_to_latex_fraction`` and\n",
    "    the resulting strings are installed as the new y tick labels.\n",
    "    \"\"\"\n",
    "    fraction_labels = list(map(float_to_latex_fraction, ax.get_yticklabels()))\n",
    "    ax.set_yticklabels(fraction_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = \"results/time_uci\"\n",
    "\n",
    "# Read every CSV found anywhere below the results directory into its own frame.\n",
    "results = []\n",
    "for root, _dirs, files in os.walk(results_dir):\n",
    "    for file in files:\n",
    "        if not file.endswith(\".csv\"):\n",
    "            continue\n",
    "        # The folder holding the CSV names the experiment; \"=-1\" is displayed as \"=None\".\n",
    "        experiment_name = os.path.basename(root).replace(\"=-1\", \"=None\")\n",
    "        print(experiment_name)\n",
    "        metrics_path = os.path.join(root, file)\n",
    "        result_df = pd.read_csv(metrics_path)\n",
    "        results.append(result_df)\n",
    "\n",
    "# Fail with an actionable message rather than pd.concat's generic error on an empty list.\n",
    "if not results:\n",
    "    raise FileNotFoundError(f\"No .csv files found under {results_dir!r}\")\n",
    "\n",
    "data = pd.concat(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the combined frame from `results` so this cell can be re-run on its own:\n",
    "# cast the known columns to fixed dtypes, order the rows, and keep only the rows\n",
    "# whose num_iters equals 100.\n",
    "data = (\n",
    "    pd.concat(results)\n",
    "    .astype({\n",
    "        \"dataset\": str,\n",
    "        \"model\": str,\n",
    "        'num_layers': int,\n",
    "        \"dataset_dim\": int,\n",
    "        \"batch_size\": int,\n",
    "        \"num_iters\": int,\n",
    "        \"num_inducing\": int,\n",
    "        \"seed\": int,\n",
    "        \"time\": float,\n",
    "    })\n",
    "    .sort_values(by=['model', 'num_layers', 'seed', 'batch_size', 'dataset'])\n",
    "    .query('num_iters == 100')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seaborn `errorbar` spec: standard deviation, scaled by 1.\n",
    "errorbar = ('sd', 1)\n",
    "\n",
    "# Plotting frame: display-ready column names and model labels, a log-time column,\n",
    "# and a per-dataset facet title that also carries batch size (B) and dataset_dim (d).\n",
    "plot_data = data.rename(columns={\n",
    "    \"dataset\": \"Dataset\",\n",
    "    \"model\": \"Model\",\n",
    "    \"num_layers\": \"Number of Layers\",\n",
    "    \"time\": \"Time (s)\",\n",
    "}).replace(\n",
    "    {\n",
    "        \"Model\": {\n",
    "            \"residual+spherical_harmonic_features\": \"Residual (IV)\",\n",
    "            \"euclidean+inducing_points\": \"Euclidean (IL)\",\n",
    "        },\n",
    "    }\n",
    ").assign(\n",
    "    Dataset=lambda x: x[\"Dataset\"].str.capitalize(),\n",
    ").assign(\n",
    "    log_time=lambda x: np.log(x[\"Time (s)\"])\n",
    ").rename(\n",
    "    columns={\n",
    "        \"log_time\": \"Log Time (s)\",\n",
    "    }\n",
    ").assign(\n",
    "    Dataset=lambda x: x[\"Dataset\"] + \"\\nB=\" + x[\"batch_size\"].astype(str) + \", d=\" + x[\"dataset_dim\"].astype(str),\n",
    ")\n",
    "\n",
    "\n",
    "# One panel per dataset, one line per model; y axes are independent across panels.\n",
    "g = sns.FacetGrid(\n",
    "    plot_data, \n",
    "    col=\"Dataset\", \n",
    "    margin_titles=True, \n",
    "    sharey=False, \n",
    "    hue=\"Model\",\n",
    "    gridspec_kws={\"wspace\":0.25, \"hspace\": 0.0},\n",
    ")\n",
    "g.map(sns.lineplot, \"Number of Layers\", \"Time (s)\", marker=\"o\", errorbar=errorbar, markersize=markersize)\n",
    "g.add_legend(\n",
    "    title=\"\",\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 1.85),\n",
    "    ncol=2,\n",
    ")\n",
    "\n",
    "g.set(xticks=[1, 2, 3, 4, 5])\n",
    "# Raw string: \"\\#\" is not a valid Python escape sequence, but LaTeX (text.usetex) needs\n",
    "# the backslash to print a literal '#'.\n",
    "g.set_axis_labels(r\"\\# Layers\")\n",
    "g.set_titles(col_template=\"{col_name}\", size=margin_title_size)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    # Optional: uncomment to show y tick labels as fractions.\n",
    "    # set_yticklabels_as_fractions(ax)\n",
    "    # Negative padding pulls the tick labels in towards the axes.\n",
    "    ax.tick_params(axis='both', which='major', pad=-3)\n",
    "    ax.tick_params(axis='y', pad=-4)\n",
    "\n",
    "g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)\n",
    "g.figure.set_size_inches(line_width, plot_height)\n",
    "\n",
    "\n",
    "os.makedirs('./plots', exist_ok=True)\n",
    "# Save through the grid's own figure instead of relying on pyplot's current figure.\n",
    "g.figure.savefig(\"./plots/uci-timing.pdf\", bbox_inches='tight', pad_inches=0.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
