{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1a852108",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace_frozen_mlp_attn.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" align=\"left\"/></a>&nbsp;or in a local notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b9dc9c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:14:03.871869Z",
     "iopub.status.busy": "2022-09-02T03:14:03.871146Z",
     "iopub.status.idle": "2022-09-02T03:14:03.894378Z",
     "shell.execute_reply": "2022-09-02T03:14:03.894777Z"
    }
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit\n",
    "cd /content && rm -rf /content/memit\n",
    "git clone https://github.com/kmeng01/memit memit > install.log 2>&1\n",
    "pip install -r /content/memit/scripts/colab_reqs/rome.txt >> install.log 2>&1\n",
    "pip install --upgrade google-cloud-storage >> install.log 2>&1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "befc82ae",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:14:03.900021Z",
     "iopub.status.busy": "2022-09-02T03:14:03.899545Z",
     "iopub.status.idle": "2022-09-02T03:14:03.901624Z",
     "shell.execute_reply": "2022-09-02T03:14:03.901219Z"
    }
   },
   "outputs": [],
   "source": [
    "IS_COLAB = False\n",
    "try:\n",
    "    import google.colab, torch, os\n",
    "\n",
    "    IS_COLAB = True\n",
    "    os.chdir(\"/content/memit\")\n",
    "    if not torch.cuda.is_available():\n",
    "        raise Exception(\"Change runtime type to include a GPU.\")\n",
    "except ModuleNotFoundError as _:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8203f43f",
   "metadata": {},
   "source": [
    "# Frozen-MLP causal tracing\n",
    "\n",
    "This notebook executes causal traces with all the MLP modules for a token disabled (we also do Attn modules separately), by freezing them at the corrupted state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f7e67a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:14:03.905534Z",
     "iopub.status.busy": "2022-09-02T03:14:03.905029Z",
     "iopub.status.idle": "2022-09-02T03:14:03.916397Z",
     "shell.execute_reply": "2022-09-02T03:14:03.916735Z"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ba3338",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:14:03.921948Z",
     "iopub.status.busy": "2022-09-02T03:14:03.921594Z",
     "iopub.status.idle": "2022-09-02T03:14:06.501717Z",
     "shell.execute_reply": "2022-09-02T03:14:06.502266Z"
    }
   },
   "outputs": [],
   "source": [
    "import os, re\n",
    "import torch, numpy\n",
    "import importlib, copy\n",
    "import transformers\n",
    "from collections import defaultdict\n",
    "from util import nethook\n",
    "from matplotlib import pyplot as plt\n",
    "from experiments.causal_trace import (\n",
    "    ModelAndTokenizer,\n",
    "    make_inputs,\n",
    "    predict_from_input,\n",
    "    decode_tokens,\n",
    "    layername,\n",
    "    find_token_range,\n",
    "    trace_with_patch,\n",
    "    plot_trace_heatmap,\n",
    "    collect_embedding_std,\n",
    ")\n",
    "from util.globals import DATA_DIR\n",
    "from dsets import KnownsDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16c46e43",
   "metadata": {},
   "source": [
    "Load model and compute its corresponding noise level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce71fd8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:14:06.506653Z",
     "iopub.status.busy": "2022-09-02T03:14:06.506137Z",
     "iopub.status.idle": "2022-09-02T03:17:51.611062Z",
     "shell.execute_reply": "2022-09-02T03:17:51.610434Z"
    }
   },
   "outputs": [],
   "source": [
    "model_name = \"EleutherAI/gpt-j-6B\"  # \"gpt2-xl\" or \"EleutherAI/gpt-j-6B\" or \"EleutherAI/gpt-neox-20b\"\n",
    "mt = ModelAndTokenizer(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=IS_COLAB,\n",
    "    torch_dtype=(torch.float16 if \"20b\" in model_name else None),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0089c69",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:17:51.615660Z",
     "iopub.status.busy": "2022-09-02T03:17:51.614971Z",
     "iopub.status.idle": "2022-09-02T03:19:09.157205Z",
     "shell.execute_reply": "2022-09-02T03:19:09.156516Z"
    }
   },
   "outputs": [],
   "source": [
    "knowns = KnownsDataset(DATA_DIR)  # Dataset of known facts\n",
    "noise_level = 3 * collect_embedding_std(mt, [k[\"subject\"] for k in knowns])\n",
    "print(f\"Using noise level {noise_level}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c18cfb8",
   "metadata": {},
   "source": [
    "## Tracing a single location\n",
    "\n",
    "The strategy here is to use three interventions, rather than two:\n",
    "\n",
    "1. As before, corrupt a subset of the input.\n",
    "2. As before, restore a subset of the internal hidden states to see\n",
    "   which ones restore the output.\n",
    "3. But now, while doing so, freeze a set of MLP modules when processing\n",
    "   the specific subject token, so that they are stuck in the corrupted\n",
    "   state.  This reveals effect of the hidden states on everything\n",
    "   except for those particular MLP executions.\n",
    "   \n",
    "This three-way intervention is implemented in `trace_with_repatch`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a36b314e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:19:09.168504Z",
     "iopub.status.busy": "2022-09-02T03:19:09.167742Z",
     "iopub.status.idle": "2022-09-02T03:19:09.187978Z",
     "shell.execute_reply": "2022-09-02T03:19:09.187397Z"
    }
   },
   "outputs": [],
   "source": [
    "def trace_with_repatch(\n",
    "    model,  # The model\n",
    "    inp,  # A set of inputs\n",
    "    states_to_patch,  # A list of (token index, layername) triples to restore\n",
    "    states_to_unpatch,  # A list of (token index, layername) triples to re-randomize\n",
    "    answers_t,  # Answer probabilities to collect\n",
    "    tokens_to_mix,  # Range of tokens to corrupt (begin, end)\n",
    "    noise=0.1,  # Level of noise to add\n",
    "):\n",
    "    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise\n",
    "    patch_spec = defaultdict(list)\n",
    "    for t, l in states_to_patch:\n",
    "        patch_spec[l].append(t)\n",
    "    unpatch_spec = defaultdict(list)\n",
    "    for t, l in states_to_unpatch:\n",
    "        unpatch_spec[l].append(t)\n",
    "\n",
    "    def untuple(x):\n",
    "        return x[0] if isinstance(x, tuple) else x\n",
    "\n",
    "    # Define the model-patching rule.\n",
    "    def patch_rep(x, layer):\n",
    "        if layer == \"transformer.wte\":\n",
    "            # If requested, we corrupt a range of token embeddings on batch items x[1:]\n",
    "            if tokens_to_mix is not None:\n",
    "                b, e = tokens_to_mix\n",
    "                x[1:, b:e] += noise * torch.from_numpy(\n",
    "                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])\n",
    "                ).to(x.device)\n",
    "            return x\n",
    "        if first_pass or (layer not in patch_spec and layer not in unpatch_spec):\n",
    "            return x\n",
    "        # If this layer is in the patch_spec, restore the uncorrupted hidden state\n",
    "        # for selected tokens.\n",
    "        h = untuple(x)\n",
    "        for t in patch_spec.get(layer, []):\n",
    "            h[1:, t] = h[0, t]\n",
    "        for t in unpatch_spec.get(layer, []):\n",
    "            h[1:, t] = untuple(first_pass_trace[layer].output)[1:, t]\n",
    "        return x\n",
    "\n",
    "    # With the patching rules defined, run the patched model in inference.\n",
    "    for first_pass in [True, False] if states_to_unpatch else [False]:\n",
    "        with torch.no_grad(), nethook.TraceDict(\n",
    "            model,\n",
    "            [\"transformer.wte\"] + list(patch_spec.keys()) + list(unpatch_spec.keys()),\n",
    "            edit_output=patch_rep,\n",
    "        ) as td:\n",
    "            outputs_exp = model(**inp)\n",
    "            if first_pass:\n",
    "                first_pass_trace = td\n",
    "\n",
    "    # We report softmax probabilities for the answers_t token predictions of interest.\n",
    "    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]\n",
    "\n",
    "    return probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe52a4a",
   "metadata": {},
   "source": [
    "## Tracing all locations\n",
    "\n",
    "Now we just need to repeat it over all locations, and draw the heatmaps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9b5a7c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:19:09.197472Z",
     "iopub.status.busy": "2022-09-02T03:19:09.196609Z",
     "iopub.status.idle": "2022-09-02T03:19:09.219938Z",
     "shell.execute_reply": "2022-09-02T03:19:09.219327Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def calculate_hidden_flow_3(\n",
    "    mt,\n",
    "    prompt,\n",
    "    subject,\n",
    "    token_range=None,\n",
    "    samples=10,\n",
    "    noise=0.1,\n",
    "    window=10,\n",
    "    extra_token=0,\n",
    "    disable_mlp=False,\n",
    "    disable_attn=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Runs causal tracing over every token/layer combination in the network\n",
    "    and returns a dictionary numerically summarizing the results.\n",
    "    \"\"\"\n",
    "    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))\n",
    "    with torch.no_grad():\n",
    "        answer_t, base_score = [d[0] for d in predict_from_input(mt.model, inp)]\n",
    "    [answer] = decode_tokens(mt.tokenizer, [answer_t])\n",
    "    e_range = find_token_range(mt.tokenizer, inp[\"input_ids\"][0], subject)\n",
    "    if token_range == \"last_subject\":\n",
    "        token_range = [e_range[1] - 1]\n",
    "    e_range = (e_range[0], e_range[1] + extra_token)\n",
    "    low_score = trace_with_patch(\n",
    "        mt.model, inp, [], answer_t, e_range, noise=noise\n",
    "    ).item()\n",
    "    differences = trace_important_states_3(\n",
    "        mt.model,\n",
    "        mt.num_layers,\n",
    "        inp,\n",
    "        e_range,\n",
    "        answer_t,\n",
    "        noise=noise,\n",
    "        disable_mlp=disable_mlp,\n",
    "        disable_attn=disable_attn,\n",
    "        token_range=token_range,\n",
    "    )\n",
    "    differences = differences.detach().cpu()\n",
    "    return dict(\n",
    "        scores=differences,\n",
    "        low_score=low_score,\n",
    "        high_score=base_score,\n",
    "        input_ids=inp[\"input_ids\"][0],\n",
    "        input_tokens=decode_tokens(mt.tokenizer, inp[\"input_ids\"][0]),\n",
    "        subject_range=e_range,\n",
    "        answer=answer,\n",
    "        window=window,\n",
    "        kind=\"\",\n",
    "    )\n",
    "\n",
    "\n",
    "def trace_important_states_3(\n",
    "    model,\n",
    "    num_layers,\n",
    "    inp,\n",
    "    e_range,\n",
    "    answer_t,\n",
    "    noise=0.1,\n",
    "    disable_mlp=False,\n",
    "    disable_attn=False,\n",
    "    token_range=None,\n",
    "):\n",
    "    ntoks = inp[\"input_ids\"].shape[1]\n",
    "    table = []\n",
    "    zero_mlps = []\n",
    "    if token_range is None:\n",
    "        token_range = range(ntoks)\n",
    "    for tnum in token_range:\n",
    "        zero_mlps = []\n",
    "        if disable_mlp:\n",
    "            zero_mlps = [\n",
    "                (tnum, layername(model, L, \"mlp\")) for L in range(0, num_layers)\n",
    "            ]\n",
    "        if disable_attn:\n",
    "            zero_mlps += [\n",
    "                (tnum, layername(model, L, \"attn\")) for L in range(0, num_layers)\n",
    "            ]\n",
    "        row = []\n",
    "        for layer in range(0, num_layers):\n",
    "            r = trace_with_repatch(\n",
    "                model,\n",
    "                inp,\n",
    "                [(tnum, layername(model, layer))],\n",
    "                zero_mlps,  # states_to_unpatch\n",
    "                answer_t,\n",
    "                tokens_to_mix=e_range,\n",
    "                noise=noise,\n",
    "            )\n",
    "            row.append(r)\n",
    "        table.append(torch.stack(row))\n",
    "    return torch.stack(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d559800a",
   "metadata": {},
   "source": [
    "Here is a causal trace with MLP disabled - it looks quite different from normal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "427f3989",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:19:09.223992Z",
     "iopub.status.busy": "2022-09-02T03:19:09.223279Z",
     "iopub.status.idle": "2022-09-02T03:20:09.387116Z",
     "shell.execute_reply": "2022-09-02T03:20:09.386501Z"
    }
   },
   "outputs": [],
   "source": [
    "prefix = \"Megan Rapinoe plays the sport of\"\n",
    "entity = \"Megan Rapinoe\"\n",
    "\n",
    "no_attn_r = calculate_hidden_flow_3(\n",
    "    mt, prefix, entity, disable_mlp=True, noise=noise_level\n",
    ")\n",
    "plot_trace_heatmap(no_attn_r, title=\"Impact with MLP at last subject token disabled\")\n",
    "ordinary_r = calculate_hidden_flow_3(mt, prefix, entity, noise=noise_level)\n",
    "plot_trace_heatmap(ordinary_r, title=\"Impact with MLP enabled as usual\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0362ea3",
   "metadata": {},
   "source": [
    "## Comparing the with-MLP/Attn and without-MLP/Attn traces\n",
    "\n",
    "Plotting on a bar graph makes it easier to see the difference between the causal effects with and without MLP enabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2926aeac",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:20:09.404243Z",
     "iopub.status.busy": "2022-09-02T03:20:09.403498Z",
     "iopub.status.idle": "2022-09-02T03:20:09.424633Z",
     "shell.execute_reply": "2022-09-02T03:20:09.424001Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def plot_last_subject(mt, prefix, entity, token_range=\"last_subject\", savepdf=None):\n",
    "    ordinary, no_attn, no_mlp = calculate_last_subject(\n",
    "        mt, prefix, entity, token_range=token_range\n",
    "    )\n",
    "    plot_comparison(ordinary, no_attn, no_mlp, prefix, savepdf=savepdf)\n",
    "\n",
    "\n",
    "def calculate_last_subject(mt, prefix, entity, cache=None, token_range=\"last_subject\"):\n",
    "    def load_from_cache(filename):\n",
    "        try:\n",
    "            dat = numpy.load(f\"{cache}/{filename}\")\n",
    "            return {\n",
    "                k: v\n",
    "                if not isinstance(v, numpy.ndarray)\n",
    "                else str(v)\n",
    "                if v.dtype.type is numpy.str_\n",
    "                else torch.from_numpy(v)\n",
    "                for k, v in dat.items()\n",
    "            }\n",
    "        except FileNotFoundError as e:\n",
    "            return None\n",
    "\n",
    "    no_attn_r = load_from_cache(\"no_attn_r.npz\")\n",
    "    uncached_no_attn_r = no_attn_r is None\n",
    "    no_mlp_r = load_from_cache(\"no_mlp_r.npz\")\n",
    "    uncached_no_mlp_r = no_mlp_r is None\n",
    "    ordinary_r = load_from_cache(\"ordinary.npz\")\n",
    "    uncached_ordinary_r = ordinary_r is None\n",
    "    if uncached_no_attn_r:\n",
    "        no_attn_r = calculate_hidden_flow_3(\n",
    "            mt,\n",
    "            prefix,\n",
    "            entity,\n",
    "            disable_attn=True,\n",
    "            token_range=token_range,\n",
    "            noise=noise_level,\n",
    "        )\n",
    "    if uncached_no_mlp_r:\n",
    "        no_mlp_r = calculate_hidden_flow_3(\n",
    "            mt,\n",
    "            prefix,\n",
    "            entity,\n",
    "            disable_mlp=True,\n",
    "            token_range=token_range,\n",
    "            noise=noise_level,\n",
    "        )\n",
    "    if uncached_ordinary_r:\n",
    "        ordinary_r = calculate_hidden_flow_3(\n",
    "            mt, prefix, entity, token_range=token_range, noise=noise_level\n",
    "        )\n",
    "    if cache is not None:\n",
    "        os.makedirs(cache, exist_ok=True)\n",
    "        for u, r, filename in [\n",
    "            (uncached_no_attn_r, no_attn_r, \"no_attn_r.npz\"),\n",
    "            (uncached_no_mlp_r, no_mlp_r, \"no_mlp_r.npz\"),\n",
    "            (uncached_ordinary_r, ordinary_r, \"ordinary.npz\"),\n",
    "        ]:\n",
    "            if u:\n",
    "                numpy.savez(\n",
    "                    f\"{cache}/{filename}\",\n",
    "                    **{\n",
    "                        k: v.cpu().numpy() if torch.is_tensor(v) else v\n",
    "                        for k, v in r.items()\n",
    "                    },\n",
    "                )\n",
    "    if False:\n",
    "        return (ordinary_r[\"scores\"][0], no_attn_r[\"scores\"][0], no_mlp_r[\"scores\"][0])\n",
    "    return (\n",
    "        ordinary_r[\"scores\"][0] - ordinary_r[\"low_score\"],\n",
    "        no_attn_r[\"scores\"][0] - ordinary_r[\"low_score\"],\n",
    "        no_mlp_r[\"scores\"][0] - ordinary_r[\"low_score\"],\n",
    "    )\n",
    "\n",
    "    # return ordinary_r['scores'][0], no_attn_r['scores'][0]\n",
    "\n",
    "\n",
    "def plot_comparison(ordinary, no_attn, no_mlp, title, savepdf=None):\n",
    "    with plt.rc_context(rc={\"font.family\": \"Times New Roman\"}):\n",
    "        import matplotlib.ticker as mtick\n",
    "\n",
    "        fig, ax = plt.subplots(1, figsize=(6, 1.5), dpi=300)\n",
    "        ax.bar(\n",
    "            [i - 0.3 for i in range(len(ordinary))],\n",
    "            ordinary,\n",
    "            width=0.3,\n",
    "            color=\"#7261ab\",\n",
    "            label=\"Impact of single state on P\",\n",
    "        )\n",
    "        ax.bar(\n",
    "            [i for i in range(len(no_attn))],\n",
    "            no_attn,\n",
    "            width=0.3,\n",
    "            color=\"#f3201b\",\n",
    "            label=\"Impact with Attn severed\",\n",
    "        )\n",
    "        ax.bar(\n",
    "            [i + 0.3 for i in range(len(no_mlp))],\n",
    "            no_mlp,\n",
    "            width=0.3,\n",
    "            color=\"#20b020\",\n",
    "            label=\"Impact with MLP severed\",\n",
    "        )\n",
    "        ax.set_title(\n",
    "            title\n",
    "        )  #'Impact of individual hidden state at last subject token with MLP disabled')\n",
    "        ax.set_ylabel(\"Indirect Effect\")\n",
    "        # ax.set_xlabel('Layer at which the single hidden state is restored')\n",
    "        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))\n",
    "        ax.set_ylim(None, max(0.025, ordinary.max() * 1.05))\n",
    "        ax.legend()\n",
    "        if savepdf:\n",
    "            os.makedirs(os.path.dirname(savepdf), exist_ok=True)\n",
    "            plt.savefig(savepdf, bbox_inches=\"tight\")\n",
    "            plt.close()\n",
    "        else:\n",
    "            plt.show()\n",
    "\n",
    "\n",
    "if False:  # Some representative cases.\n",
    "    plot_last_subject(mt, \"Megan Rapinoe plays the sport of\", \"Megan Rapinoe\")\n",
    "    plot_last_subject(mt, \"The Big Bang Theory premires on\", \"The Big Bang Theory\")\n",
    "    plot_last_subject(mt, \"Germaine Greer's domain of work is\", \"Germaine Greer\")\n",
    "    plot_last_subject(mt, \"Brian de Palma works in the area of\", \"Brian de Palma\")\n",
    "    plot_last_subject(mt, \"The headquarter of Zillow is in downtown\", \"Zillow\")\n",
    "    plot_last_subject(\n",
    "        mt,\n",
    "        \"Mitsubishi Electric started in the 1900s as a small company in\",\n",
    "        \"Mitsubishi\",\n",
    "    )\n",
    "    plot_last_subject(\n",
    "        mt,\n",
    "        \"Mitsubishi Electric started in the 1900s as a small company in\",\n",
    "        \"Mitsubishi Electric\",\n",
    "    )\n",
    "    plot_last_subject(mt, \"Madame de Montesson died in the city of\", \"Madame\")\n",
    "    plot_last_subject(\n",
    "        mt, \"Madame de Montesson died in the city of\", \"Madame de Montesson\"\n",
    "    )\n",
    "    plot_last_subject(mt, \"Edmund Neupert, performing on the\", \"Edmund Neupert\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1631379e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:20:09.429547Z",
     "iopub.status.busy": "2022-09-02T03:20:09.428848Z",
     "iopub.status.idle": "2022-09-02T03:20:21.255019Z",
     "shell.execute_reply": "2022-09-02T03:20:21.254651Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_last_subject(mt, \"The Space Needle is in the city of\", \"The Space Needle\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12e3fb70",
   "metadata": {},
   "source": [
    "## Average Indirect Effects\n",
    "\n",
    "Now we average over hundreds of factual statements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0d802d7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-09-02T03:20:21.267822Z",
     "iopub.status.busy": "2022-09-02T03:20:21.262178Z",
     "iopub.status.idle": "2022-09-02T06:24:04.688526Z",
     "shell.execute_reply": "2022-09-02T06:24:04.688828Z"
    }
   },
   "outputs": [],
   "source": [
    "import tqdm\n",
    "\n",
    "all_ordinary = []\n",
    "all_no_attn = []\n",
    "all_no_mlp = []\n",
    "for i, knowledge in enumerate(tqdm.tqdm(knowns[:1000])):\n",
    "    # plot_all_flow(mt, knowledge['prompt'], knowledge['subject'])\n",
    "    ordinary, no_attn, no_mlp = calculate_last_subject(\n",
    "        mt,\n",
    "        knowledge[\"prompt\"],\n",
    "        knowledge[\"subject\"],\n",
    "        cache=f\"results/ct_disable_attn/case_{i}\",\n",
    "    )\n",
    "    all_ordinary.append(ordinary)\n",
    "    all_no_attn.append(no_attn)\n",
    "    all_no_mlp.append(no_mlp)\n",
    "title = \"Causal effect of states at the early site with Attn or MLP modules severed\"\n",
    "\n",
    "avg_ordinary = torch.stack(all_ordinary).mean(dim=0)\n",
    "avg_no_attn = torch.stack(all_no_attn).mean(dim=0)\n",
    "avg_no_mlp = torch.stack(all_no_mlp).mean(dim=0)\n",
    "import matplotlib.ticker as mtick\n",
    "\n",
    "with plt.rc_context(rc={\"font.family\": \"Times New Roman\"}):\n",
    "    fig, ax = plt.subplots(1, figsize=(6, 2.1), dpi=300)\n",
    "    ax.bar(\n",
    "        [i - 0.3 for i in range(len(avg_ordinary))],\n",
    "        avg_ordinary,\n",
    "        width=0.3,\n",
    "        color=\"#7261ab\",\n",
    "        label=\"Effect of single state on P\",\n",
    "    )\n",
    "    ax.bar(\n",
    "        [i for i in range(len(avg_no_attn))],\n",
    "        avg_no_attn,\n",
    "        width=0.3,\n",
    "        color=\"#f3201b\",\n",
    "        label=\"Effect with Attn severed\",\n",
    "    )\n",
    "    ax.bar(\n",
    "        [i + 0.3 for i in range(len(avg_no_mlp))],\n",
    "        avg_no_mlp,\n",
    "        width=0.3,\n",
    "        color=\"#20b020\",\n",
    "        label=\"Effect with MLP severed\",\n",
    "    )\n",
    "    ax.set_title(\n",
    "        title\n",
    "    )  #'Impact of individual hidden state at last subject token with MLP disabled')\n",
    "    ax.set_ylabel(\"Average Indirect Effect\")\n",
    "    ax.set_xlabel(\"Layer at which the single hidden state is restored\")\n",
    "    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))\n",
    "    ax.set_ylim(None, max(0.025, 0.125))\n",
    "\n",
    "    ax.legend(frameon=False)\n",
    "fig.savefig(\"causal-trace-no-attn-mlp.pdf\", bbox_inches=\"tight\")\n",
    "print([d[20] - d[10] for d in [avg_ordinary, avg_no_attn, avg_no_mlp]])\n",
    "print(avg_ordinary[15], avg_no_attn[15], avg_no_mlp[15])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda42ea1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
