{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa594828",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from transformers.models.llama.modeling_llama import (\n",
    "    LlamaRotaryEmbedding,\n",
    "    apply_rotary_pos_emb,\n",
    ")\n",
    "from dataclasses import dataclass\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def normalize_lower_diagonals(attn: torch.Tensor):\n",
    "    n, _ = attn.shape\n",
    "    for i in range(n):\n",
    "        diag = attn.diag(-i)\n",
    "        diag = diag - diag.mean()\n",
    "        attn[torch.arange(i, n), torch.arange(n - i)] = diag\n",
    "\n",
    "    for i in range(n):\n",
    "        diag = attn.diag(i)\n",
    "        diag = diag - diag.mean()\n",
    "        attn[torch.arange(n - i), torch.arange(i, n)] = diag\n",
    "\n",
    "    return attn\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Config:\n",
    "    max_position_embeddings: int = 100\n",
    "    rope_theta: int = 10000\n",
    "    head_dim: int = 64\n",
    "    hidden_size: int = 64\n",
    "    num_attention_heads: int = 1\n",
    "\n",
    "\n",
    "class SimpleResModel(nn.Module):\n",
    "    def __init__(self, dim=64):\n",
    "        super().__init__()\n",
    "        self.ln = nn.LayerNorm(dim)\n",
    "        config = Config(head_dim=dim)\n",
    "        self.rope = LlamaRotaryEmbedding(config)\n",
    "\n",
    "    def forward(self, x, causal=True):\n",
    "        B, S, E = x.size()\n",
    "        hidden = x\n",
    "        hidden = hidden / (hidden.norm(dim=-1, keepdim=True) + 1e-12)\n",
    "        pos = torch.arange(S, device=x.device).unsqueeze(0)\n",
    "        cos, sin = self.rope(hidden, pos)\n",
    "        q, k = apply_rotary_pos_emb(hidden.unsqueeze(1), hidden.unsqueeze(1), cos, sin)\n",
    "        q, k = q.squeeze(1), k.squeeze(1)\n",
    "\n",
    "        attn = q @ k.transpose(-2, -1)\n",
    "\n",
    "        mask = torch.zeros(S, S, device=x.device)\n",
    "        mask = torch.masked_fill(\n",
    "            mask, torch.ones_like(mask).triu(1).bool(), float(\"-inf\")\n",
    "        )\n",
    "        if causal:\n",
    "            attn = attn + mask.unsqueeze(0)\n",
    "\n",
    "        _out = F.softmax(attn, dim=-1)  # softmax to get attention weights\n",
    "\n",
    "        out = _out @ hidden  # B, S, E\n",
    "\n",
    "        out += x\n",
    "        return out, attn, _out\n",
    "\n",
    "\n",
    "import math\n",
    "\n",
    "# adjust plt fontsize\n",
    "plt.rcParams[\"axes.titlesize\"] = 20\n",
    "plt.rcParams[\"axes.labelsize\"] = 16\n",
    "plt.rcParams[\"xtick.labelsize\"] = 14\n",
    "plt.rcParams[\"ytick.labelsize\"] = 14\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def plot_theory(alpha, causal):\n",
    "    plt.figure(figsize=(20, 8))\n",
    "\n",
    "    B, S, E = 100000, 50, 64\n",
    "\n",
    "    layer = SimpleResModel(E)\n",
    "    m = 1\n",
    "    if alpha == 1:\n",
    "        l = 1\n",
    "        m = 0\n",
    "    elif alpha == 0:\n",
    "        l = 0\n",
    "        m = 1\n",
    "    else:\n",
    "        l = math.sqrt(alpha / (1 - alpha))\n",
    "    print(l)\n",
    "\n",
    "    attn_outs, attn_scores, attn_probs = [], [], []\n",
    "\n",
    "    def _inner_plot(causal, row):\n",
    "        out = torch.nn.functional.normalize(\n",
    "            m * torch.randn(B, S, E) + l * torch.randn(B, 1, E), dim=-1\n",
    "        )\n",
    "\n",
    "        for i in range(4):\n",
    "            out, attn, _attn = layer(out, causal=causal)\n",
    "            plt.subplot(2, 4, row * 8 + i + 1)\n",
    "            inner = attn.mean(0)\n",
    "\n",
    "            plt.imshow(\n",
    "                (inner.detach().cpu()).numpy(),\n",
    "                cmap=\"rocket\",\n",
    "            )\n",
    "\n",
    "            plt.colorbar()\n",
    "            plt.clim(0, 1)\n",
    "            plt.title(f\"Layer {i+1}\")\n",
    "            attn_outs.append(out)\n",
    "            attn_scores.append(attn)\n",
    "            attn_probs.append(_attn)\n",
    "\n",
    "            plt.subplot(2, 4, row * 4 + i + 5)\n",
    "            inner = normalize_lower_diagonals(inner)\n",
    "            plt.imshow(\n",
    "                (inner.detach().cpu()).numpy(),\n",
    "                cmap=\"rocket\",\n",
    "            )\n",
    "            plt.colorbar()\n",
    "\n",
    "            plt.title(f\"Layer {i+1} (Normalized)\")\n",
    "\n",
    "    _inner_plot(causal, 0)\n",
    "\n",
    "    return attn_outs, attn_scores, attn_probs"
   ]
  },
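  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c41d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (a minimal sketch, not part of the experiment itself): after\n",
    "# normalize_lower_diagonals, every diagonal should have ~zero mean.\n",
    "_a = normalize_lower_diagonals(torch.randn(6, 6))\n",
    "for _i in range(-5, 6):\n",
    "    assert _a.diag(_i).mean().abs() < 1e-6, _i\n",
    "print(\"all diagonals demeaned\")"
   ]
  },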
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faeb710e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_theory(0.0, causal=True)\n",
    "plt.tight_layout()\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/rope_causal_0.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fbf342b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_theory(0.5, causal=True)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/rope_causal_05.pdf\")"
   ]
  },
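  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e2d07c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged check of the alpha parameterization (a sketch, assuming alpha is\n",
    "# meant to be the expected overlap between distinct tokens): with\n",
    "# x_i = eps_i + l * g and l = sqrt(alpha / (1 - alpha)), the shared direction\n",
    "# g carries a fraction alpha of each token's variance, so the cosine\n",
    "# similarity between two tokens should concentrate around alpha.\n",
    "alpha = 0.5\n",
    "l = math.sqrt(alpha / (1 - alpha))\n",
    "x = F.normalize(torch.randn(10000, 2, 64) + l * torch.randn(10000, 1, 64), dim=-1)\n",
    "print((x[:, 0] * x[:, 1]).sum(-1).mean().item())  # ~0.5"
   ]
  },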
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba87493",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_theory(0.0, causal=False)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/rope_noncausal_0.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
