{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard imports\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "import plotly.express as px\n",
    "from typing import Callable, Optional\n",
    "from tqdm import tqdm\n",
    "from jaxtyping import Int, Float, jaxtyped, BFloat16\n",
    "from beartype import beartype\n",
    "import einops\n",
    "\n",
    "# Imports for displaying vis in Colab / notebook\n",
    "import webbrowser\n",
    "import http.server\n",
    "import socketserver\n",
    "import threading\n",
    "\n",
    "PORT = 8000\n",
    "\n",
    "import sae_bench_utils.dataset_utils as dataset_utils\n",
    "import sae_bench_utils.dataset_info as dataset_info\n",
    "import sae_bench_utils\n",
    "\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(f\"Device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "from sae_lens import SAE\n",
    "\n",
    "# dtype the model is loaded in\n",
    "model_dtype = torch.bfloat16\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"pythia-70m-deduped\",\n",
    "    device=device,\n",
    "    dtype=model_dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
    "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
    "# We also return the feature sparsities which are stored in HF for convenience.\n",
    "sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "    release=\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
    "    sae_id=\"blocks.4.hook_resid_post__trainer_10\",\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "print(sae.cfg)\n",
    "\n",
    "# Take the context size and hooked layer from the SAE config so later cells match the SAE\n",
    "context_length = sae.cfg.context_size\n",
    "layer = sae.cfg.hook_layer\n",
    "batch_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set_size = 4000\n",
    "test_set_size = 1000\n",
    "\n",
    "random_seed = 42\n",
    "# Seed torch's global RNG too: the torch.randperm calls and the probe weight\n",
    "# initialisation further down draw from it, so repeated runs use the same random numbers\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "dataset_name = \"bias_in_bios\"\n",
    "train_df, test_df = dataset_utils.load_huggingface_dataset(dataset_name)\n",
    "train_data, test_data = dataset_utils.get_multi_label_train_test_data(\n",
    "    train_df, test_df, dataset_name, train_set_size, test_set_size, random_seed\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chosen_classes = [\"0\", \"1\", \"2\"]\n",
    "\n",
    "# Keep only the chosen classes in both splits\n",
    "train_data = dataset_utils.filter_dataset(train_data, chosen_classes)\n",
    "test_data = dataset_utils.filter_dataset(test_data, chosen_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize both splits with the model's tokenizer at the SAE context length\n",
    "train_data = dataset_utils.tokenize_data(train_data, model.tokenizer, context_length, device)\n",
    "test_data = dataset_utils.tokenize_data(test_data, model.tokenizer, context_length, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one class: its name and the fields stored for it\n",
    "first_key = next(iter(train_data))\n",
    "print(first_key)\n",
    "\n",
    "first_class_entry = train_data[first_key]\n",
    "print(first_class_entry.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the first class's input_ids along its leading dimension\n",
    "print(len(train_data[first_key][\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook_name = f\"blocks.{layer}.hook_resid_post\"\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "@torch.no_grad\n",
    "def get_all_llm_activations(\n",
    "    tokenized_inputs_dict: dict[str, dict[str, Int[torch.Tensor, \"dataset_size seq_len\"]]],\n",
    "    model: HookedTransformer,\n",
    "    batch_size: int,\n",
    "    hook_name: str,\n",
    ") -> dict[str, Float[torch.Tensor, \"batch_size seq_len d_model\"]]:\n",
    "    \"\"\"Collect the activations seen at `hook_name` for every class of tokenized inputs.\n",
    "\n",
    "    VERY IMPORTANT NOTE: We zero out masked token activations in this function. Later, we ignore zeroed activations.\n",
    "\n",
    "    Args:\n",
    "        tokenized_inputs_dict: Maps class name -> dict with \"input_ids\" and \"attention_mask\".\n",
    "        model: Model that is run with a forward hook on `hook_name`.\n",
    "        batch_size: Number of sequences per forward pass.\n",
    "        hook_name: Name of the hook point whose activations are kept.\n",
    "\n",
    "    Returns:\n",
    "        Maps class name -> activations of all processed batches, concatenated along dim 0.\n",
    "\n",
    "    Note:\n",
    "        Only `len(input_ids) // batch_size` full batches are run per class, so up to\n",
    "        `batch_size - 1` trailing sequences per class are left out of the result.\n",
    "    \"\"\"\n",
    "    captured = {}\n",
    "\n",
    "    def store_activation(resid_BLD: torch.Tensor, hook):\n",
    "        # Forward hook: remember the tensor passing through the hook point\n",
    "        captured[\"acts\"] = resid_BLD\n",
    "\n",
    "    activations_by_class = {}\n",
    "\n",
    "    for class_name, class_inputs in tokenized_inputs_dict.items():\n",
    "        per_batch_acts = []\n",
    "\n",
    "        for batch_idx in tqdm(range(len(class_inputs[\"input_ids\"]) // batch_size)):\n",
    "            batch_slice = slice(batch_idx * batch_size, (batch_idx + 1) * batch_size)\n",
    "            tokens_BL = class_inputs[\"input_ids\"][batch_slice]\n",
    "            mask_BL = class_inputs[\"attention_mask\"][batch_slice]\n",
    "\n",
    "            # Reset so a hook that never fires cannot silently reuse the previous batch\n",
    "            captured[\"acts\"] = None\n",
    "            model.run_with_hooks(\n",
    "                tokens_BL, return_type=None, fwd_hooks=[(hook_name, store_activation)]\n",
    "            )\n",
    "\n",
    "            # Broadcast the mask over the feature dimension to zero masked positions\n",
    "            per_batch_acts.append(captured[\"acts\"] * mask_BL[:, :, None])\n",
    "\n",
    "        activations_by_class[class_name] = torch.cat(per_batch_acts, dim=0)\n",
    "\n",
    "    return activations_by_class\n",
    "\n",
    "\n",
    "all_train_acts_BLD = get_all_llm_activations(\n",
    "    tokenized_inputs_dict=train_data, model=model, batch_size=batch_size, hook_name=hook_name\n",
    ")\n",
    "all_test_acts_BLD = get_all_llm_activations(\n",
    "    tokenized_inputs_dict=test_data, model=model, batch_size=batch_size, hook_name=hook_name\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jaxtyped(typechecker=beartype)\n",
    "def create_meaned_model_activations(\n",
    "    all_llm_activations_BLD: dict[str, Float[torch.Tensor, \"batch_size seq_len d_model\"]],\n",
    "    dtype: torch.dtype,\n",
    ") -> dict[str, Float[torch.Tensor, \"batch_size d_model\"]]:\n",
    "    \"\"\"Mean-pool each sequence's activations over its unmasked positions, per class.\n",
    "\n",
    "    VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for masked tokens.\n",
    "\n",
    "    A position counts as unmasked when its activation vector sums to a nonzero value. The\n",
    "    sum of a sequence's activations over positions is divided by the number of such positions.\n",
    "\n",
    "    Args:\n",
    "        all_llm_activations_BLD: Maps class name -> activations of shape (B, L, D).\n",
    "        dtype: dtype used for the per-position unmasked indicator and its count.\n",
    "\n",
    "    Returns:\n",
    "        Maps class name -> pooled activations of shape (B, D).\n",
    "    \"\"\"\n",
    "\n",
    "    def mean_pool(acts_BLD):\n",
    "        position_sums_BL = einops.reduce(acts_BLD, \"B L D -> B L\", \"sum\")\n",
    "        is_unmasked_BL = (position_sums_BL != 0.0).to(dtype=dtype)\n",
    "        num_unmasked_B = einops.reduce(is_unmasked_BL, \"B L -> B\", \"sum\")\n",
    "\n",
    "        summed_BD = einops.reduce(acts_BLD, \"B L D -> B D\", \"sum\")\n",
    "        return summed_BD / num_unmasked_B[:, None]\n",
    "\n",
    "    return {\n",
    "        class_name: mean_pool(acts_BLD)\n",
    "        for class_name, acts_BLD in all_llm_activations_BLD.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jaxtyped(typechecker=beartype)\n",
    "@torch.no_grad\n",
    "def get_sae_meaned_activations(\n",
    "    all_llm_activations_BLD: dict[str, Float[torch.Tensor, \"batch_size seq_len d_model\"]],\n",
    "    sae: SAE,\n",
    "    sae_batch_size: int,\n",
    "    dtype: torch.dtype,\n",
    ") -> dict[str, Float[torch.Tensor, \"batch_size d_sae\"]]:\n",
    "    \"\"\"Encode LLM activations with the SAE and mean-pool them over unmasked positions.\n",
    "\n",
    "    VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for masked tokens.\n",
    "\n",
    "    A position counts as unmasked when its LLM activation vector sums to a nonzero value.\n",
    "    SAE features at masked positions are zeroed before pooling.\n",
    "\n",
    "    Args:\n",
    "        all_llm_activations_BLD: Maps class name -> LLM activations of shape (B, L, D).\n",
    "        sae: Sparse autoencoder whose `encode` is applied to the activations.\n",
    "        sae_batch_size: Number of sequences encoded at a time.\n",
    "        dtype: dtype of the returned pooled activations (also used for the position mask).\n",
    "\n",
    "    Returns:\n",
    "        Maps class name -> pooled SAE activations of shape (B, F), one row per input\n",
    "        sequence (the last batch may be smaller than `sae_batch_size`).\n",
    "    \"\"\"\n",
    "    all_sae_activations_BF = {}\n",
    "    for class_name in tqdm(all_llm_activations_BLD):\n",
    "        all_acts_BLD = all_llm_activations_BLD[class_name]\n",
    "\n",
    "        all_acts_BF = []\n",
    "\n",
    "        # Step by start offset so the final, possibly smaller, batch is encoded as well;\n",
    "        # `len(...) // sae_batch_size` would silently drop it and shrink the output.\n",
    "        for start in range(0, len(all_acts_BLD), sae_batch_size):\n",
    "            acts_BLD = all_acts_BLD[start : start + sae_batch_size]\n",
    "            acts_BLF = sae.encode(acts_BLD)\n",
    "\n",
    "            activations_BL = einops.reduce(acts_BLD, \"B L D -> B L\", \"sum\")\n",
    "            nonzero_acts_BL = (activations_BL != 0.0).to(dtype=dtype)\n",
    "            nonzero_acts_B = einops.reduce(nonzero_acts_BL, \"B L -> B\", \"sum\")\n",
    "\n",
    "            # Zero SAE features at masked positions, then average over the unmasked ones\n",
    "            acts_BLF = acts_BLF * nonzero_acts_BL[:, :, None]\n",
    "            acts_BF = einops.reduce(acts_BLF, \"B L F -> B F\", \"sum\") / nonzero_acts_B[:, None]\n",
    "            acts_BF = acts_BF.to(dtype=dtype)\n",
    "\n",
    "            all_acts_BF.append(acts_BF)\n",
    "\n",
    "        all_acts_BF = torch.cat(all_acts_BF, dim=0)\n",
    "        all_sae_activations_BF[class_name] = all_acts_BF\n",
    "\n",
    "    return all_sae_activations_BF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jaxtyped(typechecker=beartype)\n",
    "def prepare_probe_data(\n",
    "    all_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
    "    class_idx: str,\n",
    "    batch_size: int,\n",
    "    select_top_k: Optional[int] = None,  # experimental feature\n",
    ") -> tuple[\n",
    "    list[Float[torch.Tensor, \"batch_size d_model\"]], list[Int[torch.Tensor, \"batch_size\"]]\n",
    "]:\n",
    "    \"\"\"Build shuffled, class-balanced batches for a one-vs-rest probe on `class_idx`.\n",
    "\n",
    "    Positives are all activations of `class_idx`; negatives are an equally sized random\n",
    "    sample drawn from every other class.\n",
    "\n",
    "    Args:\n",
    "        all_activations: Maps class name -> activations of shape (N, D).\n",
    "        class_idx: Key of the positive class.\n",
    "        batch_size: Number of samples per returned batch.\n",
    "        select_top_k: Experimental. If given, keep only the k features whose mean differs\n",
    "            most between the positive class and all negatives, and zero out the rest.\n",
    "\n",
    "    Returns:\n",
    "        (activation batches, label batches). Only full batches are returned; leftover\n",
    "        samples that do not fill a batch are dropped.\n",
    "    \"\"\"\n",
    "    pos_acts_BD = all_activations[class_idx]\n",
    "    device = pos_acts_BD.device\n",
    "    n_pos = len(pos_acts_BD)\n",
    "\n",
    "    # Pool every other class into one negative set\n",
    "    all_neg_acts = torch.cat(\n",
    "        [acts for name, acts in all_activations.items() if name != class_idx]\n",
    "    )\n",
    "\n",
    "    # Randomly subsample as many negatives as there are positives\n",
    "    neg_choice = torch.randperm(len(all_neg_acts))[:n_pos]\n",
    "    neg_acts_BD = all_neg_acts[neg_choice]\n",
    "\n",
    "    assert neg_acts_BD.shape == pos_acts_BD.shape\n",
    "\n",
    "    # Experimental feature: keep the top k features that differ the most between the\n",
    "    # positive class and the negatives, zero out the rest (k-sparse probing).\n",
    "    if select_top_k is not None:\n",
    "        mean_gap_D = (pos_acts_BD.mean(dim=0) - all_neg_acts.mean(dim=0)).abs()\n",
    "        kept_features = torch.argsort(mean_gap_D, descending=True)[:select_top_k]\n",
    "\n",
    "        # True marks a feature that gets zeroed\n",
    "        drop_D = torch.ones(mean_gap_D.shape[0], dtype=torch.bool, device=device)\n",
    "        drop_D[kept_features] = False\n",
    "\n",
    "        # Work on copies so the caller's tensors stay untouched\n",
    "        pos_acts_BD = pos_acts_BD.clone()\n",
    "        neg_acts_BD = neg_acts_BD.clone()\n",
    "        pos_acts_BD[:, drop_D] = 0.0\n",
    "        neg_acts_BD[:, drop_D] = 0.0\n",
    "\n",
    "    # Positives first, then negatives, with matching labels\n",
    "    acts = torch.cat([pos_acts_BD, neg_acts_BD])\n",
    "    labels = torch.empty(len(acts), dtype=torch.int, device=device)\n",
    "    labels[:n_pos] = dataset_info.POSITIVE_CLASS_LABEL\n",
    "    labels[n_pos:] = dataset_info.NEGATIVE_CLASS_LABEL\n",
    "\n",
    "    # Apply one shared permutation to activations and labels\n",
    "    order = torch.randperm(len(acts))\n",
    "    acts, labels = acts[order], labels[order]\n",
    "\n",
    "    # Cut into full batches of `batch_size`\n",
    "    num_batches = len(acts) // batch_size\n",
    "    batch_slices = [slice(i * batch_size, (i + 1) * batch_size) for i in range(num_batches)]\n",
    "\n",
    "    batched_acts = [acts[s] for s in batch_slices]\n",
    "    batched_labels = [labels[s] for s in batch_slices]\n",
    "\n",
    "    return batched_acts, batched_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe model and training\n",
    "class Probe(nn.Module):\n",
    "    \"\"\"Linear probe mapping an activation vector to a single logit.\"\"\"\n",
    "\n",
    "    def __init__(self, activation_dim: int, dtype: torch.dtype):\n",
    "        super().__init__()\n",
    "        # One output unit with a bias term, created directly in the requested dtype\n",
    "        self.net = nn.Linear(\n",
    "            in_features=activation_dim, out_features=1, bias=True, dtype=dtype\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits with the trailing size-1 dimension squeezed away.\"\"\"\n",
    "        logits = self.net(x)\n",
    "        return logits.squeeze(-1)\n",
    "\n",
    "\n",
    "def test_probe(\n",
    "    input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
    "    label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
    "    probe: Probe,\n",
    ") -> float:\n",
    "    \"\"\"Return the probe's accuracy over all of the given batches.\n",
    "\n",
    "    A logit above zero is counted as a prediction of label 1, anything else as label 0.\n",
    "\n",
    "    Args:\n",
    "        input_batches: Activation batches fed to the probe.\n",
    "        label_batches: Integer label batches, aligned with `input_batches`.\n",
    "        probe: Probe to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of samples whose prediction equals the label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If there are no batches to evaluate.\n",
    "    \"\"\"\n",
    "    if len(input_batches) == 0 or len(label_batches) == 0:\n",
    "        raise ValueError(\"test_probe needs at least one batch of inputs and labels\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        all_corrects = []\n",
    "\n",
    "        for acts_BD, labels_B in zip(input_batches, label_batches):\n",
    "            logits_B = probe(acts_BD)\n",
    "            preds_B = (logits_B > 0.0).long()\n",
    "            all_corrects.append((preds_B == labels_B).float())\n",
    "\n",
    "        # Average over individual samples rather than over per-batch means\n",
    "        accuracy_all = torch.cat(all_corrects).mean().item()\n",
    "\n",
    "    return accuracy_all\n",
    "\n",
    "\n",
    "def train_probe(\n",
    "    train_input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
    "    train_label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
    "    test_input_batches: list[Float[torch.Tensor, \"batch_size d_model\"]],\n",
    "    test_label_batches: list[Int[torch.Tensor, \"batch_size\"]],\n",
    "    dim: int,\n",
    "    epochs: int,\n",
    "    device: str,\n",
    "    model_dtype: torch.dtype,\n",
    "    lr: float,\n",
    "    verbose: bool = False,\n",
    ") -> tuple[Probe, float]:\n",
    "    \"\"\"Train a linear probe with AdamW on binary cross-entropy and report its test accuracy.\n",
    "\n",
    "    Args:\n",
    "        train_input_batches: Activation batches used for training.\n",
    "        train_label_batches: Label batches aligned with `train_input_batches`.\n",
    "        test_input_batches: Activation batches used for evaluation.\n",
    "        test_label_batches: Label batches aligned with `test_input_batches`.\n",
    "        dim: Input dimension of the probe.\n",
    "        epochs: Number of passes over the training batches.\n",
    "        device: Device the probe is moved to.\n",
    "        model_dtype: dtype of the probe parameters and of the loss targets.\n",
    "        lr: AdamW learning rate.\n",
    "        verbose: If True, print the last loss and the final train/test accuracy.\n",
    "\n",
    "    Returns:\n",
    "        (trained probe, test accuracy after the final epoch).\n",
    "    \"\"\"\n",
    "    probe = Probe(dim, model_dtype).to(device)\n",
    "    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)\n",
    "    criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "    loss = None  # stays None when no optimisation step was taken\n",
    "\n",
    "    # Enable autograd locally so training also works when grad is disabled globally\n",
    "    with torch.enable_grad():\n",
    "        for _ in range(epochs):\n",
    "            for acts_BD, labels_B in zip(train_input_batches, train_label_batches):\n",
    "                # BCEWithLogitsLoss needs floating point targets in the probe's dtype\n",
    "                targets_B = labels_B.to(device=device, dtype=model_dtype)\n",
    "                loss = criterion(probe(acts_BD), targets_B)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "    # Only the accuracy after the last epoch is reported, so evaluate once after training\n",
    "    test_accuracy = test_probe(test_input_batches, test_label_batches, probe)\n",
    "\n",
    "    if verbose and loss is not None:\n",
    "        train_accuracy = test_probe(train_input_batches, train_label_batches, probe)\n",
    "        print(\n",
    "            f\"\\nEpoch {epochs}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\\n\"\n",
    "        )\n",
    "\n",
    "    return probe, test_accuracy\n",
    "\n",
    "\n",
    "def train_probe_on_activations(\n",
    "    train_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
    "    test_activations: dict[str, Float[torch.Tensor, \"num_datapoints d_model\"]],\n",
    "    probe_batch_size: int,\n",
    "    epochs: int,\n",
    "    lr: float,\n",
    "    model_dtype: torch.dtype,\n",
    "    device: str,\n",
    "    select_top_k: Optional[int] = None,\n",
    ") -> tuple[dict[str, Probe], dict[str, float]]:\n",
    "    \"\"\"Train one one-vs-rest linear probe per class and collect the test accuracies.\n",
    "\n",
    "    Args:\n",
    "        train_activations: Maps class name -> training activations of shape (N, D).\n",
    "        test_activations: Maps class name -> test activations of shape (N, D).\n",
    "        probe_batch_size: Batch size used for probe training and evaluation.\n",
    "        epochs: Number of training epochs per probe.\n",
    "        lr: Learning rate per probe.\n",
    "        model_dtype: dtype of the probe parameters.\n",
    "        device: Device the probes are trained on.\n",
    "        select_top_k: If given, each probe only sees the k features whose mean differs\n",
    "            most between its class and the remaining classes; all others are zeroed.\n",
    "            The features are chosen from the training activations only and the same\n",
    "            selection is applied to the test activations, so the test split never\n",
    "            influences which features are kept.\n",
    "\n",
    "    Returns:\n",
    "        (probes, test_accuracies), both keyed by class name.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a class yields no full training batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def top_k_feature_mask(class_name):\n",
    "        # Boolean mask over features: True for the k largest train-set mean differences\n",
    "        positive_mean_D = train_activations[class_name].mean(dim=0)\n",
    "        negative_mean_D = torch.cat(\n",
    "            [acts for name, acts in train_activations.items() if name != class_name]\n",
    "        ).mean(dim=0)\n",
    "        mean_gap_D = (positive_mean_D - negative_mean_D).abs()\n",
    "\n",
    "        keep_D = torch.zeros(mean_gap_D.shape[0], dtype=torch.bool, device=mean_gap_D.device)\n",
    "        keep_D[torch.argsort(mean_gap_D, descending=True)[:select_top_k]] = True\n",
    "        return keep_D\n",
    "\n",
    "    def keep_only(activations, keep_D):\n",
    "        # Copy of `activations` with every feature outside `keep_D` set to zero\n",
    "        return {\n",
    "            name: torch.where(keep_D, acts, torch.zeros_like(acts))\n",
    "            for name, acts in activations.items()\n",
    "        }\n",
    "\n",
    "    probes, test_accuracies = {}, {}\n",
    "\n",
    "    for profession in train_activations.keys():\n",
    "        if select_top_k is None:\n",
    "            train_inputs, test_inputs = train_activations, test_activations\n",
    "        else:\n",
    "            keep_D = top_k_feature_mask(profession)\n",
    "            train_inputs = keep_only(train_activations, keep_D)\n",
    "            test_inputs = keep_only(test_activations, keep_D)\n",
    "\n",
    "        train_acts, train_labels = prepare_probe_data(train_inputs, profession, probe_batch_size)\n",
    "        test_acts, test_labels = prepare_probe_data(test_inputs, profession, probe_batch_size)\n",
    "\n",
    "        if not train_acts:\n",
    "            raise ValueError(\n",
    "                f\"No full training batch of size {probe_batch_size} for class {profession}\"\n",
    "            )\n",
    "\n",
    "        activation_dim = train_acts[0].shape[1]\n",
    "\n",
    "        print(f\"activation dim: {activation_dim}\")\n",
    "\n",
    "        # Scoped instead of torch.set_grad_enabled(True), so the caller's grad mode is\n",
    "        # restored once training is done\n",
    "        with torch.enable_grad():\n",
    "            probe, test_accuracy = train_probe(\n",
    "                train_acts,\n",
    "                train_labels,\n",
    "                test_acts,\n",
    "                test_labels,\n",
    "                epochs=epochs,\n",
    "                dim=activation_dim,\n",
    "                device=device,\n",
    "                model_dtype=model_dtype,\n",
    "                lr=lr,\n",
    "                verbose=False,\n",
    "            )\n",
    "\n",
    "        print(f\"Test accuracy for {profession}: {test_accuracy}\")\n",
    "\n",
    "        probes[profession] = probe\n",
    "        test_accuracies[profession] = test_accuracy\n",
    "\n",
    "    return probes, test_accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-pool the LLM activations and print one class's shape before and after\n",
    "first_key = next(iter(all_train_acts_BLD))\n",
    "print(all_train_acts_BLD[first_key].shape)\n",
    "\n",
    "all_train_acts_BD = create_meaned_model_activations(\n",
    "    all_llm_activations_BLD=all_train_acts_BLD, dtype=model_dtype\n",
    ")\n",
    "all_test_acts_BD = create_meaned_model_activations(\n",
    "    all_llm_activations_BLD=all_test_acts_BLD, dtype=model_dtype\n",
    ")\n",
    "\n",
    "print(all_train_acts_BD[first_key].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probe_batch_size = 128\n",
    "epochs = 10\n",
    "lr = 1e-3\n",
    "\n",
    "# Probes trained directly on the mean-pooled LLM activations\n",
    "probes, test_accuracies = train_probe_on_activations(\n",
    "    train_activations=all_train_acts_BD,\n",
    "    test_activations=all_test_acts_BD,\n",
    "    probe_batch_size=probe_batch_size,\n",
    "    epochs=epochs,\n",
    "    lr=lr,\n",
    "    model_dtype=model_dtype,\n",
    "    device=device,\n",
    "    select_top_k=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sae_batch_size = 32\n",
    "\n",
    "# Encode both splits with the SAE and mean-pool over sequence positions\n",
    "all_sae_train_acts_BF = get_sae_meaned_activations(\n",
    "    all_llm_activations_BLD=all_train_acts_BLD,\n",
    "    sae=sae,\n",
    "    sae_batch_size=sae_batch_size,\n",
    "    dtype=model_dtype,\n",
    ")\n",
    "all_sae_test_acts_BF = get_sae_meaned_activations(\n",
    "    all_llm_activations_BLD=all_test_acts_BLD,\n",
    "    sae=sae,\n",
    "    sae_batch_size=sae_batch_size,\n",
    "    dtype=model_dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probes trained on the mean-pooled SAE activations, using every SAE feature\n",
    "sae_probes, sae_test_accuracies = train_probe_on_activations(\n",
    "    train_activations=all_sae_train_acts_BF,\n",
    "    test_activations=all_sae_test_acts_BF,\n",
    "    probe_batch_size=probe_batch_size,\n",
    "    epochs=epochs,\n",
    "    lr=lr,\n",
    "    model_dtype=model_dtype,\n",
    "    device=device,\n",
    "    select_top_k=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# k-sparse probing on SAE features: keep the per-k test accuracies instead of discarding\n",
    "# them, and label the printed output so each accuracy can be attributed to its k\n",
    "top_k_values = [1, 5, 10, 20, 50]\n",
    "sae_top_k_test_accuracies = {}\n",
    "\n",
    "for k in top_k_values:\n",
    "    print(f\"Top-k probing with k = {k}\")\n",
    "    _, sae_top_k_test_accuracies[k] = train_probe_on_activations(\n",
    "        all_sae_train_acts_BF,\n",
    "        all_sae_test_acts_BF,\n",
    "        probe_batch_size,\n",
    "        epochs,\n",
    "        lr,\n",
    "        model_dtype,\n",
    "        device,\n",
    "        select_top_k=k,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
