{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from typing import Optional\n",
    "import os\n",
    "\n",
    "def get_nested_folders(path: str) -> list[str]:\n",
    "    \"\"\"\n",
    "    Walk the directory tree rooted at path and collect every folder that directly holds a config.json file.\n",
    "\n",
    "    The starting folder itself is included when it qualifies; folders are listed in os.walk order.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        current_dir\n",
    "        for current_dir, _subdirs, filenames in os.walk(path)\n",
    "        if \"config.json\" in filenames\n",
    "    ]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per trainer class: (legend label, matplotlib marker, matplotlib color).\n",
    "# Some markers are shared between trainers; those pairs are told apart by color.\n",
    "_TRAINER_STYLES = {\n",
    "    \"StandardTrainer\": (\"Standard\", \"o\", \"blue\"),\n",
    "    \"JumpReluTrainer\": (\"JumpReLU\", \"X\", \"orange\"),\n",
    "    \"TopKTrainer\": (\"Top K\", \"s\", \"green\"),\n",
    "    \"BatchTopKTrainer\": (\"Batch Top K\", \"d\", \"black\"),\n",
    "    \"GatedSAETrainer\": (\"Gated\", \"d\", \"red\"),\n",
    "    \"PAnnealTrainer\": (\"P-Anneal\", \"s\", \"purple\"),\n",
    "}\n",
    "\n",
    "# Per-attribute lookups derived from the table above; key order matches the table.\n",
    "TRAINER_LABELS = {trainer: style[0] for trainer, style in _TRAINER_STYLES.items()}\n",
    "TRAINER_MARKERS = {trainer: style[1] for trainer, style in _TRAINER_STYLES.items()}\n",
    "TRAINER_COLORS = {trainer: style[2] for trainer, style in _TRAINER_STYLES.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folders that are scanned for trained autoencoder runs.\n",
    "save_dirs = [\"./top_k\", \"./batch_top_k\", \"./jumprelu\"]\n",
    "# save_dirs = [\"./run2\"]\n",
    "\n",
    "# Flatten the per-root results into a single list of run folders.\n",
    "ae_paths = [\n",
    "    run_folder\n",
    "    for save_dir in save_dirs\n",
    "    for run_folder in get_nested_folders(save_dir)\n",
    "]\n",
    "\n",
    "print(ae_paths)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each run folder to the metrics and trainer metadata needed for plotting.\n",
    "plotting_results = {}\n",
    "\n",
    "for ae_path in ae_paths:\n",
    "    config_path = os.path.join(ae_path, \"config.json\")\n",
    "    eval_path = os.path.join(ae_path, \"eval_results.json\")\n",
    "\n",
    "    # Run folders are discovered by config.json alone, so a run that has no\n",
    "    # eval_results.json yet is skipped instead of aborting the whole cell.\n",
    "    if not os.path.isfile(eval_path):\n",
    "        print(f\"Skipping {ae_path}: no eval_results.json found\")\n",
    "        continue\n",
    "\n",
    "    with open(config_path) as f:\n",
    "        config = json.load(f)\n",
    "\n",
    "    with open(eval_path) as f:\n",
    "        eval_results = json.load(f)\n",
    "\n",
    "    trainer_config = config[\"trainer\"]\n",
    "\n",
    "    ae_results = {\n",
    "        \"l0\": eval_results[\"l0\"],\n",
    "        \"frac_recovered\": eval_results[\"frac_recovered\"],\n",
    "        \"trainer_class\": trainer_config[\"trainer_class\"],\n",
    "        \"dict_size\": trainer_config[\"dict_size\"],\n",
    "        \"frac_alive\": eval_results[\"frac_alive\"],\n",
    "    }\n",
    "\n",
    "    plotting_results[ae_path] = ae_results\n",
    "\n",
    "print(plotting_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_2var_graph(\n",
    "    results: dict[str, dict],\n",
    "    custom_metric: str,\n",
    "    title: str = \"L0 vs Custom Metric\",\n",
    "    y_label: str = \"Custom Metric\",\n",
    "    xlims: Optional[tuple[float, float]] = None,\n",
    "    ylims: Optional[tuple[float, float]] = None,\n",
    "    output_filename: Optional[str] = None,\n",
    "    legend_location: str = \"lower right\",\n",
    "    x_axis_key: str = \"l0\",\n",
    "    return_fig: bool = False,\n",
    "    x_label: Optional[str] = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Scatter-plot one metric against another for a set of runs, with one series per trainer class.\n",
    "\n",
    "    Args:\n",
    "        results: Maps a run identifier to that run's values. Every entry must contain\n",
    "            the keys \"trainer_class\", x_axis_key and custom_metric.\n",
    "        custom_metric: Key of the value plotted on the y axis.\n",
    "        title: Plot title.\n",
    "        y_label: Label of the y axis.\n",
    "        xlims: Optional (min, max) limits for the x axis.\n",
    "        ylims: Optional (min, max) limits for the y axis.\n",
    "        output_filename: If given, the figure is saved to this path.\n",
    "        legend_location: Location string passed to the legend.\n",
    "        x_axis_key: Key of the value plotted on the x axis.\n",
    "        return_fig: If True, return the figure instead of calling plt.show().\n",
    "        x_label: Label of the x axis. When omitted it is \"L0 (Sparsity)\" for the\n",
    "            default x_axis_key of \"l0\" and the key name itself otherwise.\n",
    "\n",
    "    Returns:\n",
    "        The figure when return_fig is True, otherwise None.\n",
    "    \"\"\"\n",
    "    if x_label is None:\n",
    "        x_label = \"L0 (Sparsity)\" if x_axis_key == \"l0\" else x_axis_key\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "    # Trainers with a configured style keep their configured order. Trainer classes\n",
    "    # without one are appended so their runs are not silently left off the plot.\n",
    "    present_trainers = {data[\"trainer_class\"] for data in results.values()}\n",
    "    trainers = list(TRAINER_MARKERS) + sorted(present_trainers - set(TRAINER_MARKERS))\n",
    "\n",
    "    handles, labels = [], []\n",
    "\n",
    "    for trainer in trainers:\n",
    "        # Keep only the runs that belong to this trainer\n",
    "        trainer_runs = [data for data in results.values() if data[\"trainer_class\"] == trainer]\n",
    "\n",
    "        if not trainer_runs:\n",
    "            continue  # Skip this trainer if no data points\n",
    "\n",
    "        trainer_label = TRAINER_LABELS.get(trainer, trainer.capitalize())\n",
    "\n",
    "        # Unconfigured trainers fall back to a gray circle\n",
    "        series = ax.scatter(\n",
    "            [data[x_axis_key] for data in trainer_runs],\n",
    "            [data[custom_metric] for data in trainer_runs],\n",
    "            marker=TRAINER_MARKERS.get(trainer, \"o\"),\n",
    "            s=100,\n",
    "            label=trainer_label,\n",
    "            color=TRAINER_COLORS.get(trainer, \"gray\"),\n",
    "            edgecolor=\"black\",\n",
    "        )\n",
    "\n",
    "        # The plotted collection already carries marker and color, so it doubles as the legend handle\n",
    "        handles.append(series)\n",
    "        labels.append(trainer_label)\n",
    "\n",
    "    # Set labels and title\n",
    "    ax.set_xlabel(x_label)\n",
    "    ax.set_ylabel(y_label)\n",
    "    ax.set_title(title)\n",
    "\n",
    "    ax.legend(handles, labels, loc=legend_location)\n",
    "\n",
    "    # Set axis limits\n",
    "    if xlims:\n",
    "        ax.set_xlim(*xlims)\n",
    "    if ylims:\n",
    "        ax.set_ylim(*ylims)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    # Save and show the plot\n",
    "    if output_filename:\n",
    "        fig.savefig(output_filename, bbox_inches=\"tight\")\n",
    "\n",
    "    if return_fig:\n",
    "        return fig\n",
    "\n",
    "    plt.show()\n",
    "    \n",
    "# Larger font so the saved figure stays readable.\n",
    "plt.rcParams[\"font.size\"] = 20\n",
    "\n",
    "plot_2var_graph(\n",
    "    plotting_results,\n",
    "    \"frac_recovered\",\n",
    "    title=\"Fraction Recovered vs L0\",\n",
    "    y_label=\"Fraction Recovered\",\n",
    "    output_filename=\"frac_recovered_vs_l0.png\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
