{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Seed PyTorch's global RNG so the random weights and inputs drawn below are reproducible.\n",
    "torch.manual_seed(2024)\n",
    "\n",
    "# Number of timed repetitions the benchmark functions below run per configuration.\n",
    "T = 10000\n",
    "\n",
    "def phi(x):\n",
    "    \"\"\"Return the shifted ELU feature map ``elu(x) + 1``, applied elementwise.\"\"\"\n",
    "    shifted = nn.functional.elu(x) + 1\n",
    "    return shifted\n",
    "\n",
    "\n",
    "def prefix_attn(X, P, W_Q, W_K, W_V):\n",
    "    \"\"\"Softmax attention where the rows of ``X`` attend over ``P`` stacked on top of ``X``.\n",
    "\n",
    "    Queries are projected from ``X`` only, whereas keys and values are\n",
    "    projected from the concatenation of ``P`` and ``X`` along dim 0. The\n",
    "    scores are the raw products ``Q @ K.T`` (no scaling is applied) and the\n",
    "    softmax is taken over the last dimension.\n",
    "    \"\"\"\n",
    "    context = torch.cat((P, X))\n",
    "    queries = X @ W_Q\n",
    "    keys, values = context @ W_K, context @ W_V\n",
    "\n",
    "    weights = torch.softmax(queries @ keys.T, dim=-1)\n",
    "    return weights @ values\n",
    "\n",
    "def ntk_attn(X, Z, k, W_Q, W_K, W_V):\n",
    "    \"\"\"Self-attention over ``X`` with extra ``phi(Q)`` terms built from ``Z`` and ``k``.\n",
    "\n",
    "    With ``Q``, ``K``, ``V`` all projected from ``X`` and ``A = exp(Q @ K.T)``,\n",
    "    the result is ``(A @ V + phi(Q) @ Z) / (rowsum(A) + phi(Q) @ k)``.\n",
    "\n",
    "    NOTE: the exponential is applied to the raw scores with no\n",
    "    max-subtraction, so large scores can overflow.\n",
    "    \"\"\"\n",
    "    queries, keys, values = X @ W_Q, X @ W_K, X @ W_V\n",
    "\n",
    "    scores = (queries @ keys.T).exp()\n",
    "    feats = phi(queries)\n",
    "    denom = scores.sum(-1, keepdim=True) + feats @ k\n",
    "    numer = scores @ values + feats @ Z\n",
    "    # Multiply by the reciprocal (rather than divide) to keep the original operation order.\n",
    "    return (1 / denom) * numer\n",
    "\n",
    "def get_time_prefix_attn(L, m, d, n_trials=None):\n",
    "    \"\"\"Time ``prefix_attn`` on random inputs and summarise the per-call duration.\n",
    "\n",
    "    Args:\n",
    "        L: number of rows in each random input ``X``.\n",
    "        m: number of rows in the random prefix ``P``.\n",
    "        d: feature width; ``X`` is ``(L, d)``, ``P`` is ``(m, d)`` and the\n",
    "            three projection matrices are ``(d, d)``.\n",
    "        n_trials: how many timed calls to make. ``None`` (the default) uses\n",
    "            the notebook-level constant ``T``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean, max, min)`` of the per-call durations in seconds.\n",
    "    \"\"\"\n",
    "    if n_trials is None:\n",
    "        n_trials = T\n",
    "    W_Q, W_K, W_V = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)\n",
    "    P = torch.randn(m, d)\n",
    "    durations = []\n",
    "    for _ in range(n_trials):\n",
    "        # A fresh input is drawn for every trial, outside the timed region.\n",
    "        X = torch.randn(L, d)\n",
    "        # perf_counter is a monotonic, high-resolution clock; time.time() can be too\n",
    "        # coarse for calls this short and may jump if the system clock is adjusted.\n",
    "        start = time.perf_counter()\n",
    "        prefix_attn(X, P, W_Q, W_K, W_V)\n",
    "        durations.append(time.perf_counter() - start)\n",
    "    return sum(durations) / n_trials, max(durations), min(durations)\n",
    "\n",
    "def get_time_ntk_attn(L, m, d, n_trials=None):\n",
    "    \"\"\"Time ``ntk_attn`` on random inputs and summarise the per-call duration.\n",
    "\n",
    "    Args:\n",
    "        L: number of rows in each random input ``X``.\n",
    "        m: not used by this function; it is accepted so the signature matches\n",
    "            ``get_time_prefix_attn``.\n",
    "        d: feature width; ``X`` is ``(L, d)``, ``Z`` and the three projection\n",
    "            matrices are ``(d, d)`` and ``k`` is ``(d, 1)``.\n",
    "        n_trials: how many timed calls to make. ``None`` (the default) uses\n",
    "            the notebook-level constant ``T``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean, max, min)`` of the per-call durations in seconds.\n",
    "    \"\"\"\n",
    "    if n_trials is None:\n",
    "        n_trials = T\n",
    "    W_Q, W_K, W_V = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)\n",
    "    Z = torch.randn(d, d)\n",
    "    k = torch.randn(d, 1)\n",
    "    durations = []\n",
    "    for _ in range(n_trials):\n",
    "        # A fresh input is drawn for every trial, outside the timed region.\n",
    "        X = torch.randn(L, d)\n",
    "        # perf_counter is a monotonic, high-resolution clock; time.time() can be too\n",
    "        # coarse for calls this short and may jump if the system clock is adjusted.\n",
    "        start = time.perf_counter()\n",
    "        ntk_attn(X, Z, k, W_Q, W_K, W_V)\n",
    "        durations.append(time.perf_counter() - start)\n",
    "    return sum(durations) / n_trials, max(durations), min(durations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    \"\"\"Run the timing sweep for both attention variants.\n",
    "\n",
    "    For every ``L`` in ``L_list``, ``get_time_prefix_attn`` is called once per\n",
    "    ``m`` in ``m_list`` and ``get_time_ntk_attn`` is called once.\n",
    "\n",
    "    Returns:\n",
    "        ``(prefix_ret, ntk_ret)``: ``prefix_ret[i][j]`` is the timing tuple for\n",
    "        ``L_list[i]`` and ``m_list[j]``; ``ntk_ret[i]`` is the timing tuple for\n",
    "        ``L_list[i]``.\n",
    "    \"\"\"\n",
    "    L_list = [32, 64, 128, 256]\n",
    "    m_list = [2 ** i for i in range(16)]\n",
    "    d = 32\n",
    "\n",
    "    prefix_ret = []\n",
    "    ntk_ret = []\n",
    "    for L in L_list:\n",
    "        prefix_ret.append([get_time_prefix_attn(L, m, d) for m in tqdm(m_list)])\n",
    "        # get_time_ntk_attn does not use its m argument, so it is timed once per L.\n",
    "        # Pass an explicit value instead of relying on a loop variable that has\n",
    "        # leaked out of the sweep over m_list.\n",
    "        ntk_ret.append(get_time_ntk_attn(L, m_list[-1], d))\n",
    "    return prefix_ret, ntk_ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the full sweep: x1 receives the prefix-attention timings, x2 the NTK-attention timings.\n",
    "x1, x2 = main()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
