{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8626c376",
   "metadata": {},
   "source": [
    "# Training and evaluating an SAE on Gemma-2-2B layer 12\n",
    "\n",
    "You probably don't want to do this sort of training in a notebook, but the process is demonstrated here regardless. We'll train a single SAE with 32k latents and L0=100 on Gemma-2-2b layer 12 using 500M tokens from the Pile. In the paper, we train a suite of SAEs with different L0s. We won't attempt to do that here, as really you should run that sort of sweep on a cluster, but the process is the same (except for changing L0 for each run)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b3a5e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparse_but_wrong.enchanced_batch_topk_sae import (\n",
    "    EnhancedBatchTopKTrainingSAEConfig,\n",
    ")\n",
    "\n",
    "from sae_lens import LanguageModelSAETrainingRunner, LanguageModelSAERunnerConfig\n",
    "\n",
    "# These parameters assume an H100 80GB GPU. You can change them to fit your GPU.\n",
    "\n",
    "# Build the full training configuration: the SAE architecture plus the\n",
    "# language-model, dataset and optimisation settings consumed by the runner.\n",
    "runner_cfg = LanguageModelSAERunnerConfig(\n",
    "    sae=EnhancedBatchTopKTrainingSAEConfig(\n",
    "        k=100,  # the L0=100 target described in the introduction\n",
    "        d_in=2304,  # width of the activations fed to the SAE\n",
    "        d_sae=32 * 1024,  # 32k latents\n",
    "        normalize_acts_by_decoder_norm=True,\n",
    "    ),\n",
    "    model_name=\"google/gemma-2-2b\",\n",
    "    model_class_name=\"AutoModelForCausalLM\",\n",
    "    dataset_path=\"monology/pile-uncopyrighted\",\n",
    "    # Plain string literal: nothing is interpolated, so no f-prefix is needed.\n",
    "    hook_name=\"model.layers.12\",\n",
    "    context_size=1024,\n",
    "    training_tokens=500_000_000,  # 500M tokens\n",
    "    device=\"cuda\",\n",
    "    lr=3e-4,\n",
    "    n_batches_in_buffer=64,\n",
    "    train_batch_size_tokens=4096,\n",
    "    store_batch_size_prompts=12,\n",
    "    eval_batch_size_prompts=6,\n",
    "    autocast_lm=True,\n",
    "    autocast=True,\n",
    ")\n",
    "\n",
    "runner = LanguageModelSAETrainingRunner(runner_cfg)\n",
    "\n",
    "# Train, then persist the result to the `l0_100_sae` directory so that the\n",
    "# later cells can reload it from disk.\n",
    "sae = runner.run()\n",
    "sae.save_inference_model(\"l0_100_sae\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32753cff",
   "metadata": {},
   "source": [
    "Now that we've trained the SAE, let's calculate its nth decoder projection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7306bb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from sae_lens import ActivationsStore, SAE\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from sparse_but_wrong.nth_decoder_projection import nth_decoder_projection\n",
    "\n",
    "# Start by loading the inference SAE from the `l0_100_sae` directory, i.e. the\n",
    "# path the training cell passed to `save_inference_model`.\n",
    "\n",
    "loaded_sae = SAE.load_from_disk(\"l0_100_sae\", device=\"cuda\")\n",
    "\n",
    "# First, we load some training activations from the Pile dataset to test nth decoder projection against\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_activation_batches(\n",
    "    n_batches: int = 100,\n",
    "    store_batch_size_prompts: int = 8,\n",
    "    n_batches_in_buffer: int = 8,\n",
    "    train_batch_size_tokens: int = 4096,\n",
    "    model_name: str = \"gemma-2-2b\",\n",
    "    hook_name: str = \"blocks.12.hook_resid_post\",\n",
    "    dataset: str = \"monology/pile-uncopyrighted\",\n",
    "    d_in: int = 2304,\n",
    "    context_size: int = 1024,\n",
    "    device: str = \"cuda\",\n",
    ") -> list[torch.Tensor]:\n",
    "    \"\"\"Collect batches of model activations from a streamed dataset.\n",
    "\n",
    "    Loads `model_name` with `HookedTransformer.from_pretrained_no_processing`,\n",
    "    wraps it in an `ActivationsStore` reading `dataset` at `hook_name`, and\n",
    "    pulls `n_batches` batches from the store.\n",
    "\n",
    "    The whole function runs under `torch.inference_mode()`, so the returned\n",
    "    tensors carry no autograd state.\n",
    "\n",
    "    Args:\n",
    "        n_batches: Number of batches to pull from the store.\n",
    "        store_batch_size_prompts: Forwarded to `ActivationsStore`.\n",
    "        n_batches_in_buffer: Forwarded to `ActivationsStore`.\n",
    "        train_batch_size_tokens: Forwarded to `ActivationsStore`.\n",
    "        model_name: Model identifier handed to `HookedTransformer`.\n",
    "        hook_name: Hook point whose activations are collected.\n",
    "        dataset: Dataset identifier, read in streaming mode.\n",
    "        d_in: Width of the activations at `hook_name`.\n",
    "        context_size: Context length forwarded to `ActivationsStore`.\n",
    "        device: Device used for both the model and the store.\n",
    "\n",
    "    Returns:\n",
    "        A list of `n_batches` tensors, each being `store.next_batch()` with\n",
    "        dimension 1 squeezed out (presumably of shape\n",
    "        `(train_batch_size_tokens, d_in)` - TODO confirm).\n",
    "    \"\"\"\n",
    "    model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)\n",
    "    store = ActivationsStore(\n",
    "        model=model,\n",
    "        dataset=dataset,\n",
    "        d_in=d_in,\n",
    "        hook_name=hook_name,\n",
    "        hook_head_index=None,\n",
    "        context_size=context_size,\n",
    "        prepend_bos=True,\n",
    "        streaming=True,\n",
    "        store_batch_size_prompts=store_batch_size_prompts,\n",
    "        train_batch_size_tokens=train_batch_size_tokens,\n",
    "        n_batches_in_buffer=n_batches_in_buffer,\n",
    "        total_training_tokens=10**9,\n",
    "        normalize_activations=\"none\",\n",
    "        dataset_trust_remote_code=True,\n",
    "        dtype=\"float32\",\n",
    "        device=torch.device(device),\n",
    "        seqpos_slice=(None,),\n",
    "    )\n",
    "    return [\n",
    "        store.next_batch().squeeze(1)\n",
    "        for _ in tqdm(range(n_batches), desc=\"Generating inputs\")\n",
    "    ]\n",
    "\n",
    "train_activation_batches = get_activation_batches()\n",
    "\n",
    "# N is set to 12_000, slightly under half of the SAE's width. The exact value\n",
    "# is arbitrary but reasonable; in practice anything below roughly d_sae / 2\n",
    "# seems to work well.\n",
    "N = 12_000\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Compute the projection once per activation batch, then average the\n",
    "    # stacked per-batch results into a single Python float.\n",
    "    projections = [\n",
    "        nth_decoder_projection(acts, loaded_sae, N)\n",
    "        for acts in train_activation_batches\n",
    "    ]\n",
    "    s_n = torch.stack(projections).mean().item()\n",
    "\n",
    "print(f\"nth decoder projection (N={N}): {s_n}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87f1eadb",
   "metadata": {},
   "source": [
    "Finally, let's evaluate the SAE using the [SAE-Probes benchmark](https://github.com/sae-probes/sae-probes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a8668e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_probes import run_sae_evals\n",
    "\n",
    "# Run the sae-probes evaluation on the SAE reloaded from disk. The results are\n",
    "# written under `results_path`, which the aggregation cell reads back.\n",
    "run_sae_evals(\n",
    "    sae=loaded_sae,\n",
    "    model_name=\"gemma-2-2b\",\n",
    "    hook_name=\"blocks.12.hook_resid_post\",\n",
    "    reg_type=\"l1\",\n",
    "    # presumably why results are later read from `*/normal_setting/` - confirm\n",
    "    setting=\"normal\",\n",
    "    # the aggregation cell averages the metrics separately for each `k`\n",
    "    ks=[1, 16],\n",
    "    results_path=\"l0_100_sae_probes_results\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89296861",
   "metadata": {},
   "source": [
    "This saves the sparse probing results for each dataset to the `l0_100_sae_probes_results` directory. Let's load these and calculate the mean AUC, accuracy, and F1 score of the SAE for each value of `k`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a55de855",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "# Gather every per-dataset result file written by the evaluation cell.\n",
    "# sorted() materialises the glob and makes the row order deterministic.\n",
    "sae_probes_result_paths = sorted(\n",
    "    Path(\"l0_100_sae_probes_results\").glob(\"*/normal_setting/*.json\")\n",
    ")\n",
    "\n",
    "rows = []\n",
    "for sae_probes_path in sae_probes_result_paths:\n",
    "    with open(sae_probes_path, encoding=\"utf-8\") as f:\n",
    "        sae_probes_data = json.load(f)\n",
    "    # Every entry of a result file becomes one row of the frame.\n",
    "    rows.extend(sae_probes_data)\n",
    "\n",
    "# Without this guard an empty results directory surfaces as an opaque\n",
    "# KeyError: 'k' from the groupby below.\n",
    "if not rows:\n",
    "    raise FileNotFoundError(\n",
    "        \"No sae-probes results found under \"\n",
    "        \"'l0_100_sae_probes_results/*/normal_setting/*.json'; \"\n",
    "        \"run the evaluation cell first.\"\n",
    "    )\n",
    "\n",
    "sae_probes_df = pd.DataFrame(rows)\n",
    "\n",
    "# Mean of each metric across all result rows, reported separately per k.\n",
    "sae_probes_mean_df = (\n",
    "    sae_probes_df.groupby(['k'])[['test_auc', 'test_acc', 'test_f1']]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "sae_probes_mean_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
