{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Tracing\n",
    "\n",
    "A demonstration of the double-intervention causal tracing method.\n",
    "\n",
    "The strategy used by causal tracing is to understand important\n",
    "states within a transformer by doing two interventions simultaneously:\n",
    "\n",
    "1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens\n",
    "   to frustrate the ability of the transformer to accurately complete factual\n",
    "   prompts about the subject.\n",
    "2. Restore a subset of the internal hidden states.  In our paper, we scan\n",
    "   hidden states at all layers and all tokens, searching for individual states\n",
    "   that carry the necessary information for the transformer to recover its\n",
    "   capability to complete the factual prompt.\n",
    "\n",
    "The traces of decisive states can be shown on a heatmap.  This notebook\n",
    "demonstrates the code for conducting causal traces and creating these heatmaps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `experiments.causal_trace` module contains a set of functions for running causal traces.\n",
    "\n",
    "In this notebook, we reproduce, demonstrate and discuss the interesting functions.\n",
    "\n",
    "We begin by importing several utility functions that deal with tokens and transformer models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os, re, json\n",
    "import torch, numpy\n",
    "from collections import defaultdict\n",
    "from util import nethook\n",
    "from util.globals import DATA_DIR\n",
    "from experiments.causal_trace import (\n",
    "    ModelAndTokenizer,\n",
    "    layername,\n",
    "    guess_subject,\n",
    "    plot_trace_heatmap,\n",
    ")\n",
    "from experiments.causal_trace import (\n",
    "    make_inputs,\n",
    "    decode_tokens,\n",
    "    find_token_range,\n",
    "    predict_token,\n",
    "    predict_from_input,\n",
    "    collect_embedding_std,\n",
    ")\n",
    "from dsets import KnownsDataset\n",
    "\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we load a model and tokenizer, and show that it can complete a couple of factual statements correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_name = \"gpt2-xl\"  # or \"EleutherAI/gpt-j-6B\" or \"EleutherAI/gpt-neox-20b\"\n",
    "mt = ModelAndTokenizer(\n",
    "    model_name,\n",
    "    torch_dtype=(torch.float16 if \"20b\" in model_name else None),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_token(\n",
    "    mt,\n",
    "    [\"Megan Rapinoe plays the sport of\", \"The Space Needle is in the city of\"],\n",
    "    return_p=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To obfuscate the subject during Causal Tracing, we use noise sampled from a zero-centered spherical Gaussian whose standard deviation is 3 times the standard deviation $\\sigma$ of the model's embeddings. Let's compute that value."
   ]
  },
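  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the corruption described above can be sketched as follows, where $e_i$ is the embedding of the $i$-th subject token and $d$ is the embedding dimension. The symbols $e_i$, $\\tilde{e}_i$, $\\epsilon_i$, $\\nu$ and $d$ are notation assumed here for illustration; they are not names used in the code.\n",
    "\n",
    "$$\\tilde{e}_i = e_i + \\epsilon_i, \\qquad \\epsilon_i \\sim \\mathcal{N}(0, \\nu^2 I_d), \\qquad \\nu = 3\\sigma$$\n",
    "\n",
    "Embeddings of tokens outside the subject are left unchanged. The value $\\nu$ corresponds to `noise_level` computed in the next cell."
   ]
  },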
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knowns = KnownsDataset(DATA_DIR)  # Dataset of known facts\n",
    "noise_level = 3 * collect_embedding_std(mt, [k[\"subject\"] for k in knowns])\n",
    "print(f\"Using noise level {noise_level}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tracing a single location\n",
    "\n",
    "The core intervention in causal tracing is captured in this function:\n",
    "\n",
    "`trace_with_patch` performs a single causal trace.\n",
    "\n",
    "It enables running a batch of inferences with two interventions.\n",
    "\n",
    "  1. Random noise can be added to corrupt the inputs of some of the batch.\n",
    "  2. At any point, clean non-noised state can be copied over from an\n",
    "     uncorrupted batch member to other batch members.\n",
    "  \n",
    "The convention used by this function is that the zeroth element of the\n",
    "batch is the uncorrupted run, and the subsequent elements of the batch\n",
    "are the corrupted runs.  The argument `tokens_to_mix` specifies a\n",
    "`(begin, end)` range of tokens to be corrupted by adding Gaussian noise\n",
    "to the embedding for the batch inputs other than the first element in\n",
    "the batch.  Alternatively,\n",
    "subsequent runs could be corrupted by simply providing different\n",
    "input tokens via the passed input batch.\n",
    "\n",
    "To ensure that corrupted behavior is representative, in practice, we\n",
    "will actually run several (ten) corrupted runs in the same batch,\n",
    "each with its own sample of noise.\n",
    "\n",
    "Then when running, a specified set of hidden states will be uncorrupted\n",
    "by restoring their values to the same vector that they had in the\n",
    "zeroth uncorrupted run.  This set of hidden states is listed in\n",
    "`states_to_patch`, by listing `[(token_index, layername), ...]` pairs.\n",
    "To trace the effect of just a single state, this can be just a single\n",
    "token/layer pair.  To trace the effect of restoring a set of states,\n",
    "any number of token indices and layers can be listed.\n",
    "\n",
    "Note that this function is also defined in `experiments.causal_trace`; the code\n",
    "is reproduced here to expose its logic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trace_with_patch(\n",
    "    model,  # The model\n",
    "    inp,  # A set of inputs\n",
    "    states_to_patch,  # A list of (token index, layername) pairs to restore\n",
    "    answers_t,  # Answer token(s) whose probabilities are collected\n",
    "    tokens_to_mix,  # Range of tokens to corrupt (begin, end)\n",
    "    noise=0.1,  # Level of noise to add\n",
    "    trace_layers=None,  # List of traced outputs to return\n",
    "):\n",
    "    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise\n",
    "    patch_spec = defaultdict(list)\n",
    "    for t, l in states_to_patch:\n",
    "        patch_spec[l].append(t)\n",
    "    embed_layername = layername(model, 0, \"embed\")\n",
    "\n",
    "    def untuple(x):\n",
    "        return x[0] if isinstance(x, tuple) else x\n",
    "\n",
    "    # Define the model-patching rule.\n",
    "    def patch_rep(x, layer):\n",
    "        if layer == embed_layername:\n",
    "            # If requested, we corrupt a range of token embeddings on batch items x[1:]\n",
    "            if tokens_to_mix is not None:\n",
    "                b, e = tokens_to_mix\n",
    "                x[1:, b:e] += noise * torch.from_numpy(\n",
    "                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])\n",
    "                ).to(x.device)\n",
    "            return x\n",
    "        if layer not in patch_spec:\n",
    "            return x\n",
    "        # If this layer is in the patch_spec, restore the uncorrupted hidden state\n",
    "        # for selected tokens.\n",
    "        h = untuple(x)\n",
    "        for t in patch_spec[layer]:\n",
    "            h[1:, t] = h[0, t]\n",
    "        return x\n",
    "\n",
    "    # With the patching rules defined, run the patched model in inference.\n",
    "    additional_layers = [] if trace_layers is None else trace_layers\n",
    "    with torch.no_grad(), nethook.TraceDict(\n",
    "        model,\n",
    "        [embed_layername] + list(patch_spec.keys()) + additional_layers,\n",
    "        edit_output=patch_rep,\n",
    "    ) as td:\n",
    "        outputs_exp = model(**inp)\n",
    "\n",
    "    # We report softmax probabilities for the answers_t token predictions of interest.\n",
    "    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]\n",
    "\n",
    "    # If tracing all layers, collect all activations together to return.\n",
    "    if trace_layers is not None:\n",
    "        all_traced = torch.stack(\n",
    "            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2\n",
    "        )\n",
    "        return probs, all_traced\n",
    "\n",
    "    return probs"
   ]
  },
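  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading of the last lines of `trace_with_patch`: suppose the batch contains $n$ corrupted runs (batch items $1, \\ldots, n$), $z_i$ is the vector of final-position logits for run $i$, and $a$ is the answer token passed as `answers_t`. The symbols $n$, $z_i$, $a$ and $p$ are notation introduced here, not names from the code. The returned score is then the answer probability averaged over the corrupted runs:\n",
    "\n",
    "$$p = \\frac{1}{n} \\sum_{i=1}^{n} \\mathrm{softmax}(z_i)_a$$\n",
    "\n",
    "The uncorrupted run at batch index 0 is excluded from this average; within the function it serves only as the source of the restored states."
   ]
  },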
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scanning all locations\n",
    "\n",
    "A causal flow heatmap is created by repeating `trace_with_patch` at every individual hidden state, and measuring the impact of restoring state at each location.\n",
    "\n",
    "The `calculate_hidden_flow` function performs this loop.  It handles both restoring a single hidden state (via `trace_important_states`) and restoring MLP or attention states.  Because each MLP or attention module makes only a small contribution to the residual stream, observing a causal effect in those cases requires restoring the contributions of several layers at once, which is done by `trace_important_window`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_hidden_flow(\n",
    "    mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Runs causal tracing over every token/layer combination in the network\n",
    "    and returns a dictionary numerically summarizing the results.\n",
    "    \"\"\"\n",
    "    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))\n",
    "    with torch.no_grad():\n",
    "        answer_t, base_score = [d[0] for d in predict_from_input(mt.model, inp)]\n",
    "    [answer] = decode_tokens(mt.tokenizer, [answer_t])\n",
    "    e_range = find_token_range(mt.tokenizer, inp[\"input_ids\"][0], subject)\n",
    "    low_score = trace_with_patch(\n",
    "        mt.model, inp, [], answer_t, e_range, noise=noise\n",
    "    ).item()\n",
    "    if not kind:\n",
    "        differences = trace_important_states(\n",
    "            mt.model, mt.num_layers, inp, e_range, answer_t, noise=noise\n",
    "        )\n",
    "    else:\n",
    "        differences = trace_important_window(\n",
    "            mt.model,\n",
    "            mt.num_layers,\n",
    "            inp,\n",
    "            e_range,\n",
    "            answer_t,\n",
    "            noise=noise,\n",
    "            window=window,\n",
    "            kind=kind,\n",
    "        )\n",
    "    differences = differences.detach().cpu()\n",
    "    return dict(\n",
    "        scores=differences,\n",
    "        low_score=low_score,\n",
    "        high_score=base_score,\n",
    "        input_ids=inp[\"input_ids\"][0],\n",
    "        input_tokens=decode_tokens(mt.tokenizer, inp[\"input_ids\"][0]),\n",
    "        subject_range=e_range,\n",
    "        answer=answer,\n",
    "        window=window,\n",
    "        kind=kind or \"\",\n",
    "    )\n",
    "\n",
    "\n",
    "def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1):\n",
    "    ntoks = inp[\"input_ids\"].shape[1]\n",
    "    table = []\n",
    "    for tnum in range(ntoks):\n",
    "        row = []\n",
    "        for layer in range(0, num_layers):\n",
    "            r = trace_with_patch(\n",
    "                model,\n",
    "                inp,\n",
    "                [(tnum, layername(model, layer))],\n",
    "                answer_t,\n",
    "                tokens_to_mix=e_range,\n",
    "                noise=noise,\n",
    "            )\n",
    "            row.append(r)\n",
    "        table.append(torch.stack(row))\n",
    "    return torch.stack(table)\n",
    "\n",
    "\n",
    "def trace_important_window(\n",
    "    model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1\n",
    "):\n",
    "    ntoks = inp[\"input_ids\"].shape[1]\n",
    "    table = []\n",
    "    for tnum in range(ntoks):\n",
    "        row = []\n",
    "        for layer in range(0, num_layers):\n",
    "            layerlist = [\n",
    "                (tnum, layername(model, L, kind))\n",
    "                for L in range(\n",
    "                    max(0, layer - window // 2), min(num_layers, layer - (-window // 2))\n",
    "                )\n",
    "            ]\n",
    "            r = trace_with_patch(\n",
    "                model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise\n",
    "            )\n",
    "            row.append(r)\n",
    "        table.append(torch.stack(row))\n",
    "    return torch.stack(table)"
   ]
  },
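  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The window of layers restored by `trace_important_window` can be written out explicitly. For a center layer $\\ell$, a window size $w$ and $N$ layers in total (symbols introduced here for illustration), the restored layers $L$ satisfy:\n",
    "\n",
    "$$\\max(0, \\ell - \\lfloor w/2 \\rfloor) \\le L < \\min(N, \\ell + \\lceil w/2 \\rceil)$$\n",
    "\n",
    "The upper bound rounds up because Python applies the unary minus before floor division in `layer - (-window // 2)`. For example, with the default `window=10` and `layer=20` in a model with at least 25 layers, layers 15 through 24 are restored, ten layers in total. Near the first and last layers the window is clipped and covers fewer layers."
   ]
  },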
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting the results\n",
    "\n",
    "The `plot_trace_heatmap` function draws the data on a heatmap.  That function is not shown here; it is in `experiments.causal_trace`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hidden_flow(\n",
    "    mt,\n",
    "    prompt,\n",
    "    subject=None,\n",
    "    samples=10,\n",
    "    noise=0.1,\n",
    "    window=10,\n",
    "    kind=None,\n",
    "    modelname=None,\n",
    "    savepdf=None,\n",
    "):\n",
    "    if subject is None:\n",
    "        subject = guess_subject(prompt)\n",
    "    result = calculate_hidden_flow(\n",
    "        mt, prompt, subject, samples=samples, noise=noise, window=window, kind=kind\n",
    "    )\n",
    "    plot_trace_heatmap(result, savepdf, modelname=modelname)\n",
    "\n",
    "\n",
    "def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None):\n",
    "    for kind in [None, \"mlp\", \"attn\"]:\n",
    "        plot_hidden_flow(\n",
    "            mt, prompt, subject, modelname=modelname, noise=noise, kind=kind\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following prompt can be changed to any factual statement to trace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_all_flow(mt, \"The Space Needle is in the city of\", noise=noise_level)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we trace a few more factual statements from a file of test cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for knowledge in knowns[:5]:\n",
    "    plot_all_flow(mt, knowledge[\"prompt\"], knowledge[\"subject\"], noise=noise_level)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3.9.7 ('rome')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
