{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell runs, so edits to library\n",
    "# code (e.g. sae_bench) are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "import sae_bench.custom_saes.custom_sae_config as custom_sae_config\n",
    "import sae_bench.custom_saes.relu_sae as relu_sae\n",
    "import sae_bench.custom_saes.run_all_evals_custom_saes as run_all_evals_custom_saes\n",
    "import sae_bench.evals.core.main as core\n",
    "import sae_bench.evals.sparse_probing.main as sparse_probing\n",
    "import sae_bench.sae_bench_utils.general_utils as general_utils\n",
    "from sae_bench.sae_bench_utils.sae_selection_utils import get_saes_from_regex\n",
    "\n",
    "# Seed passed to the evals that accept one (e.g. the sparse probing config below).\n",
    "RANDOM_SEED = 42\n",
    "\n",
    "# Results folder for each eval type.\n",
    "output_folders = {\n",
    "    \"absorption\": \"eval_results/absorption\",\n",
    "    \"autointerp\": \"eval_results/autointerp\",\n",
    "    \"core\": \"eval_results/core\",\n",
    "    \"scr\": \"eval_results/scr\",\n",
    "    \"tpp\": \"eval_results/tpp\",\n",
    "    \"sparse_probing\": \"eval_results/sparse_probing\",\n",
    "    \"unlearning\": \"eval_results/unlearning\",\n",
    "}\n",
    "\n",
    "# Note: Unlearning is not recommended for models with < 2B parameters, and we recommend an instruct-tuned model.\n",
    "# Unlearning also requires requesting permission for the WMDP dataset (see unlearning/README.md).\n",
    "# Absorption is not recommended for models with < 2B parameters.\n",
    "# asyncio doesn't like notebooks, so autointerp must be run using a python script.\n",
    "\n",
    "# Select your eval types here.\n",
    "eval_types = [\n",
    "    \"absorption\",\n",
    "    # \"autointerp\",\n",
    "    \"core\",\n",
    "    \"scr\",\n",
    "    \"tpp\",\n",
    "    \"sparse_probing\",\n",
    "    # \"unlearning\",\n",
    "]\n",
    "\n",
    "# Fail early instead of part-way through the (long) full eval run.\n",
    "if \"autointerp\" in eval_types:\n",
    "    raise ValueError(\"autointerp must be ran using a python script\")\n",
    "\n",
    "device = general_utils.setup_environment()\n",
    "\n",
    "model_name = \"pythia-70m-deduped\"\n",
    "llm_batch_size = 512\n",
    "torch_dtype = torch.float32\n",
    "\n",
    "# Currently all evals take str_dtype instead of torch_dtype. We did this for serialization purposes, but it was probably a mistake.\n",
    "# For now we will just use the str_dtype. TODO: Fix this\n",
    "# str(torch.float32) is \"torch.float32\", so the last dotted component is the bare dtype name.\n",
    "str_dtype = str(torch_dtype).split(\".\")[-1]\n",
    "\n",
    "\n",
    "# If evaluating multiple SAEs on the same layer, set save_activations to True\n",
    "# This will require at least 100GB of disk space\n",
    "save_activations = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This cell loads your custom SAEs. If you just want to use existing SAE Lens SAEs, comment it out.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repo id and checkpoint path of the custom SAE to evaluate.\n",
    "# NOTE(review): presumably a Hugging Face Hub repo — confirm in relu_sae.\n",
    "repo_id = \"canrager/lm_sae\"\n",
    "baseline_filename = (\n",
    "    \"pythia70m_sweep_standard_ctx128_0712/resid_post_layer_4/trainer_8/ae.pt\"\n",
    ")\n",
    "# The hook name must refer to the same layer as the checkpoint path above.\n",
    "hook_layer = 4\n",
    "hook_name = f\"blocks.{hook_layer}.hook_resid_post\"\n",
    "\n",
    "sae = relu_sae.load_dictionary_learning_relu_sae(\n",
    "    repo_id, baseline_filename, model_name, device, torch_dtype, layer=hook_layer\n",
    ")\n",
    "\n",
    "print(f\"sae dtype: {sae.dtype}, device: {sae.device}\")\n",
    "\n",
    "# The decoder weight shape is unpacked as (d_sae, d_in).\n",
    "d_sae, d_in = sae.W_dec.data.shape\n",
    "\n",
    "# Sanity check on the unpacking order; assumes the dictionary is at least as wide as its input.\n",
    "assert d_sae >= d_in\n",
    "\n",
    "print(f\"d_in: {d_in}, d_sae: {d_sae}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our sae object we need to have a CustomSAEConfig. This contains some information which is used by the evals (hook_name, hook_layer, model_name, d_sae, etc). In addition, it contains information that is used by our plotting functions, like number of training tokens and architecture. For example, we should have the sae.cfg.architecture defined if we want to plot multiple SAE architectures.\n",
    "\n",
    "Note: Everything in this cell, except `architecture` and `training_tokens`, is done in the `BaseSAE` class that the `ReluSAE` inherits from. Because of this, we recommend that you modify an existing SAE class.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the config object to the SAE, built from the values defined in the cells above.\n",
    "sae.cfg = custom_sae_config.CustomSAEConfig(\n",
    "    model_name, d_in=d_in, d_sae=d_sae, hook_name=hook_name, hook_layer=hook_layer\n",
    ")\n",
    "\n",
    "# Core evals require us to specify cfg.dtype, which must be a string for easy serialization. TODO: Refactor to use torch.dtype\n",
    "sae.cfg.dtype = str_dtype\n",
    "\n",
    "\n",
    "# The following contains our current defined SAE types and the shapes to plot for each. Add your custom SAE as new_sae_key\n",
    "new_sae_key = \"vanilla\"\n",
    "# Marker shape per SAE architecture; passed to the plotting calls further down.\n",
    "trainer_markers = {\n",
    "    \"standard\": \"o\",\n",
    "    \"jumprelu\": \"X\",\n",
    "    \"topk\": \"^\",\n",
    "    \"p_anneal\": \"*\",\n",
    "    \"gated\": \"d\",\n",
    "    new_sae_key: \"s\",  # New SAE\n",
    "}\n",
    "\n",
    "# Color per SAE architecture; keys mirror trainer_markers.\n",
    "trainer_colors = {\n",
    "    \"standard\": \"blue\",\n",
    "    \"jumprelu\": \"orange\",\n",
    "    \"topk\": \"green\",\n",
    "    \"p_anneal\": \"red\",\n",
    "    \"gated\": \"purple\",\n",
    "    new_sae_key: \"black\",  # New SAE\n",
    "}\n",
    "\n",
    "# The architecture is set to the key added to the marker / color maps above.\n",
    "sae.cfg.architecture = new_sae_key\n",
    "sae.cfg.training_tokens = 200_000_000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`selected_saes` is a list of tuples of (unique_sae_id, sae object) OR (sae lens release, sae lens id). If it is a list of custom sae objects, then memory size will increase with the length of the list. This is especially important if the SAEs are large. If memory is a concern, I recommend calling the `run_eval()` function multiple times with lists of length 1, each list containing a new sae object.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The custom SAE id is reused in the names of the intermediate and final results files,\n",
    "# so it has to be unique per SAE. Path separators and dots both become underscores.\n",
    "separators_to_underscores = str.maketrans(\"/.\", \"__\")\n",
    "unique_custom_sae_id = baseline_filename.translate(separators_to_underscores)\n",
    "print(f\"sae_id: {unique_custom_sae_id}\")\n",
    "\n",
    "# Each entry is a (sae_id, sae object) tuple.\n",
    "custom_saes = [(unique_custom_sae_id, sae)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select your baseline SAEs here. Refer to `sae_regex_selection.ipynb` for more regex patterns. We are going to get a topk SAE from the same layer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First pattern selects the SAE Lens release, second selects the SAE id within it.\n",
    "sae_regex_pattern = r\"(sae_bench_pythia70m_sweep_topk_ctx128_0730).*\"\n",
    "sae_block_pattern = r\".*blocks\\.([4])\\.hook_resid_post__trainer_(8)$\"\n",
    "\n",
    "baseline_saes = get_saes_from_regex(sae_regex_pattern, sae_block_pattern)\n",
    "print(f\"baseline_saes: {baseline_saes}\")\n",
    "\n",
    "# Fail with an actionable message instead of a bare IndexError when nothing matches.\n",
    "if not baseline_saes:\n",
    "    raise ValueError(\n",
    "        f\"No SAEs matched sae_regex_pattern={sae_regex_pattern!r} \"\n",
    "        f\"and sae_block_pattern={sae_block_pattern!r}\"\n",
    "    )\n",
    "\n",
    "# Id of the first match: \"<release>_<sae id>\" with dots replaced by underscores.\n",
    "baseline_sae_id = f\"{baseline_saes[0][0]}_{baseline_saes[0][1]}\".replace(\".\", \"_\")\n",
    "print(f\"baseline_sae_id: {baseline_sae_id}\")\n",
    "\n",
    "# Custom SAE objects first, then the (release, sae id) baselines.\n",
    "selected_saes = custom_saes + baseline_saes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run time for the next 2 functions is approximately 2 minutes on an RTX 3090.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: We typically run with n_eval_sparsity_variance_batches=2000, but I have reduced it here for a faster run\n",
    "\n",
    "# The output folder comes from the shared `output_folders` config rather than a second\n",
    "# hard-coded copy of the path, so the two cannot drift apart.\n",
    "_ = core.multiple_evals(\n",
    "    selected_saes=selected_saes,\n",
    "    n_eval_reconstruction_batches=200,\n",
    "    n_eval_sparsity_variance_batches=200,\n",
    "    eval_batch_size_prompts=32,\n",
    "    compute_featurewise_density_statistics=True,\n",
    "    compute_featurewise_weight_based_metrics=True,\n",
    "    exclude_special_tokens_from_reconstruction=True,\n",
    "    dataset=\"Skylion007/openwebtext\",\n",
    "    context_size=128,\n",
    "    output_folder=output_folders[\"core\"],\n",
    "    verbose=True,\n",
    "    dtype=str_dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We do a subset of the sparse probing datasets here for shorter runtime\n",
    "dataset_names = [\"LabHC/bias_in_bios_class_set1\"]\n",
    "\n",
    "# TODO: Add a verbose flag\n",
    "# The output folder comes from the shared `output_folders` config rather than a second\n",
    "# hard-coded copy of the path, so the two cannot drift apart.\n",
    "_ = sparse_probing.run_eval(\n",
    "    sparse_probing.SparseProbingEvalConfig(\n",
    "        model_name=model_name,\n",
    "        random_seed=RANDOM_SEED,\n",
    "        llm_batch_size=llm_batch_size,\n",
    "        llm_dtype=str_dtype,\n",
    "        dataset_names=dataset_names,\n",
    "    ),\n",
    "    selected_saes,\n",
    "    device,\n",
    "    output_folders[\"sparse_probing\"],\n",
    "    force_rerun=False,\n",
    "    clean_up_activations=True,\n",
    "    save_activations=save_activations,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import torch\n",
    "\n",
    "import sae_bench.sae_bench_utils.graphing_utils as graphing_utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that the plots below are written to.\n",
    "image_path = \"./images\"\n",
    "\n",
    "# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.\n",
    "os.makedirs(image_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_folders = [\"./eval_results\"]\n",
    "\n",
    "eval_type = \"sparse_probing\"\n",
    "\n",
    "# For every results root, look in the sub-folder of the chosen eval and in the core sub-folder.\n",
    "eval_folders = [f\"{root}/{eval_type}\" for root in results_folders]\n",
    "core_folders = [f\"{root}/core\" for root in results_folders]\n",
    "\n",
    "eval_filenames = graphing_utils.find_eval_results_files(eval_folders)\n",
    "core_filenames = graphing_utils.find_eval_results_files(core_folders)\n",
    "\n",
    "print(f\"eval_filenames: {eval_filenames}\")\n",
    "print(f\"core_filenames: {core_filenames}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we can view the raw results, and we see that both SAEs significantly outperform the residual stream baseline.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_results_dict = graphing_utils.get_eval_results(eval_filenames)\n",
    "core_results_dict = graphing_utils.get_eval_results(core_filenames)\n",
    "\n",
    "# Merge the core metrics of each SAE into its eval results.\n",
    "# The loop variable is deliberately not called `sae`: that would overwrite the SAE\n",
    "# object defined earlier with a string key and break re-running the cells above.\n",
    "for sae_name, sae_eval_results in eval_results_dict.items():\n",
    "    sae_eval_results.update(core_results_dict[sae_name])\n",
    "\n",
    "\n",
    "print(eval_results_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_results_file(filenames, sae_id, fallback_index):\n",
    "    \"\"\"Return the results file whose name contains `sae_id`.\n",
    "\n",
    "    Dots in the file name are turned into underscores before matching, because the\n",
    "    SAE ids built earlier in this notebook have had \".\" replaced by \"_\".\n",
    "    Falls back to `filenames[fallback_index]` when no file name matches.\n",
    "    \"\"\"\n",
    "    for filename in filenames:\n",
    "        if sae_id in os.path.basename(filename).replace(\".\", \"_\"):\n",
    "            return filename\n",
    "    # NOTE(review): assumes results file names embed the SAE id - confirm against the\n",
    "    # eval's file naming. The positional fallback keeps the previous behavior otherwise.\n",
    "    return filenames[fallback_index]\n",
    "\n",
    "\n",
    "# Pick the files by SAE id instead of by list position: this notebook does not control\n",
    "# the order of `eval_filenames`, so position alone could swap the two labels below.\n",
    "baseline_filepath = find_results_file(eval_filenames, baseline_sae_id, 0)\n",
    "\n",
    "with open(baseline_filepath) as f:\n",
    "    baseline_sae_eval_results = json.load(f)\n",
    "\n",
    "custom_filepath = find_results_file(eval_filenames, unique_custom_sae_id, 1)\n",
    "\n",
    "with open(custom_filepath) as f:\n",
    "    custom_sae_eval_results = json.load(f)\n",
    "\n",
    "# Number of top features whose probing accuracy is reported; also used by the plot below.\n",
    "k = 1\n",
    "\n",
    "print(baseline_sae_eval_results.keys())\n",
    "\n",
    "print(\n",
    "    f\"Baseline SAE top {k} accuracy was:\",\n",
    "    baseline_sae_eval_results[\"eval_result_metrics\"][\"sae\"][\n",
    "        f\"sae_top_{k}_test_accuracy\"\n",
    "    ],\n",
    ")\n",
    "print(\n",
    "    f\"Custom SAE top {k} accuracy was:\",\n",
    "    custom_sae_eval_results[\"eval_result_metrics\"][\"sae\"][f\"sae_top_{k}_test_accuracy\"],\n",
    ")\n",
    "# The LLM residual stream accuracy is read from the baseline SAE's results file.\n",
    "print(\n",
    "    f\"LLM residual stream top {k} accuracy was:\",\n",
    "    baseline_sae_eval_results[\"eval_result_metrics\"][\"llm\"][\n",
    "        f\"llm_top_{k}_test_accuracy\"\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also plot the metrics, either as L0 vs. the custom metric or as L0 vs. loss recovered vs. the custom metric. We can use different marker shapes for the SAE type or dictionary size.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base path handed to the plotting function for the sparse probing plots.\n",
    "image_base_name = os.path.join(image_path, \"sparse_probing\")\n",
    "\n",
    "# `eval_filenames`, `core_filenames`, `eval_type` and `k` come from the cells above.\n",
    "graphing_utils.plot_results(\n",
    "    eval_filenames,\n",
    "    core_filenames,\n",
    "    eval_type,\n",
    "    image_base_name,\n",
    "    k,\n",
    "    trainer_markers=trainer_markers,\n",
    "    trainer_colors=trainer_colors,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will run all of the evals and create more plots. Be warned: this takes around an hour. Note that a significant portion of the cost is incurred once per layer. For example, with absorption we have to train a set of probes on a given layer. So, if we evaluate multiple SAEs per layer, the cost should be much less than 30 minutes per SAE. In addition, to save disk space, we are currently not saving activations for reuse by multiple SAEs.\n",
    "\n",
    "Additionally, we can make this faster by evaluating on a subset of the datasets. Sparse probing and Spurious Correlation Removal each have approximately 8 datasets. We could reduce the run time by using fewer datasets, at the cost of a weaker signal.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run every eval selected in `eval_types` on all of `selected_saes`.\n",
    "# NOTE(review): api_key is presumably only needed for autointerp, which is rejected in\n",
    "# the config cell — confirm in run_all_evals_custom_saes.\n",
    "_ = run_all_evals_custom_saes.run_evals(\n",
    "    model_name,\n",
    "    selected_saes,\n",
    "    llm_batch_size,\n",
    "    str_dtype,\n",
    "    device,\n",
    "    eval_types,\n",
    "    api_key=None,\n",
    "    force_rerun=False,\n",
    "    save_activations=save_activations,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the results of every eval that was run.\n",
    "# Loop-local names are used so that the `eval_type`, `eval_folders` and `eval_filenames`\n",
    "# set by the sparse probing cells above are not overwritten if those cells are re-run.\n",
    "for plot_eval_type in eval_types:\n",
    "    plot_eval_folders = [\n",
    "        f\"{results_folder}/{plot_eval_type}\" for results_folder in results_folders\n",
    "    ]\n",
    "    plot_eval_filenames = graphing_utils.find_eval_results_files(plot_eval_folders)\n",
    "\n",
    "    graphing_utils.plot_results(\n",
    "        plot_eval_filenames,\n",
    "        core_filenames,\n",
    "        plot_eval_type,\n",
    "        # Use a per-eval base name so the plots of different evals do not all end up\n",
    "        # under the sparse probing base name.\n",
    "        os.path.join(image_path, plot_eval_type),\n",
    "        k=10,\n",
    "        trainer_markers=trainer_markers,\n",
    "        trainer_colors=trainer_colors,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
