{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "777f4802",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nnsight\n",
    "import torch\n",
    "from nnsight import LanguageModel\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-1b\")\n",
    "tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token\n",
    "\n",
    "# Local HF-format checkpoints: model_1 loads the step=5000 checkpoint, model_2 the step=0 one.\n",
    "# NOTE(review): nothing here expands \"~\", so these are paths relative to the working\n",
    "# directory; if \"~/pythia_replicate/...\" was meant, wrap them in os.path.expanduser - confirm.\n",
    "path1 = \"~pythia_replicate/hf_output/clean_1b/step=5000\"\n",
    "path2 = \"~pythia_replicate/hf_output/clean_1b_step0/step=0\"\n",
    "\n",
    "# Label used in the figure names passed to visualize_head_scores_publication further down;\n",
    "# those cells read it but nothing else in the notebook assigns it.\n",
    "# NOTE(review): value inferred from the checkpoint directory name - confirm.\n",
    "model_type = \"clean_1b\"\n",
    "\n",
    "model_1 = LanguageModel(path1, revision=None, tokenizer=tokenizer, device_map=\"auto\", dispatch=True)\n",
    "model_2 = LanguageModel(path2, revision=None, tokenizer=tokenizer, device_map=\"auto\", dispatch=True)\n",
    "\n",
    "# Synthetic induction data: random token sequences, each concatenated with itself.\n",
    "SEED = 42\n",
    "N_SEQUENCES = 1000\n",
    "BASE_SEQ_LEN = 10\n",
    "MAX_TOKEN_ID = 10000  # exclusive upper bound for the sampled token ids\n",
    "torch.manual_seed(SEED)\n",
    "rand_seq = torch.randint(0, MAX_TOKEN_ID, (N_SEQUENCES, BASE_SEQ_LEN))\n",
    "rand_repeated_seq = torch.cat([rand_seq, rand_seq], dim=1)  # [N_SEQUENCES, 2 * BASE_SEQ_LEN]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1cc7eea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select eager attention on each model's HF config (the object the attention modules read),\n",
    "# not on the nnsight wrapper: assigning to the wrapper did not take effect - a later traced\n",
    "# run still logged a fallback from `sdpa` to `eager`.\n",
    "# NOTE(review): assumes the installed transformers reads config._attn_implementation at\n",
    "# forward time - confirm for older versions.\n",
    "model_1.config._attn_implementation = \"eager\"\n",
    "model_2.config._attn_implementation = \"eager\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2142fda0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention geometry of model_1, used to split attention outputs into per-head slices.\n",
    "n_heads, head_dim = (\n",
    "    model_1.config.num_attention_heads,\n",
    "    model_1.config.hidden_size // model_1.config.num_attention_heads,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d459e6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of sequences per forward pass in the batched cells below.\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3947966",
   "metadata": {},
   "source": [
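    "## Patch Attention Heads\n",
    "\n",
    "Copy the per-head attention outputs (the input to each layer's `attention.dense`) for the heads in `heads_to_patch` from `model_1` (step 5000) into `model_2` (step 0). Then replace the same heads with an idealised previous-token head and plot the resulting attention maps."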
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "297ccc9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every head in the first four layers. The head count comes from the model config: these\n",
    "# 1b checkpoints have 8 heads per layer, so a hard-coded range(12) would index past the\n",
    "# last head.\n",
    "heads_to_patch = [(layer, head) for layer in range(4) for head in range(n_heads)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a51b2a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record model_1's per-head attention output for each (layer, head) in heads_to_patch:\n",
    "# the input to attention.dense, split into n_heads slices of head_dim features.\n",
    "n_heads = model_1.config.num_attention_heads \n",
    "head_dim = model_1.config.hidden_size // n_heads \n",
    "\n",
    "# (layer_idx, head_idx) -> saved tensor of shape [batch, seq_len, head_dim].\n",
    "# NOTE(review): the Patch Layers section later rebinds `activations` to a defaultdict of\n",
    "# layer outputs; re-run this cell before re-running the head-patching cell.\n",
    "activations = {}\n",
    "with torch.no_grad():\n",
    "    # All sequences go through in a single forward pass (no batching here).\n",
    "    with model_1.trace(rand_repeated_seq):\n",
    "        # NOTE(review): o_proj.inputs is read once per head, i.e. several times per layer; if the\n",
    "        # installed nnsight requires modules to be accessed once and in forward order, group the\n",
    "        # heads by layer.\n",
    "        for layer_idx, head_idx in heads_to_patch:\n",
    "            o_proj = model_1.gpt_neox.layers[layer_idx].attention.dense\n",
    "\n",
    "            # First positional argument passed to the output projection.\n",
    "            o_proj_inp = o_proj.inputs[0][0]\n",
    "            \n",
    "            bsz = o_proj_inp.shape[0]\n",
    "            seq_len = o_proj_inp.shape[1]\n",
    "            # Split the feature dimension into heads; assumes head h owns features\n",
    "            # h*head_dim:(h+1)*head_dim - TODO confirm against the attention implementation.\n",
    "            head_acts = o_proj_inp.view(bsz, seq_len, n_heads, head_dim)\n",
    "            \n",
    "            activation = head_acts[:, :, head_idx, :].save()\n",
    "            activations[(layer_idx, head_idx)] = activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846e3adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the saved activations by shape instead of printing every tensor.\n",
    "{key: tuple(act.shape) for key, act in activations.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f098596",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patch model_2: overwrite each head in heads_to_patch with model_1's saved output for that\n",
    "# head (from `activations`), and keep an unpatched model_2 run for comparison.\n",
    "n_heads = model_2.config.num_attention_heads \n",
    "head_dim = model_2.config.hidden_size // n_heads \n",
    "\n",
    "# Baseline: unpatched model_2 on the same inputs.\n",
    "with torch.no_grad():\n",
    "    with model_2.trace(rand_repeated_seq, output_attentions=True):\n",
    "        out_clean = model_2.output.save()\n",
    "\n",
    "with torch.no_grad():\n",
    "    with model_2.trace(rand_repeated_seq, output_attentions=True):\n",
    "        for layer_idx, head_idx in heads_to_patch:\n",
    "            o_proj = model_2.gpt_neox.layers[layer_idx].attention.dense\n",
    "\n",
    "            o_proj_inp = o_proj.inputs[0][0]\n",
    "            \n",
    "            bsz = o_proj_inp.shape[0]\n",
    "            seq_len = o_proj_inp.shape[1]\n",
    "            # A view shares storage with o_proj_inp, so the in-place assignment below edits\n",
    "            # the tensor handed to attention.dense.\n",
    "            head_acts = o_proj_inp.view(bsz, seq_len, n_heads, head_dim)\n",
    "            \n",
    "            # Requires `activations` to be the (layer, head)-keyed dict captured from model_1\n",
    "            # on the same rand_repeated_seq.\n",
    "            head_acts[:, :, head_idx, :] = activations[(layer_idx, head_idx)]\n",
    "        out_modified = model_2.output.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73585768",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "def find_induction_positions(token_strings):\n",
    "    \"\"\"Return the (query, key) cells an induction head is expected to attend to.\n",
    "\n",
    "    For each position ``i`` and each earlier position ``j`` holding the same token,\n",
    "    the token that followed the earlier occurrence sits at ``j + 1``, so the pair\n",
    "    ``(i, j + 1)`` is emitted. Pairs are ordered by ``i``, then by ``j``.\n",
    "    \"\"\"\n",
    "    # j < i <= len - 1 guarantees that j + 1 is a valid index, so no bounds check is needed.\n",
    "    return [\n",
    "        (query_pos, earlier_pos + 1)\n",
    "        for query_pos, current in enumerate(token_strings)\n",
    "        for earlier_pos, previous in enumerate(token_strings[:query_pos])\n",
    "        if previous == current\n",
    "    ]\n",
    "\n",
    "def plot_attention_heads(attention_weights, tokens=None, tokenizer=None,\n",
    "                         max_layers=None, max_heads=None, figsize=(12, 12)):\n",
    "    \"\"\"Plot one attention heatmap per (layer, head), outlining expected induction cells in red.\n",
    "\n",
    "    Args:\n",
    "        attention_weights: one entry per layer. Either torch tensors indexed\n",
    "            [batch, head, query, key] (only batch item 0 is plotted) or arrays\n",
    "            already indexed [head, query, key].\n",
    "        tokens: optional 1-D tensor of token ids for the plotted sequence.\n",
    "        tokenizer: decodes ``tokens``; outlines are drawn only when both\n",
    "            ``tokens`` and ``tokenizer`` are given.\n",
    "        max_layers, max_heads: optionally show only the first N layers / heads\n",
    "            (default: all of them), e.g. to plot faster.\n",
    "        figsize: matplotlib figure size for the whole grid.\n",
    "    \"\"\"\n",
    "    if isinstance(attention_weights[0], torch.Tensor):\n",
    "        # .float() first: numpy has no bfloat16, so .numpy() would fail on bf16 attentions.\n",
    "        attention_weights = [w[0].detach().float().cpu().numpy() for w in attention_weights]\n",
    "\n",
    "    num_layers = len(attention_weights)\n",
    "    num_heads = attention_weights[0].shape[0]\n",
    "    layers_to_show = num_layers if max_layers is None else min(max_layers, num_layers)\n",
    "    heads_to_show = num_heads if max_heads is None else min(max_heads, num_heads)\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D even when the grid has a single row or column.\n",
    "    fig, axes = plt.subplots(layers_to_show, heads_to_show, figsize=figsize, squeeze=False)\n",
    "\n",
    "    # (query, key) cells an induction head should attend to, if the tokens are known.\n",
    "    induction_positions = []\n",
    "    if tokens is not None and tokenizer is not None:\n",
    "        token_strings = [tokenizer.decode(i.item()) for i in tokens]\n",
    "        induction_positions = find_induction_positions(token_strings)\n",
    "\n",
    "    for layer_idx in range(layers_to_show):\n",
    "        for head_idx in range(heads_to_show):\n",
    "            ax = axes[layer_idx, head_idx]\n",
    "            ax.imshow(attention_weights[layer_idx][head_idx], cmap=\"Blues\")\n",
    "\n",
    "            # imshow puts keys on x and queries on y; each cell spans +/-0.5 around its index.\n",
    "            for query_pos, key_pos in induction_positions:\n",
    "                rect = Rectangle(\n",
    "                    (key_pos - 0.5, query_pos - 0.5),\n",
    "                    1,\n",
    "                    1,\n",
    "                    linewidth=1,\n",
    "                    edgecolor=\"red\",\n",
    "                    facecolor=\"none\",\n",
    "                )\n",
    "                ax.add_patch(rect)\n",
    "\n",
    "            ax.set_title(f\"L{layer_idx}H{head_idx}\")\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d837aa33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention maps of unpatched model_2 (`out_clean`), labelled with the first sequence's tokens.\n",
    "# NOTE(review): once the Patch Layers cells have run, out_clean holds only the last batch,\n",
    "# whose first row is not rand_repeated_seq[0].\n",
    "plot_attention_heads(\n",
    "    out_clean.attentions,\n",
    "    tokens=rand_repeated_seq[0],\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b77fd97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace each head in heads_to_patch with an idealised previous-token head: its output at\n",
    "# position t becomes the value vector of position t-1 (position 0 keeps its own value),\n",
    "# i.e. what the head would emit if it attended only to the previous token.\n",
    "n_heads = model_2.config.num_attention_heads\n",
    "head_dim = model_2.config.hidden_size // n_heads\n",
    "\n",
    "with torch.no_grad():\n",
    "    with model_2.trace(rand_repeated_seq, output_attentions=True):\n",
    "        # Visit each layer once, in forward order, and patch all of its selected heads there.\n",
    "        for layer_idx in sorted({layer for layer, _ in heads_to_patch}):\n",
    "            attn_layer = model_2.gpt_neox.layers[layer_idx].attention\n",
    "            heads_in_layer = [head for layer, head in heads_to_patch if layer == layer_idx]\n",
    "\n",
    "            # Fused QKV projection exactly as computed in the forward pass.\n",
    "            qkv = attn_layer.query_key_value.output\n",
    "            bsz = qkv.shape[0]\n",
    "            seq_len = qkv.shape[1]\n",
    "            # GPT-NeoX lays the fused projection out per head as (n_heads, 3, head_dim) with\n",
    "            # q, k, v in that order - not (3, n_heads, head_dim) - so the values are index 2\n",
    "            # of the second-to-last axis: [batch, seq_len, n_heads, head_dim].\n",
    "            value = qkv.view(bsz, seq_len, n_heads, 3, head_dim)[:, :, :, 2, :]\n",
    "\n",
    "            # Per-head attention outputs = input of the output projection; head_acts is a\n",
    "            # view, so writing into it edits what attention.dense receives.\n",
    "            o_proj_inp = attn_layer.dense.inputs[0][0]\n",
    "            head_acts = o_proj_inp.view(bsz, seq_len, n_heads, head_dim)\n",
    "\n",
    "            for head_idx in heads_in_layer:\n",
    "                head_values = value[:, :, head_idx, :]  # [batch, seq_len, head_dim]\n",
    "                prev_token_output = head_values.clone()\n",
    "                prev_token_output[:, 1:, :] = head_values[:, :-1, :]  # shift right by one position\n",
    "                head_acts[:, :, head_idx, :] = prev_token_output\n",
    "        out = model_2.output.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "185ce2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention maps of model_2 from the previous-token patching run (`out`).\n",
    "plot_attention_heads(\n",
    "    out.attentions,\n",
    "    tokens=rand_repeated_seq[0],\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa2a71a6",
   "metadata": {},
   "source": [
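    "## Patch Layers\n",
    "\n",
    "Replace the full output of each layer in `layers_to_patch` in `model_2` (step 0) with the corresponding layer output from `model_1` (step 5000), batch by batch. Then compare induction and previous-token attention scores across the patched model and both unpatched checkpoints."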
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5004bca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder layers whose full output is copied from model_1 into model_2 in the cells below.\n",
    "layers_to_patch = [4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "15c44332",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "05c27c94",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "Setting `attention_type` to `eager` because `sdpa` does not support `output_attentions=True` or `head_mask`.\n"
     ]
    }
   ],
   "source": [
    "n_heads = model_1.config.num_attention_heads\n",
    "head_dim = model_1.config.hidden_size // n_heads\n",
    "\n",
    "# layer index -> saved outputs of that model_1 layer, one entry per batch, in batch order.\n",
    "activations = defaultdict(list)\n",
    "with torch.no_grad():\n",
    "    for start in range(0, len(rand_repeated_seq), batch_size):\n",
    "        # Same output_attentions flag as the patched model_2 run, presumably so the saved layer\n",
    "        # outputs have the structure that run expects - confirm before changing it.\n",
    "        with model_1.trace(rand_repeated_seq[start:start + batch_size], output_attentions=True):\n",
    "            for layer_idx in layers_to_patch:\n",
    "                layer_output = model_1.gpt_neox.layers[layer_idx].output\n",
    "                activation = layer_output.save()\n",
    "                activations[layer_idx].append(activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "24e53691",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_layers = model_2.config.num_hidden_layers\n",
    "\n",
    "# Layer outputs are copied from model_1 into model_2 below, which only makes sense if both\n",
    "# checkpoints share the same architecture and the patched layer indices exist.\n",
    "for attr in (\"num_hidden_layers\", \"num_attention_heads\", \"hidden_size\"):\n",
    "    assert getattr(model_1.config, attr) == getattr(model_2.config, attr), f\"model_1/model_2 differ in {attr}\"\n",
    "assert all(0 <= layer_idx < n_layers for layer_idx in layers_to_patch), \"layers_to_patch out of range\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ac67d402",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect per-batch attention maps (output_attentions=True) for three runs:\n",
    "#   batch2attentions_step0    - unmodified model_2\n",
    "#   batch2attentions_step1000 - unmodified model_1 (NOTE(review): path1 is the step=5000\n",
    "#                               checkpoint, so the \"step1000\" name looks stale)\n",
    "#   batch2attentions          - model_2 with each layer in layers_to_patch replaced by\n",
    "#                               model_1's saved output for the same batch\n",
    "# Each dict maps batch index -> the saved `attentions` field of that batch's model output.\n",
    "n_heads = model_2.config.num_attention_heads \n",
    "head_dim = model_2.config.hidden_size // n_heads \n",
    "# NOTE(review): keys and queries are never filled or read in this notebook.\n",
    "keys = {}\n",
    "queries = {}\n",
    "batch2attentions = {}\n",
    "batch2attentions_step1000 = {}\n",
    "batch2attentions_step0 = {}\n",
    "\n",
    "# NOTE(review): this rebinds out_clean (also used by an earlier plotting cell) to the output\n",
    "# of the last batch only.\n",
    "with torch.no_grad():\n",
    "    for bidx, start in enumerate(range(0, len(rand_repeated_seq), batch_size)):\n",
    "        with model_2.trace(rand_repeated_seq[start:start+batch_size], output_attentions=True):\n",
    "            out_clean = model_2.output.save()\n",
    "            batch2attentions_step0[bidx] = out_clean.attentions.save()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for bidx, start in enumerate(range(0, len(rand_repeated_seq), batch_size)):\n",
    "        with model_1.trace(rand_repeated_seq[start:start+batch_size], output_attentions=True):\n",
    "            out_step1000 = model_1.output.save()\n",
    "            batch2attentions_step1000[bidx] = out_step1000.attentions.save()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for bidx, start in enumerate(range(0, len(rand_repeated_seq), batch_size)):\n",
    "        with model_2.trace(rand_repeated_seq[start:start+batch_size], output_attentions=True):\n",
    "            # Overwrite the whole layer output with model_1's output for this batch; the batch\n",
    "            # boundaries must match the loop that filled `activations`.\n",
    "            for layer_idx in layers_to_patch:\n",
    "                model_2.gpt_neox.layers[layer_idx].output = activations[layer_idx][bidx]\n",
    "            out2_patched = model_2.output.save()\n",
    "            batch2attentions[bidx] = out2_patched.attentions.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cdeea466",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def compute_attention_scores(batch2attentions, score_type: str = \"induction\") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Compute per-layer, per-head attention scores averaged over all sequences.\n",
    "\n",
    "    Args:\n",
    "        batch2attentions: dict[int -> sequence of tensors]\n",
    "            Each entry: one attention tensor per layer (length num_layers)\n",
    "            Each tensor: [B, H, S, S] attention matrix (query x key)\n",
    "        score_type: \"induction\" or \"previous\"\n",
    "            - \"induction\": for sequences made of a length-S/2 block repeated twice, mean\n",
    "              attention from each query in the second block to the token right after that\n",
    "              query's earlier occurrence (key = query - S/2 + 1)\n",
    "            - \"previous\": mean attention to the immediately previous token (key = query - 1)\n",
    "\n",
    "    Returns:\n",
    "        scores: torch.Tensor of shape [num_layers, n_heads], on the device of the first\n",
    "            batch's first-layer attention tensor\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if score_type is unknown or batch2attentions is empty.\n",
    "    \"\"\"\n",
    "    if score_type not in (\"induction\", \"previous\"):\n",
    "        raise ValueError(\"score_type must be 'induction' or 'previous'\")\n",
    "    if not batch2attentions:\n",
    "        raise ValueError(\"batch2attentions is empty; nothing to score\")\n",
    "\n",
    "    # assume all batches have the same number of layers and heads\n",
    "    first_layers = next(iter(batch2attentions.values()))\n",
    "    num_layers = len(first_layers)\n",
    "    n_heads = first_layers[0].shape[1]\n",
    "    device = first_layers[0].device\n",
    "\n",
    "    scores = torch.zeros(num_layers, n_heads, device=device)\n",
    "    total_sequences = 0\n",
    "\n",
    "    for attn_layers in batch2attentions.values():\n",
    "        B, _, S, _ = attn_layers[0].shape\n",
    "        total_sequences += B\n",
    "\n",
    "        for layer_idx, attn_probas in enumerate(attn_layers):\n",
    "            if score_type == \"induction\":\n",
    "                base_seq_len = S // 2\n",
    "                # Diagonal -(base_seq_len - 1) holds the (q, q - base_seq_len + 1) entries; its first\n",
    "                # element (q = base_seq_len - 1) is still in the first block, so skip it.\n",
    "                diag = attn_probas.diagonal(offset=-(base_seq_len - 1), dim1=-2, dim2=-1)[..., 1:]\n",
    "            else:\n",
    "                diag = attn_probas.diagonal(offset=-1, dim1=-2, dim2=-1)  # [B, H, S-1]\n",
    "\n",
    "            # Mean over positions, summed over the batch; divided by total_sequences at the end.\n",
    "            per_head_sum = diag.mean(dim=-1).sum(dim=0)  # [H]\n",
    "            # With device_map=\"auto\" layers may sit on different GPUs; accumulate on one device.\n",
    "            scores[layer_idx] += per_head_sum.to(scores.device)\n",
    "\n",
    "    scores /= total_sequences\n",
    "    return scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12e112c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Induction scores: torch.Size([16, 8]) 0.06728119403123856\n",
      "Induction scores: torch.Size([16, 8]) 0.12408338487148285\n",
      "Induction scores: torch.Size([16, 8]) 0.06691370904445648\n",
      "Previous-token scores: torch.Size([16, 8]) 0.15170779824256897\n",
      "Previous-token scores: torch.Size([16, 8]) 0.17876824736595154\n",
      "Previous-token scores: torch.Size([16, 8]) 0.13638468086719513\n"
     ]
    }
   ],
   "source": [
    "# Mean attention scores for the three runs collected above. batch2attentions_step1000 holds\n",
    "# model_1's attentions and path1 is the step=5000 checkpoint, hence the *_step5000 names.\n",
    "ind_scores = compute_attention_scores(batch2attentions, score_type=\"induction\")\n",
    "ind_scores_step5000 = compute_attention_scores(batch2attentions_step1000, score_type=\"induction\")\n",
    "ind_scores_step0 = compute_attention_scores(batch2attentions_step0, score_type=\"induction\")\n",
    "\n",
    "prev_scores = compute_attention_scores(batch2attentions, score_type=\"previous\")\n",
    "prev_scores_step5000 = compute_attention_scores(batch2attentions_step1000, score_type=\"previous\")\n",
    "prev_scores_step0 = compute_attention_scores(batch2attentions_step0, score_type=\"previous\")\n",
    "\n",
    "print(\"Induction scores (patched):\", ind_scores.shape, ind_scores.mean().item())\n",
    "print(\"Induction scores (step5000):\", ind_scores_step5000.shape, ind_scores_step5000.mean().item())\n",
    "print(\"Induction scores (step0):\", ind_scores_step0.shape, ind_scores_step0.mean().item())\n",
    "print(\"Previous-token scores (patched):\", prev_scores.shape, prev_scores.mean().item())\n",
    "print(\"Previous-token scores (step5000):\", prev_scores_step5000.shape, prev_scores_step5000.mean().item())\n",
    "print(\"Previous-token scores (step0):\", prev_scores_step0.shape, prev_scores_step0.mean().item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "23e174a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auto-reload edited project modules (e.g. lib.attn_heads, imported further down) without\n",
    "# restarting the kernel. Conventionally this belongs in the setup cell at the top.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0faf8d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace each patched layer's row of the patched scores with the step-0 row. Driven by\n",
    "# layers_to_patch rather than a hard-coded index so the two stay in sync.\n",
    "# NOTE(review): patching replaces a layer's *output*, so that layer's own attention in the\n",
    "# patched run presumably already equals step 0's, which would make this copy a no-op -\n",
    "# confirm the intent (ind_scores_step5000 instead?).\n",
    "for layer_idx in layers_to_patch:\n",
    "    ind_scores[layer_idx] = ind_scores_step0[layer_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f36f34a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared colour-scale limits for the patched and step-0 heatmaps, so they are directly comparable.\n",
    "vmin = 0.05\n",
    "vmax = 0.10\n",
    "# Score threshold passed to visualize_head_scores_publication (equal to vmax here).\n",
    "threshold = 0.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "af5d62b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lib.attn_heads import visualize_head_scores_publication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e8acb01e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Induction-score heatmap for the patched model_2 run.\n",
    "# NOTE(review): model_type must be assigned in an earlier cell.\n",
    "visualize_head_scores_publication(\n",
    "    ind_scores,\n",
    "    threshold=threshold,\n",
    "    title=\"Head Scores\",\n",
    "    save_dir=\"plot_attention_figures\",\n",
    "    model_type=f\"{model_type}_patched\",\n",
    "    color_map=\"Purples\",\n",
    "    vmin=vmin,\n",
    "    vmax=vmax,\n",
    "    layer_start=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "742fd18c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Induction-score heatmap for the unmodified step-0 model, on the same colour scale.\n",
    "# NOTE(review): model_type must be assigned in an earlier cell.\n",
    "visualize_head_scores_publication(\n",
    "    ind_scores_step0,\n",
    "    threshold=threshold,\n",
    "    title=\"Head Scores\",\n",
    "    save_dir=\"plot_attention_figures\",\n",
    "    model_type=f\"{model_type}_step0\",\n",
    "    color_map=\"Purples\",\n",
    "    vmin=vmin,\n",
    "    vmax=vmax,\n",
    "    layer_start=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2aa7dd4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9247, device='cuda:1')"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest per-head induction score of model_1 (step-5000 checkpoint).\n",
    "torch.max(ind_scores_step5000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythia_replicate",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
