{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa594828",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "class SimpleResModel(nn.Module):\n",
    "    def __init__(self, dim=64):\n",
    "        super().__init__()\n",
    "        self.ln = nn.LayerNorm(dim, elementwise_affine=False, bias=False, eps=1e-12)\n",
    "\n",
    "    def forward(self, x, layernorm=False, ln_scaling=False):\n",
    "        B, S, E = x.size()\n",
    "        hidden = x\n",
    "        if layernorm:\n",
    "            hidden = self.ln(hidden)\n",
    "        else:\n",
    "            hidden = hidden / (hidden.norm(dim=-1, keepdim=True) + 1e-12)\n",
    "        attn = hidden @ hidden.transpose(-2, -1)\n",
    "        if layernorm:\n",
    "            if ln_scaling:\n",
    "                attn = attn / E\n",
    "            else:\n",
    "                attn = attn / (E**0.5)\n",
    "\n",
    "        mask = torch.zeros(S, S, device=x.device)\n",
    "        mask = torch.masked_fill(\n",
    "            mask, torch.ones_like(mask).triu(1).bool(), float(\"-inf\")\n",
    "        )\n",
    "        attn = attn + mask.unsqueeze(0)\n",
    "\n",
    "        _out = F.softmax(attn, dim=-1)  # softmax to get attention weights\n",
    "\n",
    "        out = _out @ hidden  # B, S, E\n",
    "\n",
    "        out += x\n",
    "        return out, attn, _out\n",
    "\n",
    "\n",
    "import math\n",
    "\n",
    "# adjust plt fontsize\n",
    "plt.rcParams[\"axes.titlesize\"] = 20\n",
    "plt.rcParams[\"axes.labelsize\"] = 16\n",
    "plt.rcParams[\"xtick.labelsize\"] = 14\n",
    "plt.rcParams[\"ytick.labelsize\"] = 14\n",
    "\n",
    "\n",
    "def plot_theory(\n",
    "    alpha, row, layernorm, total_rows, ln_scaling=False, num_layers=4\n",
    "):\n",
    "    B, S, E = 100000, 50, 64\n",
    "\n",
    "    layer = SimpleResModel(E)\n",
    "    m = 1\n",
    "    if alpha == 1:\n",
    "        l = 1\n",
    "        m = 0\n",
    "    elif alpha == 0:\n",
    "        l = 0\n",
    "        m = 1\n",
    "    else:\n",
    "        l = math.sqrt(alpha / (1 - alpha))\n",
    "    attn_outs, attn_scores, attn_probs = [], [], []\n",
    "\n",
    "    def _inner_plot():\n",
    "        out = torch.nn.functional.normalize(\n",
    "            m * torch.randn(B, S, E) + l * torch.randn(B, 1, E), dim=-1\n",
    "        )\n",
    "\n",
    "        for i in range(num_layers):\n",
    "            out, attn, _attn = layer(\n",
    "                out, layernorm=layernorm, ln_scaling=ln_scaling\n",
    "            )\n",
    "            plt.subplot(total_rows, num_layers, row * num_layers + i + 1)\n",
    "            inner = attn.mean(0)\n",
    "            plt.imshow(inner.detach().cpu().numpy(), cmap=\"rocket\")\n",
    "            if not layernorm:\n",
    "                plt.clim(vmin=0, vmax=1)\n",
    "            else:\n",
    "                if ln_scaling:\n",
    "                    plt.clim(vmin=0, vmax=1)\n",
    "                else:\n",
    "                    plt.clim(vmin=0, vmax=8)\n",
    "            plt.colorbar()\n",
    "\n",
    "            title = rf\"Layer {i+1}, $\\alpha$={alpha}\"\n",
    "            if layernorm:\n",
    "                if ln_scaling:\n",
    "                    title += \" (Scaling=$d$)\"\n",
    "                else:\n",
    "                    title += \" (Scaling=$\\sqrt{d}$)\"\n",
    "            plt.title(title)\n",
    "            attn_outs.append(out)\n",
    "            attn_scores.append(attn)\n",
    "            attn_probs.append(_attn)\n",
    "\n",
    "    _inner_plot()\n",
    "\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "plot_theory(0, 0, False, 2)\n",
    "plot_theory(0.2, 1, False, 2)\n",
    "plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/l2norm.pdf\")"
   ]
  },
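  {
   "cell_type": "markdown",
   "id": "1c0ffee1",
   "metadata": {},
   "source": [
    "A quick sanity check of the L2-normalization path (a minimal sketch, assuming the `SimpleResModel` cell above has been run): with unit-norm tokens every attention score is a cosine similarity, so scores stay in $[-1, 1]$ regardless of the embedding dimension, and the softmax rows still sum to 1 under the causal mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d4e6f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (assumes SimpleResModel from the cell above): L2-normalized\n",
    "# tokens give cosine-similarity scores, bounded in [-1, 1] for any E.\n",
    "torch.manual_seed(0)\n",
    "x = torch.randn(8, 50, 64)\n",
    "model = SimpleResModel(64)\n",
    "_, scores, probs = model(x, layernorm=False)\n",
    "finite = scores[torch.isfinite(scores)]  # drop the -inf causal-mask entries\n",
    "print(f\"score range: [{finite.min():.3f}, {finite.max():.3f}]\")\n",
    "print(f\"rows sum to 1: {torch.allclose(probs.sum(-1), torch.ones(8, 50))}\")"
   ]
  },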
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32ac8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 5))\n",
    "plot_theory(0, 0, False, 1)\n",
    "plot_theory(0.2, 1, False, 2)\n",
    "plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/l2norm.pdf\")"
   ]
  },
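  {
   "cell_type": "markdown",
   "id": "3a5b7c92",
   "metadata": {},
   "source": [
    "How the $\\alpha$ mixture works (a short empirical check, reusing the construction from `plot_theory`): each token is independent noise plus a shared direction scaled by $\\sqrt{\\alpha/(1-\\alpha)}$, so the expected cosine similarity between two distinct tokens comes out to approximately $\\alpha$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b6c8da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical check of the alpha mixture used in plot_theory: for tokens\n",
    "# t = noise + shared_scale * shared, E[cos(t_i, t_j)] is approximately alpha.\n",
    "alpha = 0.2\n",
    "shared_scale = math.sqrt(alpha / (1 - alpha))\n",
    "x = F.normalize(\n",
    "    torch.randn(100000, 2, 64) + shared_scale * torch.randn(100000, 1, 64),\n",
    "    dim=-1,\n",
    ")\n",
    "cos = (x[:, 0] * x[:, 1]).sum(-1).mean()\n",
    "print(f\"mean pairwise cosine: {cos:.3f} (target alpha = {alpha})\")"
   ]
  },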
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ee15fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 15))\n",
    "plot_theory(0.4, 0, False, 3)\n",
    "plot_theory(0.6, 1, False, 3)\n",
    "plot_theory(0.8, 2, False, 3)\n",
    "plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/l2norm_alpha.pdf\")"
   ]
  },
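  {
   "cell_type": "markdown",
   "id": "5c7d9eb4",
   "metadata": {},
   "source": [
    "Why the LayerNorm runs below need explicit score scaling (a minimal sketch, assuming the imports above): LayerNorm only standardizes mean and variance, so its outputs have norm $\\approx \\sqrt{d}$ rather than 1. Raw dot products can therefore grow like $d$; dividing by $\\sqrt{d}$ leaves scores of order $\\sqrt{d}$, while dividing by $d$ brings them back to the cosine-similarity scale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d8eafc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norms after LayerNorm vs. L2 normalization (d = 64, matching the cells\n",
    "# above): LayerNorm output has norm ~ sqrt(d), so unscaled scores reach +/- d.\n",
    "d = 64\n",
    "v = torch.randn(10000, d)\n",
    "ln = nn.LayerNorm(d, elementwise_affine=False)\n",
    "print(f\"LayerNorm output norm: {ln(v).norm(dim=-1).mean():.2f} (sqrt(d) = {d**0.5:.2f})\")\n",
    "print(f\"L2-normalized norm:    {F.normalize(v, dim=-1).norm(dim=-1).mean():.2f}\")"
   ]
  },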
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6e34005",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 20))\n",
    "plot_theory(0.0, 0, True, 4)\n",
    "plot_theory(0.2, 1, True, 4)\n",
    "plot_theory(0.0, 2, True, 4, ln_scaling=True)\n",
    "plot_theory(0.2, 3, True, 4, ln_scaling=True)\n",
    "plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1)\n",
    "plt.tight_layout(pad=0.9)\n",
    "plt.savefig(\"../figures/layernorm.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
