{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "local_dir = \"./graphing_eval_results/\"\n",
    "image_path = \"./temp_images\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_type = \"pca_sae\"\n",
    "include_baseline = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns\n",
    "\n",
    "selections = {\n",
    "    # \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "    #     [\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*layer_{layer}.*(16k|65k|1m).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "    #     [\n",
    "    #         r\"(gemma-scope-9b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*layer_{layer}.*(16k|131k|1m).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B TopK Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B Vanilla Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 4K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 16K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*layer_{layer}.*(16k).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    \"SAE Bench Gemma-2-2B 65K Width Series\": [\n",
    "        [\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "            r\"(gemma-scope-2b-pt-res)\",\n",
    "        ],\n",
    "        [\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*layer_{layer}.*(65k).*\",\n",
    "        ],\n",
    "    ],\n",
    "    # \"SAE Bench Pythia-70M SAE Type Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_pythia70m_sweep.*_ctx128_.*\",\n",
    "    #     ],\n",
    "    #     [r\".*blocks\\.({layer})\\.hook_resid_post__trainer_.*\"],\n",
    "    # ],\n",
    "}\n",
    "\n",
    "eval_path = f\"{local_dir}/unlearning\"\n",
    "eval_path = f\"{local_dir}/scr\"\n",
    "eval_path = f\"{local_dir}/sparse_probing\"\n",
    "# eval_path = f\"{local_dir}/tpp\"\n",
    "# eval_path = f\"{local_dir}/absorption\"\n",
    "\n",
    "eval_paths = [\n",
    "    # f\"{local_dir}/sparse_probing\",\n",
    "    # f\"{local_dir}/tpp\",\n",
    "    f\"{local_dir}/scr\",\n",
    "    f\"{local_dir}/core\",\n",
    "    # f\"{local_dir}/absorption\",\n",
    "    f\"{local_dir}/autointerp\",\n",
    "    f\"{local_dir}/unlearning\",\n",
    "]\n",
    "\n",
    "core_results_path = f\"{local_dir}/core\"\n",
    "\n",
    "combinations = list(itertools.product(eval_paths, selections.keys()))\n",
    "\n",
    "ks_lookup = {\n",
    "    \"scr\": [50],\n",
    "    \"tpp\": [50],\n",
    "    \"sparse_probing\": [1],\n",
    "}\n",
    "\n",
    "figures = []\n",
    "\n",
    "for eval_path, selection in combinations:\n",
    "    eval_type = eval_path.split(\"/\")[-1]\n",
    "\n",
    "    if eval_type in ks_lookup:\n",
    "        ks = ks_lookup[eval_type]\n",
    "    else:\n",
    "        ks = [-1]\n",
    "\n",
    "    if \"Gemma-2-2B\" in selection:\n",
    "        layers = [12]\n",
    "        model_name = \"gemma-2-2b\"\n",
    "    # elif \"Pythia-70M\" in selection:\n",
    "    #     layers = [3, 4]\n",
    "    # elif \"Gemma-2-9B\" in selection:\n",
    "    #     layers = [9, 20, 31]\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    layer_ks_combinations = list(itertools.product(layers, ks))\n",
    "\n",
    "    for layer, k in layer_ks_combinations:\n",
    "        sae_regex_patterns = selections[selection][0]\n",
    "        sae_block_patterns = list(selections[selection][1])\n",
    "\n",
    "        for i, pattern in enumerate(sae_block_patterns):\n",
    "            sae_block_patterns[i] = pattern.format(layer=layer)\n",
    "\n",
    "        suptitle = f\"{selection} Layer {layer}\\n\"\n",
    "        prefix = \"\"\n",
    "\n",
    "        image_base_folder = os.path.join(image_path, eval_type)\n",
    "\n",
    "        if not os.path.exists(image_base_folder):\n",
    "            os.makedirs(image_base_folder)\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            image_base_folder, f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        )\n",
    "\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        selected_saes = select_saes_multiple_patterns(\n",
    "            sae_regex_patterns, sae_block_patterns\n",
    "        )\n",
    "\n",
    "        baseline_sae = None\n",
    "        baseline_label = None\n",
    "\n",
    "        if include_baseline:\n",
    "            if model_name != \"gemma-2-9b\":\n",
    "                baseline_sae = (f\"{model_name}_layer_{layer}_pca_sae\", \"custom_sae\")\n",
    "                baseline_label = \"PCA Baseline\"\n",
    "\n",
    "        fig = graphing_utils.get_single_figure(\n",
    "            selected_saes,\n",
    "            eval_path,\n",
    "            core_results_path,\n",
    "            image_base_name,\n",
    "            k=k,\n",
    "            title=prefix,\n",
    "            baseline_sae=baseline_sae,\n",
    "            baseline_label=baseline_label,\n",
    "        )\n",
    "        figures.append(fig)\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(12, 8))\n",
    "\n",
    "# Add a main title\n",
    "fig.suptitle(suptitle, fontsize=20, y=0.92)\n",
    "\n",
    "# Adjust the spacing between subplots\n",
    "plt.subplots_adjust(\n",
    "    wspace=0.1,  # Width spacing between subplots\n",
    "    hspace=0.0,  # Height spacing between subplots\n",
    "    left=0.05,  # Left margin\n",
    "    right=0.9,  # Right margin\n",
    "    top=0.90,  # Top margin - adjusted to make room for suptitle\n",
    "    bottom=0.05,  # Bottom margin\n",
    ")\n",
    "\n",
    "\n",
    "for ax, sub_fig in zip(axes.flatten(), figures):\n",
    "    # Draw each figure's content onto the subplot axes\n",
    "    sub_fig.canvas.draw()\n",
    "    ax.imshow(sub_fig.canvas.buffer_rgba())\n",
    "    ax.axis(\"off\")  # Optional: Turn off axis for clarity\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"2x2_plot.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns\n",
    "\n",
    "selections = {\n",
    "    \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "        [\n",
    "            r\"(gemma-scope-2b-pt-res)\",\n",
    "        ],\n",
    "        [\n",
    "            r\".*layer_{layer}.*(16k|65k|1m).*\",\n",
    "        ],\n",
    "    ],\n",
    "    # \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "    #     [\n",
    "    #         r\"(gemma-scope-9b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*layer_{layer}.*(16k|131k|1m).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B TopK Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B Vanilla Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 4K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 16K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*layer_{layer}.*(16k).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 65K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*layer_{layer}.*(65k).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Pythia-70M SAE Type Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_pythia70m_sweep.*_ctx128_.*\",\n",
    "    #     ],\n",
    "    #     [r\".*blocks\\.({layer})\\.hook_resid_post__trainer_.*\"],\n",
    "    # ],\n",
    "}\n",
    "\n",
    "eval_path = f\"{local_dir}/unlearning\"\n",
    "eval_path = f\"{local_dir}/scr\"\n",
    "eval_path = f\"{local_dir}/sparse_probing\"\n",
    "# eval_path = f\"{local_dir}/tpp\"\n",
    "# eval_path = f\"{local_dir}/absorption\"\n",
    "\n",
    "eval_paths = [\n",
    "    f\"{local_dir}/sparse_probing\",\n",
    "    # f\"{local_dir}/tpp\",\n",
    "    # f\"{local_dir}/scr\",\n",
    "    # f\"{local_dir}/core\",\n",
    "    # f\"{local_dir}/absorption\",\n",
    "    # f\"{local_dir}/autointerp\",\n",
    "    # f\"{local_dir}/unlearning\",\n",
    "]\n",
    "\n",
    "core_results_path = f\"{local_dir}/core\"\n",
    "\n",
    "combinations = list(itertools.product(eval_paths, selections.keys()))\n",
    "\n",
    "ks_lookup = {\n",
    "    \"scr\": [50],\n",
    "    \"tpp\": [50],\n",
    "    \"sparse_probing\": [1],\n",
    "}\n",
    "\n",
    "figures = []\n",
    "\n",
    "for eval_path, selection in combinations:\n",
    "    eval_type = eval_path.split(\"/\")[-1]\n",
    "\n",
    "    if eval_type in ks_lookup:\n",
    "        ks = ks_lookup[eval_type]\n",
    "    else:\n",
    "        ks = [-1]\n",
    "\n",
    "    if \"Gemma-2-2B\" in selection:\n",
    "        layers = [5, 19]\n",
    "        model_name = \"gemma-2-2b\"\n",
    "    # elif \"Pythia-70M\" in selection:\n",
    "    #     layers = [3, 4]\n",
    "    # elif \"Gemma-2-9B\" in selection:\n",
    "    #     layers = [9, 20, 31]\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    layer_ks_combinations = list(itertools.product(layers, ks))\n",
    "\n",
    "    for layer, k in layer_ks_combinations:\n",
    "        sae_regex_patterns = selections[selection][0]\n",
    "        sae_block_patterns = list(selections[selection][1])\n",
    "\n",
    "        for i, pattern in enumerate(sae_block_patterns):\n",
    "            sae_block_patterns[i] = pattern.format(layer=layer)\n",
    "\n",
    "        suptitle = f\"{selection}\\nSparse Probing Top {k} Accuracy\"\n",
    "\n",
    "        image_base_folder = os.path.join(image_path, eval_type)\n",
    "\n",
    "        if not os.path.exists(image_base_folder):\n",
    "            os.makedirs(image_base_folder)\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            image_base_folder, f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        )\n",
    "\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        title = f\"Layer {layer}\"\n",
    "\n",
    "        selected_saes = select_saes_multiple_patterns(\n",
    "            sae_regex_patterns, sae_block_patterns\n",
    "        )\n",
    "\n",
    "        baseline_sae = None\n",
    "        baseline_label = None\n",
    "\n",
    "        if include_baseline:\n",
    "            if model_name != \"gemma-2-9b\":\n",
    "                baseline_sae = (f\"{model_name}_layer_{layer}_pca_sae\", \"custom_sae\")\n",
    "                baseline_label = \"PCA Baseline\"\n",
    "\n",
    "        fig = graphing_utils.get_single_figure(\n",
    "            selected_saes,\n",
    "            eval_path,\n",
    "            core_results_path,\n",
    "            image_base_name,\n",
    "            k=k,\n",
    "            title=title,\n",
    "            plot_type=False,\n",
    "            baseline_sae=baseline_sae,\n",
    "            baseline_label=baseline_label,\n",
    "        )\n",
    "        figures.append(fig)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 8))\n",
    "\n",
    "# Add a main title\n",
    "fig.suptitle(suptitle, fontsize=20, y=0.75)\n",
    "\n",
    "# Adjust the spacing between subplots\n",
    "plt.subplots_adjust(\n",
    "    wspace=0.1,  # Width spacing between subplots\n",
    "    hspace=0.0,  # Height spacing between subplots\n",
    "    left=0.05,  # Left margin\n",
    "    right=0.9,  # Right margin\n",
    "    top=0.90,  # Top margin - adjusted to make room for suptitle\n",
    "    bottom=0.05,  # Bottom margin\n",
    ")\n",
    "\n",
    "\n",
    "for ax, sub_fig in zip(axes.flatten(), figures):\n",
    "    # Draw each figure's content onto the subplot axes\n",
    "    sub_fig.canvas.draw()\n",
    "    ax.imshow(sub_fig.canvas.buffer_rgba())\n",
    "    ax.axis(\"off\")  # Optional: Turn off axis for clarity\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"1x2_sparse_probing_plot.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns\n",
    "\n",
    "selections = {\n",
    "    # \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "    #     [\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*layer_{layer}.*(16k|65k|1m).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "    #     [\n",
    "    #         r\"(gemma-scope-9b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*layer_{layer}.*(16k|131k|1m).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B TopK Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B Vanilla Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Gemma-2-2B 4K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    \"SAE Bench Gemma-2-2B 16K Width Series\": [\n",
    "        [\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "            r\"(gemma-scope-2b-pt-res)\",\n",
    "        ],\n",
    "        [\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*layer_{layer}.*(16k).*\",\n",
    "        ],\n",
    "    ],\n",
    "    # \"SAE Bench Gemma-2-2B 65K Width Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "    #         r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "    #         r\"(gemma-scope-2b-pt-res)\",\n",
    "    #     ],\n",
    "    #     [\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "    #         r\".*layer_{layer}.*(65k).*\",\n",
    "    #     ],\n",
    "    # ],\n",
    "    # \"SAE Bench Pythia-70M SAE Type Series\": [\n",
    "    #     [\n",
    "    #         r\"sae_bench_pythia70m_sweep.*_ctx128_.*\",\n",
    "    #     ],\n",
    "    #     [r\".*blocks\\.({layer})\\.hook_resid_post__trainer_.*\"],\n",
    "    # ],\n",
    "}\n",
    "\n",
    "eval_path = f\"{local_dir}/unlearning\"\n",
    "eval_path = f\"{local_dir}/scr\"\n",
    "eval_path = f\"{local_dir}/sparse_probing\"\n",
    "# eval_path = f\"{local_dir}/tpp\"\n",
    "# eval_path = f\"{local_dir}/absorption\"\n",
    "\n",
    "eval_paths = [\n",
    "    # f\"{local_dir}/sparse_probing\",\n",
    "    f\"{local_dir}/tpp\",\n",
    "    f\"{local_dir}/scr\",\n",
    "    # f\"{local_dir}/core\",\n",
    "    # f\"{local_dir}/absorption\",\n",
    "    # f\"{local_dir}/autointerp\",\n",
    "    # f\"{local_dir}/unlearning\",\n",
    "]\n",
    "\n",
    "core_results_path = f\"{local_dir}/core\"\n",
    "\n",
    "combinations = list(itertools.product(eval_paths, selections.keys()))\n",
    "\n",
    "ks_lookup = {\n",
    "    \"scr\": [50],\n",
    "    \"tpp\": [50],\n",
    "    \"sparse_probing\": [1],\n",
    "}\n",
    "\n",
    "figures = []\n",
    "\n",
    "for eval_path, selection in combinations:\n",
    "    eval_type = eval_path.split(\"/\")[-1]\n",
    "\n",
    "    if eval_type in ks_lookup:\n",
    "        ks = ks_lookup[eval_type]\n",
    "    else:\n",
    "        ks = [-1]\n",
    "\n",
    "    if \"Gemma-2-2B\" in selection:\n",
    "        layers = [12]\n",
    "    # elif \"Pythia-70M\" in selection:\n",
    "    #     layers = [3, 4]\n",
    "    # elif \"Gemma-2-9B\" in selection:\n",
    "    #     layers = [9, 20, 31]\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    layer_ks_combinations = list(itertools.product(layers, ks))\n",
    "\n",
    "    for layer, k in layer_ks_combinations:\n",
    "        sae_regex_patterns = selections[selection][0]\n",
    "        sae_block_patterns = list(selections[selection][1])\n",
    "\n",
    "        for i, pattern in enumerate(sae_block_patterns):\n",
    "            sae_block_patterns[i] = pattern.format(layer=layer)\n",
    "\n",
    "        suptitle = f\"{selection} Layer {layer}\"\n",
    "\n",
    "        image_base_folder = os.path.join(image_path, eval_type)\n",
    "\n",
    "        if not os.path.exists(image_base_folder):\n",
    "            os.makedirs(image_base_folder)\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            image_base_folder, f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        )\n",
    "\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        selected_saes = select_saes_multiple_patterns(\n",
    "            sae_regex_patterns, sae_block_patterns\n",
    "        )\n",
    "\n",
    "        baseline_sae = None\n",
    "        baseline_label = None\n",
    "\n",
    "        if include_baseline:\n",
    "            if model_name != \"gemma-2-9b\":\n",
    "                baseline_sae = (f\"{model_name}_layer_{layer}_pca_sae\", \"custom_sae\")\n",
    "                baseline_label = \"PCA Baseline\"\n",
    "\n",
    "        fig = graphing_utils.get_single_figure(\n",
    "            selected_saes,\n",
    "            eval_path,\n",
    "            core_results_path,\n",
    "            image_base_name,\n",
    "            k=k,\n",
    "            baseline_sae=baseline_sae,\n",
    "            baseline_label=baseline_label,\n",
    "        )\n",
    "        figures.append(fig)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 8))\n",
    "\n",
    "# Add a main title\n",
    "fig.suptitle(suptitle, fontsize=20, y=0.75)\n",
    "\n",
    "# Adjust the spacing between subplots\n",
    "plt.subplots_adjust(\n",
    "    wspace=0.1,  # Width spacing between subplots\n",
    "    hspace=0.0,  # Height spacing between subplots\n",
    "    left=0.05,  # Left margin\n",
    "    right=0.9,  # Right margin\n",
    "    top=0.90,  # Top margin - adjusted to make room for suptitle\n",
    "    bottom=0.05,  # Bottom margin\n",
    ")\n",
    "\n",
    "\n",
    "for ax, sub_fig in zip(axes.flatten(), figures):\n",
    "    # Draw each figure's content onto the subplot axes\n",
    "    sub_fig.canvas.draw()\n",
    "    ax.imshow(sub_fig.canvas.buffer_rgba())\n",
    "    ax.axis(\"off\")  # Optional: Turn off axis for clarity\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"1x2_scr_tpp.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
