{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "daf9aced",
   "metadata": {},
   "source": [
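    "# Open-source SAE L0 analysis\n",
    "\n",
    "Distribution of L0 (the average number of active features) across the open-source SAEs listed in the SAELens pretrained-SAE directory. Only SAEs that have a known expected L0, a Neuronpedia ID, and a strictly positive L0 are included. The notebook plots one histogram per model.\n"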
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0251063",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory\n",
    "\n",
    "saes = get_pretrained_saes_directory()\n",
    "\n",
    "# Group the expected L0 of each SAE by its model.\n",
    "# Only SAEs that have a Neuronpedia ID and a strictly positive L0 are kept.\n",
    "model_l0s = defaultdict(list)\n",
    "for release in saes.values():\n",
    "    if release.expected_l0 is None:\n",
    "        continue\n",
    "    for sae_id, expected_l0 in release.expected_l0.items():\n",
    "        has_neuronpedia_id = release.neuronpedia_id.get(sae_id) is not None\n",
    "        if has_neuronpedia_id and expected_l0 > 0:\n",
    "            model_l0s[release.model].append(expected_l0)\n",
    "\n",
    "print(\"Models found:\", list(model_l0s.keys()))\n",
    "print(\"Number of L0 values per model:\")\n",
    "for model_name, model_l0_values in model_l0s.items():\n",
    "    print(f\"  {model_name}: {len(model_l0_values)} SAEs\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c5242d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "sns.set_theme()\n",
    "\n",
    "# One histogram of L0 values per model, laid out on a grid with at most 3 columns.\n",
    "models = list(model_l0s.keys())\n",
    "n_models = len(models)\n",
    "if n_models == 0:\n",
    "    # Without this guard the grid computation below fails with ZeroDivisionError.\n",
    "    raise ValueError(\"No SAEs with a known L0 and a Neuronpedia ID were found; nothing to plot.\")\n",
    "\n",
    "n_cols = min(3, n_models)\n",
    "n_rows = (n_models + n_cols - 1) // n_cols  # Ceiling division\n",
    "\n",
    "# squeeze=False always returns a 2-D array of Axes, so a single flatten() handles\n",
    "# every grid shape (including a lone subplot) without special cases.\n",
    "fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)\n",
    "axes = axes_grid.flatten()\n",
    "\n",
    "for ax, model in zip(axes, models):\n",
    "    l0_values = model_l0s[model]\n",
    "\n",
    "    ax.hist(l0_values, bins=30, alpha=0.7, edgecolor='black')\n",
    "    ax.set_title(f'{model}\\n({len(l0_values)} SAEs)')\n",
    "    ax.set_xlabel('L0 (Average number of active features)')\n",
    "    ax.set_ylabel('Number of SAEs')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "    # Mark the mean and median L0 with dashed vertical reference lines.\n",
    "    mean_l0 = np.mean(l0_values)\n",
    "    median_l0 = np.median(l0_values)\n",
    "    ax.axvline(mean_l0, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_l0:.1f}')\n",
    "    ax.axvline(median_l0, color='orange', linestyle='--', alpha=0.7, label=f'Median: {median_l0:.1f}')\n",
    "    ax.legend(fontsize=8)\n",
    "\n",
    "# Hide the unused panels in the last row of the grid.\n",
    "for ax in axes[n_models:]:\n",
    "    ax.set_visible(False)\n",
    "\n",
    "fig.tight_layout()\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(\"plots/open_source_saes_analysis.pdf\")\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
