{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e1d9a3b",
   "metadata": {},
   "source": [
    "# Plot Synthetic Tree Results\n",
    "\n",
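    "This notebook loads results written by `run.py` (by default from the most recently modified run directory under `logs/runs`) and plots the value estimation error of UCT, MENTS, RENTS, TENTS, DENTS, BTS, and VarDE against the number of simulations. The figure is saved as `stree.pdf` in the run directory.\n"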
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceaf7fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c74c3812",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper figure style: all text is typeset by LaTeX with Times-like fonts\n",
    "# (newtx packages) so the plots match the paper's typography.\n",
    "FONT_SIZE = 14  # one size for titles, labels, legends and ticks; tune here\n",
    "\n",
    "# Times-like fonts, matching paper\n",
    "latex_preamble = r\"\"\"\n",
    "        \\usepackage{newtxtext}\n",
    "        \\usepackage{newtxmath}\n",
    "    \"\"\"\n",
    "\n",
    "paper_rc = {\n",
    "    \"text.usetex\": True,\n",
    "    \"font.family\": \"serif\",\n",
    "    \"text.latex.preamble\": latex_preamble,\n",
    "    # Request Type 42 (TrueType) font embedding for PDF/PS output.\n",
    "    # NOTE(review): with text.usetex enabled, text is typeset by LaTeX and\n",
    "    # embedded via the DVI pipeline, so these keys likely do not affect it;\n",
    "    # confirm before relying on them.\n",
    "    \"pdf.fonttype\": 42,\n",
    "    \"ps.fonttype\": 42,\n",
    "}\n",
    "size_keys = (\n",
    "    \"axes.titlesize\",\n",
    "    \"axes.labelsize\",\n",
    "    \"legend.fontsize\",\n",
    "    \"legend.title_fontsize\",\n",
    "    \"xtick.labelsize\",\n",
    "    \"ytick.labelsize\",\n",
    ")\n",
    "paper_rc.update(dict.fromkeys(size_keys, FONT_SIZE))\n",
    "\n",
    "mpl.rcParams.update(paper_rc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef4208a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run to plot: set `run_dir` to a specific run directory (str or Path) if\n",
    "# needed; leave it as None to use the most recently modified run under base_dir.\n",
    "run_dir = None\n",
    "\n",
    "base_dir = Path(\"logs/runs\")\n",
    "if run_dir is None:\n",
    "    # Only directories count as runs, so a stray file matching run_* (e.g. a\n",
    "    # log file) is never picked as the latest run.\n",
    "    runs = sorted(\n",
    "        (p for p in base_dir.glob(\"run_*\") if p.is_dir()),\n",
    "        key=lambda p: p.stat().st_mtime,\n",
    "    )\n",
    "    if not runs:\n",
    "        raise FileNotFoundError(\"No run directories found under logs/runs.\")\n",
    "    run_dir = runs[-1]\n",
    "else:\n",
    "    run_dir = Path(run_dir)\n",
    "    # Fail early with a clear message rather than later, when config.json is read.\n",
    "    if not run_dir.is_dir():\n",
    "        raise FileNotFoundError(f\"Run directory not found: {run_dir}\")\n",
    "\n",
    "print(f\"Using run directory: {run_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e915d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the run's configuration (config.json in the run directory).\n",
    "config_path = run_dir / \"config.json\"\n",
    "config = json.loads(config_path.read_text(encoding=\"utf-8\"))\n",
    "\n",
    "algorithms = config[\"algorithms\"]\n",
    "# Each entry is a mapping with \"k\" and \"d\" keys; keep them as (k, d) tuples.\n",
    "kds = [(entry[\"k\"], entry[\"d\"]) for entry in config[\"kds\"]]\n",
    "n_exp, n_trees, n_simulations = (\n",
    "    config[key] for key in (\"n_exp\", \"n_trees\", \"n_simulations\")\n",
    ")\n",
    "\n",
    "print(\"Algorithms:\", algorithms)\n",
    "print(\"k,d pairs:\", kds)\n",
    "print(\"n_exp, n_trees, n_simulations:\", n_exp, n_trees, n_simulations)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "507adfa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (k, d) configurations shown in the figure, in panel order. Set to None to\n",
    "# plot every configuration listed in this run's config.json.\n",
    "PLOT_KDS = [(16, 1), (14, 3), (16, 4), (200, 2)]\n",
    "\n",
    "# Rebuilt from `config` (not the previous `kds`) so re-running this cell is\n",
    "# idempotent. Requested pairs this run did not produce are reported and\n",
    "# dropped instead of being drawn as empty panels.\n",
    "run_kds = [(item[\"k\"], item[\"d\"]) for item in config[\"kds\"]]\n",
    "if PLOT_KDS is None:\n",
    "    kds = run_kds\n",
    "else:\n",
    "    skipped = [kd for kd in PLOT_KDS if kd not in run_kds]\n",
    "    if skipped:\n",
    "        print(\"Not in this run's config, skipping (k, d):\", skipped)\n",
    "    kds = [kd for kd in PLOT_KDS if kd in run_kds]\n",
    "if not kds:\n",
    "    raise ValueError(\"No (k, d) configurations to plot.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92d4c505",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names and line colours per algorithm key; an unknown key is labelled\n",
    "# with the raw key itself.\n",
    "labels = {\n",
    "    \"uct\": \"UCT\",\n",
    "    \"ments\": \"MENTS\",\n",
    "    \"rents\": \"RENTS\",\n",
    "    \"tents\": \"TENTS\",\n",
    "    \"dents\": \"DENTS\",\n",
    "    \"bts\": \"BTS\",\n",
    "    \"varde\": \"VarDE\",\n",
    "}\n",
    "\n",
    "colors = {\n",
    "    \"uct\": \"#1f77b4\",\n",
    "    \"ments\": \"#d62728\",\n",
    "    \"rents\": \"#9467bd\",\n",
    "    \"tents\": \"#8c564b\",\n",
    "    \"dents\": \"#ff27be\",\n",
    "    \"bts\": \"#ff7f0e\",\n",
    "    \"varde\": \"#2ca02c\",\n",
    "}\n",
    "\n",
    "# MENTS, RENTS and TENTS are drawn dotted; every other algorithm is solid.\n",
    "dotted_algs = {\"ments\", \"rents\", \"tents\"}\n",
    "\n",
    "fig, axes = plt.subplots(1, len(kds), figsize=(4 * len(kds), 3), squeeze=False)\n",
    "axes = axes[0]\n",
    "\n",
    "for ax, (k, d) in zip(axes, kds):\n",
    "    max_val = 0.0\n",
    "    for alg in algorithms:\n",
    "        path = run_dir / \"results\" / f\"k_{k}_d_{d}\" / f\"diff_uct_{alg}.npy\"\n",
    "        if not path.exists():\n",
    "            print(f\"Missing {path}\")\n",
    "            continue\n",
    "        diff_uct = np.load(path)\n",
    "        # Collapse every leading axis into one sample axis so the mean and the\n",
    "        # error band are computed over the same samples, whatever the stored\n",
    "        # layout. NOTE(review): assumes the last axis indexes simulations.\n",
    "        flat = diff_uct.reshape(-1, diff_uct.shape[-1])\n",
    "        avg = flat.mean(axis=0)\n",
    "        # Band of +/- 2 standard errors of the mean (roughly a 95% interval).\n",
    "        err = 2 * flat.std(axis=0) / np.sqrt(flat.shape[0])\n",
    "        steps = np.arange(avg.shape[0])\n",
    "        linestyle = \":\" if alg in dotted_algs else \"-\"\n",
    "        ax.plot(steps, avg, color=colors.get(alg), linestyle=linestyle, linewidth=2.0, label=labels.get(alg, alg))\n",
    "        ax.fill_between(steps, avg - err, avg + err, color=colors.get(alg), alpha=0.15)\n",
    "        max_val = max(max_val, float(avg.max()))\n",
    "\n",
    "    ax.set_title(f\"k={k} d={d}\")\n",
    "    # Raw string: LaTeX needs the literal backslash, and \"\\#\" is not a valid\n",
    "    # Python escape sequence.\n",
    "    ax.set_xlabel(r\"\\# Simulations\")\n",
    "    ax.grid(True, linestyle=\"--\", linewidth=0.5)\n",
    "    ax.set_xlim(0, n_simulations - 1)\n",
    "    # y-range follows the highest mean curve (the error band may be clipped).\n",
    "    ax.set_ylim(0, max_val * 1.1 if max_val > 0 else 1.0)\n",
    "\n",
    "axes[0].set_ylabel(\"Value Estimation Error\")\n",
    "\n",
    "# One figure-level legend built from every panel (deduplicated by label), so an\n",
    "# algorithm whose results are missing from the first panel is still listed.\n",
    "legend_entries = {}\n",
    "for ax in axes:\n",
    "    for handle, label in zip(*ax.get_legend_handles_labels()):\n",
    "        legend_entries.setdefault(label, handle)\n",
    "if legend_entries:\n",
    "    fig.legend(\n",
    "        list(legend_entries.values()),\n",
    "        list(legend_entries),\n",
    "        loc=\"lower center\",\n",
    "        ncol=len(legend_entries),\n",
    "        frameon=True,\n",
    "        framealpha=0.8,\n",
    "    )\n",
    "fig.tight_layout()\n",
    "# Leave room below the panels for the legend.\n",
    "fig.subplots_adjust(bottom=0.33)\n",
    "fig.savefig(run_dir / \"stree.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dude",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
