{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c8d68d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "from datasets import load_from_disk\n",
    "\n",
    "\n",
    "from transformers import LlamaForCausalLM, LlamaConfig\n",
    "from transformers.models.llama.modeling_llama import repeat_kv\n",
    "\n",
    "import transformers.models.llama.modeling_llama as modeling_llama\n",
    "\n",
    "# Configuration. The checkpoint path can be overridden with the NOPE_MODEL_PATH\n",
    "# environment variable; otherwise the original placeholder string is used.\n",
    "NOPE_MODEL_PATH = os.environ.get(\"NOPE_MODEL_PATH\", \"{NOPE_MODEL_PATH}\")\n",
    "DATA_PATH = \"../data\"\n",
    "SEQ_LEN = 50  # number of leading tokens kept from every example\n",
    "\n",
    "\n",
    "def noop_apply_rotary_pos_emb(q, k, *args, **kwargs):\n",
    "    \"\"\"Replacement for Llama's rotary embedding: return q and k unrotated (NoPE).\"\"\"\n",
    "    return q, k\n",
    "\n",
    "\n",
    "# Disable RoPE by rebinding the module-level function. This relies on the Llama\n",
    "# attention code resolving `apply_rotary_pos_emb` through the module namespace at\n",
    "# call time, so the patch must be in place before the forward pass below.\n",
    "modeling_llama.apply_rotary_pos_emb = noop_apply_rotary_pos_emb\n",
    "\n",
    "\n",
    "nope_model = LlamaForCausalLM.from_pretrained(\n",
    "    NOPE_MODEL_PATH,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"eager\",\n",
    ")\n",
    "\n",
    "\n",
    "dset = load_from_disk(DATA_PATH)\n",
    "all_input_ids = dset[\"input_ids\"]\n",
    "\n",
    "# torch.stack needs equal lengths, so every example must have at least SEQ_LEN tokens.\n",
    "n_short = sum(len(ids) < SEQ_LEN for ids in all_input_ids)\n",
    "assert n_short == 0, f\"{n_short} examples are shorter than SEQ_LEN={SEQ_LEN} tokens\"\n",
    "\n",
    "input_ids = torch.stack(\n",
    "    [torch.as_tensor(ids[:SEQ_LEN], dtype=torch.long) for ids in all_input_ids]\n",
    ").to(nope_model.device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    nope_out = nope_model(\n",
    "        input_ids=input_ids,\n",
    "        output_hidden_states=True,\n",
    "    )\n",
    "\n",
    "# The cells below treat nope_hiddens[i] as the input to decoder layer i.\n",
    "nope_hiddens = nope_out.hidden_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc443dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.models.llama.modeling_llama import apply_rotary_pos_emb  # noqa: F401 (the no-op patched in the first cell; unused here)\n",
    "\n",
    "plt.rcParams[\"axes.titlesize\"] = 20\n",
    "plt.rcParams[\"axes.labelsize\"] = 16\n",
    "plt.rcParams[\"xtick.labelsize\"] = 14\n",
    "plt.rcParams[\"ytick.labelsize\"] = 14\n",
    "\n",
    "\n",
    "def plot_matrix(mat, title):\n",
    "    \"\"\"Draw a square (query x key) matrix as a heatmap on the current axes.\n",
    "\n",
    "    Any leading dimensions (e.g. batch, heads) are averaged away first, and the\n",
    "    strictly upper (non-causal) triangle is set to NaN so imshow leaves it blank.\n",
    "    \"\"\"\n",
    "    mat = mat.detach().float().cpu().numpy()\n",
    "    if mat.ndim > 2:\n",
    "        mat = mat.mean(axis=tuple(range(mat.ndim - 2)))\n",
    "    mat[np.triu_indices(mat.shape[0], k=1)] = np.nan\n",
    "    plt.imshow(mat, cmap=\"rocket\")  # colormap registered by the seaborn import\n",
    "    plt.title(title)\n",
    "    plt.colorbar()\n",
    "    plt.xlabel(\"Key Position\")\n",
    "    plt.ylabel(\"Query Position\")\n",
    "\n",
    "\n",
    "def _split_heads(proj, hidden, head_dim):\n",
    "    \"\"\"Apply linear `proj` to `hidden` (B, S, E) and reshape to (B, heads, S, head_dim).\"\"\"\n",
    "    B, S, _ = hidden.shape\n",
    "    return proj(hidden).reshape(B, S, -1, head_dim).transpose(1, 2)\n",
    "\n",
    "\n",
    "def plot_activations(layer_idx, save_path=\"../figures/nope_params.pdf\"):\n",
    "    \"\"\"Recompute decoder layer `layer_idx` by hand and plot its intermediate matrices.\n",
    "\n",
    "    The recomputation (input_layernorm -> causal attention without RoPE ->\n",
    "    residual -> post_attention_layernorm -> MLP -> residual) is asserted to match\n",
    "    the model's own next hidden state. Panels (a)-(g) describe layer\n",
    "    `layer_idx + 1` (1-based); panels (h) and (i) show the normalized input and\n",
    "    the Q K^T / sqrt(d) scores of the following layer, so `layer_idx + 1` must\n",
    "    also be a valid layer index.\n",
    "\n",
    "    Args:\n",
    "        layer_idx: 0-based index of the decoder layer to inspect.\n",
    "        save_path: file the 3x3 figure is saved to.\n",
    "    \"\"\"\n",
    "    layer = nope_model.model.layers[layer_idx]\n",
    "    next_layer = nope_model.model.layers[layer_idx + 1]\n",
    "    attn, next_attn = layer.self_attn, next_layer.self_attn\n",
    "    n, n_next = layer_idx + 1, layer_idx + 2  # 1-based layer numbers for the titles\n",
    "\n",
    "    # Inspection only: avoid building an autograd graph through the weights.\n",
    "    with torch.no_grad():\n",
    "        x = nope_hiddens[layer_idx]\n",
    "        next_hidden = nope_hiddens[layer_idx + 1]\n",
    "        B, S, _ = x.shape\n",
    "\n",
    "        ln = layer.input_layernorm(x)\n",
    "        q = _split_heads(attn.q_proj, ln, attn.head_dim)\n",
    "        # Repeat the grouped K/V heads so they line up with the query heads.\n",
    "        k = repeat_kv(_split_heads(attn.k_proj, ln, attn.head_dim), attn.num_key_value_groups)\n",
    "        v = repeat_kv(_split_heads(attn.v_proj, ln, attn.head_dim), attn.num_key_value_groups)\n",
    "\n",
    "        attn_score = q @ k.mT / math.sqrt(attn.head_dim)\n",
    "        # Additive causal mask: the dtype's most negative value strictly above the diagonal.\n",
    "        attn_score = attn_score + torch.triu(\n",
    "            torch.ones((S, S), device=attn_score.device)\n",
    "            * torch.finfo(attn_score.dtype).min,\n",
    "            1,\n",
    "        )\n",
    "        attn_prob = attn_score.softmax(dim=-1, dtype=torch.float32).to(q.dtype)\n",
    "        qkv = (attn_prob @ v).transpose(1, 2).reshape(B, S, -1)\n",
    "        qkvo = attn.o_proj(qkv)\n",
    "        qkvox = qkvo + x\n",
    "\n",
    "        ln_f = layer.post_attention_layernorm(qkvox)\n",
    "        f = layer.mlp(ln_f)\n",
    "        fx = f + qkvox\n",
    "\n",
    "        next_ln = next_layer.input_layernorm(fx)\n",
    "        next_q = _split_heads(next_attn.q_proj, next_ln, next_attn.head_dim)\n",
    "        next_k = repeat_kv(\n",
    "            _split_heads(next_attn.k_proj, next_ln, next_attn.head_dim),\n",
    "            next_attn.num_key_value_groups,\n",
    "        )\n",
    "        next_qk = next_q @ next_k.mT / math.sqrt(next_attn.head_dim)\n",
    "\n",
    "    # The hand-written layer must reproduce the model's own output for this layer.\n",
    "    assert torch.allclose(fx, next_hidden, atol=1e-5)\n",
    "\n",
    "    panels = [\n",
    "        (\n",
    "            ln @ ln.mT,\n",
    "            rf\"(a) $\\mathbf{{Y^{{({n})}}Y^{{({n})\\intercal}}}}$\" + \"\\n$(Y=LN(X))$\",\n",
    "        ),\n",
    "        (\n",
    "            attn_score,\n",
    "            rf\"(b) $\\mathbf{{Q^{{({n})}}K^{{({n})\\intercal}}/\\sqrt{{d}}}}$\"\n",
    "            + \"\\n$(Q=W_QY, K=W_KY)$\",\n",
    "        ),\n",
    "        (\n",
    "            qkv @ qkv.mT,\n",
    "            rf\"(c) $\\mathbf{{(A^{{({n})}}V^{{({n})}})(A^{{({n})}}V^{{({n})}})^\\intercal}}$\"\n",
    "            + \"\\n\"\n",
    "            + r\"$(A=Softmax(Causal(QK^\\intercal)))$\",\n",
    "        ),\n",
    "        (\n",
    "            qkvo @ qkvo.mT,\n",
    "            \"(d)\\n\"\n",
    "            + rf\"$\\mathbf{{(A^{{({n})}}V^{{({n})}}W_O^{{({n})}})(A^{{({n})}}V^{{({n})}}W_O^{{({n})}})^\\intercal}}$\",\n",
    "        ),\n",
    "        (\n",
    "            qkvox @ qkvox.mT,\n",
    "            rf\"(e) $\\mathbf{{O^{{({n})}}O^{{({n})\\intercal}}}}$\" + \"\\n$(O=AVW_O+X)$\",\n",
    "        ),\n",
    "        (\n",
    "            ln_f @ ln_f.mT,\n",
    "            rf\"(f) $\\mathbf{{LN(O^{{({n})}})LN(O^{{({n})}})^\\intercal}}$\",\n",
    "        ),\n",
    "        (\n",
    "            fx @ fx.mT,\n",
    "            rf\"(g) $\\mathbf{{X^{{({n})}}X^{{({n})\\intercal}}}}$\" + \"\\n$(X=FFN(LN(O))+O)$\",\n",
    "        ),\n",
    "        (\n",
    "            next_ln @ next_ln.mT,\n",
    "            rf\"(h) $\\mathbf{{Y^{{({n_next})}}Y^{{({n_next})\\intercal}}}}$\",\n",
    "        ),\n",
    "        (\n",
    "            # Scores of the *following* layer, hence panel (i) and layer n_next.\n",
    "            next_qk,\n",
    "            rf\"(i) $\\mathbf{{Q^{{({n_next})}}K^{{({n_next})\\intercal}}/\\sqrt{{d}}}}$\",\n",
    "        ),\n",
    "    ]\n",
    "\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    for panel_no, (mat, title) in enumerate(panels, start=1):\n",
    "        plt.subplot(3, 3, panel_no)\n",
    "        plot_matrix(mat, title)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path)\n",
    "\n",
    "\n",
    "plot_activations(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc71cf9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_lower_diagonals(attn: torch.Tensor):\n",
    "    \"\"\"Subtract each lower diagonal's mean from it, modifying `attn` in place.\n",
    "\n",
    "    Diagonal -i of a square (query x key) matrix holds every pair with\n",
    "    query - key == i. After the call each diagonal i = 0 .. n-1 has zero mean,\n",
    "    so the plot shows deviations from the average score at each distance.\n",
    "    The strictly upper triangle is not modified. Returns `attn` itself.\n",
    "    \"\"\"\n",
    "    n, _ = attn.shape\n",
    "    for i in range(n):\n",
    "        diag = attn.diag(-i)\n",
    "        attn[torch.arange(i, n), torch.arange(n - i)] = diag - diag.mean()\n",
    "    return attn\n",
    "\n",
    "\n",
    "n_layers = len(nope_model.model.layers)\n",
    "n_cols = 4\n",
    "n_rows = math.ceil(n_layers / n_cols)\n",
    "plt.figure(figsize=(5 * n_cols, 5 * n_rows))\n",
    "\n",
    "# Inspection only: no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    for layer_idx, layer in enumerate(nope_model.model.layers):\n",
    "        plt.subplot(n_rows, n_cols, layer_idx + 1)\n",
    "\n",
    "        self_attn = layer.self_attn\n",
    "        x = nope_hiddens[layer_idx]\n",
    "        B, S, _ = x.shape\n",
    "\n",
    "        ln = layer.input_layernorm(x)\n",
    "        q = self_attn.q_proj(ln).reshape(B, S, -1, self_attn.head_dim).transpose(1, 2)\n",
    "        k = self_attn.k_proj(ln).reshape(B, S, -1, self_attn.head_dim).transpose(1, 2)\n",
    "        # Repeat the grouped K heads so they line up with the query heads.\n",
    "        k = repeat_kv(k, self_attn.num_key_value_groups)\n",
    "\n",
    "        attn_score = q @ k.mT / math.sqrt(self_attn.head_dim)\n",
    "\n",
    "        # Average over batch and heads in float32, then center each lower diagonal.\n",
    "        score_map = normalize_lower_diagonals(\n",
    "            attn_score.float().mean(dim=(0, 1)).cpu()\n",
    "        ).numpy()\n",
    "        # Blank the non-causal upper triangle explicitly (NaN is left uncolored)\n",
    "        # rather than relying on a huge negative mask value, which -- if it stays\n",
    "        # finite after averaging -- would dominate the color scale.\n",
    "        score_map[np.triu_indices(S, k=1)] = np.nan\n",
    "\n",
    "        plt.imshow(score_map, cmap=\"rocket\")\n",
    "        plt.colorbar()\n",
    "        plt.title(f\"Layer {layer_idx+1}\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../figures/nope_attns_all.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
