{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "parent_dir = os.path.abspath('..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "from datasets import load_dataset\n",
    "import random\n",
    "from nnsight import LanguageModel\n",
    "import torch as t\n",
    "from torch import nn\n",
    "from attribution import patching_effect\n",
    "from dictionary_learning import AutoEncoder, ActivationBuffer\n",
    "from dictionary_learning.dictionary import IdentityDict\n",
    "from dictionary_learning.interp import examine_dimension\n",
    "from dictionary_learning.utils import hf_dataset_to_generator\n",
    "from tqdm import tqdm\n",
    "import gc\n",
    "\n",
    "DEBUGGING = False\n",
    "\n",
    "if DEBUGGING:\n",
    "    tracer_kwargs = dict(scan=True, validate=True)\n",
    "else:\n",
    "    tracer_kwargs = dict(scan=False, validate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset hyperparameters\n",
    "dataset = load_dataset(\"LabHC/bias_in_bios\")\n",
    "profession_dict = {'professor' : 21, 'nurse' : 13}\n",
    "male_prof = 'professor'\n",
    "female_prof = 'nurse'\n",
    "\n",
    "# model hyperparameters\n",
    "DEVICE = 'cuda:0'\n",
    "model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)\n",
    "activation_dim = 512\n",
    "layer = 4\n",
    "\n",
    "# dictionary hyperparameters\n",
    "dict_id = 10\n",
    "expansion_factor = 64\n",
    "dictionary_size = expansion_factor * activation_dim\n",
    "\n",
    "# data preparation hyperparameters\n",
    "batch_size = 1024\n",
    "SEED = 42\n",
    "\n",
    "def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):\n",
    "    if train:\n",
    "        data = dataset['train']\n",
    "    else:\n",
    "        data = dataset['test']\n",
    "    if ambiguous:\n",
    "        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        n = min([len(neg), len(pos)])\n",
    "        neg, pos = neg[:n], pos[:n]\n",
    "        data = neg + pos\n",
    "        labels = [0]*n + [1]*n\n",
    "        idxs = list(range(2*n))\n",
    "        random.Random(seed).shuffle(idxs)\n",
    "        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]\n",
    "        true_labels = spurious_labels = labels\n",
    "    else:\n",
    "        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]\n",
    "        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]\n",
    "        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])\n",
    "        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]\n",
    "        data = neg_neg + neg_pos + pos_neg + pos_pos\n",
    "        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n\n",
    "        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n\n",
    "        idxs = list(range(4*n))\n",
    "        random.Random(seed).shuffle(idxs)\n",
    "        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]\n",
    "\n",
    "    batches = [\n",
    "        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)\n",
    "    ]\n",
    "\n",
    "    return batches\n",
    "\n",
    "def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):\n",
    "    if train:\n",
    "        data = dataset['train']\n",
    "    else:\n",
    "        data = dataset['test']\n",
    "    if ambiguous:\n",
    "        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        neg_labels, pos_labels = (0, 0), (1, 1)\n",
    "        subgroups = [(neg, neg_labels), (pos, pos_labels)]\n",
    "    else:\n",
    "        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]\n",
    "        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]\n",
    "        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)\n",
    "        subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]\n",
    "    \n",
    "    out = {}\n",
    "    for data, label_profile in subgroups:\n",
    "        out[label_profile] = []\n",
    "        for i in range(0, len(data), batch_size):\n",
    "            text = data[i:i+batch_size]\n",
    "            out[label_profile].append(\n",
    "                (\n",
    "                    text,\n",
    "                    t.tensor([label_profile[0]]*len(text), device=DEVICE),\n",
    "                    t.tensor([label_profile[1]]*len(text), device=DEVICE)\n",
    "                )\n",
    "            )\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# probe training hyperparameters\n",
    "class Probe(nn.Module):\n",
    "    def __init__(self, activation_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Linear(activation_dim, 1, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.net(x).squeeze(-1)\n",
    "        return logits\n",
    "\n",
    "def train_probe(get_acts, label_idx=0, batches=get_data(), lr=1e-2, epochs=1, dim=512, seed=SEED):\n",
    "    t.manual_seed(seed)\n",
    "    probe = Probe(dim).to(DEVICE)\n",
    "    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)\n",
    "    criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "    losses = []\n",
    "    for epoch in range(epochs):\n",
    "        for batch in batches:\n",
    "            text = batch[0]\n",
    "            labels = batch[label_idx+1] \n",
    "            acts = get_acts(text)\n",
    "            logits = probe(acts)\n",
    "            loss = criterion(logits, labels.float())\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    return probe, losses\n",
    "\n",
    "def test_probe(probe, get_acts, label_idx=0, batches=get_data(train=False), seed=SEED):\n",
    "    with t.no_grad():\n",
    "        corrects = []\n",
    "\n",
    "        for batch in batches:\n",
    "            text = batch[0]\n",
    "            labels = batch[label_idx+1]\n",
    "            acts = get_acts(text)\n",
    "            logits = probe(acts)\n",
    "            preds = (logits > 0.0).long()\n",
    "            corrects.append((preds == labels).float())\n",
    "        return t.cat(corrects).mean().item()\n",
    "    \n",
    "def get_acts(text):\n",
    "    with t.no_grad(): \n",
    "        with model.trace(text, **tracer_kwargs):\n",
    "            attn_mask = model.input[1]['attention_mask']\n",
    "            acts = model.gpt_neox.layers[layer].output[0]\n",
    "            acts = acts * attn_mask[:, :, None]\n",
    "            acts = acts.sum(1) / attn_mask.sum(1)[:, None]\n",
    "            acts = acts.save()\n",
    "        return acts.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False))\n",
    "print(\"ambiguous test accuracy\", test_probe(oracle, get_acts, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print(\"ground truth accuracy:\", test_probe(oracle, get_acts, batches=batches, label_idx=0))\n",
    "print(\"unintended feature accuracy:\", test_probe(oracle, get_acts, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get worst-group accuracy of oracle probe\n",
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(oracle, get_acts, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probe, _ = train_probe(get_acts, label_idx=0)\n",
    "print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))\n",
    "print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "submodules = []\n",
    "dictionaries = {}\n",
    "\n",
    "submodules.append(model.gpt_neox.embed_in)\n",
    "ae = AutoEncoder(512, dictionary_size).to(DEVICE)\n",
    "ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_{dictionary_size}/ae.pt'))\n",
    "dictionaries[model.gpt_neox.embed_in] = ae\n",
    "\n",
    "for i in range(layer + 1):\n",
    "    submodules.append(model.gpt_neox.layers[i].attention)\n",
    "    ae = AutoEncoder(512, dictionary_size).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'))\n",
    "    dictionaries[model.gpt_neox.layers[i].attention] = ae\n",
    "\n",
    "    submodules.append(model.gpt_neox.layers[i].mlp)\n",
    "    ae = AutoEncoder(512, dictionary_size).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'))\n",
    "    dictionaries[model.gpt_neox.layers[i].mlp] = ae\n",
    "\n",
    "    submodules.append(model.gpt_neox.layers[i])\n",
    "    ae = AutoEncoder(512, dictionary_size).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'))\n",
    "    dictionaries[model.gpt_neox.layers[i]] = ae\n",
    "\n",
    "def metric_fn(model, labels=None):\n",
    "    attn_mask = model.input[1]['attention_mask']\n",
    "    acts = model.gpt_neox.layers[layer].output[0]\n",
    "    acts = acts * attn_mask[:, :, None]\n",
    "    acts = acts.sum(1) / attn_mask.sum(1)[:, None]\n",
    "    \n",
    "    return t.where(\n",
    "        labels == 0,\n",
    "        probe(acts),\n",
    "        - probe(acts)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_batches = 25\n",
    "batch_size = 4\n",
    "\n",
    "running_total = 0\n",
    "nodes = None\n",
    "\n",
    "for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):\n",
    "    if batch_idx == n_batches:\n",
    "        break\n",
    "\n",
    "    effects, _, _, _ = patching_effect(\n",
    "        clean,\n",
    "        None,\n",
    "        model,\n",
    "        submodules,\n",
    "        dictionaries,\n",
    "        metric_fn,\n",
    "        metric_kwargs=dict(labels=labels),\n",
    "        method='ig'\n",
    "    )\n",
    "    with t.no_grad():\n",
    "        if nodes is None:\n",
    "            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}\n",
    "        else:\n",
    "            for k, v in effects.items():\n",
    "                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)\n",
    "        running_total += len(clean)\n",
    "    del effects, _\n",
    "    gc.collect()\n",
    "\n",
    "nodes = {k : v / running_total for k, v in nodes.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features = 0\n",
    "for component_idx, effect in enumerate(nodes.values()):\n",
    "    print(f\"Component {component_idx}:\")\n",
    "    for idx in (effect > 0.1).nonzero():\n",
    "        print(idx.item(), effect[idx].item())\n",
    "        n_features += 1\n",
    "print(f\"total features: {n_features}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "component_idx = 9\n",
    "feat_idx = 31098\n",
    "\n",
    "submodule = submodules[component_idx]\n",
    "dictionary = dictionaries[submodule]\n",
    "\n",
    "# interpret some features\n",
    "data = hf_dataset_to_generator(\"monology/pile-uncopyrighted\")\n",
    "buffer = ActivationBuffer(\n",
    "    data,\n",
    "    model,\n",
    "    submodule,\n",
    "    out_feats=512,\n",
    "    in_batch_size=128,\n",
    "    n_ctxs=512,\n",
    ")\n",
    "\n",
    "out = examine_dimension(\n",
    "    model,\n",
    "    submodule,\n",
    "    buffer,\n",
    "    dictionary,\n",
    "    dim_idx=feat_idx,\n",
    "    n_inputs=256\n",
    ")\n",
    "print(out.top_tokens)\n",
    "print(out.top_affected)\n",
    "out.top_contexts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feats_to_ablate = {\n",
    "    submodules[0] : [\n",
    "        946, # 'his'\n",
    "        # 5719, # 'research'\n",
    "        7392, # 'He'\n",
    "        # 10784, # 'Nursing'\n",
    "        17846, # 'He'\n",
    "        22068, # 'His'\n",
    "        # 23079, # 'tastes'\n",
    "        # 25904, # 'nursing'\n",
    "        28533, # 'She'\n",
    "        29476, # 'he'\n",
    "        31461, # 'His'\n",
    "        31467, # 'she'\n",
    "        32081, # 'her'\n",
    "        32469, # 'She'\n",
    "    ],\n",
    "    submodules[1] : [\n",
    "        # 23752, # capitalized words, especially pronouns\n",
    "    ],\n",
    "    submodules[2] : [\n",
    "        2995, # 'he'\n",
    "        3842, # 'She'\n",
    "        10258, # female names\n",
    "        13387, # 'she'\n",
    "        13968, # 'He'\n",
    "        18382, # 'her'\n",
    "        19369, # 'His'\n",
    "        28127, # 'She'\n",
    "        30518, # 'He'\n",
    "    ],\n",
    "    submodules[3] : [\n",
    "        1022, # 'she'\n",
    "        9651, # female names\n",
    "        10060, # 'She'\n",
    "        18967, # 'He'\n",
    "        22084, # 'he'\n",
    "        23898, # 'His'\n",
    "        # 24799, # promotes surnames\n",
    "        26504, # 'her'\n",
    "        29626, # 'his'\n",
    "        # 31201, # 'nursing'\n",
    "    ],\n",
    "    submodules[4] : [\n",
    "        # 8147, # unclear, something with names\n",
    "    ],\n",
    "    submodules[5] : [\n",
    "        24159, # 'She', 'she'\n",
    "        25018, # female names\n",
    "    ],\n",
    "    submodules[6] : [\n",
    "        4592, # 'her'\n",
    "        8920, # 'he'\n",
    "        9877, # female names\n",
    "        12128, # 'his'\n",
    "        15017, # 'she'\n",
    "        # 17369, # contact info\n",
    "        # 26969, # related to nursing\n",
    "        30248, # female names\n",
    "    ],\n",
    "    submodules[7] : [\n",
    "        13570, # promotes male-related words\n",
    "        27472, # female names, promotes female-related words\n",
    "    ],\n",
    "    submodules[8] : [\n",
    "    ],\n",
    "    submodules[9] : [\n",
    "        1995, # promotes female-associated words\n",
    "        9128, # feminine pronouns\n",
    "        11656, # promotes male-associated words\n",
    "        12440, # promotes female-associated words\n",
    "        # 14638, # related to contact information?\n",
    "        29206, # gendered pronouns\n",
    "        29295, # female names\n",
    "        # 31098, # nursing-related words\n",
    "    ],\n",
    "    submodules[10] : [\n",
    "        2959, # promotes female-associated words\n",
    "        19128, # promotes male-associated words\n",
    "        22029, # promotes female-associated words\n",
    "    ],\n",
    "    submodules[11] : [\n",
    "    ],\n",
    "    submodules[12] : [\n",
    "        19558, # promotes female-associated words\n",
    "        23545, # 'she'\n",
    "        24806, # 'her'\n",
    "        27334, # promotes male-associated words\n",
    "        31453, # female names\n",
    "    ],\n",
    "    submodules[13] : [\n",
    "        31101, # promotes female-associated words\n",
    "    ],\n",
    "    submodules[14] : [\n",
    "    ],\n",
    "    submodules[15] : [\n",
    "        9766, # promotes female-associated words\n",
    "        12420, # promotes female pronouns\n",
    "        30220, # promotes male pronouns\n",
    "    ]\n",
    "}\n",
    "\n",
    "print(f\"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n_hot(feats, dim=dictionary_size):\n",
    "    out = t.zeros(dim, dtype=t.bool, device=DEVICE)\n",
    "    for feat in feats:\n",
    "        out[feat] = True\n",
    "    return out\n",
    "\n",
    "feats_to_ablate = {\n",
    "    submodule : n_hot(feats) for submodule, feats in feats_to_ablate.items()\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_tuple = {}\n",
    "with t.no_grad(), model.trace(\"_\"):\n",
    "    for submodule in submodules:\n",
    "        is_tuple[submodule] = type(submodule.output.shape) == tuple\n",
    "\n",
    "def get_acts_ablated(\n",
    "    text,\n",
    "    model,\n",
    "    submodules,\n",
    "    dictionaries,\n",
    "    to_ablate\n",
    "):\n",
    "\n",
    "    with t.no_grad(), model.trace(text, **tracer_kwargs):\n",
    "        for submodule in submodules:\n",
    "            dictionary = dictionaries[submodule]\n",
    "            feat_idxs = to_ablate[submodule]\n",
    "            x = submodule.output\n",
    "            if is_tuple[submodule]:\n",
    "                x = x[0]\n",
    "            x_hat, f = dictionary(x, output_features=True)\n",
    "            res = x - x_hat\n",
    "            f[...,feat_idxs] = 0. # zero ablation\n",
    "            if is_tuple[submodule]:\n",
    "                submodule.output[0][:] = dictionary.decode(f) + res\n",
    "            else:\n",
    "                submodule.output = dictionary.decode(f) + res\n",
    "        attn_mask = model.input[1]['attention_mask']\n",
    "        act = model.gpt_neox.layers[layer].output[0]\n",
    "        act = act * attn_mask[:, :, None]\n",
    "        act = act.sum(1) / attn_mask.sum(1)[:, None]\n",
    "        act = act.save()\n",
    "    return act.value\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Accuracy after ablating features judged irrelevant by human annotators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)\n",
    "\n",
    "print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))\n",
    "print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concept bottleneck probing baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "concepts = [    \n",
    "    ' nurse',\n",
    "    ' healthcare',\n",
    "    ' hospital',\n",
    "    ' patient',\n",
    "    ' medical',\n",
    "    ' clinic',\n",
    "    ' triage',\n",
    "    ' medication',\n",
    "    ' emergency',\n",
    "    ' surgery',\n",
    "    ' professor',\n",
    "    ' academia',\n",
    "    ' research',\n",
    "    ' university',\n",
    "    ' tenure',\n",
    "    ' faculty',\n",
    "    ' dissertation',\n",
    "    ' sabbatical',\n",
    "    ' publication',\n",
    "    ' grant',\n",
    "]\n",
    "# get concept vectors\n",
    "with t.no_grad(), model.trace(concepts):\n",
    "    concept_vectors = model.gpt_neox.layers[layer].output[0][:, -1, :].save()\n",
    "concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)\n",
    "\n",
    "def get_bottleneck(text):\n",
    "    with t.no_grad(), model.trace(text, **tracer_kwargs):\n",
    "        attn_mask = model.input[1]['attention_mask']\n",
    "        acts = model.gpt_neox.layers[layer].output[0]\n",
    "        acts = acts * attn_mask[:, :, None]\n",
    "        acts = acts.sum(1) / attn_mask.sum(1)[:, None]\n",
    "        # compute cosine similarity with concept vectors\n",
    "        sims = (acts @ concept_vectors.T) / (acts.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None])\n",
    "        sims = sims.save()\n",
    "    return sims.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))\n",
    "print('Unintended feature accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get subgroup accuracies\n",
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get skyline neuron performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get neurons which are most influential for giving gender label\n",
    "neuron_dicts = {\n",
    "    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules\n",
    "}\n",
    "\n",
    "n_batches = 25\n",
    "batch_size = 4\n",
    "\n",
    "running_total = 0\n",
    "nodes = None\n",
    "\n",
    "for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):\n",
    "    if batch_idx == n_batches:\n",
    "        break\n",
    "\n",
    "    effects, _, _, _ = patching_effect(\n",
    "        clean,\n",
    "        None,\n",
    "        model,\n",
    "        submodules,\n",
    "        neuron_dicts,\n",
    "        metric_fn,\n",
    "        metric_kwargs=dict(labels=labels),\n",
    "        method='ig'\n",
    "    )\n",
    "    with t.no_grad():\n",
    "        if nodes is None:\n",
    "            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}\n",
    "        else:\n",
    "            for k, v in effects.items():\n",
    "                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)\n",
    "        running_total += len(clean)\n",
    "    del effects, _\n",
    "    gc.collect()\n",
    "\n",
    "nodes = {k : v / running_total for k, v in nodes.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neurons_to_ablate = {}\n",
    "total_neurons = 0\n",
    "for component_idx, effect in enumerate(nodes.values()):\n",
    "    print(f\"Component {component_idx}:\")\n",
    "    neurons_to_ablate[submodules[component_idx]] = []\n",
    "    for idx in (effect.act > 0.2135).nonzero():\n",
    "        print(idx.item(), effect[idx].item())\n",
    "        neurons_to_ablate[submodules[component_idx]].append(idx.item())\n",
    "        total_neurons += 1\n",
    "print(f\"total neurons: {total_neurons}\")\n",
    "\n",
    "neurons_to_ablate = {\n",
    "    submodule : n_hot([neuron_idx], dim=512) for submodule, neuron_idx in neurons_to_ablate.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_acts_abl(text):\n",
    "    with t.no_grad(), model.trace(text, **tracer_kwargs):\n",
    "        for submodule in submodules:\n",
    "            x = submodule.output\n",
    "            if is_tuple[submodule]:\n",
    "                x = x[0]\n",
    "            x[...,neurons_to_ablate[submodule]] = x.mean(dim=(0,1))[...,neurons_to_ablate[submodule]] # mean ablation\n",
    "            if is_tuple[submodule]:\n",
    "                submodule.output[0][:] = x\n",
    "            else:\n",
    "                submodule.output = x\n",
    "\n",
    "        attn_mask = model.input[1]['attention_mask']\n",
    "        act = model.gpt_neox.layers[layer].output[0]\n",
    "        act = act * attn_mask[:, :, None]\n",
    "        act = act.sum(1) / attn_mask.sum(1)[:, None]\n",
    "        act = act.save()\n",
    "    return act.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))\n",
    "print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get skyline feature performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_batches = 25\n",
    "batch_size = 4\n",
    "\n",
    "running_total = 0\n",
    "running_nodes = None\n",
    "\n",
    "for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):\n",
    "    if batch_idx == n_batches:\n",
    "        break\n",
    "\n",
    "    effects, _, _, _ = patching_effect(\n",
    "        clean,\n",
    "        None,\n",
    "        model,\n",
    "        submodules,\n",
    "        dictionaries,\n",
    "        metric_fn,\n",
    "        metric_kwargs=dict(labels=labels),\n",
    "        method='ig'\n",
    "    )\n",
    "    with t.no_grad():\n",
    "        if running_nodes is None:\n",
    "            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}\n",
    "        else:\n",
    "            for k, v in effects.items():\n",
    "                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)\n",
    "        running_total += len(clean)\n",
    "    del effects, _\n",
    "    gc.collect()\n",
    "\n",
    "nodes = {k : v / running_total for k, v in running_nodes.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_feats_to_ablate = {}\n",
    "total_features = 0\n",
    "for component_idx, effect in enumerate(nodes.values()):\n",
    "    print(f\"Component {component_idx}:\")\n",
    "    top_feats_to_ablate[submodules[component_idx]] = []\n",
    "    for idx in (effect > 0.1107).nonzero():\n",
    "        print(idx.item(), effect[idx].item())\n",
    "        top_feats_to_ablate[submodules[component_idx]].append(idx.item())\n",
    "        total_features += 1\n",
    "print(f\"total features: {total_features}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_feats_to_ablate = {\n",
    "    submodule : n_hot(feats) for submodule, feats in top_feats_to_ablate.items()\n",
    "}\n",
    "get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, top_feats_to_ablate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))\n",
    "print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retraining probe on activations after ablating features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)\n",
    "\n",
    "new_probe, _ = train_probe(get_acts_abl, label_idx=0)\n",
    "print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))\n",
    "print('Unintended feature accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subgroups = get_subgroups(train=False, ambiguous=False)\n",
    "for label_profile, batches in subgroups.items():\n",
    "    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
