{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "parent_dir = os.path.abspath('..')\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "from nnsight import LanguageModel\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "import torch as t\n",
    "import torch.nn as nn\n",
    "from dictionary_learning import AutoEncoder\n",
    "from circuit import get_circuit\n",
    "from circuit_plotting import plot_circuit\n",
    "from activation_utils import SparseAct\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import gc\n",
    "\n",
    "DEBUGGING = False\n",
    "if DEBUGGING:\n",
    "    tracer_kwargs = {'validate' : True, 'scan' : True}\n",
    "else:\n",
    "    tracer_kwargs = {'validate' : False, 'scan' : False}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"LabHC/bias_in_bios\")\n",
    "profession_dict = {'professor' : 21, 'nurse' : 13}\n",
    "male_prof = 'professor'\n",
    "female_prof = 'nurse'\n",
    "\n",
    "DEVICE = 'cuda:0'\n",
    "model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)\n",
    "layer = 4\n",
    "\n",
    "dict_id = 10\n",
    "\n",
    "batch_size = 1024\n",
    "SEED = 42\n",
    "\n",
    "def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):\n",
    "    if train:\n",
    "        data = dataset['train']\n",
    "    else:\n",
    "        data = dataset['test']\n",
    "    if ambiguous:\n",
    "        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        n = min([len(neg), len(pos)])\n",
    "        neg, pos = neg[:n], pos[:n]\n",
    "        data = neg + pos\n",
    "        labels = [0]*n + [1]*n\n",
    "        idxs = list(range(2*n))\n",
    "        random.Random(seed).shuffle(idxs)\n",
    "        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]\n",
    "        true_labels = spurious_labels = labels\n",
    "    else:\n",
    "        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]\n",
    "        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]\n",
    "        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])\n",
    "        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]\n",
    "        data = neg_neg + neg_pos + pos_neg + pos_pos\n",
    "        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n\n",
    "        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n\n",
    "        idxs = list(range(4*n))\n",
    "        random.Random(seed).shuffle(idxs)\n",
    "        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]\n",
    "\n",
    "    batches = [\n",
    "        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)\n",
    "    ]\n",
    "\n",
    "    return batches\n",
    "\n",
    "def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):\n",
    "    if train:\n",
    "        data = dataset['train']\n",
    "    else:\n",
    "        data = dataset['test']\n",
    "    if ambiguous:\n",
    "        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        neg_labels, pos_labels = (0, 0), (1, 1)\n",
    "        subgroups = [(neg, neg_labels), (pos, pos_labels)]\n",
    "    else:\n",
    "        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]\n",
    "        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]\n",
    "        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]\n",
    "        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]\n",
    "        neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)\n",
    "        subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]\n",
    "    \n",
    "    out = {}\n",
    "    for data, label_profile in subgroups:\n",
    "        out[label_profile] = []\n",
    "        for i in range(0, len(data), batch_size):\n",
    "            text = data[i:i+batch_size]\n",
    "            out[label_profile].append(\n",
    "                (\n",
    "                    text,\n",
    "                    t.tensor([label_profile[0]]*len(text), device=DEVICE),\n",
    "                    t.tensor([label_profile[1]]*len(text), device=DEVICE)\n",
    "                )\n",
    "            )\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# probe training hyperparameters\n",
    "class Probe(nn.Module):\n",
    "    def __init__(self, activation_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Linear(activation_dim, 1, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.net(x).squeeze(-1)\n",
    "        return logits\n",
    "\n",
    "def train_probe(get_acts, label_idx=0, batches=get_data(), lr=1e-2, epochs=1, dim=512, seed=SEED):\n",
    "    t.manual_seed(seed)\n",
    "    probe = Probe(dim).to('cuda:0')\n",
    "    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)\n",
    "    criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "    losses = []\n",
    "    for epoch in range(epochs):\n",
    "        for batch in batches:\n",
    "            text = batch[0]\n",
    "            labels = batch[label_idx+1] \n",
    "            acts = get_acts(text)\n",
    "            logits = probe(acts)\n",
    "            loss = criterion(logits, labels.float())\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    return probe, losses\n",
    "\n",
    "def test_probe(probe, get_acts, label_idx=0, batches=get_data(train=False), seed=SEED):\n",
    "    with t.no_grad():\n",
    "        corrects = []\n",
    "\n",
    "        for batch in batches:\n",
    "            text = batch[0]\n",
    "            labels = batch[label_idx+1]\n",
    "            acts = get_acts(text)\n",
    "            logits = probe(acts)\n",
    "            preds = (logits > 0.0).long()\n",
    "            corrects.append((preds == labels).float())\n",
    "        return t.cat(corrects).mean().item()\n",
    "    \n",
    "def get_acts(text):\n",
    "    with t.no_grad(): \n",
    "        with model.trace(text, **tracer_kwargs):\n",
    "            attn_mask = model.input[1]['attention_mask']\n",
    "            acts = model.gpt_neox.layers[layer].output[0]\n",
    "            acts = acts * attn_mask[:, :, None]\n",
    "            acts = acts.sum(1) / attn_mask.sum(1)[:, None]\n",
    "            acts = acts.save()\n",
    "        return acts.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "probe, _ = train_probe(get_acts, label_idx=0)\n",
    "print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))\n",
    "batches = get_data(train=False, ambiguous=False)\n",
    "print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))\n",
    "print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed = model.gpt_neox.embed_in\n",
    "attns = [l.attention for l in model.gpt_neox.layers[:layer+1]]\n",
    "mlps = [l.mlp for l in model.gpt_neox.layers[:layer+1]]\n",
    "resids = model.gpt_neox.layers[:layer+1]\n",
    "\n",
    "dictionaries = {}\n",
    "ae = AutoEncoder(512, 64 * 512).to(DEVICE)\n",
    "ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_32768/ae.pt'))\n",
    "dictionaries[embed] = ae\n",
    "for i in range(layer + 1):\n",
    "    ae = AutoEncoder(512, 64 * 512).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_32768/ae.pt'))\n",
    "    dictionaries[attns[i]] = ae\n",
    "\n",
    "    ae = AutoEncoder(512, 64 * 512).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_32768/ae.pt'))\n",
    "    dictionaries[mlps[i]] = ae\n",
    "\n",
    "    ae = AutoEncoder(512, 64 * 512).to(DEVICE)\n",
    "    ae.load_state_dict(t.load(f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_32768/ae.pt'))\n",
    "    dictionaries[resids[i]] = ae\n",
    "\n",
    "def metric_fn(model, labels=None):\n",
    "    attn_mask = model.input[1]['attention_mask']\n",
    "    acts = model.gpt_neox.layers[layer].output[0]\n",
    "    acts = acts * attn_mask[:, :, None]\n",
    "    acts = acts.sum(1) / attn_mask.sum(1)[:, None]\n",
    "    \n",
    "    return t.where(\n",
    "        labels == 0,\n",
    "        probe(acts),\n",
    "        - probe(acts)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_batches = 25\n",
    "batch_size = 4\n",
    "\n",
    "running_total = 0\n",
    "running_nodes = None\n",
    "running_edges = None\n",
    "for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):\n",
    "    if batch_idx == n_batches:\n",
    "        break\n",
    "    nodes, edges = get_circuit(\n",
    "        clean,\n",
    "        None,\n",
    "        model,\n",
    "        embed,\n",
    "        attns,\n",
    "        mlps,\n",
    "        resids,\n",
    "        dictionaries,\n",
    "        metric_fn,\n",
    "        metric_kwargs={'labels': labels},\n",
    "        node_threshold=0.2, # NOTE: use lower threshold if sigmoid\n",
    "        edge_threshold=0.02,\n",
    "    )\n",
    "    running_total += len(clean)\n",
    "    if running_nodes is None:\n",
    "        running_nodes = { k : len(clean) * v.to('cpu') if k != 'y' else None for k, v in nodes.items() }\n",
    "        running_edges = { k : { kk : len(clean) * v.to('cpu') for kk, v in vv.items() } for k, vv in edges.items() }\n",
    "    else:\n",
    "        for k, effect in nodes.items():\n",
    "            if k == 'y': continue\n",
    "            running_nodes[k] += len(clean) * effect.to('cpu')\n",
    "        for k in edges.keys():\n",
    "            for kk, effect in edges[k].items():\n",
    "                running_edges[k][kk] += len(clean) * effect.to('cpu')\n",
    "    del nodes, edges\n",
    "    gc.collect()\n",
    "\n",
    "for k in running_nodes.keys():\n",
    "    if k == 'y': continue\n",
    "    running_nodes[k] = running_nodes[k].to('cuda:0') / running_total\n",
    "for k in running_edges.keys():\n",
    "    for kk in running_edges[k].keys():\n",
    "        running_edges[k][kk] = running_edges[k][kk].to('cuda:0') / running_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only plot positive effect nodes\n",
    "running_nodes = {\n",
    "    k : SparseAct(act=t.clamp(v.act, min=0), resc=t.clamp(v.resc, min=0)) if v is not None else None for k, v in running_nodes.items()\n",
    "}\n",
    "\n",
    "# get annotations\n",
    "try:\n",
    "    annotations = {}\n",
    "    with open(f\"../annotations/10_32768.jsonl\", 'r') as annotations_data:\n",
    "        for annotation_line in annotations_data:\n",
    "            annotation = json.loads(annotation_line)\n",
    "            annotations[annotation[\"Name\"]] = annotation[\"Annotation\"]\n",
    "except:\n",
    "    annotations = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_circuit(\n",
    "    running_nodes,\n",
    "    running_edges,\n",
    "    layers=5,\n",
    "    node_threshold=0.1,\n",
    "    edge_threshold=0.01,\n",
    "    pen_thickness=1,\n",
    "    save_dir='../circuits/figures/bib_circuit',\n",
    "    annotations=annotations\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
