{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67075609-6138-4bf0-b225-8488f04bdaec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "# Configuration: directory containing the eval results. Each entry is either a\n",
    "# `*_custom_sae_eval_results` directory holding eval_results.json, or a `.json` file.\n",
    "# Set SAE_EVAL_RESULTS_DIR instead of editing this cell; the \"/\" fallback is only a\n",
    "# placeholder and should point at the real results directory.\n",
    "PARENT_DIR = os.environ.get(\"SAE_EVAL_RESULTS_DIR\", \"/\")\n",
    "\n",
    "# Entry names look like `topk_l0_<l0>_custom_sae_eval_results...` or\n",
    "# `kronsae_<m>_<n>_l0_<l0>_custom_sae_eval_results...`: capture the SAE type,\n",
    "# the optional m/n pair (expected for KronSAE runs), and l0.\n",
    "pattern = re.compile(\n",
    "    r'^(?P<sae_type>topk|kronsae)(?:_(?P<m>\\d+)_(?P<n>\\d+))?_l0_(?P<l0>\\d+)_custom_sae_eval_results'\n",
    ")\n",
    "\n",
    "records = []\n",
    "\n",
    "# Sort the listing: os.listdir order is arbitrary, and the row order of `df` is what\n",
    "# seaborn falls back on for hue (colour/legend) order when none is given.\n",
    "for name in sorted(os.listdir(PARENT_DIR)):\n",
    "    path = os.path.join(PARENT_DIR, name)\n",
    "    # A result is either a directory holding eval_results.json or a JSON file itself.\n",
    "    json_path = os.path.join(path, \"eval_results.json\") if os.path.isdir(path) else path\n",
    "\n",
    "    if not os.path.isfile(json_path) or not name.endswith((\"_custom_sae_eval_results\", \".json\")):\n",
    "        continue\n",
    "\n",
    "    # `name` is already a bare entry name from listdir, so no basename() is needed.\n",
    "    hit = pattern.match(name)\n",
    "    if hit is None:\n",
    "        continue\n",
    "\n",
    "    sae_type = hit[\"sae_type\"]\n",
    "    l0 = int(hit[\"l0\"])\n",
    "    # The m/n groups are optional in the regex and are None when they did not match.\n",
    "    m_val = hit[\"m\"]\n",
    "    n_val = hit[\"n\"]\n",
    "\n",
    "    # Human-readable label; used as the hue variable (and legend text) when plotting.\n",
    "    if m_val and n_val:\n",
    "        model_name = f\"KronSAE $m={m_val}$ $n={n_val}$\"\n",
    "    else:\n",
    "        model_name = \"TopK\"\n",
    "\n",
    "    # JSON is UTF-8 by spec; do not rely on the platform's default encoding.\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # Missing sections or metrics become None instead of raising.\n",
    "    mean = data.get(\"eval_result_metrics\", {}).get(\"mean\", {})\n",
    "\n",
    "    records.append({\n",
    "        \"name\": name,\n",
    "        \"SAE\": model_name,\n",
    "        \"sae_type\": sae_type,\n",
    "        \"l0\": l0,\n",
    "        \"mean_absorption_fraction_score\": mean.get(\"mean_absorption_fraction_score\"),\n",
    "        \"mean_full_absorption_score\":     mean.get(\"mean_full_absorption_score\"),\n",
    "        \"mean_num_split_features\":        mean.get(\"mean_num_split_features\"),\n",
    "        \"std_dev_absorption_fraction_score\": mean.get(\"std_dev_absorption_fraction_score\"),\n",
    "        \"std_dev_full_absorption_score\":     mean.get(\"std_dev_full_absorption_score\"),\n",
    "        \"std_dev_num_split_features\":        mean.get(\"std_dev_num_split_features\"),\n",
    "    })\n",
    "\n",
    "# Fail fast with a clear message rather than a KeyError in the plotting cell.\n",
    "assert records, f\"No *_custom_sae_eval_results entries matched in {PARENT_DIR!r}\"\n",
    "\n",
    "# Create DataFrame; sort so row order does not depend on directory-listing order.\n",
    "df = pd.DataFrame(records).sort_values([\"SAE\", \"l0\"], ignore_index=True)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c39662e1-8fdf-4ff7-8029-ba0cedc24982",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "# One panel per metric: (column of `df`, y-axis label).\n",
    "METRICS = [\n",
    "    (\"mean_absorption_fraction_score\", \"Mean Absorption Fraction Score\"),\n",
    "    (\"mean_full_absorption_score\", \"Mean Full Absorption Score\"),\n",
    "    (\"mean_num_split_features\", \"Mean Number of Feature Splits\"),\n",
    "]\n",
    "\n",
    "\n",
    "def _natural_key(label):\n",
    "    \"\"\"Sort key comparing the integers embedded in `label` numerically (n=4 before n=16).\"\"\"\n",
    "    return [int(tok) for tok in re.findall(r\"\\d+\", label)]\n",
    "\n",
    "\n",
    "# Derive the hue order (and legend entries) from the data instead of hard-coding\n",
    "# label strings, which silently mis-name bars whenever the data order changes.\n",
    "# KronSAE variants come first, ordered by (m, n); the TopK baseline comes last.\n",
    "kron_levels = sorted(df.loc[df[\"sae_type\"] == \"kronsae\", \"SAE\"].unique(), key=_natural_key)\n",
    "topk_levels = sorted(df.loc[df[\"sae_type\"] == \"topk\", \"SAE\"].unique())\n",
    "hue_order = kron_levels + topk_levels\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "sns.set_palette(\"tab10\")\n",
    "fig, axs = plt.subplots(1, len(METRICS), figsize=(19, 5), dpi=250)\n",
    "\n",
    "for i, (ax, (column, ylabel)) in enumerate(zip(axs, METRICS)):\n",
    "    is_last = i == len(METRICS) - 1\n",
    "    sns.barplot(\n",
    "        data=df,\n",
    "        x=\"l0\", y=column, hue=\"SAE\",\n",
    "        # The same hue_order on every panel keeps colours consistent across panels.\n",
    "        hue_order=hue_order,\n",
    "        # NOTE(review): \"sd\" is the spread over df rows within each (l0, SAE) group;\n",
    "        # with one run per group there is likely no error bar at all, and the\n",
    "        # std_dev_* columns read from the JSON are unused. Confirm this is intended.\n",
    "        errorbar=\"sd\",\n",
    "        # Only the last panel builds a seaborn legend; its handles are reused below.\n",
    "        legend=\"auto\" if is_last else False,\n",
    "        alpha=.99, ax=ax,\n",
    "    )\n",
    "    ax.set_xlabel(r\"$\\ell_0$\", fontsize=18)\n",
    "    ax.set_ylabel(ylabel, fontsize=18)\n",
    "    ax.tick_params(axis=\"both\", which=\"major\", labelsize=14)\n",
    "    ax.tick_params(axis=\"both\", which=\"minor\", labelsize=10)\n",
    "\n",
    "# Look handles up by label so every legend text is paired with its own colour.\n",
    "handles, labels = axs[-1].get_legend_handles_labels()\n",
    "handle_by_label = dict(zip(labels, handles))\n",
    "if axs[-1].get_legend() is not None:\n",
    "    axs[-1].get_legend().remove()\n",
    "\n",
    "# Invisible entry used as a group header (\"KronSAE:\", \"SAE:\") inside the legend.\n",
    "empty_handle = Line2D(\n",
    "    [0], [0], color=\"white\", marker='', linestyle=\"\", markersize=4, label=''\n",
    ")\n",
    "\n",
    "legend_handles, legend_labels = [], []\n",
    "for header, levels in ((\"KronSAE: \", kron_levels), (\"SAE:\", topk_levels)):\n",
    "    if not levels:\n",
    "        continue\n",
    "    legend_handles.append(empty_handle)\n",
    "    legend_labels.append(header)\n",
    "    for level in levels:\n",
    "        legend_handles.append(handle_by_label[level])\n",
    "        # \"KronSAE $m=2$ $n=8$\" -> \"$m=2$ $n=8$\"; \"TopK\" is left unchanged.\n",
    "        legend_labels.append(level.removeprefix(\"KronSAE \"))\n",
    "\n",
    "axs[0].legend(\n",
    "    handles=legend_handles,\n",
    "    labels=legend_labels,\n",
    "    loc=\"upper right\",\n",
    "    fontsize=14,\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "fig.savefig(\"absorption_score.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e62da582-1b8f-4bd2-b778-85aea5212191",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
