{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting Custom Metric Results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.general_utils as general_utils\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NOTE\n",
    "\n",
    "This graphing notebook is fairly complex and primarily for producing a variety of SAE Bench specific plots. You can run this notebook to download our SAE Bench data and replicate all of our graphs. This notebook may also serve as inspiration for creating your own custom graphs. If you need to graph your own results, by default we recommend referring to `sae_bench_demo.ipynb` for a simpler graphing notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # SAE Bench USE ONLY\n",
    "\n",
    "# # This cell will download the SAE Bench results for graphing.\n",
    "\n",
    "# from huggingface_hub import snapshot_download\n",
    "\n",
    "# hf_repo_id = \"adamkarvonen/new_sae_bench_results\"\n",
    "local_dir = \"./graphing_eval_results_0119\"\n",
    "# os.makedirs(local_dir, exist_ok=True)\n",
    "\n",
    "# snapshot_download(\n",
    "#     repo_id=hf_repo_id,\n",
    "#     local_dir=local_dir,\n",
    "#     repo_type=\"dataset\",\n",
    "#     ignore_patterns=[\n",
    "#         \"*autointerp_with_generations*\",\n",
    "#         \"*core_with_feature_statistics*\",\n",
    "#     ],  # These use significant disk space / download time and are not needed for graphing\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # SAE Bench USE ONLY\n",
    "\n",
    "# # The purpose of this is that we currently organize results like this:\n",
    "\n",
    "# # {eval_type}/{sae_release}/{sae_release}_{sae_id}_eval_results.json\n",
    "# # because we have results for over 600 SAEs\n",
    "\n",
    "# # However, the current scripts output results like this:\n",
    "# # {eval_type}/{sae_release}_{sae_id}_eval_results.json\n",
    "\n",
    "# # So, we just flatten the sae bench results to match the expected format\n",
    "\n",
    "# import os\n",
    "# import shutil\n",
    "# import glob\n",
    "# from pathlib import Path\n",
    "\n",
    "# # Get all immediate subdirectories in eval_results\n",
    "# main_dirs = [d for d in os.listdir(local_dir) if os.path.isdir(os.path.join(local_dir, d))]\n",
    "\n",
    "# for main_dir in main_dirs:\n",
    "#     main_dir_path = os.path.join(local_dir, main_dir)\n",
    "#     print(f\"\\nProcessing {main_dir}...\")\n",
    "\n",
    "#     # Get all subdirectories in the current directory\n",
    "#     subdirs = [d for d in os.listdir(main_dir_path) if os.path.isdir(os.path.join(main_dir_path, d))]\n",
    "\n",
    "#     for subdir in subdirs:\n",
    "#         if not subdir.startswith('.'): # Skip hidden directories\n",
    "#             subdir_path = os.path.join(main_dir_path, subdir)\n",
    "#             print(f\"Moving files from {subdir}\")\n",
    "\n",
    "#             # Get all files in the subdirectory\n",
    "#             files = glob.glob(os.path.join(subdir_path, '*'))\n",
    "\n",
    "#             for file_path in files:\n",
    "#                 if os.path.isfile(file_path):  # Make sure it's a file, not a directory\n",
    "#                     file_name = os.path.basename(file_path)\n",
    "#                     destination = os.path.join(main_dir_path, file_name)\n",
    "\n",
    "#                     # Handle file name conflicts\n",
    "#                     if os.path.exists(destination):\n",
    "#                         base, extension = os.path.splitext(file_name)\n",
    "#                         counter = 1\n",
    "#                         while os.path.exists(destination):\n",
    "#                             new_name = f\"{base}_{counter}{extension}\"\n",
    "#                             destination = os.path.join(main_dir_path, new_name)\n",
    "#                             counter += 1\n",
    "\n",
    "#                     # Move the file\n",
    "#                     try:\n",
    "#                         shutil.move(file_path, destination)\n",
    "#                         print(f\"  Moved: {file_name} -> {os.path.basename(os.path.dirname(destination))}/\")\n",
    "#                     except Exception as e:\n",
    "#                         print(f\"  Error moving {file_name}: {str(e)}\")\n",
    "\n",
    "# print(\"\\nFile moving complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Add all folders with results to `results_folders`. Also, select your `eval_type`.\n",
    "\n",
    "We need a custom eval (SCR, Sparse Probing, etc.) and the core eval results, as we plot the SAEs by L0 / Loss Recovered in most plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each results folder is expected to hold one sub-folder per eval type (e.g. `core`).\n",
    "results_folders = [\n",
    "    \"./matroyshka_0117\",\n",
    "    local_dir,\n",
    "]\n",
    "\n",
    "# Select exactly one eval type to graph. Alternatives:\n",
    "#   eval_type = \"sparse_probing\"\n",
    "#   eval_type = \"core\"  # graphs L0 vs Loss Recovered\n",
    "eval_type = \"absorption\"\n",
    "\n",
    "# For every results folder: the folder of the selected eval and the folder of the\n",
    "# core eval, which provides L0 / Loss Recovered for the plots.\n",
    "eval_folders = [f\"{results_folder}/{eval_type}\" for results_folder in results_folders]\n",
    "core_folders = [f\"{results_folder}/core\" for results_folder in results_folders]\n",
    "\n",
    "image_path = \"./images\"\n",
    "\n",
    "# `exist_ok=True` makes the cell safe to re-run and avoids the check-then-create race.\n",
    "os.makedirs(image_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell loads both the custom eval results (e.g. SCR or sparse probing) and the core eval results (L0 / Loss Recovered) for every SAE found in `results_folders`, and merges them into a single dictionary per SAE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
    "core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
    "\n",
    "eval_results = graphing_utils.get_eval_results(eval_filenames)\n",
    "core_results = graphing_utils.get_core_results(core_filenames)\n",
    "\n",
    "# Every SAE with custom eval results also needs core results before the two can be\n",
    "# merged. Report all SAEs without core results at once, before anything is merged,\n",
    "# instead of failing with a bare KeyError on the first one.\n",
    "missing_core = [sae for sae in eval_results if sae not in core_results]\n",
    "if missing_core:\n",
    "    raise KeyError(f\"No core eval results found for: {missing_core}\")\n",
    "\n",
    "# Merge the core metrics into each SAE's eval results (in place).\n",
    "for sae in eval_results:\n",
    "    eval_results[sae].update(core_results[sae])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the SAEs we want to plot, and normalise the \"matroyshka\" spelling\n",
    "# to \"matryoshka\" in the result keys.\n",
    "new_eval_results = {}\n",
    "\n",
    "for original_name, sae_result in eval_results.items():\n",
    "    if \"matroyshka\" in original_name:\n",
    "        # The label is written to the shared result dict, so it is also applied to\n",
    "        # entries that are filtered out below.\n",
    "        sae_result[\"sae_class\"] = \"matryoshka_batch_topk\"\n",
    "    sae_name = original_name.replace(\"matroyshka\", \"matryoshka\")\n",
    "\n",
    "    if \"matryoshka\" in sae_name:\n",
    "        # Of the matryoshka SAEs, keep only the `65k` and the `notemp` variants.\n",
    "        keep = \"65k\" in sae_name or \"notemp\" in sae_name\n",
    "    else:\n",
    "        # Everything else is kept unless its name contains `checkpoint`.\n",
    "        keep = \"checkpoint\" not in sae_name\n",
    "\n",
    "    if keep:\n",
    "        new_eval_results[sae_name] = sae_result\n",
    "\n",
    "eval_results = new_eval_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of all SAEs that survived the filtering above; later cells index into this list.\n",
    "sae_names = list(eval_results.keys())\n",
    "\n",
    "# Print the key set once (it was previously printed twice, once without a label).\n",
    "print(\"\\nAvailable SAEs:\\n\", eval_results.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes we may want to graph and differentiate SAEs of the same type. For example, we may have trained multiple SAEs with different hyperparameters. In that case, we can use the next cell to customize the `sae_class` field for graphing purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def customize_class(sae_dict: dict) -> dict:\n",
    "    for sae_name in sae_dict:\n",
    "        if \"temp_1_\" in sae_name:\n",
    "            sae_dict[sae_name][\"sae_class\"] = \"temp 1\"\n",
    "        elif \"temp_2_\" in sae_name:\n",
    "            sae_dict[sae_name][\"sae_class\"] = \"temp 2\"\n",
    "        elif \"temp_100_\" in sae_name:\n",
    "            sae_dict[sae_name][\"sae_class\"] = \"temp 100\"\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown class for {sae_name}\")\n",
    "\n",
    "    return sae_dict\n",
    "\n",
    "\n",
    "# eval_results = customize_class(eval_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For plotting purposes we also want dictionary size, sae type, and number of training steps. The following cell populates this information."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot custom metric above unsupervised metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first SAE's merged results to see which metric keys can be plotted.\n",
    "first_sae_results = eval_results[sae_names[0]]\n",
    "\n",
    "print(\"\\nAvailable custom metrics:\\n\", first_sae_results.keys())\n",
    "print(first_sae_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-k value used by the k-dependent metric keys below (TPP, SCR, sparse probing).\n",
    "k = 2\n",
    "\n",
    "# The metric is chosen from the path of the first eval folder.\n",
    "eval_path = eval_folders[0]\n",
    "\n",
    "# NOTE(review): this cell used to call\n",
    "# `graphing_utils.get_custom_metric_key_and_name(eval_path, k)` first, but both return\n",
    "# values were always overwritten by the chain below (or the cell raised), so the\n",
    "# call was dead. Confirm whether the helper should replace the chain instead.\n",
    "if \"tpp\" in eval_path:\n",
    "    custom_metric = f\"tpp_threshold_{k}_total_metric\"\n",
    "    custom_metric_name = f\"TPP Top {k} Metric\"\n",
    "elif \"scr\" in eval_path:\n",
    "    custom_metric = f\"scr_metric_threshold_{k}\"\n",
    "    custom_metric_name = f\"SCR Top {k} Metric\"\n",
    "elif \"sparse_probing\" in eval_path:\n",
    "    custom_metric = f\"sae_top_{k}_test_accuracy\"\n",
    "    custom_metric_name = f\"Sparse Probing Top {k} Test Accuracy\"\n",
    "elif \"absorption\" in eval_path:\n",
    "    custom_metric = \"mean_absorption_fraction_score\"\n",
    "    custom_metric_name = \"Mean Absorption Score\"\n",
    "elif \"autointerp\" in eval_path:\n",
    "    custom_metric = \"autointerp_score\"\n",
    "    custom_metric_name = \"Autointerp Score\"\n",
    "elif \"core\" in eval_path:\n",
    "    custom_metric = \"ce_loss_score\"\n",
    "    custom_metric_name = \"Loss Recovered\"\n",
    "else:\n",
    "    raise ValueError(\"Please add the correct key for the custom metric\")\n",
    "\n",
    "# Output files of the following plot cells are named after the metric key.\n",
    "image_base_name = os.path.join(image_path, custom_metric)\n",
    "print(custom_metric, custom_metric_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "title_3var = f\"L0 vs Loss Recovered vs {custom_metric_name}\"\n",
    "title_2var = f\"L0 vs {custom_metric_name}\"\n",
    "\n",
    "# Optional 3-variable plot (disabled):\n",
    "# graphing_utils.plot_3var_graph(\n",
    "#     eval_results,\n",
    "#     title_3var,\n",
    "#     custom_metric,\n",
    "#     colorbar_label=\"Custom Metric\",\n",
    "#     output_filename=f\"{image_base_name}_3var.png\",\n",
    "# )\n",
    "\n",
    "# 2-variable plot of the selected metric, saved as `{image_base_name}_2var.png`.\n",
    "plot_2var_options = {\n",
    "    \"y_label\": custom_metric_name,\n",
    "    \"title\": title_2var,\n",
    "    \"output_filename\": f\"{image_base_name}_2var.png\",\n",
    "}\n",
    "graphing_utils.plot_2var_graph(eval_results, custom_metric, **plot_2var_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ...with interactive hovering\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graphing_utils.plot_interactive_3var_graph(\n",
    "#     eval_results,\n",
    "#     custom_metric,\n",
    "#     title=title_3var,\n",
    "#     output_filename=f\"{image_base_name}_3var_interactive.html\",\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise ValueError(\"Stop here\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saved under its own file name: the plain 2-variable plot above is written to\n",
    "# `{image_base_name}_2var.png`, and reusing that name here would overwrite it.\n",
    "graphing_utils.plot_2var_graph_dict_size(\n",
    "    eval_results,\n",
    "    custom_metric,\n",
    "    y_label=custom_metric_name,\n",
    "    title=title_2var,\n",
    "    output_filename=f\"{image_base_name}_2var_dict_size.png\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot metric over training checkpoints\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: We have SAE checkpoints at initialization (step 0), which does not fit on\n",
    "a log scale (log(0) = -inf). We visualize this with a cut in the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graphing_utils.plot_training_steps(\n",
    "#     eval_results,\n",
    "#     custom_metric,\n",
    "#     title=f\"Tokens vs {custom_metric_name} Gemma Layer {layer}\",\n",
    "#     output_filename=f\"{image_base_name}_tokens_vs_diff.png\",\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This cell combines all of the above steps into a single function so we can plot results from multiple runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_folders = [\"./matroyshka_0117\", local_dir]\n",
    "\n",
    "eval_type = \"absorption\"\n",
    "\n",
    "# Per results folder: the folder of the selected eval and the folder of the core eval.\n",
    "eval_folders = [f\"{results_folder}/{eval_type}\" for results_folder in results_folders]\n",
    "core_folders = [f\"{results_folder}/core\" for results_folder in results_folders]\n",
    "\n",
    "eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
    "core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
    "\n",
    "# Only result files matching one of these patterns are plotted.\n",
    "sae_regex_patterns = [\n",
    "    r\"saebench_gemma-2-2b_width-2pow16_date-0108_TopK_gemma-2-2b__0108_resid_post_layer_12_.*\",\n",
    "    r\"saebench_gemma-2-2b_width-2pow16_date-0108_Standard_gemma-2-2b__0108_resid_post_layer_12_.*\",\n",
    "]\n",
    "\n",
    "filtered_eval_filenames = general_utils.filter_with_regex(\n",
    "    eval_filenames, sae_regex_patterns\n",
    ")\n",
    "filtered_core_filenames = general_utils.filter_with_regex(\n",
    "    core_filenames, sae_regex_patterns\n",
    ")\n",
    "\n",
    "print(filtered_eval_filenames)\n",
    "\n",
    "# Load and merge the selected results so they are available for inspection.\n",
    "eval_results = graphing_utils.get_eval_results(filtered_eval_filenames)\n",
    "core_results = graphing_utils.get_core_results(filtered_core_filenames)\n",
    "\n",
    "for sae, sae_result in eval_results.items():\n",
    "    sae_result.update(core_results[sae])\n",
    "\n",
    "# NOTE: `image_base_name` comes from the metric-selection cell further up.\n",
    "graphing_utils.plot_results(\n",
    "    filtered_eval_filenames,\n",
    "    filtered_core_filenames,\n",
    "    eval_type,\n",
    "    image_base_name,\n",
    "    k=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SAE Bench Plot Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.general_utils as general_utils\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "\n",
    "# Plot title -> regex patterns selecting the SAE result files shown in that plot.\n",
    "# A pattern may contain a `{layer}` placeholder, which is filled in per layer below.\n",
    "selections = {\n",
    "    # \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "    #     r\"gemma-scope-2b-pt-res_layer_{layer}_width_(16k|65k|1m)\",\n",
    "    # ],\n",
    "    # \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "    #     r\"gemma-scope-9b-pt-res_layer_{layer}_width_(16k|131k|1m)\",\n",
    "    # ],\n",
    "    \"SAE Bench Gemma-2-2B All Width Series\": [\n",
    "        r\"saebench_gemma-2-2b_width-2pow14_date-0108(?!.*step)(?!.*Standard).*\",\n",
    "        r\"saebench_gemma-2-2b_width-2pow16_date-0108(?!.*step)(?!.*Standard).*\",\n",
    "    ],\n",
    "    \"SAE Bench Gemma-2-2B Matryoshka Width Series\": [\n",
    "        r\"matroyshka_gemma-2-2b-16k-v2_MatroyshkaBatchTopKTrainer_notemp.*\",\n",
    "        r\"matroyshka_gemma-2-2b-16k-v2_MatryoshkaBatchTopKTrainer_65k_temp1000.*\",\n",
    "    ],\n",
    "    # \"SAE Bench Gemma-2-2B 4K Width Series\": [\n",
    "    #     r\"saebench_gemma-2-2b_width-2pow12_date-0108(?!.*step).*\",\n",
    "    # ],\n",
    "    \"SAE Bench Gemma-2-2B 16K Width Series\": [\n",
    "        r\"saebench_gemma-2-2b_width-2pow14_date-0108(?!.*step).*\",\n",
    "        r\"matroyshka_gemma-2-2b-16k-v2_MatroyshkaBatchTopKTrainer_notemp.*\",\n",
    "    ],\n",
    "    \"SAE Bench Gemma-2-2B 65K Width Series\": [\n",
    "        r\"saebench_gemma-2-2b_width-2pow16_date-0108(?!.*step).*\",\n",
    "        r\"matroyshka_gemma-2-2b-16k-v2_MatryoshkaBatchTopKTrainer_65k_temp1000.*\",\n",
    "    ],\n",
    "    # \"SAE Bench Pythia-70M SAE Type Series\": [\n",
    "    #     r\"sae_bench_pythia70m_sweep.*_ctx128_.*blocks\\.({layer})\\.hook_resid_post__trainer_.*\",\n",
    "    # ],\n",
    "}\n",
    "\n",
    "# The first folder is also where the optional PCA baseline results are read from.\n",
    "results_folders = [\"./graphing_eval_results_0119\", \"./matroyshka_0117\"]\n",
    "baseline_folder = results_folders[0]\n",
    "\n",
    "eval_type = \"absorption\"\n",
    "\n",
    "eval_types = [\n",
    "    \"scr\",\n",
    "    \"tpp\",\n",
    "    \"sparse_probing\",\n",
    "    \"absorption\",\n",
    "    \"core\",\n",
    "    \"autointerp\",\n",
    "    \"unlearning\",\n",
    "]\n",
    "\n",
    "combinations = list(itertools.product(eval_types, selections.keys()))\n",
    "\n",
    "# Top-k values to plot for the evals whose metric depends on k. Evals that are not\n",
    "# listed here are plotted once, with the placeholder k = -1.\n",
    "# The second entry used to be a duplicate \"scr\" key, which silently dropped \"tpp\".\n",
    "ks_lookup = {\n",
    "    \"scr\": [5, 10, 20, 50, 500],\n",
    "    \"tpp\": [5, 10, 20, 50, 500],\n",
    "    \"sparse_probing\": [1, 2, 5],\n",
    "}\n",
    "\n",
    "baseline_type = \"pca_sae\"\n",
    "include_baseline = False\n",
    "image_path = \"./images\"\n",
    "\n",
    "\n",
    "for eval_type, selection in combinations:\n",
    "    ks = ks_lookup.get(eval_type, [-1])\n",
    "\n",
    "    # Layers and model name are derived from the selection title.\n",
    "    if \"Gemma-2-2B\" in selection:\n",
    "        layers = [12]\n",
    "        model_name = \"gemma-2-2b\"\n",
    "    elif \"Pythia-70M\" in selection:\n",
    "        layers = [3, 4]\n",
    "        model_name = \"pythia-70m-deduped\"\n",
    "    elif \"Gemma-2-9B\" in selection:\n",
    "        layers = [9, 20, 31]\n",
    "        model_name = \"gemma-2-9b\"\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    layer_ks_combinations = list(itertools.product(layers, ks))\n",
    "\n",
    "    for layer, k in layer_ks_combinations:\n",
    "        # Format into a new list. Writing the formatted strings back into\n",
    "        # `selections[selection]` would destroy the `{layer}` placeholder, so every\n",
    "        # later layer would silently reuse the first layer's patterns.\n",
    "        sae_regex_patterns = [\n",
    "            pattern.format(layer=layer) for pattern in selections[selection]\n",
    "        ]\n",
    "\n",
    "        prefix = f\"{selection} Layer {layer}\\n\"\n",
    "\n",
    "        # One image folder per eval type; `exist_ok` keeps re-runs safe.\n",
    "        image_base_folder = os.path.join(image_path, eval_type)\n",
    "        os.makedirs(image_base_folder, exist_ok=True)\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            image_base_folder, f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        )\n",
    "\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        baseline_sae = None\n",
    "        baseline_label = None\n",
    "\n",
    "        # The PCA baseline is skipped for gemma-2-9b.\n",
    "        if include_baseline and model_name != \"gemma-2-9b\":\n",
    "            baseline_sae = os.path.join(\n",
    "                baseline_folder,\n",
    "                eval_type,\n",
    "                f\"{model_name}_layer_{layer}_pca_sae_custom_sae_eval_results.json\",\n",
    "            )\n",
    "            baseline_label = \"PCA Baseline\"\n",
    "\n",
    "        eval_folders = [\n",
    "            f\"{results_folder}/{eval_type}\" for results_folder in results_folders\n",
    "        ]\n",
    "        core_folders = [f\"{results_folder}/core\" for results_folder in results_folders]\n",
    "\n",
    "        eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
    "        core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
    "\n",
    "        filtered_eval_filenames = general_utils.filter_with_regex(\n",
    "            eval_filenames, sae_regex_patterns\n",
    "        )\n",
    "        filtered_core_filenames = general_utils.filter_with_regex(\n",
    "            core_filenames, sae_regex_patterns\n",
    "        )\n",
    "\n",
    "        # We will have failures on e.g. Pythia and unlearning, because the results don't exist\n",
    "        try:\n",
    "            graphing_utils.plot_results(\n",
    "                filtered_eval_filenames,\n",
    "                filtered_core_filenames,\n",
    "                eval_type,\n",
    "                image_base_name,\n",
    "                k=k,\n",
    "                title_prefix=prefix,\n",
    "                baseline_sae_path=baseline_sae,\n",
    "                baseline_label=baseline_label,\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(f\"Error plotting {selection} Layer {layer}: {str(e)}\")\n",
    "            continue\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.general_utils as general_utils\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "\n",
    "# Plot title -> regex patterns selecting the SAE result files shown in that plot.\n",
    "# A pattern may contain a `{layer}` placeholder, which is filled in per layer below.\n",
    "selections = {\n",
    "    # \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "    #     r\"gemma-scope-2b-pt-res_layer_{layer}_width_(16k|65k|1m)\",\n",
    "    # ],\n",
    "    # \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "    #     r\"gemma-scope-9b-pt-res_layer_{layer}_width_(16k|131k|1m)\",\n",
    "    # ],\n",
    "    \"SAE Bench Gemma-2-2B Checkpoint Series\": [\n",
    "        r\"saebench_gemma-2-2b_width-2pow14_date-0108.*(_Standard|_TopK).*\",\n",
    "    ],\n",
    "}\n",
    "\n",
    "results_folders = [\"./graphing_eval_results_0119\"]\n",
    "\n",
    "eval_type = \"absorption\"\n",
    "\n",
    "eval_types = [\n",
    "    \"scr\",\n",
    "    \"tpp\",\n",
    "    \"sparse_probing\",\n",
    "    \"absorption\",\n",
    "    \"core\",\n",
    "    \"autointerp\",\n",
    "    \"unlearning\",\n",
    "]\n",
    "\n",
    "combinations = list(itertools.product(eval_types, selections.keys()))\n",
    "\n",
    "# Top-k values to plot for the evals whose metric depends on k. Evals that are not\n",
    "# listed here are plotted once, with the placeholder k = -1.\n",
    "ks_lookup = {\n",
    "    \"scr\": [5, 10, 20, 50, 500],\n",
    "    \"tpp\": [5, 10, 20, 50, 500],\n",
    "    \"sparse_probing\": [1, 2, 5],\n",
    "}\n",
    "\n",
    "image_path = \"./images_checkpoints\"\n",
    "\n",
    "\n",
    "for eval_type, selection in combinations:\n",
    "    ks = ks_lookup.get(eval_type, [-1])\n",
    "\n",
    "    layers = [12]\n",
    "    model_name = \"gemma-2-2b\"\n",
    "\n",
    "    layer_ks_combinations = list(itertools.product(layers, ks))\n",
    "\n",
    "    for layer, k in layer_ks_combinations:\n",
    "        # Format into a new list. Writing the formatted strings back into\n",
    "        # `selections[selection]` would destroy any `{layer}` placeholder, so every\n",
    "        # later layer would silently reuse the first layer's patterns.\n",
    "        sae_regex_patterns = [\n",
    "            pattern.format(layer=layer) for pattern in selections[selection]\n",
    "        ]\n",
    "\n",
    "        prefix = f\"{selection} Layer {layer}\\n\"\n",
    "\n",
    "        # One image folder per eval type; `exist_ok` keeps re-runs safe.\n",
    "        image_base_folder = os.path.join(image_path, eval_type)\n",
    "        os.makedirs(image_base_folder, exist_ok=True)\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            image_base_folder, f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        )\n",
    "\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        eval_folders = [\n",
    "            f\"{results_folder}/{eval_type}\" for results_folder in results_folders\n",
    "        ]\n",
    "        core_folders = [f\"{results_folder}/core\" for results_folder in results_folders]\n",
    "\n",
    "        eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
    "        core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
    "\n",
    "        filtered_eval_filenames = general_utils.filter_with_regex(\n",
    "            eval_filenames, sae_regex_patterns\n",
    "        )\n",
    "        filtered_core_filenames = general_utils.filter_with_regex(\n",
    "            core_filenames, sae_regex_patterns\n",
    "        )\n",
    "\n",
    "        eval_results = graphing_utils.get_eval_results(filtered_eval_filenames)\n",
    "        core_results = graphing_utils.get_core_results(filtered_core_filenames)\n",
    "\n",
    "        for sae in eval_results:\n",
    "            eval_results[sae].update(core_results[sae])\n",
    "\n",
    "            # Guard: abort as soon as an SAE whose name or `sae_class` contains\n",
    "            # \"batch\" is selected (presumably a batch-type SAE that should not\n",
    "            # appear in this plot; confirm).\n",
    "            if \"batch\" in sae or \"batch\" in eval_results[sae][\"sae_class\"]:\n",
    "                print(sae)\n",
    "                raise ValueError(\"Stop here\")\n",
    "\n",
    "        custom_metric, custom_metric_name = graphing_utils.get_custom_metric_key_and_name(\n",
    "            eval_type, k\n",
    "        )\n",
    "\n",
    "        graphing_utils.plot_training_steps(\n",
    "            eval_results,\n",
    "            custom_metric,\n",
    "            title=f\"Tokens vs {custom_metric_name} Gemma Layer {layer}\",\n",
    "            output_filename=f\"{image_base_name}_tokens_vs_diff.png\",\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EVERYTHING BELOW THIS HAS NOT BEEN TESTED AND IS PROBABLY BROKEN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise ValueError(\"Stop here\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "\n",
    "# Plot title -> [first pattern list, second pattern list]. Both lists are passed to\n",
    "# `select_saes_multiple_patterns`; the second list's `{layer}` placeholder is\n",
    "# filled in per layer below.\n",
    "selections = {\n",
    "    \"Gemma-Scope Gemma-2-2B Width Series\": [\n",
    "        [r\"(gemma-scope-2b-pt-res)\"],\n",
    "        [r\".*layer_{layer}.*(16k|65k|1m).*\"],\n",
    "    ],\n",
    "    \"Gemma-Scope Gemma-2-9B Width Series\": [\n",
    "        [r\"(gemma-scope-9b-pt-res)\"],\n",
    "        [r\".*layer_{layer}.*(16k|131k|1m).*\"],\n",
    "    ],\n",
    "    \"SAE Bench Gemma-2-2B TopK Width Series\": [\n",
    "        [\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\",\n",
    "        ],\n",
    "        [\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "        ],\n",
    "    ],\n",
    "    \"SAE Bench Gemma-2-2B Vanilla Width Series\": [\n",
    "        [\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\",\n",
    "        ],\n",
    "        [\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "            r\".*blocks\\.{layer}(?!.*step).*\",\n",
    "        ],\n",
    "    ],\n",
    "}\n",
    "\n",
    "eval_paths = [\n",
    "    # f\"{local_dir}/sparse_probing\",\n",
    "    # f\"{local_dir}/tpp\",\n",
    "    f\"{local_dir}/scr\",\n",
    "]\n",
    "\n",
    "core_results_path = f\"{local_dir}/core\"\n",
    "\n",
    "combinations = list(itertools.product(eval_paths, selections.keys()))\n",
    "\n",
    "# Top-k values combined into each plot, per eval type.\n",
    "ks_lookup = {\n",
    "    \"scr\": [10, 50, 100, 500],\n",
    "    \"tpp\": [10, 50, 100, 500],\n",
    "    \"sparse_probing\": [1, 2, 5, 10, 50],\n",
    "}\n",
    "\n",
    "best_of_image_path = \"./images/best_of\"\n",
    "\n",
    "for eval_path, selection in combinations:\n",
    "    # The eval type is the last component of the eval path.\n",
    "    eval_type = eval_path.split(\"/\")[-1]\n",
    "\n",
    "    if eval_type not in ks_lookup:\n",
    "        raise ValueError(\"This cell is only meant for SCR, TPP, and Sparse Probing\")\n",
    "    ks = ks_lookup[eval_type]\n",
    "\n",
    "    # Layers and model name are derived from the selection title.\n",
    "    if \"Gemma-2-2B\" in selection:\n",
    "        layers = [5, 12, 19]\n",
    "        model_name = \"gemma-2-2b\"\n",
    "    elif \"Pythia-70M\" in selection:\n",
    "        layers = [3, 4]\n",
    "        model_name = \"pythia-70m-deduped\"\n",
    "    elif \"Gemma-2-9B\" in selection:\n",
    "        layers = [9, 20, 31]\n",
    "        model_name = \"gemma-2-9b\"\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    for layer in layers:\n",
    "        sae_regex_patterns = selections[selection][0]\n",
    "        # Formatted into a new list so the templates in `selections` stay intact.\n",
    "        sae_block_patterns = [\n",
    "            pattern.format(layer=layer) for pattern in selections[selection][1]\n",
    "        ]\n",
    "\n",
    "        prefix = f\"{selection} Layer {layer}\\n\"\n",
    "\n",
    "        image_base_folder = os.path.join(best_of_image_path, eval_type)\n",
    "\n",
    "        if not os.path.exists(image_base_folder):\n",
    "            os.makedirs(image_base_folder)\n",
    "\n",
    "        # `eval_type` is always in `ks_lookup` at this point, so the top-k suffix\n",
    "        # (the string form of the whole `ks` list) is always appended.\n",
    "        layer_tag = f\"{selection.replace(' ', '_').lower()}_layer_{layer}\"\n",
    "        image_base_name = f\"{os.path.join(image_base_folder, layer_tag)}_topk_{str(ks)}\"\n",
    "\n",
    "        selected_saes = select_saes_multiple_patterns(\n",
    "            sae_regex_patterns, sae_block_patterns\n",
    "        )\n",
    "\n",
    "        baseline_sae = None\n",
    "        baseline_label = None\n",
    "\n",
    "        # NOTE: `include_baseline` is defined in the plot generator cell above.\n",
    "        if include_baseline and model_name != \"gemma-2-9b\":\n",
    "            baseline_sae = (f\"{model_name}_layer_{layer}_pca_sae\", \"custom_sae\")\n",
    "            baseline_label = \"PCA Baseline\"\n",
    "\n",
    "        # We will have failures on e.g. Pythia and unlearning, because the results don't exist\n",
    "        try:\n",
    "            graphing_utils.plot_best_of_ks_results(\n",
    "                selected_saes,\n",
    "                eval_path,\n",
    "                core_results_path,\n",
    "                image_base_name,\n",
    "                ks=ks,\n",
    "                title_prefix=prefix,\n",
    "                baseline_sae=baseline_sae,\n",
    "                baseline_label=baseline_label,\n",
    "            )\n",
    "        except Exception as e:\n",
    "            print(f\"Error plotting {selection} Layer {layer}: {str(e)}\")\n",
    "            continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns\n",
    "\n",
    "# Plot each eval's custom metric against training tokens for the checkpointed SAEs.\n",
    "# NOTE: `eval_results` and `custom_metric` are left at notebook scope on purpose; the\n",
    "# correlation cells further down read whatever the final loop iteration stored in them.\n",
    "\n",
    "# Maps a display name to [SAE release regexes, block regex templates]. The `{layer}`\n",
    "# placeholder in each block template is filled in per layer inside the loop.\n",
    "checkpoint_selections = {\n",
    "    # r\"sae_bench_gemma-2-2b_topk_width-2pow12_date-1109\": r\".*blocks\\.{layer}.*\",\n",
    "    \"gemma-2-2b 16k\": [\n",
    "        [\n",
    "            r\"sae_bench_gemma-2-2b_topk_width-2pow14_date-1109\",\n",
    "            r\"sae_bench_gemma-2-2b_vanilla_width-2pow14_date-1109\",\n",
    "        ],\n",
    "        [r\".*blocks\\.{layer}.*\", r\".*blocks\\.{layer}.*\"],\n",
    "    ],\n",
    "    # r\"sae_bench_gemma-2-2b_topk_width-2pow16_date-1109\": r\".*blocks\\.{layer}.*\",\n",
    "    # r\"sae_bench_gemma-2-2b_vanilla_width-2pow12_date-1109\": r\".*blocks\\.{layer}.*\",\n",
    "    # r\"sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109\": r\".*blocks\\.{layer}.*\",\n",
    "}\n",
    "\n",
    "eval_paths = [\n",
    "    f\"{local_dir}/sparse_probing\",\n",
    "    f\"{local_dir}/tpp\",\n",
    "    f\"{local_dir}/scr\",\n",
    "    f\"{local_dir}/absorption\",\n",
    "    f\"{local_dir}/unlearning\",\n",
    "    f\"{local_dir}/autointerp\",\n",
    "]\n",
    "\n",
    "core_results_path = f\"{local_dir}/core\"\n",
    "\n",
    "combinations = list(itertools.product(eval_paths, checkpoint_selections.keys()))\n",
    "\n",
    "# Evals listed here are plotted once per k; any other eval is plotted once with k=-1.\n",
    "ks_lookup = {\n",
    "    \"scr\": [10, 50],\n",
    "    \"tpp\": [10, 50],\n",
    "    \"sparse_probing\": [1, 2, 5],\n",
    "}\n",
    "\n",
    "checkpoint_image_path = \"./checkpoint_images\"\n",
    "# Create the output folder once up front instead of re-checking on every iteration;\n",
    "# exist_ok=True makes this a no-op when the folder is already there.\n",
    "os.makedirs(checkpoint_image_path, exist_ok=True)\n",
    "\n",
    "for eval_path, selection in combinations:\n",
    "    # The eval type is the last path component, e.g. \"scr\" or \"autointerp\".\n",
    "    eval_type = eval_path.split(\"/\")[-1]\n",
    "    ks = ks_lookup.get(eval_type, [-1])\n",
    "\n",
    "    if \"gemma-2-2b\" in selection:\n",
    "        layers = [5, 12, 19]\n",
    "    else:\n",
    "        raise ValueError(\"Please add the correct layers for the selection\")\n",
    "\n",
    "    sae_regex_patterns, block_pattern_templates = checkpoint_selections[selection]\n",
    "\n",
    "    for layer, k in itertools.product(layers, ks):\n",
    "        # Build a fresh list each time so the templates themselves are never mutated.\n",
    "        sae_block_patterns = [\n",
    "            template.format(layer=layer) for template in block_pattern_templates\n",
    "        ]\n",
    "\n",
    "        image_base_name = os.path.join(\n",
    "            checkpoint_image_path,\n",
    "            f\"{eval_type}_{selection.replace(' ', '_').lower()}_layer_{layer}\",\n",
    "        )\n",
    "\n",
    "        # Only the per-k evals carry the k value in their file name.\n",
    "        if eval_type in ks_lookup:\n",
    "            image_base_name = f\"{image_base_name}_topk_{k}\"\n",
    "\n",
    "        selected_saes = select_saes_multiple_patterns(\n",
    "            sae_regex_patterns, sae_block_patterns\n",
    "        )\n",
    "\n",
    "        eval_results = graphing_utils.get_eval_results(selected_saes, eval_path)\n",
    "        core_results = graphing_utils.get_core_results(selected_saes, core_results_path)\n",
    "\n",
    "        # Merge the core metrics into each SAE's eval metrics.\n",
    "        # NOTE(review): this raises KeyError when an SAE has eval results but no core\n",
    "        # results - presumably both mappings are keyed by SAE name; confirm.\n",
    "        for sae in eval_results:\n",
    "            eval_results[sae].update(core_results[sae])\n",
    "\n",
    "        custom_metric, custom_metric_name = (\n",
    "            graphing_utils.get_custom_metric_key_and_name(eval_path, k)\n",
    "        )\n",
    "\n",
    "        graphing_utils.plot_training_steps(\n",
    "            eval_results,\n",
    "            custom_metric,\n",
    "            title=f\"Tokens vs {custom_metric_name} Gemma Layer {layer}\",\n",
    "            output_filename=f\"{image_base_name}_tokens_vs_diff.png\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot metric correlations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation heatmap over the `l0`, `frac_recovered`, and `custom_metric` keys.\n",
    "# `eval_results` and `custom_metric` are whatever an earlier cell left in the kernel.\n",
    "# To pin a specific metric instead, set it here first, e.g.:\n",
    "#   k = 100\n",
    "#   custom_metric = f\"sae_top_{k}_test_accuracy\"\n",
    "\n",
    "metric_keys = [\"l0\", \"frac_recovered\", custom_metric]\n",
    "\n",
    "graphing_utils.plot_correlation_heatmap(\n",
    "    eval_results,\n",
    "    metric_names=metric_keys,\n",
    "    ae_names=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter one `sae_top_{k}_test_accuracy` metric against another for two values of k.\n",
    "# The axis labels and the output file name are derived from the two thresholds so\n",
    "# they always agree with the metrics that are actually plotted.\n",
    "# NOTE(review): `image_path` and `eval_results` are not defined in this cell -\n",
    "# presumably earlier cells provide them; confirm on a fresh kernel.\n",
    "\n",
    "threshold_x = 50\n",
    "threshold_y = 100\n",
    "\n",
    "metric_x = f\"sae_top_{threshold_x}_test_accuracy\"\n",
    "metric_y = f\"sae_top_{threshold_y}_test_accuracy\"\n",
    "\n",
    "title = \"\"\n",
    "x_label = f\"k={threshold_x} Sparse Probe Accuracy\"\n",
    "y_label = f\"k={threshold_y} Sparse Probe Accuracy\"\n",
    "output_filename = os.path.join(\n",
    "    image_path,\n",
    "    f\"sparse_probing_result_correlation_for_thresholds_{threshold_x}_{threshold_y}.png\",\n",
    ")\n",
    "\n",
    "graphing_utils.plot_correlation_scatter(\n",
    "    eval_results,\n",
    "    metric_x=metric_x,\n",
    "    metric_y=metric_y,\n",
    "    title=title,\n",
    "    x_label=x_label,\n",
    "    y_label=y_label,\n",
    "    output_filename=output_filename,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
