{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Standard imports\n",
        "# Imports for displaying vis in Colab / notebook\n",
        "from typing import Optional\n",
        "\n",
        "import einops\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from beartype import beartype\n",
        "from jaxtyping import Float, Int, jaxtyped\n",
        "from tqdm import tqdm\n",
        "\n",
        "PORT = 8000\n",
        "\n",
        "import sae_bench.sae_bench_utils.dataset_info as dataset_info\n",
        "import sae_bench.sae_bench_utils.dataset_utils as dataset_utils\n",
        "\n",
        "torch.set_grad_enabled(False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "if torch.backends.mps.is_available():\n",
        "    device = \"mps\"\n",
        "else:\n",
        "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "print(f\"Device: {device}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from sae_lens import SAE\n",
        "from transformer_lens import HookedTransformer\n",
        "\n",
        "model_dtype = torch.bfloat16\n",
        "\n",
        "model = HookedTransformer.from_pretrained(\n",
        "    \"pythia-70m-deduped\", device=device, dtype=model_dtype\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
        "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
        "# We also return the feature sparsities which are stored in HF for convenience.\n",
        "sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
        "    release=\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
        "    sae_id=\"blocks.4.hook_resid_post__trainer_10\",\n",
        "    device=device,\n",
        ")\n",
        "8\n",
        "\n",
        "print(sae.cfg)\n",
        "\n",
        "context_length = sae.cfg.context_size\n",
        "layer = sae.cfg.hook_layer\n",
        "batch_size = 128"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "train_set_size = 4000\n",
        "test_set_size = 1000\n",
        "\n",
        "random_seed = 42\n",
        "\n",
        "dataset_name = \"bias_in_bios\"\n",
        "train_df, test_df = dataset_utils.load_huggingface_dataset(dataset_name)\n",
        "train_data, test_data = dataset_utils.get_multi_label_train_test_data(\n",
        "    train_df, test_df, dataset_name, train_set_size, test_set_size, random_seed\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "chosen_classes = [\"0\", \"1\", \"2\"]\n",
        "\n",
        "train_data = utils.filter_dataset(train_data, chosen_classes)\n",
        "test_data = utils.filter_dataset(test_data, chosen_classes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "train_data = utils.tokenize_data(train_data, model.tokenizer, context_length, device)\n",
        "test_data = utils.tokenize_data(test_data, model.tokenizer, context_length, device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "first_key = next(iter(train_data.keys()))\n",
        "print(first_key)\n",
        "print(train_data[first_key].keys())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(len(train_data[first_key][\"input_ids\"]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "hook_name = f\"blocks.{layer}.hook_resid_post\"\n",
        "\n",
        "\n",
        "@jaxtyped(typechecker=beartype)\n",
        "@torch.no_grad\n",
        "def get_all_llm_activations(\n",
        "    tokenized_inputs_dict: dict[\n",
        "        str, dict[str, Int[torch.Tensor, \"dataset_size seq_len\"]]\n",
        "    ],\n",
        "    model: HookedTransformer,\n",
        "    batch_size: int,\n",
        "    hook_name: str,\n",
        ") -> dict[str, Float[torch.Tensor, \"batch_size seq_len d_model\"]]:\n",
        "    \"\"\"VERY IMPORTANT NOTE: We zero out masked token activations in this function. Later, we ignore zeroed activations.\"\"\"\n",
        "    all_classes_acts_BLD = {}\n",
        "\n",
        "    for class_name in tokenized_inputs_dict:\n",
        "        all_acts_BLD = []\n",
        "        tokenized_inputs = tokenized_inputs_dict[class_name]\n",
        "\n",
        "        for i in tqdm(range(len(tokenized_inputs[\"input_ids\"]) // batch_size)):\n",
        "            tokens_BL = tokenized_inputs[\"input_ids\"][\n",
        "                i * batch_size : (i + 1) * batch_size\n",
        "            ]\n",
        "            attention_mask_BL = tokenized_inputs[\"attention_mask\"][\n",
        "                i * batch_size : (i + 1) * batch_size\n",
        "            ]\n",
        "\n",
        "            acts_BLD = None\n",
        "\n",
        "            def activation_hook(resid_BLD: torch.Tensor, hook):\n",
        "                nonlocal acts_BLD\n",
        "                acts_BLD = resid_BLD\n",
        "\n",
        "            model.run_with_hooks(\n",
        "                tokens_BL, return_type=None, fwd_hooks=[(hook_name, activation_hook)]\n",
        "            )\n",
        "\n",
        "            acts_BLD = acts_BLD * attention_mask_BL[:, :, None]\n",
        "            all_acts_BLD.append(acts_BLD)\n",
        "\n",
        "        all_acts_BLD = torch.cat(all_acts_BLD, dim=0)\n",
        "\n",
        "        all_classes_acts_BLD[class_name] = all_acts_BLD\n",
        "\n",
        "    return all_classes_acts_BLD\n",
        "\n",
        "\n",
        "all_train_acts_BLD = get_all_llm_activations(train_data, model, batch_size, hook_name)\n",
        "all_test_acts_BLD = get_all_llm_activations(test_data, model, batch_size, hook_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "@jaxtyped(typechecker=beartype)\n",
        "def create_meaned_model_activations(\n",
        "    all_llm_activations_BLD: dict[\n",
        "        str, Float[torch.Tensor, \"batch_size seq_len d_model\"]\n",
        "    ],\n",
        "    dtype: torch.dtype,\n",
        ") -> dict[str, Float[torch.Tensor, \"batch_size d_model\"]]:\n",
        "    \"\"\"VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for masked tokens.\"\"\"\n",
        "    all_llm_activations_BD = {}\n",
        "    for class_name in all_llm_activations_BLD:\n",
        "        acts_BLD = all_llm_activations_BLD[class_name]\n",
        "        activations_BL = einops.reduce(acts_BLD, \"B L D -> B L\", \"sum\")\n",
        "        nonzero_acts_BL = (activations_BL != 0.0).to(dtype=dtype)\n",
        "        nonzero_acts_B = einops.reduce(nonzero_acts_BL, \"B L -> B\", \"sum\")\n",
        "\n",
        "        meaned_acts_BD = (\n",
        "            einops.reduce(acts_BLD, \"B L D -> B D\", \"sum\") / nonzero_acts_B[:, None]\n",
        "        )\n",
        "        all_llm_activations_BD[class_name] = meaned_acts_BD\n",
        "\n",
        "    return all_llm_activations_BD"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "@jaxtyped(typechecker=beartype)\n",
        "@torch.no_grad\n",
        "def get_sae_meaned_activations(\n",
        "    all_llm_activations_BLD: dict[\n",
        "        str, Float[torch.Tensor, \"batch_size seq_len d_model\"]\n",
        "    ],\n",
        "    sae: SAE,\n",
        "    sae_batch_size: int,\n",
        "    dtype: torch.dtype,\n",
        ") -> dict[str, Float[torch.Tensor, \"batch_size d_sae\"]]:\n",
        "    \"\"\"VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for masked tokens.\"\"\"\n",
        "    all_sae_activations_BF = {}\n",
        "    for class_name in tqdm(all_llm_activations_BLD):\n",
        "        all_acts_BLD = all_llm_activations_BLD[class_name]\n",
        "\n",
        "        all_acts_BF = []\n",
        "\n",
        "        for i in range(len(all_acts_BLD) // sae_batch_size):\n",
        "            acts_BLD = all_acts_BLD[i * sae_batch_size : (i + 1) * sae_batch_size]\n",
        "            acts_BLF = sae.encode(acts_BLD)\n",
        "\n",
        "            activations_BL = einops.reduce(acts_BLD, \"B L D -> B L\", \"sum\")\n",
        "            nonzero_acts_BL = (activations_BL != 0.0).to(dtype=dtype)\n",
        "            nonzero_acts_B = einops.reduce(nonzero_acts_BL, \"B L -> B\", \"sum\")\n",
        "\n",
        "            acts_BLF = acts_BLF * nonzero_acts_BL[:, :, None]\n",
        "            acts_BF = (\n",
        "                einops.reduce(acts_BLF, \"B L F -> B F\", \"sum\") / nonzero_acts_B[:, None]\n",
        "            )\n",
        "            acts_BF = acts_BF.to(dtype=dtype)\n",
        "\n",
        "            all_acts_BF.append(acts_BF)\n",
        "\n",
        "        all_acts_BF = torch.cat(all_acts_BF, dim=0)\n",
        "        all_sae_activations_BF[class_name] = all_acts_BF\n",
        "\n",
        "    return all_sae_activations_BF"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "@jaxtyped(typechecker=beartype)\n",
        "def prepare_probe_data(\n",
        "    all_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
        "    class_idx: str,\n",
        "    batch_size: int,\n",
        "    select_top_k: int | None = None,  # experimental feature\n",
        ") -> tuple[\n",
        "    list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
        "    list[Int[torch.Tensor, \"batch_size\"]],\n",
        "]:\n",
        "    positive_acts_BD = all_activations[class_idx]\n",
        "    device = positive_acts_BD.device\n",
        "\n",
        "    num_positive = len(positive_acts_BD)\n",
        "\n",
        "    # Collect all negative class activations and labels\n",
        "    negative_acts = []\n",
        "    for idx, acts in all_activations.items():\n",
        "        if idx != class_idx:\n",
        "            negative_acts.append(acts)\n",
        "\n",
        "    negative_acts = torch.cat(negative_acts)\n",
        "\n",
        "    # Randomly select num_positive samples from negative class\n",
        "    indices = torch.randperm(len(negative_acts))[:num_positive]\n",
        "    selected_negative_acts_BD = negative_acts[indices]\n",
        "\n",
        "    assert selected_negative_acts_BD.shape == positive_acts_BD.shape\n",
        "\n",
        "    # Experimental feature: find the top k features that differ the most between in distribution and out of distribution\n",
        "    # zero out the rest. Useful for k-sparse probing experiments.\n",
        "    if select_top_k is not None:\n",
        "        positive_distribution_D = positive_acts_BD.mean(dim=(0))\n",
        "        negative_distribution_D = negative_acts.mean(dim=(0))\n",
        "        distribution_diff_D = (positive_distribution_D - negative_distribution_D).abs()\n",
        "        top_k_indices_D = torch.argsort(distribution_diff_D, descending=True)[\n",
        "            :select_top_k\n",
        "        ]\n",
        "\n",
        "        mask_D = torch.ones(\n",
        "            distribution_diff_D.shape[0],\n",
        "            dtype=torch.bool,\n",
        "            device=positive_acts_BD.device,\n",
        "        )\n",
        "        mask_D[top_k_indices_D] = False\n",
        "\n",
        "        masked_positive_acts_BD = positive_acts_BD.clone()\n",
        "        masked_negative_acts_BD = selected_negative_acts_BD.clone()\n",
        "\n",
        "        masked_positive_acts_BD[:, mask_D] = 0.0\n",
        "        masked_negative_acts_BD[:, mask_D] = 0.0\n",
        "    else:\n",
        "        masked_positive_acts_BD = positive_acts_BD\n",
        "        masked_negative_acts_BD = selected_negative_acts_BD\n",
        "\n",
        "    # Combine positive and negative samples\n",
        "    combined_acts = torch.cat([masked_positive_acts_BD, masked_negative_acts_BD])\n",
        "\n",
        "    combined_labels = torch.empty(len(combined_acts), dtype=torch.int, device=device)\n",
        "    combined_labels[:num_positive] = dataset_info.POSITIVE_CLASS_LABEL\n",
        "    combined_labels[num_positive:] = dataset_info.NEGATIVE_CLASS_LABEL\n",
        "\n",
        "    # Shuffle the combined data\n",
        "    shuffle_indices = torch.randperm(len(combined_acts))\n",
        "    shuffled_acts = combined_acts[shuffle_indices]\n",
        "    shuffled_labels = combined_labels[shuffle_indices]\n",
        "\n",
        "    # Reshape into lists of tensors with specified batch_size\n",
        "    num_samples = len(shuffled_acts)\n",
        "    num_batches = num_samples // batch_size\n",
        "\n",
        "    batched_acts = [\n",
        "        shuffled_acts[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)\n",
        "    ]\n",
        "    batched_labels = [\n",
        "        shuffled_labels[i * batch_size : (i + 1) * batch_size]\n",
        "        for i in range(num_batches)\n",
        "    ]\n",
        "\n",
        "    return batched_acts, batched_labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Probe model and training\n",
        "class Probe(nn.Module):\n",
        "    def __init__(self, activation_dim: int, dtype: torch.dtype):\n",
        "        super().__init__()\n",
        "        self.net = nn.Linear(activation_dim, 1, bias=True, dtype=dtype)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x).squeeze(-1)\n",
        "\n",
        "\n",
        "def test_probe(\n",
        "    input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
        "    label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
        "    probe: Probe,\n",
        ") -> float:\n",
        "    criterion = nn.BCEWithLogitsLoss()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        corrects_0 = []\n",
        "        corrects_1 = []\n",
        "        all_corrects = []\n",
        "        losses = []\n",
        "\n",
        "        for acts_BD, labels_B in zip(input_batches, label_batches):\n",
        "            logits_B = probe(acts_BD)\n",
        "            preds_B = (logits_B > 0.0).long()\n",
        "            correct_B = (preds_B == labels_B).float()\n",
        "\n",
        "            all_corrects.append(correct_B)\n",
        "            corrects_0.append(correct_B[labels_B == 0])\n",
        "            corrects_1.append(correct_B[labels_B == 1])\n",
        "\n",
        "            loss = criterion(logits_B, labels_B.to(dtype=probe.net.weight.dtype))\n",
        "            losses.append(loss)\n",
        "\n",
        "        accuracy_all = torch.cat(all_corrects).mean().item()\n",
        "        accuracy_0 = torch.cat(corrects_0).mean().item() if corrects_0 else 0.0\n",
        "        accuracy_1 = torch.cat(corrects_1).mean().item() if corrects_1 else 0.0\n",
        "        loss = torch.stack(losses).mean().item()\n",
        "\n",
        "    return accuracy_all\n",
        "\n",
        "\n",
        "def train_probe(\n",
        "    train_input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
        "    train_label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
        "    test_input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
        "    test_label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
        "    dim: int,\n",
        "    epochs: int,\n",
        "    device: str,\n",
        "    model_dtype: torch.dtype,\n",
        "    lr: float,\n",
        "    verbose: bool = False,\n",
        ") -> tuple[Probe, float]:\n",
        "    probe = Probe(dim, model_dtype).to(device)\n",
        "    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)\n",
        "    criterion = nn.BCEWithLogitsLoss()\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        for acts_BD, labels_B in zip(train_input_batches, train_label_batches):\n",
        "            logits_B = probe(acts_BD)\n",
        "            loss = criterion(\n",
        "                logits_B, labels_B.clone().detach().to(device=device, dtype=model_dtype)\n",
        "            )\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        train_accuracy = test_probe(train_input_batches, train_label_batches, probe)\n",
        "\n",
        "        test_accuracy = test_probe(test_input_batches, test_label_batches, probe)\n",
        "\n",
        "        if epoch == epochs - 1 and verbose:\n",
        "            print(\n",
        "                f\"\\nEpoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\\n\"\n",
        "            )\n",
        "\n",
        "    return probe, test_accuracy\n",
        "\n",
        "\n",
        "def train_probe_on_activations(\n",
        "    train_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
        "    test_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
        "    probe_batch_size: int,\n",
        "    epochs: int,\n",
        "    lr: float,\n",
        "    model_dtype: torch.dtype,\n",
        "    device: str,\n",
        "    select_top_k: int | None = None,\n",
        ") -> tuple[dict[str, Probe], dict[str, float]]:\n",
        "    torch.set_grad_enabled(True)\n",
        "\n",
        "    probes, test_accuracies = {}, {}\n",
        "\n",
        "    for profession in train_activations.keys():\n",
        "        train_acts, train_labels = prepare_probe_data(\n",
        "            train_activations, profession, probe_batch_size, select_top_k\n",
        "        )\n",
        "\n",
        "        test_acts, test_labels = prepare_probe_data(\n",
        "            test_activations, profession, probe_batch_size, select_top_k\n",
        "        )\n",
        "\n",
        "        activation_dim = train_acts[0].shape[1]\n",
        "\n",
        "        print(f\"activation dim: {activation_dim}\")\n",
        "\n",
        "        probe, test_accuracy = train_probe(\n",
        "            train_acts,\n",
        "            train_labels,\n",
        "            test_acts,\n",
        "            test_labels,\n",
        "            epochs=epochs,\n",
        "            dim=activation_dim,\n",
        "            device=device,\n",
        "            model_dtype=model_dtype,\n",
        "            lr=lr,\n",
        "            verbose=False,\n",
        "        )\n",
        "\n",
        "        print(f\"Test accuracy for {profession}: {test_accuracy}\")\n",
        "\n",
        "        probes[profession] = probe\n",
        "        test_accuracies[profession] = test_accuracy\n",
        "\n",
        "    return probes, test_accuracies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "first_key = next(iter(all_train_acts_BLD.keys()))\n",
        "print(all_train_acts_BLD[first_key].shape)\n",
        "all_train_acts_BD = create_meaned_model_activations(all_train_acts_BLD, model_dtype)\n",
        "all_test_acts_BD = create_meaned_model_activations(all_test_acts_BLD, model_dtype)\n",
        "print(all_train_acts_BD[first_key].shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "probe_batch_size = 128\n",
        "epochs = 10\n",
        "lr = 1e-3\n",
        "\n",
        "probes, test_accuracies = train_probe_on_activations(\n",
        "    all_train_acts_BD,\n",
        "    all_test_acts_BD,\n",
        "    probe_batch_size,\n",
        "    epochs,\n",
        "    lr,\n",
        "    model_dtype,\n",
        "    device,\n",
        "    select_top_k=None,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "sae_batch_size = 32\n",
        "all_sae_train_acts_BF = get_sae_meaned_activations(\n",
        "    all_train_acts_BLD, sae, sae_batch_size, model_dtype\n",
        ")\n",
        "all_sae_test_acts_BF = get_sae_meaned_activations(\n",
        "    all_test_acts_BLD, sae, sae_batch_size, model_dtype\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "sae_probes, sae_test_accuracies = train_probe_on_activations(\n",
        "    all_sae_train_acts_BF,\n",
        "    all_sae_test_acts_BF,\n",
        "    probe_batch_size,\n",
        "    epochs,\n",
        "    lr,\n",
        "    model_dtype,\n",
        "    device,\n",
        "    select_top_k=None,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "for k in [1, 5, 10, 20, 50]:\n",
        "    train_probe_on_activations(\n",
        "        all_sae_train_acts_BF,\n",
        "        all_sae_test_acts_BF,\n",
        "        probe_batch_size,\n",
        "        epochs,\n",
        "        lr,\n",
        "        model_dtype,\n",
        "        device,\n",
        "        select_top_k=k,\n",
        "    )"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
