{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "960c32c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from typing import Optional\n",
    "import transformers\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams[\"axes.titlesize\"] = 20\n",
    "plt.rcParams[\"axes.labelsize\"] = 16\n",
    "plt.rcParams[\"xtick.labelsize\"] = 14\n",
    "plt.rcParams[\"ytick.labelsize\"] = 14\n",
    "\n",
    "\n",
    "# identical for phi-3, llama3, qwen3\n",
    "def eager_attention_forward(\n",
    "    module: nn.Module,\n",
    "    query: torch.Tensor,\n",
    "    key: torch.Tensor,\n",
    "    value: torch.Tensor,\n",
    "    attention_mask: Optional[torch.Tensor],\n",
    "    scaling: float,\n",
    "    dropout: float = 0.0,\n",
    "    **kwargs,\n",
    "):\n",
    "    key_states = transformers.models.llama.modeling_llama.repeat_kv(\n",
    "        key, module.num_key_value_groups\n",
    "    )\n",
    "    value_states = transformers.models.llama.modeling_llama.repeat_kv(\n",
    "        value, module.num_key_value_groups\n",
    "    )\n",
    "\n",
    "    _attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n",
    "    if attention_mask is not None:\n",
    "        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n",
    "        _attn_weights = _attn_weights + causal_mask\n",
    "\n",
    "    attn_weights = nn.functional.softmax(_attn_weights, dim=-1, dtype=torch.float32).to(\n",
    "        query.dtype\n",
    "    )\n",
    "    attn_weights = nn.functional.dropout(\n",
    "        attn_weights, p=dropout, training=module.training\n",
    "    )\n",
    "    attn_output = torch.matmul(attn_weights, value_states)\n",
    "    attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "\n",
    "    return attn_output, _attn_weights.mean(dim=(0, 1))\n",
    "\n",
    "\n",
    "transformers.models.llama.modeling_llama.eager_attention_forward = (\n",
    "    eager_attention_forward\n",
    ")\n",
    "transformers.models.phi3.modeling_phi3.eager_attention_forward = eager_attention_forward\n",
    "transformers.models.qwen3.modeling_qwen3.eager_attention_forward = (\n",
    "    eager_attention_forward\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c317f335",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_from_disk\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "def normalize_lower_diagonals(attn: torch.Tensor):\n",
    "    n, _ = attn.shape\n",
    "    for i in range(n):\n",
    "        diag = attn.diag(-i)\n",
    "        diag = diag - diag.mean()\n",
    "        attn[torch.arange(i, n), torch.arange(n - i)] = diag\n",
    "\n",
    "    mask = torch.triu(torch.ones_like(attn), diagonal=1).bool()\n",
    "    attn = attn.masked_fill(mask, float(\"nan\"))\n",
    "\n",
    "    return attn\n",
    "\n",
    "\n",
    "def plot_normalized_attention(model_name, normalize=False):\n",
    "    dset = load_from_disk(\"../data\")\n",
    "    tok = AutoTokenizer.from_pretrained(model_name)\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        device_map=\"auto\",\n",
    "        torch_dtype=torch.bfloat16,\n",
    "        attn_implementation=\"eager\",\n",
    "    )\n",
    "    model.eval()\n",
    "\n",
    "    dset = dset.map(lambda x: {\"input_ids\": tok.encode(x[\"text\"])[:1024]}, num_proc=24)\n",
    "    input_ids = (\n",
    "        torch.from_numpy(np.asarray([i for i in dset[\"input_ids\"]])).long().cuda()\n",
    "    )\n",
    "\n",
    "    attns = [torch.zeros(1024, 1024) for _ in range(len(model.model.layers))]\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, 20, 20):\n",
    "            inputs = input_ids[i : i + 20].to(\"cuda\")\n",
    "            out = model(input_ids=inputs, output_attentions=True)\n",
    "            for j in range(len(model.model.layers)):\n",
    "                attns[j] += out.attentions[j].cpu().float()\n",
    "    attns = [i / 50.0 for i in attns]\n",
    "\n",
    "    plt.figure(figsize=(20, 40))\n",
    "    num_layers = len(attns)\n",
    "    for l in range(num_layers):\n",
    "        plt.subplot(num_layers // 4, 4, l + 1)\n",
    "        attn = attns[l]\n",
    "        if normalize:\n",
    "            attn = normalize_lower_diagonals(attn.clone())\n",
    "\n",
    "        notna = attn[~torch.isnan(attn)]\n",
    "        notinf = notna[~torch.isinf(notna)]\n",
    "        vmin = torch.quantile(notinf, 0.01).item()\n",
    "        vmax = torch.quantile(notinf, 0.99).item()\n",
    "        plt.imshow(\n",
    "            attn.numpy(),\n",
    "            cmap=\"rocket\",\n",
    "            vmin=vmin,\n",
    "            vmax=vmax,\n",
    "        )\n",
    "        plt.colorbar()\n",
    "        plt.title(f\"Layer {l+1}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99dc3292",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "plot_normalized_attention(\"meta-llama/Llama-3.1-8B\", False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/llama3_rope_orig.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "plot_normalized_attention(\"microsoft/phi-4\", False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/phi4_rope_orig.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "plot_normalized_attention(\"Qwen/Qwen3-8B\", False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/qwen3_rope_orig.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "plot_normalized_attention(\"meta-llama/Llama-3.1-8B\", True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/llama3_rope_norm.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "plot_normalized_attention(\"microsoft/phi-4\", True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/phi4_rope_norm.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "plot_normalized_attention(\"Qwen/Qwen3-8B\", True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/qwen3_rope_norm.pdf\")\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c6c639",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_all_models(models, aliases):\n",
    "    plt.figure(figsize=(20, 12))\n",
    "\n",
    "    for idx, model_name in enumerate(models):\n",
    "        dset = load_from_disk(\"../data\")\n",
    "        tok = AutoTokenizer.from_pretrained(model_name)\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name,\n",
    "            device_map=\"auto\",\n",
    "            torch_dtype=torch.bfloat16,\n",
    "            attn_implementation=\"eager\",\n",
    "        )\n",
    "        model.eval()\n",
    "\n",
    "        dset = dset.map(\n",
    "            lambda x: {\"input_ids\": tok.encode(x[\"text\"])[:1024]}, num_proc=24\n",
    "        )\n",
    "        input_ids = (\n",
    "            torch.from_numpy(np.asarray([i for i in dset[\"input_ids\"]])).long().cuda()\n",
    "        )\n",
    "\n",
    "        attns = [torch.zeros(1024, 1024) for _ in range(len(model.model.layers))]\n",
    "        with torch.no_grad():\n",
    "            for i in range(0, 1000, 20):\n",
    "                inputs = input_ids[i : i + 20].to(\"cuda\")\n",
    "                out = model(input_ids=inputs, output_attentions=True)\n",
    "                for j in range(len(model.model.layers)):\n",
    "                    attns[j] += out.attentions[j].cpu().float()\n",
    "        attns = [i / 50.0 for i in attns]\n",
    "\n",
    "        num_layers = len(attns)\n",
    "        for l in range(4):\n",
    "            plt.subplot(len(models), 4, idx * 4 + l + 1)\n",
    "            normalized = normalize_lower_diagonals(attns[l])\n",
    "            notna = normalized[~torch.isnan(normalized)]\n",
    "            notinf = notna[~torch.isinf(notna)]\n",
    "            vmin = torch.quantile(notinf, 0.01).item()\n",
    "            vmax = torch.quantile(notinf, 0.99).item()\n",
    "            vmin = torch.quantile(notinf, 0.01).item()\n",
    "            vmax = torch.quantile(notinf, 0.99).item()\n",
    "\n",
    "            plt.imshow(\n",
    "                normalize_lower_diagonals(attns[l]), cmap=\"rocket\", vmin=vmin, vmax=vmax\n",
    "            )\n",
    "            plt.colorbar()\n",
    "            plt.title(f\"{aliases[idx]}, Layer {l+1}\")\n",
    "\n",
    "\n",
    "models = [\"meta-llama/Llama-3.1-8B\", \"microsoft/phi-4\", \"Qwen/Qwen3-8B\"]\n",
    "alias = [\"Llama-3\", \"Phi-4\", \"Qwen3\"]\n",
    "\n",
    "plot_all_models(models, alias)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/all_rope.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
