{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "def imshow(img):\n",
    "    import torch\n",
    "    if isinstance(img, torch.Tensor):\n",
    "        img = img.detach().cpu().numpy()\n",
    "    plt.imshow(img)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/user/anaconda3/envs/torch/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 101, 7592, 1010, 2088,  999,  102, 1045, 2293, 8870, 1012,  102,    0,\n",
       "             0,    0,    0,    0,    0,    0,    0,    0,    0],\n",
       "         [ 101, 2023, 2003, 1996, 2742,  102, 1045, 2215, 2000, 2022, 3407,  102,\n",
       "             0,    0,    0,    0,    0,    0,    0,    0,    0],\n",
       "         [ 101, 3198, 2025, 2054, 2115, 2406, 2064, 2079, 2005, 2017,  102, 3198,\n",
       "          2054, 2017, 2064, 2079, 2005, 2115, 2406, 1012,  102]]),\n",
       " '[CLS] hello, world! [SEP] i love cats. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXTS = [\n",
    "    \"Hello, World! [SEP] I love cats.\",\n",
    "    \"This is the example [SEP] I want to be happy\",\n",
    "    \"Ask not what your country can do for you [SEP] Ask what you can do for your country.\",\n",
    "]\n",
    "\n",
    "model_id_hf = 'bert-base-uncased'\n",
    "tokenizer = transformers.BertTokenizerFast.from_pretrained(model_id_hf)\n",
    "tokenized_result = tokenizer(TEXTS, padding=True, max_length=512, truncation=True, return_tensors=\"pt\")\n",
    "tokenized_result['input_ids'], tokenizer.decode(tokenized_result['input_ids'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: mrpc\n",
      "Trainer: Checkpoint path saves/glue-mrpc-4.pth\n"
     ]
    }
   ],
   "source": [
    "# Manual top-k\n",
    "from models import sparse_token as sparse\n",
    "from trainer import glue_base\n",
    "\n",
    "topk_trainer = glue_base.GlueAttentionApproxTrainer('mrpc', factor=4, batch_size=1, wiki_train=False, device='cpu')\n",
    "target_ks = 0.5\n",
    "if target_ks <= 0.666:\n",
    "    ksx = [target_ks*0.5+((1-x/10.0)**1.0) * target_ks for x in range(12)]\n",
    "else:\n",
    "    ksx = [(1-x/10.0)*(2-2*target_ks)+(2*target_ks-1) for x in range(12)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD4CAYAAAAO2kjhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXeUlEQVR4nO3de4zdd3nn8ffnzMX3a3yLL8QJjWiz3cYgK0CbdpNS0iRCDe22NFG3m+6yMluBVKRWFe1KTUW1EquKsmqDoC5YSSsI9BYaqRGJxRaFlAIxUYITCMRrzMaXjO3YHsf2xOOZ8+wf8xs0me85c54z53jmjPm8JGvO/M4z3+/vHB8/PpdnnkcRgZnZVLX5PgEz6z1ODGZWcGIws4ITg5kVnBjMrNA/3yfQyLq1fbF920DLuBf3L8svOth6PQAujeXXTIp6PReo/JqqJXO6covG2Hh+7+R9GaOX0mtmzxNyn6KpjTsz+7mc+nL3eTv3ZfY0tWhRfs3x1vuPjL3KaH2k6e49mRi2bxvgG49taxl3x3VvS69Z27wpFVd/+Xh6zaz6yEgqTn196TVrS5fmApP/iMdfOZXeu3/z1lTc2EtH02tqIPlQTDzoAdSff2jHeC5x11YuT8WNnx5O761aLjPUtm/Przn8asuYr57825n3S+/W6ASk2yV9V9IBSR9qcP0iSZ+vrv+6pO2d7Gdmc2PWiUFSH/Bx4A7gBuAeSTdMC3svcDoifgz4GPC/Zrufmc2dTp4x3AQciIiDETEKfA64a1rMXcCD1eW/B94hpV9Mmtk86SQxbAFemvL94epYw5iIGAOGgas62NPM5kDPfFwpaZekfZL2nXiljXd1zazrOkkMR4CpHx1srY41jJHUD6wCXmm0WETsjoidEbFz/VX5d+fNrPs6SQxPAddLulbSIHA38Mi0mEeAe6vLvwr8n/Cvc5r1vFnXMUTEmKQPAI8BfcCeiHhe0oeBfRHxCPBp4G8kHQBOMZE8zKzHqRf/A1++dlv81Dt+p2Xcqn89lF5z/Op1qbiLG5ak11z67ZfTsRmRLIQCIFmUQ2SrLtt48rgh9/6x2qgijWRRkBYN5hZcsji9d/34yVxg8t9KW8VVo6O5NZclC9oAJYrfvvryZxm+ONT0E8KeefPRzHqHE4OZFZwYzKzgxGBmBScGMys4MZhZwYnBzApODGZWcGIws4ITg5kVerLnY+3iOMsPnWsZd/o/XJtec/WXD6bixq/Znl4zW8IcI6+l4tRGGS/JxqRkf4M92ywX0IXc7WmHVq1IxcW587n12mjIWluR6+VYP38hFRdjbTQUTjb1rSdvN0DfYKJsvEV5t58xmFnBicHMCk4MZlZwYjCzghODmRWcGMys0MnAmW2S/kXStyU9L6louSTpFknDkp6p/vxRZ6drZnOhkzqGMeB3I+JpSSuAb0raGxHfnhb3lYh4Vwf7mNkcm/Uzhog4FhFPV5dfBb5DOXDGzBagrlQ+VsNq3wx8vcHVb5f0LHAU+L2IeL7JGruAXQCL+1fQd7p1pVffpVzFGsDwz+WqJFc++lx6zbi29URuAL12MRd3LldZB8DSXNPaGD6bW6+dar1kY9LITrAG4uhQKi7bDDZebT3x+YdrLs5VnGanK0YbUxjTE87bmIQely4lgmaufOw4MUhaDvwD8MGImP4ofBq4JiLOSboT+AJwfePzjN3AboBVizf1Xutqsx8hHX0qIWmAiaTwmYj4x+nXR8TZiDhXXX4UGJCU6+NuZvOmk08lxMRAme9ExJ81idk0Od1a0k3Vfg1H1JlZ7+jkpcTPAL8J7Jf0THXsD4E3AETEJ5kYS/fbksaAEeBuj6gz632djKh7EpjxXZaIuB+4f7Z7mNn8cOWjmRWcGMys4MRgZgUnBjMrODGYWaEnm8FGXx/jV7VuDrpi/4n0mhrJlSWPvP3H02sueeHlVNz4pjW5BVflSo0B+o6cTMdmaNXKfPClXPm02vlkOluWvCxZCt5G89S4kGzqOzqaitOS3DkCxMXc45J6Pb1m7r6cuWzbzxjMrODEYGYFJwYzKzgxmFnBicHMCk4MZlZwYjCzghODmRWcGMys0JOVjxq9RO1Q66pCZUfBkx9NvuRgvsHU4f94TSru6ieHU3G6mB/dThsNR1OS1YwAsWRRLrCNar24kGyEm60UrOXvn2xFY/Yx1M7jUtmmvsnqzPT+Le4eP2Mws0LHiUHSIUn7q0lT+xpcL0l/LumApG9Jekune5rZ5dWtlxK3RkSz3+q5g4mW8dcDbwU+UX01sx41Fy8l7gL+OiZ8DVgt6eo52NfMZqkbiSGAxyV9s5omNd0W4KUp3x+mwSg7Sbsk7ZO0b7Sef6PFzLqvGy8lbo6II5I2AHslvRART7S7yOsmUQ1scIt5s3nU8TOGiDhSfT0OPAzcNC3kCDB1yOPW6piZ9ahOR9Qtk7Ri8jJwGzB9KuwjwH+uPp14GzAcEcc62dfMLq9OX0psBB6uptD1A5+NiC9K+u/ww2lUjwJ3AgeAC8B/6XBPM7vMOkoMEXEQuLHB8U9OuRzA+9s7qz5Y37pPYjvT7mIweVPH8tV6m798OhWn86/lFjx5Kr33xTe/MRU3+Mz3U3GxNNdzEUAXkrenPz+6naW5fpdKrlk/fSa9tQYHU3HZykdGE2PoJ9dMPoa1KFltCtTPJ6pIx2d+nLvy0cwKTgxmVnBiMLOCE4OZFZwYzKzgxGBmBScGMys4MZhZwYnBzApODGZW6MlmsIyNwVDrMe/tjBvPltJeujo5sh4Y+H7rhrUTm+cakx79T/8uvfemB/fnArONSdspL180kAtsp2Ht6hW5vc+8mtt6+bL01vXh5JrJx1u2uSwAA7n7sn7ufHpJLc6XTzfjZwxmVnBiMLOCE4OZFZwYzKzgxGBmBScGMyvMOjFIelM1fWryz1lJH5wWc4uk4Skxf9TxGZvZZTfrOoaI+C6wA0BSHxOdnx9uEPqViHjXbPcxs7nXrZcS7wD+b0T8oEvrmdk86lbl493AQ02ue7ukZ4GjwO9FxPONgqopVrsAFteW53Zto9lodnR7/8E2OtvXknl1Sa7R6qYv55vB1n9ieyqu/8grqbhL65L3OdB/JjcpTGfz1XoM5ioA0w1Za/mqy/TY+kvJJq/JakYAxsdzcZFvUlzLVH2OzHybuzHtehD4JeDvGlz9NHBNRNwI/AXwhWbrRMTuiNgZETsHa/mOxWbWfd14KXEH8HREDE2/IiLORsS56vKjwICkdV3Y08wuo24khnto8jJC0iZV02gk3VTtl3tua2bzpqP3GKqxdO8E3jfl2NQpVL8K/LakMWAEuDvamRJjZvOi00lU54Grph2bOoXqfuD+TvYws7nnykczKzgxmFnBicHMCk4MZlbozZ6PEXApUeHWzgccx47n4tRGrlyV61OYui2ALuZ7BfafS4w6B0av3ZCKGzycr7o8+5arU3HLD+YfXrWXc59ia82qVFycOp3em75cBW226rK2aX166zhalP801LdtS3rN+snE3+X4zJWUfsZgZgUnBjMrODGYWcGJwcwKTgxmVnBiMLOCE4OZFZwYzKzgxGBmBScGMyv0ZEn0pTVLOP7u1iPh1z/0bHpNLcuNRU83BgUiOZo8PWa9nfHl/bm/usFk2fjoG/Id98YW5Rqt6gdtNNZdNJiLS5aXU8+Xy8fFi6k4LVmSWy9Z5gxQP597DOlC7vELUFuZKNUfmbkM3M8YzKyQSgyS9kg6Lum5KcfWStor6cXq65omP3tvFfOipHu7deJmdvlknzE8ANw+7diHgC9FxPXAl6rvX0fSWuA+4K3ATcB9zRKImfWOVGKIiCeA6b/LeRfwYHX5QeDdDX70F4G9EXEqIk4DeykTjJn1mE7eY9gYEZPvLr0MbGwQswV4acr3h6tjBUm7JO2TtG/stTYmGJlZ13XlzceqJXxHbeGnTqLqX5x/B9bMuq+TxDAk6WqA6mujFklHgG1Tvt9aHTOzHtZJYngEmPyU4V7gnxrEPAbcJmlN9abjbdUxM+th2Y8rHwL+DXiTpMOS3gt8BHinpBeBX6i+R9JOSZ8CiIhTwJ8AT1V/PlwdM7Mepl6cGLdqcEP89Ppfbx2YHJ0OEKeHc4HJkfVtGXktF5dsSgoQo7nGsUquqUy1XOXwr12Titv68OH0mmffnGswu+L5k6k4nT2X3jsujOTisvd5skISgFquirSdSk4tbb3/V4c+x/DoUNPNXfloZgUnBjMrODGYWcGJwcwKTgxmVnBiMLOCE4OZFZwYzKzgxGBmBScGMyv0ZDPYGBxgbNv6lnH9Q2fyiybLp2P4bH7NgdyamRJVAMaSjU6B2orlqbhsyXuM5MqCAa7an2ueeuKWhq03Gq/5d9/KBV6TW7N+/kJ6b67bmgqrDeV+zaedvZUt6x8fT68Zo5cSQTM/LvyMwcwKTgxmVnBiMLOCE4OZFZwYzKzgxGBmhZaJockUqj+V9IKkb0l6WNLqJj97SNJ+Sc9I2tfF8zazyyjzjOEByiExe4GfjIifAr4H/MEMP39rROyIiJ2zO0Uzm2stE0OjKVQR8XhETFbjfI2JtvBmdoXoRuXjfwU+3+S6AB6XFMBfRsTuZotI2gXsAljct5z+l0603rnWxsj6ZMNPLVuaXlODudHtUa+n18yK13INZseTTVH7t+SasQIsOpm7L2tjbYyi//HtuTVP5W7P2L+/Lr13/8ncmnEx1wy2tv6q9N71ocTjHNp6rNf6ErEtetB2lBgk/Q9gDPhMk5CbI+KIpA3AXkkvVM9AClXS2A0TXaI7OS8z68ysP5WQ9FvAu4DfiCYF+RFxpPp6HHiYiYnXZtbjZpUYJN0O/D7wSxHR8DdGJC2TtGLyMhNTqJ5rFGtmvSXzcWWjKVT3AyuYeHnwjKRPVrGbJT1a/ehG4ElJzwLfAP45Ir54WW6FmXVVy/cYIuKeBoc/3ST2KHBndfkgcGNHZ2dm88KVj2ZWcGIws4ITg5kVnBjMrNCTPR+JIC617lvXzuj22oZ1ucDXcv0MIT8WHSVHnfe38deRHIteW7I4t17i/v7hmifOpOIGWJ1eM+v4rZtTcesf+V56zYs3bk/FLTp1JrdgssoWQEtzlbZx/nx6zVTPyfGZq3H9jMHMCk4MZlZwYjCzghODmRWcGMys4MRgZgUnBjMrODGYWcGJwcwKTgxmVujNkuh6pBpvqo3S0+zI+lizMr/midxY9LSRXINXyI+3bzXu/IdhyeayAFqcK7POjo0HGN+ca6C6/uu5NY+9503pvTf9zf5U3PhPbE/F9Z84m96bs6/m4rJl9UAt09B4ZObnBH7GYGaF2U6i+mNJR6q2bs9IurPJz94u6buSDkj6UDdP3Mwun9lOogL4WDVhakdEPDr9Skl9wMeBO4AbgHsk3dDJyZrZ3JjVJKqkm4ADEXEwIkaBzwF3zWIdM5tjnbzH8IFqqO0eSWsaXL8FeGnK94erYw1J2iVpn6R9o5F/I8zMum+2ieETwBuBHcAx4KOdnkhE7I6InRGxc1DJ5iJmdlnMKjFExFBEjEdEHfgrGk+YOgJsm/L91uqYmfW42U6imjoB9ZdpPGHqKeB6SddKGgTuBh6ZzX5mNrdaFjhVk6huAdZJOgzcB9wiaQcT06wPAe+rYjcDn4qIOyNiTNIHgMeAPmBPRDx/OW6EmXWX0hV0c2jVok3x01t+o7uLJivH4kwbVWvj47m4wVzVJWNj+b2T4lJuzWxTUgBWLc+tmdwboL5iWSqudj5X7Vpfmb89Op98szv52Bj/sabvsRf6h4ZTcXEu0eB1MnZj6yrSrx34NMMXjjb9R+HKRzMrODGYWcGJwcwKTgxmVnBiMLOCE4OZFZwYzKzgxGBmBScGMyv0Zs/Hvhr1Va0r4Wonc1VjAPVXci0ltGJFes3s2Hr196Xi6ok+l+2KixdTcbWVbdzuNioas2qvnEnFZXtD9h19Jb13ff3qVJxqubj+l06m9073cly7Kr1kLG79uIwW2/oZg5kVnBjMrODEYGYFJwYzKzgxmFnBicHMCk4MZlbItHbbA7wLOB4RP1kd+zwwORxwNXAmInY0+NlDwKvAODAWETu7ctZmdlllKnQeAO4H/nryQET8+uRlSR8FZqo0ujUi2qj4MLP51jIxRMQTkrY3uk6SgPcAP9/l8zKzedRpSfTPAkMR8WKT6wN4XFIAfxkRu5stJGkXsAtgsZbBwcMtN4++/Fsk2aa3cfp0es1asoHq+NZtrYOA2vdzjU4B4rVcqTN9uXLsGMnvTbJxq9oo8Y7RXKy+9/9y67Xx2NCxXFNfLV6Uihu685r03hueOJGKu7glXxJdG68ngmauie40MdwDPDTD9TdHxBFJG4C9kl6oZmEWqqSxG2BV37rea11t9iNk1p9KSOoHfgX4fLOYiDhSfT0OPEzjiVVm1mM6+bjyF4AXIqLhc35JyyStmLwM3EbjiVVm1mNaJoZqEtW/AW+SdFjSe6ur7mbaywhJmyU9Wn27EXhS0rPAN4B/jogvdu/UzexyyXwqcU+T47/V4NhR4M7q8kHgxg7Pz8zmgSsfzazgxGBmBScGMys4MZhZoWebwdaWt66ui/P50eBKVgBq5cr8mksX5+K+dygVVx/LN1lVthHteLaqL3dbALiQGxtfH86NjQdQslJRG9el4uLoUH7verIqtp67L9c+l39cXtyaq2i8sHEgveb4YOsGs+P7Z76//YzBzApODGZWcGIws4ITg5kVnBjMrODEYGYFJwYzKzgxmFnBicHMCk4MZlboyZLo0bWLOHz3dS3jNu9+Nr1mbc3qVFxcyDdFrQ/lGnlGpjkn+bJgAJKlzgzkSmnbud2XbnxjKq7/qVfSa9KXa7TK6ZkmFUyh1mXBk9KNaAdy/1z6DxxN793foinrpNPX5+5zgNUHWt+e2ujMZeCZDk7bJP2LpG9Lel7S71TH10raK+nF6uuaJj9/bxXzoqR7W56xmc27zH9RY8DvRsQNwNuA90u6AfgQ8KWIuB74UvX960haC9wHvJWJRrD3NUsgZtY7WiaGiDgWEU9Xl18FvgNsAe4CHqzCHgTe3eDHfxHYGxGnIuI0sBe4vQvnbWaXUVtvPlYTqd4MfB3YGBHHqqteZqL563RbgJemfH+4OmZmPSydGCQtB/4B+GBEvO4X7WNizFNHQ2Ik7ZK0T9K+8QvnO1nKzDqUSgySBphICp+JiH+sDg9Jurq6/mrgeIMfPQJMnc+2tTpWiIjdEbEzInb2Lc2NQDOzyyPzqYSATwPfiYg/m3LVI8Dkpwz3Av/U4McfA26TtKZ60/G26piZ9bDMM4afAX4T+HlJz1R/7gQ+ArxT0otMTKX6CICknZI+BRARp4A/AZ6q/ny4OmZmPSwzcOZJoFkVxjsaxO8D/tuU7/cAe2Z7gmY295QdDz+XJJ0AfjDt8Drg5DyczuVyJd2eK+m2wI/G7bkmItY3+4GeTAyNSNoXETvn+zy65Uq6PVfSbQHfHvAvUZlZA04MZlZYSIlh93yfQJddSbfnSrot4NuzcN5jMLO5s5CeMZjZHHFiMLNCzycGSbdL+q6kA5KKng8LjaRDkvZXFaT75vt82iVpj6Tjkp6bcizVtKcXNbk9fyzpyLRK357XaVOlqXo6MUjqAz4O3AHcANxTNYlZ6G6NiB0L9LPyByh7arRs2tPDHqBxj5CPVX9HOyLi0Tk+p9madVOl6Xo6MTDR9elARByMiFHgc0w0iLF5EhFPANN/3yXTtKcnNbk9C1KHTZVep9cTw5XY6CWAxyV9U9Ku+T6ZLsk07VloPiDpW9VLjQXz0mjSLJoqvU6vJ4Yr0c0R8RYmXh69X9LPzfcJdVM3mvb0gE8AbwR2AMeAj87r2bSpG02Vej0xpBu9LBQRcaT6ehx4mImXSwtdpmnPghERQxExHhF14K9YQH9HHTRVep1eTwxPAddLulbSIHA3Ew1iFiRJyyStmLzMROOa52b+qQUh07RnwZj8R1T5ZRbI31GHTZVev1avVz5WHxX9b6AP2BMR/3N+z2j2JF3HxLMEmOiF8dmFdnskPQTcwsSv8g4xMR7gC8DfAm9g4tfl37NQGvI0uT23MPEyIoBDwPumvEbvWZJuBr4C7Acmpxz9IRPvM7T199PzicHM5l6vv5Qws3ngxGBmBScGMys4MZhZwYnBzApODGZWcGIws8L/Bx0tISAKoKXAAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAADeCAYAAADPT+AWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAANoElEQVR4nO3ce6hl5X3G8e/TGS8dIzpesN6oWqxgS41yEHOphGiN2qBpCUVpWnOBIbS2WlrClEAS+lfTS+iFkDI1NrYVlRrTSDAdjUkIhTrmzGS8jnFGa+PoqBMtmjYQL/n1j72mnBzPPnPOXvsyvn4/cDhrr/Xu8/5499rPWXuttd9UFZKkN7afmnUBkqT+DHNJaoBhLkkNMMwlqQGGuSQ1YO00OzvmqDV1yskHTbNLjdGj96+baf8//0s/nGn/s9R37PuO3Rv9tZ/1+PW19f4ffb+qjl2uTaZ5a+LcWYfWvZtPnlp/Gq/3nPDWmfa/+entM+1/lvqOfd+xe6O/9rMev77WHL9ra1XNLdfG0yyS1ADDXJIaYJhLUgN6hXmSi5N8N8muJBvHVZQkaXVGDvMka4DPApcAZwJXJjlzXIVJklauz5H5ucCuqnq8ql4GbgYuH09ZkqTV6BPmJwJPLni8u1v3E5JsSDKfZH7v86/16E6SNMzEL4BW1aaqmququWOPXjPp7iTpTalPmD8FLPwG0EndOknSlPUJ828Dpyc5NcnBwBXA7eMpS5K0GiPPzVJVrya5GtgMrAGur6qHxlaZJGnFek20VVV3AHeMqRZJ0oj8BqgkNcAwl6QGTHU+c+mN6o0+hara55G5JDXAMJekBhjmktQAw1ySGmCYS1IDDHNJaoBhLkkNMMwlqQGGuSQ1wDCXpAYY5pLUAMNckhpgmEtSAwxzSWqAYS5JDXA+c+lNoO987DrweWQuSQ0wzCWpAYa5JDXAMJekBowc5klOTvKNJA8neSjJNeMsTJK0cn3uZnkV+MOq2pbkcGBrkruq6uEx1SZJWqGRj8yrak9VbeuWfwDsAE4cV2GSpJUbyznzJKcAZwNblti2Icl8kvm9z782ju4kSYv0DvMkbwG+CFxbVS8t3l5Vm6pqrqrmjj16Td/uJElL6BXmSQ5iEOQ3VtVt4ylJkrRafe5mCfB5YEdVfWZ8JUmSVqvPkfk7gN8C3p1ke/dz6ZjqkiStwsi3JlbVvwMZYy2SpBH5DVBJaoBhLkkNcD5zSQc852PfP4/MJakBhrkkNcAwl6QGGOaS1ADDXJIaYJhLUgMMc0lqgGEuSQ0wzCWpAYa5JDXAMJekBhjmktQAw1ySGmCYS1IDDHNJaoDzmU9R3zmZNz+9fSx1SGqPR+aS1ADDXJIaYJhLUgMMc0lqQO8wT7ImyXeSfGUcBUmSVm8cR+bXADvG8HckSSPqFeZJTgJ+FbhuPOVIkkbR98j8r4CPAT8e1iDJhiTzSeb3Pv9az+4kSUsZOcyTvBd4rqq2LteuqjZV1VxVzR179JpRu5MkLaPPkfk7gMuSPAHcDLw7yT+PpSpJ0qqMHOZV9cdVdVJVnQJcAXy9qj4wtsokSSvmfeaS1ICxTLRVVd8EvjmOvyVJWj2PzCWpAYa5JDXgTTWfufOJS2qVR+aS1ADDXJIaYJhLUgMMc0lqgGEuSQ0wzCWpAYa5JDXAMJekBhjmktQAw1ySGmCYS1IDDHNJaoBhLkkNMMwlqQGGuSQ14E01n/kbnfOxSxrGI3NJaoBhLkkNMMwlqQGGuSQ1oFeYJzkyya1JHkmyI8nbxlWYJGnl+t7N8tfAv1XV+5McDKwbQ02SpFUaOcyTHAGcD3wQoKpeBl4eT1mSpNXoc5rlVGAv8A9JvpPkuiSHLW6UZEOS+STze59/rUd3kqRh+oT5WuAc4HNVdTbwv8DGxY2qalNVzVXV3LFHr+nRnSRpmD5hvhvYXVVbuse3Mgh3SdKUjRzmVfUM8GSSM7pVFwAPj6UqSdKq9L2b5feAG7s7WR4HPtS/JEnSavUK86raDsyNpxRJ0qj8BqgkNcAwl6QGTHU+80fvX9drTm7n45akpXlkLkkNMMwlqQGGuSQ1wDCXpAYY5pLUAMNckhpgmEtSAwxzSWqAYS5JDTDMJakBhrkkNcAwl6QGGOaS1ADDXJIaYJhLUgOmOp+59GbVZx5/aSU8MpekBhjmktQAw1ySGmCYS1IDeoV5kj9I8lCSB5PclOTQcRUmSVq5kcM8yYnA7wNzVfWLwBrginEVJklaub6nWdYCP51kLbAOeLp/SZKk1Ro5zKvqKeAvgO8Be4AXq+rOxe2SbEgyn2T+FX40eqWSpKH6nGZZD1wOnAqcAByW5AOL21XVpqqaq6q5gzhk9EolSUP1Oc1yIfCfVbW3ql4BbgPePp6yJEmr0SfMvwecl2RdkgAXADvGU5YkaTX6nDPfAtwKbAMe6P7WpjHVJUlahV4TbVXVJ4FPjqkWSdKI/AaoJDXAMJekBjifuSTtR9/56Dc/vX0sdSzHI3NJaoBhLkkNMMwlqQGGuSQ1wDCXpAYY5pLUAMNckhpgmEtSAwxzSWqAYS5JDTDMJakBhrkkNcAwl6QGGOaS1ADDXJIa4Hzmq9B3TmNJmhSPzCWpAYa5JDXAMJekBhjmktSA/YZ5kuuTPJfkwQXrjkpyV5Kd3e/1ky1TkrSclRyZfwG4eNG6jcDdVXU6cHf3WJI0I/sN86r6FvDCotWXAzd0yzcA7xtvWZKk1Rj1PvPjqmpPt/wMcNywhkk2ABsADmXdiN1JkpbT+wJoVRVQy2zfVFVzVTV3EIf07U6StIRRw/zZJMcDdL+fG19JkqTVGjXMbweu6pavAr48nnIkSaNYya2JNwH/AZyRZHeSjwB/CvxKkp3Ahd1jSdKM7PcCaFVdOWTTBWOuRZI0Ir8BKkkNMMwlqQEZ3Fk4HXNnHVr3bj555Oc7n7ikN6Ov1a1bq2puuTYemUtSAwxzSWqAYS5JDTDMJakBhrkkNcAwl6QGGOaS1ADDXJIaYJhLUgMMc0lqgGEuSQ0wzCWpAYa5JDXAMJekBhjmktSAqc5nnmQv8F/LNDkG+P6UyhmF9Y3uQK4NrK8v6+tnf/X9bFUdu9wfmGqY70+S+f1NwD5L1je6A7k2sL6+rK+fcdTnaRZJaoBhLkkNONDCfNOsC9gP6xvdgVwbWF9f1tdP7/oOqHPmkqTRHGhH5pKkERjmktSAmYR5kouTfDfJriQbl9h+SJJbuu1bkpwypbpOTvKNJA8neSjJNUu0eVeSF5Ns734+MY3aFvT/RJIHur7nl9ieJH/Tjd39Sc6ZYm1nLBiX7UleSnLtojZTHb8k1yd5LsmDC9YdleSuJDu73+uHPPeqrs3OJFdNsb4/T/JI9/p9KcmRQ5677L4wwfo+leSpBa/hpUOeu+z7fIL13bKgtieSbB/y3ImO37A8mdj+V1VT/QHWAI8BpwEHA/cBZy5q8zvA33XLVwC3TKm244FzuuXDgUeXqO1dwFemPW4L+n8COGaZ7ZcCXwUCnAdsmVGda4BnGHzZYWbjB5wPnAM8uGDdnwEbu+WNwKeXeN5RwOPd7/Xd8vop1XcRsLZb/vRS9a1kX5hgfZ8C/mgFr/+y7/NJ1bdo+18Cn5jF+A3Lk0ntf7M4Mj8X2FVVj1fVy8DNwOWL2lwO3NAt3wpckCSTLqyq9lTVtm75B8AO4MRJ9ztmlwP/WAP3AEcmOX4GdVwAPFZVy33jd+Kq6lvAC4tWL9y/bgDet8RT3wPcVVUvVNV/A3cBF0+jvqq6s6pe7R7eA5w07n5Xasj4rcRK3ue9LVdflxm/Adw07n5XYpk8mcj+N4swPxF4csHj3bw+MP+/TbdTvwgcPZXqOt2pnbOBLUtsfluS+5J8NckvTLMuoIA7k2xNsmGJ7SsZ32m4guFvolmOH8BxVbWnW34GOG6JNgfKOH6YwSetpexvX5ikq7vTQNcPOU1wIIzfLwPPVtXOIdunNn6L8mQi+58XQJeQ5C3AF4Frq+qlRZu3MTh1cBbwt8C/Trm8d1bVOcAlwO8mOX/K/e9XkoOBy4B/WWLzrMfvJ9TgM+0BeX9uko8DrwI3Dmkyq33hc8DPAW8F9jA4lXEgupLlj8qnMn7L5ck4979ZhPlTwMkLHp/UrVuyTZK1wBHA89MoLslBDAb+xqq6bfH2qnqpqv6nW74DOCjJMdOorevzqe73c8CXGHycXWgl4ztplwDbqurZxRtmPX6dZ/edeup+P7dEm5mOY5IPAu8FfrN7w7/OCvaFiaiqZ6vqtar6MfD3Q/qd9fitBX4duGVYm2mM35A8mcj+N4sw/zZwepJTuyO4K4DbF7W5Hdh39fb9wNeH7dDj1J1j+zywo6o+M6TNz+w7f5/kXAZjOK1/NIclOXzfMoMLZQ8uanY78NsZOA94ccFHumkZekQ0y/FbYOH+dRXw5SXabAYuSrK+O41wUbdu4pJcDHwMuKyqfjikzUr2hUnVt/AazK8N6Xcl7/NJuhB4pKp2L7VxGuO3TJ5MZv+b1JXc/VzlvZTBld3HgI936/6Ewc4LcCiDj+i7gHuB06ZU1zsZfOS5H9je/VwKfBT4aNfmauAhBlfn7wHePsVxO63r976uhn1jt7C+AJ/txvYBYG7Kr+1hDML5iAXrZjZ+DP6p7AFeYXDe8SMMrr/cDewEvgYc1bWdA65b8NwPd/vgLuBDU6xvF4Pzpfv2wX13dp0A3LHcvjCl+v6p27fuZxBMxy+ur3v8uvf5NOrr1n9h3z63oO1Ux2+ZPJnI/ufX+SWpAV4AlaQGGOaS1ADDXJIaYJhLUgMMc0lqgGEuSQ0wzCWpAf8Hhj0t5qkLj6IAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "BATCH_IDX = 2\n",
    "wrapped_bert = sparse.ApproxSparseBertModel(topk_trainer.model_bert, approx_bert=topk_trainer.approx_bert.module, ks=ksx)\n",
    "wrapped_bert.use_forward_sparse = True\n",
    "wrapped_bert.run_original_attention = False\n",
    "wrapped_bert.sparse_bert.load_state_dict(topk_trainer.model.bert.state_dict(), strict=False)\n",
    "wrapped_bert.to('cpu')\n",
    "\n",
    "output = wrapped_bert(\n",
    "    input_ids=tokenized_result['input_ids'], \n",
    "    attention_mask=tokenized_result['attention_mask'],\n",
    "    output_attentions=True\n",
    ")\n",
    "imshow(torch.mean(output['attentions'][0][BATCH_IDX], dim=0))\n",
    "\n",
    "masks = []\n",
    "for layer in wrapped_bert.sparse_bert.encoder.layer:\n",
    "    indices = layer.output.dense.channel_indices\n",
    "    mask = torch.zeros(tokenized_result['input_ids'].shape[0], tokenized_result['input_ids'].shape[1], device='cpu', dtype=torch.float32)\\\n",
    "        .scatter_(1, indices.squeeze(-1), 1.0)\n",
    "    masks.append(mask[BATCH_IDX])\n",
    "imshow(torch.stack(masks, dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention back-tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concrete masking"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.12 ('torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "58c896f8fe28377dc6f47dbc9814b9367447c8ff4b1090ace6962dd6db7d2533"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
