{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6e3fe139-9140-487a-82d6-1b2efab1b269",
   "metadata": {},
   "source": [
    "# Sparsity-faithfulness SAE and transcoder evaluations\n",
    "\n",
    "This notebook demonstrates how to perform the sparsity-faithfulness SAE and transcoder evaluations, as seen in Section 3.2.2 of our paper. We will be evaluating our transcoders and SAEs on Pythia-410M."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e096e8e5-3bba-456b-ab84-4571aea3690f",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9da60e4b-f27d-41eb-bd42-736e6231092c",
   "metadata": {},
   "source": [
    "Import the standard `transcoder_circuits` code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2bd1544d-2ea4-472a-8446-864fb872b993",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transcoder_circuits.circuit_analysis import *\n",
    "from transcoder_circuits.feature_dashboards import *\n",
    "from transcoder_circuits.replacement_ctx import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cf48a63-5aad-4ba3-ab08-b8d36397998c",
   "metadata": {},
   "source": [
    "Import the SAE/transcoder code, along with the model that we'll be analyzing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "587bb6de-8af2-4d23-bffe-095b76389a3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model pythia-410m into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "from sae_training.sparse_autoencoder import SparseAutoencoder\n",
    "from transformer_lens import HookedTransformer, utils\n",
    "import os\n",
    "import torch\n",
    "\n",
    "model = HookedTransformer.from_pretrained('pythia-410m')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1970b171-f1bf-44ff-a850-28ab3f5ad395",
   "metadata": {
    "id": "N3D_0qDmBY5K"
   },
   "source": [
    "Now, load in a corpus of text that we'll use for our analysis. We'll be drawing from OpenWebText."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb24f806-aca2-441f-9e11-8de389bbeb90",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# This function was stolen from one of Neel Nanda's exploratory notebooks\n",
    "# Thanks, Neel!\n",
    "import einops\n",
    "def tokenize_and_concatenate(\n",
    "    dataset,\n",
    "    tokenizer,\n",
    "    streaming = False,\n",
    "    max_length = 1024,\n",
    "    column_name = \"text\",\n",
    "    add_bos_token = True,\n",
    "):\n",
    "    \"\"\"Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.\n",
    "\n",
    "    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)\n",
    "\n",
    "    Args:\n",
    "        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.\n",
    "        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.\n",
    "        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.\n",
    "        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.\n",
    "        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.\n",
    "        add_bos_token (bool, optional): . Defaults to True.\n",
    "\n",
    "    Returns:\n",
    "        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called \"tokens\"\n",
    "\n",
    "    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why\n",
    "    \"\"\"\n",
    "    for key in dataset.features:\n",
    "        if key != column_name:\n",
    "            dataset = dataset.remove_columns(key)\n",
    "\n",
    "    if tokenizer.pad_token is None:\n",
    "        # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.\n",
    "        tokenizer.add_special_tokens({\"pad_token\": \"<PAD>\"})\n",
    "    # Define the length to chop things up into - leaving space for a bos_token if required\n",
    "    if add_bos_token:\n",
    "        seq_len = max_length - 1\n",
    "    else:\n",
    "        seq_len = max_length\n",
    "\n",
    "    def tokenize_function(examples):\n",
    "        text = examples[column_name]\n",
    "        # Concatenate it all into an enormous string, separated by eos_tokens\n",
    "        full_text = tokenizer.eos_token.join(text)\n",
    "        # Divide into 20 chunks of ~ equal length\n",
    "        num_chunks = 20\n",
    "        chunk_length = (len(full_text) - 1) // num_chunks + 1\n",
    "        chunks = [\n",
    "            full_text[i * chunk_length : (i + 1) * chunk_length]\n",
    "            for i in range(num_chunks)\n",
    "        ]\n",
    "        # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned\n",
    "        tokens = tokenizer(chunks, return_tensors=\"np\", padding=True)[\n",
    "            \"input_ids\"\n",
    "        ].flatten()\n",
    "        # Drop padding tokens\n",
    "        tokens = tokens[tokens != tokenizer.pad_token_id]\n",
    "        num_tokens = len(tokens)\n",
    "        num_batches = num_tokens // (seq_len)\n",
    "        # Drop the final tokens if not enough to make a full sequence\n",
    "        tokens = tokens[: seq_len * num_batches]\n",
    "        tokens = einops.rearrange(\n",
    "            tokens, \"(batch seq) -> batch seq\", batch=num_batches, seq=seq_len\n",
    "        )\n",
    "        if add_bos_token:\n",
    "            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)\n",
    "            tokens = np.concatenate([prefix, tokens], axis=1)\n",
    "        return {\"tokens\": tokens}\n",
    "\n",
    "    tokenized_dataset = dataset.map(\n",
    "        tokenize_function,\n",
    "        batched=True,\n",
    "        remove_columns=[column_name],\n",
    "    )\n",
    "    #tokenized_dataset.set_format(type=\"torch\", columns=[\"tokens\"])\n",
    "    return tokenized_dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d5bf917d-7bed-4a8a-99d1-284a6a5bda78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from huggingface_hub import HfApi\n",
    "import numpy as np\n",
    "\n",
    "dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)\n",
    "dataset = dataset.shuffle(seed=42, buffer_size=10_000)\n",
    "tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)\n",
    "tokenized_owt = tokenized_owt.shuffle(42)\n",
    "tokenized_owt = tokenized_owt.take(12800*2)\n",
    "owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ba288f88-eab1-4eac-b49d-1da315133f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "owt_tokens_torch = torch.from_numpy(owt_tokens).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "318aebb9-7da8-4641-8705-9330d8736721",
   "metadata": {},
   "source": [
    "# SAE sweep evaluation\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2736d6f2-d365-4fd4-bcbe-20e73ab38592",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_sae(model, owt_tokens_torch, sae, num_batches=100, batch_size=128):\n",
    "    layer = sae.cfg.hook_point_layer\n",
    "\n",
    "    # evaluate l0s\n",
    "    l0s = []\n",
    "    losses = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm.tqdm(range(0, num_batches)):\n",
    "            cur_tokens = owt_tokens_torch[batch*batch_size:(batch+1)*batch_size]\n",
    "            \n",
    "            sae_acts = []\n",
    "            def replacement_hook(acts, hook):\n",
    "                sae_out = sae(acts)\n",
    "                activations = sae_out[0].to(acts.dtype)\n",
    "                sae_acts.append(sae_out[1])\n",
    "                return activations\n",
    "            \n",
    "            loss = model.run_with_hooks(cur_tokens, return_type=\"loss\", fwd_hooks=[(sae.cfg.hook_point, replacement_hook)])\n",
    "            binarized_acts = 1.0*(sae_acts[0] > 0)\n",
    "            l0s.append(\n",
    "                (binarized_acts.reshape(-1, binarized_acts.shape[-1])).sum(dim=1).mean().item()\n",
    "            )\n",
    "            losses.append(utils.to_numpy(loss))\n",
    "    \n",
    "    return {\n",
    "        'l0': np.mean(l0s),\n",
    "        'sae_loss': np.mean(losses)\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e82b99dd-7fe5-4a8d-a4f2-c7eb4896bb65",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:49<00:00,  1.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0': 391.38703887939454, 'sae_loss': 3.3405383}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sae_template = \"./pythia-saes/l1_3e-05/aols0e3v/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "sae = SparseAutoencoder.load_from_pretrained(f\"{sae_template}.pt\").eval()\n",
    "print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1633225e-8fc7-4a08-899f-9c33687906c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:49<00:00,  1.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0': 214.0961669921875, 'sae_loss': 3.3458128}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sae_template = \"./pythia-saes/l1_4e-05/dnses7qu/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "sae = SparseAutoencoder.load_from_pretrained(f\"{sae_template}.pt\").eval()\n",
    "print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "09f2af7e-f1b1-48a5-83af-fda0ec2f907b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:49<00:00,  1.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0': 97.85787994384765, 'sae_loss': 3.3527865}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sae_template = \"./pythia-saes/l1_5.5e-05/oy0zuy0k/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "sae = SparseAutoencoder.load_from_pretrained(f\"{sae_template}.pt\").eval()\n",
    "print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1f02d292-a799-4673-b2b3-7116a0208780",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:48<00:00,  1.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0': 47.84608245849609, 'sae_loss': 3.3621392}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sae_template = \"./pythia-saes/l1_7e-05/olr1w3lx/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "sae = SparseAutoencoder.load_from_pretrained(f\"{sae_template}.pt\").eval()\n",
    "print(eval_sae(model, owt_tokens_torch, sae, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6fc8a51-4a3c-48ab-9eb6-e1f58aeb5dcd",
   "metadata": {},
   "source": [
    "# New sweep evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cb3a71ef-0666-415f-a4e0-91f14dbf0190",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_transcoder_l0_ce(model, all_tokens, transcoder, num_batches=100, batch_size=128):\n",
    "    l0s = []\n",
    "    transcoder_losses = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm.tqdm(range(0, num_batches)):\n",
    "            torch.cuda.empty_cache()\n",
    "            cur_batch_tokens = all_tokens[batch*batch_size:(batch+1)*batch_size]\n",
    "            with TranscoderReplacementContext(model, [transcoder]):\n",
    "                cur_losses, cache = model.run_with_cache(cur_batch_tokens, return_type=\"loss\", names_filter=[transcoder.cfg.hook_point])\n",
    "                # measure losses\n",
    "                transcoder_losses.append(utils.to_numpy(cur_losses))\n",
    "                # measure l0s\n",
    "                acts = cache[transcoder.cfg.hook_point]\n",
    "                binarized_transcoder_acts = 1.0*(transcoder(acts)a[1] > 0)\n",
    "                l0s.append(\n",
    "                    (binarized_transcoder_acts.reshape(-1, binarized_transcoder_acts.shape[-1])).sum(dim=1).mean().item()\n",
    "                )\n",
    "\n",
    "    return {\n",
    "        'l0s': np.mean(l0s),\n",
    "        'ce_loss': np.mean(transcoder_losses)\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6020d14-2883-4ced-9fb1-924111355e33",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:26<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0s': 203.77172332763672, 'ce_loss': 3.341213}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "transcoder_template = \"./pythia-transcoders/lr_0.0002_l1_2.5e-05/pk60eijx/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "transcoder = SparseAutoencoder.load_from_pretrained(f\"{transcoder_template}.pt\").eval()\n",
    "print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17691515-39f7-46cc-a53a-9c210a804ea9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:25<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0s': 148.1538818359375, 'ce_loss': 3.3440711}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "transcoder_template = \"pythia-transcoders/lr_0.0002_l1_3e-05/67jdp0mv/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "transcoder = SparseAutoencoder.load_from_pretrained(f\"{transcoder_template}.pt\").eval()\n",
    "print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "86d93f40-4a6f-45db-86ac-73a5324e7fe6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:26<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0s': 82.748544921875, 'ce_loss': 3.3491273}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "transcoder_template = \"pythia-transcoders/lr_0.0002_l1_4e-05/pze62n3h/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "transcoder = SparseAutoencoder.load_from_pretrained(f\"{transcoder_template}.pt\").eval()\n",
    "print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "63b025f1-397e-4639-ac1a-da79d59f3e75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:26<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0s': 44.042958984375, 'ce_loss': 3.3549356}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "transcoder_template = \"pythia-transcoders/lr_0.0002_l1_5.5e-05/szsvunrm/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "transcoder = SparseAutoencoder.load_from_pretrained(f\"{transcoder_template}.pt\").eval()\n",
    "print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e0f69707-763a-4e25-9d52-b454832fcc06",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [03:26<00:00,  1.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'l0s': 27.454230651855468, 'ce_loss': 3.3682058}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "transcoder_template = \"pythia-transcoders/lr_0.0002_l1_7e-05/v4gqmaoc/final_sparse_autoencoder_pythia-410m_blocks.15.ln2.hook_normalized_32768\"\n",
    "transcoder = SparseAutoencoder.load_from_pretrained(f\"{transcoder_template}.pt\").eval()\n",
    "print(eval_transcoder_l0_ce(model, owt_tokens_torch, transcoder, num_batches=200, batch_size=128))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
