{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f79958",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "319b76fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project-local package; the first cell appends '..' to sys.path, presumably so this import resolves.\n",
    "from src import models\n",
    "# Device string handed to load_model here and reused by later cells to place tokenized inputs.\n",
    "device = \"cuda:0\"\n",
    "# `mt` is the notebook-wide handle: later cells read mt.model, mt.tokenizer and mt.lm_head.\n",
    "mt = models.load_model(\"gptj\", device=device)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f12fe410",
   "metadata": {},
   "source": [
    "For convenience, here's a little API for getting GPT-J completions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e8e0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "@torch.inference_mode()\n",
    "def complete(prompt, max_new_tokens=1):\n",
    "    \"\"\"Generate from `prompt` with the notebook's model and return the decoded text.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text passed to `mt.tokenizer`.\n",
    "        max_new_tokens: Forwarded unchanged to `mt.model.generate`.\n",
    "\n",
    "    Returns:\n",
    "        The first string produced by `mt.tokenizer.batch_decode` on the generated ids.\n",
    "    \"\"\"\n",
    "    tokenizer = mt.tokenizer\n",
    "    encoded = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "    generation_kwargs = dict(\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        # The tokenizer's own pad id is handed to generate explicitly.\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "    )\n",
    "    generated_ids = mt.model.generate(**encoded, **generation_kwargs)\n",
    "    decoded = tokenizer.batch_decode(generated_ids)\n",
    "    return decoded[0]\n",
    "\n",
    "# Quick sanity check of the helper on a single prompt.\n",
    "complete(\"The Eiffel Tower is located in the city of\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f913de60",
   "metadata": {},
   "source": [
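    "Which attention heads, at layers after the ROME-optimal layer, look at the last subject token? The helper below runs one prompt through the model and, for every requested layer, records how much attention each head pays from the final prompt position to a chosen subject token, together with the top next-token guess obtained by applying the model's LM head to that layer's last-position output."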
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3093ae6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "from src.utils import tokenizer_utils\n",
    "\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def compute_attns_and_words(prompt_template, subject, layers, ablates=None, subject_token_offset=-1):\n",
    "    \"\"\"Trace one completion and collect, per layer, attention to a subject token plus a decoded guess.\n",
    "\n",
    "    Args:\n",
    "        prompt_template: Format string whose `{}` slot receives `subject`.\n",
    "        subject: Subject text, located in the prompt with `tokenizer_utils.find_token_range`.\n",
    "        layers: Re-iterable collection of block indices (it is iterated twice). Each index is\n",
    "            traced at `transformer.h.{layer}` and `transformer.h.{layer}.attn.attn_dropout`.\n",
    "        ablates: Optional iterable of `(layer, head)` pairs. For each pair, the entry\n",
    "            `[:, head, -1, ti]` of that layer's traced `attn_dropout` output is set to 0 through\n",
    "            the `edit_output` hook given to `baukit.TraceDict`.\n",
    "        subject_token_offset: Selects the target index `ti`. Negative values are added to the\n",
    "            end of the subject token range, non-negative values to its start.\n",
    "\n",
    "    Returns:\n",
    "        `(attns, words)`: two lists with one entry per layer. `attns[i]` is the slice\n",
    "        `[0, :, -1, ti]` of the traced `attn_dropout` output (assumed to be laid out as\n",
    "        batch, head, query, key -- TODO confirm). `words[i]` is `<token> (<prob>)` for the\n",
    "        argmax of `mt.lm_head` applied to the layer's last-position output.\n",
    "    \"\"\"\n",
    "    prompt = prompt_template.format(subject)\n",
    "    subj_start, subj_end = tokenizer_utils.find_token_range(prompt, subject, tokenizer=mt.tokenizer)\n",
    "\n",
    "    # Negative offsets are anchored at the end of the subject span, all others at its start.\n",
    "    anchor = subj_end if subject_token_offset < 0 else subj_start\n",
    "    ti = anchor + subject_token_offset\n",
    "\n",
    "    block_names = [f\"transformer.h.{layer}\" for layer in layers]\n",
    "    dropout_names = [f\"transformer.h.{layer}.attn.attn_dropout\" for layer in layers]\n",
    "\n",
    "    hook = None\n",
    "    if ablates is not None:\n",
    "        # Group the requested heads under the name of the module whose output gets edited.\n",
    "        heads_to_zero = {}\n",
    "        for layer, head in ablates:\n",
    "            heads_to_zero.setdefault(f\"transformer.h.{layer}.attn.attn_dropout\", []).append(head)\n",
    "\n",
    "        def hook(output, layer):\n",
    "            # Modules without ablated heads fall through with their output untouched.\n",
    "            for head in heads_to_zero.get(layer, ()):\n",
    "                output[:, head, -1, ti] = 0\n",
    "            return output\n",
    "\n",
    "    traced_names = (*block_names, *dropout_names)\n",
    "    with baukit.TraceDict(mt.model, traced_names, edit_output=hook) as ret:\n",
    "        complete(prompt)\n",
    "\n",
    "    attns, words = [], []\n",
    "    for block_name, dropout_name in tqdm(tuple(zip(block_names, dropout_names))):\n",
    "        attns.append(ret[dropout_name].output[0, :, -1, ti])\n",
    "\n",
    "        # First element of the block output, restricted to the last position.\n",
    "        last_hidden = ret[block_name].output[0][:, -1]\n",
    "        logits = mt.lm_head(last_hidden)\n",
    "        probs = torch.softmax(logits.float(), dim=-1)\n",
    "        top_id = logits.argmax(dim=-1).item()\n",
    "        top_word = mt.tokenizer.decode(top_id).strip()\n",
    "        words.append(f\"{top_word} ({probs.max():.2f})\")\n",
    "    return attns, words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0653735",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "# Prompt under study: the subject text is substituted into the `{}` slot of the template.\n",
    "subject = \"The Space Needle\"\n",
    "prompt_template = \"{} is located in the city of\"\n",
    "# subject = \"Canberra\"\n",
    "# prompt_template = \"{} is the capital city of\"\n",
    "# NOTE(review): if the loaded GPT-J checkpoint has 28 blocks, range(27) omits the last one -- confirm this is intended.\n",
    "layers = range(27)\n",
    "\n",
    "attns, words = compute_attns_and_words(\n",
    "    prompt_template,\n",
    "    subject,\n",
    "    layers,\n",
    "#     ablates=[(8, 3)],\n",
    "    subject_token_offset=-1,\n",
    ")\n",
    "\n",
    "# Within each layer, order the heads by ascending attention so the strongest heads land in\n",
    "# the right-most columns. Column k of the heatmap is therefore a rank, not head id k.\n",
    "heads_ranked_by_layer = [\n",
    "    sorted(enumerate(attn.tolist()), key=lambda kv: kv[-1])\n",
    "    for attn in attns\n",
    "]\n",
    "heads_ordered = [[head for head, _ in ranked] for ranked in heads_ranked_by_layer]\n",
    "attns_ordered = [[weight for _, weight in ranked] for ranked in heads_ranked_by_layer]\n",
    "\n",
    "# Keep the Axes returned by seaborn so every later call targets this figure explicitly.\n",
    "ax = sns.heatmap(\n",
    "    torch.tensor(attns_ordered).numpy(),\n",
    "    yticklabels=layers,\n",
    "    cbar=False,\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    "    # Uncomment this to see the ID.\n",
    "#     annot=heads_ordered\n",
    ")\n",
    "ax.set_title(\"Last Token Attn. to Subj Token\")\n",
    "ax.set_ylabel(\"Layer\")\n",
    "# The columns are sorted per layer, so label the axis as a rank rather than a head id.\n",
    "ax.set_xlabel(\"Head rank (ascending attention within layer)\")\n",
    "\n",
    "# Write each layer's decoded guess just right of the heatmap, centred on its row.\n",
    "n_heads = attns[0].shape[0]\n",
    "for row, word in enumerate(words):\n",
    "    ax.text(n_heads + 0.5, row + 0.5, word, va='center')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41530b92",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
