{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, random, math, time, sys, os, tqdm\n",
    "import numpy as np\n",
    "import numba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 120000/120000 [00:00<00:00, 209422.79it/s]\n",
      "100%|██████████| 7600/7600 [00:00<00:00, 205411.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:1012, avg_len:236.477525, count:120000\n",
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:892, avg_len:235.2992105263158, count:7600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at google/bert_uncased_L-12_H-768_A-12 were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer.__init__: Model initialized. model = bert-base\n",
      "Trainer.load: Loading... saves/cls_bert-base.pth\n"
     ]
    }
   ],
   "source": [
    "from trainer.classification import Trainer\n",
    "from trainer.attention_approx import Trainer as ApproxTrainer\n",
    "# --- configuration ---\n",
    "batch_size = 4\n",
    "device = 0\n",
    "factor = 16\n",
    "\n",
    "# fine-tuned classifier: load the checkpoint and switch to inference mode\n",
    "trainer = Trainer(batch_size=batch_size, model='bert-base', device=device)\n",
    "trainer.load()\n",
    "trainer.model.eval()\n",
    "bert = trainer.model.bert\n",
    "fc = trainer.model.classifier\n",
    "\n",
    "# one batch from get_batch's default split ...\n",
    "batch = trainer.get_batch()\n",
    "# ... and one from the test split (test=True selects it, as in the evaluation loops below);\n",
    "# this previously passed test=False, so `test_batch` did not hold test data\n",
    "test_batch = trainer.get_batch(test=True)\n",
    "\n",
    "# approximation trainer built for the same model type and device as the classifier\n",
    "approx_trainer = ApproxTrainer(batch_size=batch_size, factor=factor, model=trainer.model_type, device=trainer.device)\n",
    "approx_trainer.load()\n",
    "approx_bert = approx_trainer.bert\n",
    "approx_bert = approx_bert.eval()\n",
    "print('approx trained', approx_trainer.steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'models.sparse_token' from 'f:\\\\Library\\\\discrete_edge_learning\\\\models\\\\sparse_token.py'>"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib\n",
    "import models.sparse_token as sparse\n",
    "# re-execute the module so edits to models/sparse_token.py are picked up without restarting the kernel\n",
    "importlib.reload(sparse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build the sparse model from the classifier's config, place it on the trainer device in eval mode,\n",
    "# then copy the fine-tuned BERT weights into it\n",
    "sparse_bert = sparse.SparseBertModel(bert.config).to(trainer.device).eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "apply input mask, before softmax. input_mask:torch.Size([4, 71]), attention_scores:torch.Size([4, 4, 71, 71])\n",
      "SelfAttention.forward: last_attention_probs backuped\n",
      "apply input mask, before softmax. input_mask:torch.Size([4, 71]), attention_scores:torch.Size([4, 4, 71, 71])\n",
      "SelfAttention.forward: last_attention_probs backuped\n",
      "apply input mask, before softmax. input_mask:torch.Size([4, 71]), attention_scores:torch.Size([4, 4, 71, 71])\n",
      "SelfAttention.forward: last_attention_probs backuped\n",
      "apply input mask, before softmax. input_mask:torch.Size([4, 71]), attention_scores:torch.Size([4, 4, 71, 71])\n",
      "SelfAttention.forward: last_attention_probs backuped\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "((tensor([1, 2, 4, 1], device='cuda:1'),\n",
       "  tensor([1, 2, 4, 1], device='cuda:1')),\n",
       " (tensor([1, 2, 4, 1], device='cuda:1'),\n",
       "  tensor([1, 2, 4, 1], device='cuda:1')))"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval_fc(lm_output, fc=fc, batch=batch):\n",
    "    \"\"\"Apply the classification head to the first-position hidden state.\n",
    "\n",
    "    Note: the `fc` and `batch` defaults are bound when this cell is executed.\n",
    "\n",
    "    Returns:\n",
    "        (argmax of fc's output over the last dim, batch.labels, lm_output)\n",
    "    \"\"\"\n",
    "    first_token_state = lm_output.last_hidden_state[:, 0, :]\n",
    "    scores = fc(first_token_state)\n",
    "    predictions = torch.argmax(scores, dim=-1)\n",
    "    return predictions, batch.labels, lm_output\n",
    "\n",
    "def eval(bert, fc=fc, batch=batch):\n",
    "    \"\"\"Run `bert` on the batch and classify its output with `eval_fc`.\n",
    "\n",
    "    Hidden states and attentions are requested from the model.\n",
    "    NOTE: this name shadows Python's builtin `eval` inside the notebook.\n",
    "\n",
    "    Returns:\n",
    "        Whatever eval_fc returns: (predictions, batch.labels, lm_output).\n",
    "    \"\"\"\n",
    "    model_kwargs = dict(\n",
    "        input_ids=batch.input_ids,\n",
    "        attention_mask=batch.attention_masks,\n",
    "        output_hidden_states=True,\n",
    "        output_attentions=True,\n",
    "    )\n",
    "    return eval_fc(bert(**model_kwargs), fc=fc, batch=batch)\n",
    "\n",
    "def approx_eval(sparse_bert, approx_bert, fc=fc, batch=batch, k=0.25):\n",
    "    \"\"\"Run `sparse_bert` together with `approx_bert` and classify the result.\n",
    "\n",
    "    Args:\n",
    "        sparse_bert: model exposing `encoder.layer`; passed to sparse.run_bert_with_approx.\n",
    "        approx_bert: approximation model passed to sparse.run_bert_with_approx.\n",
    "        fc: head applied by eval_fc to the first-position hidden state.\n",
    "        batch: object providing input_ids, attention_masks and labels.\n",
    "        k: value repeated once per encoder layer to form `ks` (default 0.25, the\n",
    "            previously hard-coded value). Presumably the per-layer keep ratio --\n",
    "            confirm against run_bert_with_approx.\n",
    "\n",
    "    Returns:\n",
    "        Whatever eval_fc returns: (predictions, batch.labels, lm_output).\n",
    "    \"\"\"\n",
    "    model_inputs = {\n",
    "        'input_ids': batch.input_ids,\n",
    "        'attention_mask': batch.attention_masks,\n",
    "        'output_hidden_states': True,\n",
    "        'output_attentions': True,\n",
    "    }\n",
    "    # one entry per encoder layer\n",
    "    ks = [k] * len(sparse_bert.encoder.layer)\n",
    "    lm_output = sparse.run_bert_with_approx(sparse_bert, approx_bert, model_inputs, ks = ks)\n",
    "    return eval_fc(lm_output, fc=fc, batch=batch)\n",
    "\n",
    "# sanity check: (predictions, labels) from the dense and the approximated model on the default batch\n",
    "dense_preds_labels = eval(bert)[:2]\n",
    "approx_preds_labels = approx_eval(sparse_bert, approx_bert)[:2]\n",
    "dense_preds_labels, approx_preds_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 475/475 [00:07<00:00, 61.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bert.attention.output        : 0:00:00.480983 (2.51%)\n",
      "bert.attention.probs.dropout : 0:00:00.022029 (0.12%)\n",
      "bert.attention.qkv           : 0:00:00.922955 (4.82%)\n",
      "bert.attention.scores.matmul : 0:00:00.188991 (0.99%)\n",
      "bert.intermediate            : 0:00:00.359040 (1.88%)\n",
      "bert.output                  : 0:00:00.447039 (2.34%)\n",
      "eval                         : 0:00:06.368620 (33.29%)\n",
      "eval.approx_att_bert         : 0:00:02.271654 (11.87%)\n",
      "eval.reset_mask              : 0:00:00.130015 (0.68%)\n",
      "eval.reset_mask.convert_mask : 0:00:00.038006 (0.20%)\n",
      "eval.sparse_bert             : 0:00:03.411039 (17.83%)\n",
      "eval.update_mask             : 0:00:00.548912 (2.87%)\n",
      "sparselinear                 : 0:00:01.653979 (8.65%)\n",
      "sparselinear.gather          : 0:00:00.372953 (1.95%)\n",
      "sparselinear.linear          : 0:00:00.578048 (3.02%)\n",
      "sparselinear.scatter_        : 0:00:00.634994 (3.32%)\n",
      "update_mask                  : 0:00:00.358924 (1.88%)\n",
      "update_mask.reduce           : 0:00:00.185955 (0.97%)\n",
      "update_mask.topk             : 0:00:00.156970 (0.82%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.9106578947368421, 0.9106578947368421, False)"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pick up the latest models/sparse_token.py before measuring\n",
    "importlib.reload(sparse)\n",
    "# NOTE(review): presumably these clear measurements accumulated by earlier cells -- confirm in sparse_token.py\n",
    "sparse.benchmark_reset()\n",
    "sparse.timer_reset()\n",
    "\n",
    "def accuracy(batch_eval, N=7600//16, return_lm=False):\n",
    "    \"\"\"Average the per-batch accuracy of `batch_eval` over N test batches.\n",
    "\n",
    "    Calls trainer.seed() and sets trainer.dataset.batch_size to 16 before iterating.\n",
    "    The default N (7600 // 16) matches the 7600-sample test set at that batch size.\n",
    "\n",
    "    Args:\n",
    "        batch_eval: callable taking a batch and returning (predictions, labels, lm_output).\n",
    "        N: number of test batches to evaluate; must be positive.\n",
    "        return_lm: when True, also return the lm_output produced for the last batch.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-batch accuracies, or (that mean, last lm_output) when return_lm is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if N is not positive (previously an AttributeError on int.item()).\n",
    "    \"\"\"\n",
    "    if N <= 0:\n",
    "        raise ValueError('accuracy: N must be a positive number of batches, got %r' % (N,))\n",
    "    trainer.seed()\n",
    "    trainer.dataset.batch_size = 16\n",
    "    acc_sum = 0\n",
    "    lm_output = None\n",
    "    for _step in tqdm.tqdm(range(N)):\n",
    "        batch = trainer.get_batch(test=True)\n",
    "        with torch.no_grad():\n",
    "            output, label, lm_output = batch_eval(batch)\n",
    "        # keep the running sum as a tensor; it is converted with a single .item() after the loop\n",
    "        acc_sum += torch.mean((output == label) * 1.0)\n",
    "    acc = acc_sum.item() / N\n",
    "    if return_lm: return acc, lm_output\n",
    "    return acc\n",
    "\n",
    "# setup for evaluation: a fresh sparse model carrying the fine-tuned BERT weights\n",
    "sparse_bert = sparse.SparseBertModel(bert.config)\n",
    "sparse_bert.to(trainer.device)\n",
    "sparse_bert.eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())\n",
    "approx_bert = approx_bert.to(trainer.device)\n",
    "bert = bert.to(trainer.device)\n",
    "\n",
    "# switch off the sparse model's print / backup-last-inputs / output-masking options\n",
    "# (each was previously set twice; once is enough)\n",
    "sparse.set_print(sparse_bert, False)\n",
    "sparse.set_backup_last_inputs(sparse_bert, False)\n",
    "sparse.set_output_masking(sparse_bert, False)\n",
    "\n",
    "sparse.timer_reset()\n",
    "acc_approx, lm = accuracy(\n",
    "    lambda batch: approx_eval(sparse_bert, approx_bert, batch=batch),\n",
    "    return_lm = True,\n",
    ")\n",
    "# NOTE: the dense baseline run is disabled, so acc_bert below is only a placeholder\n",
    "# equal to acc_approx; re-enable the commented line to measure the real baseline\n",
    "acc_bert = acc_approx\n",
    "#acc_bert = accuracy(lambda batch: eval(bert, batch=batch))\n",
    "sparse.timer_report()\n",
    "sparse.benchmark_report()\n",
    "# NOTE(review): reference accuracy; presumably recorded from an earlier configuration -- confirm\n",
    "reference_acc = 0.9236111111111112\n",
    "acc_bert, acc_approx, abs(acc_approx - reference_acc) < 0.001\n",
    "\n",
    "#bert-mini 4 0.25\n",
    "#sum, sum = 0.9217105263157894\n",
    "#sum, max = 0.9236842105263158\n",
    "#max, sum = 0.9223684210526316\n",
    "#max, max = 0.9221052631578948\n",
    "#no token impact = 0.9239473684210526 V\n",
    "\n",
    "#bert-mini 8 0.25\n",
    "#sum, sum = 0.9178947368421052 V\n",
    "#sum, max = 0.9111842105263158\n",
    "#max, sum = 0.9157894736842105\n",
    "#max, max = 0.9102631578947369\n",
    "#no token impact = 0.9146052631578947\n",
    "\n",
    "#bert-mini 16 0.25\n",
    "#sum, sum = 0.910921052631579\n",
    "#sum, max = 0.91\n",
    "#max, sum = 0.9134210526315789 V\n",
    "#max, max = 0.9117105263157895\n",
    "#no token impact = 0.9106578947368421"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD6CAYAAABnLjEDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAATEklEQVR4nO3dfWyd1X0H8O/3XtuxkxA7TkIwdtaEkg4C5WXLgI6pqpIipWnV5A9WgaopkyJFmroJWKU2bNOkapsE1Qb0j65TVFAzqS0wyhQEbBNL03VTSyAkQIEAMSlgBwcDscm74+v72x/3CfE557Hvte+7z/cjWfF57vPy842/fu55Xs5DM4OIzH2ZehcgIrWhsItEQmEXiYTCLhIJhV0kEgq7SCTKCjvJDSTfINlPcnulihKRyuNsz7OTzAJ4E8AtAAYBPA/gdjN7bapl2jjP2rHgwjpaW8KZJvJO06+P2ZS/T96PYBMTYb0tWf8H8BZKeR/y3rSMv+2UZVrcnynf5m43M+7+fABgWa+Wk2ecZtr7ZOO5cNsCtrU6bTs3PvOVzO9w26fPpM/XgM7iFM7ZGNNeS0lbyW4A0G9mhwGA5MMANgGYMuztWIAbuf7CxpcuD+bJnzjptC3n/lJnFi5AwAvlxMhIMEt28RKnzVbvl2I8/KUw7z+ZC7xt58M/Klja7TTPrFrstNuH3J8PAHKL2p125n8POO2WZZeEywwdDbctaLl0hdPOvTMw43Xw6s86bXv+N2XVVEt7bfeUr5XzMb4XwOR3cjCZJiINqJw9e0lIbgOwDQDaMb/amxORKZQT9iMAJn9m6kumOcxsB4AdALCI3c7n7dz7w0U3wqzX1/b70QDyJ08VXY/581x0kbudlO4BW6Z/eyaOHw+mZZa4H9vHutz6OwbCPvvpnnlOe6H3ui3w+pAyJRv5uPx1+MdQ5ohyPsY/D2A1yVUk2wDcBuCJypQlIpU26z27meVI/jmA/wKQBfCQmb1ascpEpKLK6rOb2dMAnq5QLSJSRbqCTiQSVT8aP53g4BsAXHeF07T9B512/sSJYBH/XHwp8t65+M/8x2gwT/8XvbMHrW3eSlIOFr4z6LQ7B4fc11Muhuk84p4z98/eUxfQlCx/5mz562hx94FzZY84V34OESlCYReJhMIuEom69tnT2Avu2buTf3yj01746LOzWm/evyHCu679zZvdC1sA4D9/+z9Oe+Nn1znttBtuMln/Roxzbjvl+EL+zPQXceQ/PDbt63KBjZ8rPlMRLQcOOe3wMqjmpD27SCQUdpFIKOwikahrnz3z6ZXBtPzhd532RY/vc9oDd/9hsEzfPb92J6QNRGFezyvjnuNPO+f/pdU3e+vwzuH66wTABd65+RXevegH3wqWyax078GeeKPfaQ/+2bXBMpf+46+CaQJwnnvsxcbGZr6OXu//7M3w/6wZac8uEgmFXSQSCrtIJBR2kUjU96Kaj8KBIf0LVZhxLzhZcd8LwTJvP3K10/7U18IBAtnm3sSSucgdDyY/mjLCiXfQLhjRleHfyvzKHqd9brE7mKR3K01hvd7NMr6L98/8IFOs/IuYZoPHw0FB5wLt2UUiobCLREJhF4lEXfvsljLQgH9xi9+Hb1naFSzz6bs+ctqpQz34A034T5qZF94I41+QEVx4k3JRTfY9t5YPv3CZ014+tiZY5uxStyc//9/3Ou3MxFy5FaP6MvPdi5ryp4qPPOyzOfp+a88uEgmFXSQSCrtIJBR2kUjUd3TZjvZwmncQLO8dJMsNfxgukyn+uJ6rnnUP2716k3vhRCbtEUv+BT7ehTlIOag3ccy9UKj3+/vddXiPnQKAtnZ3Pf4BxtbXBuFLeX6sYIrHgM9Ud6fb/uCD8tfZALRnF4mEwi4SCYVdJBL17bNf5D+cGLBR9zHIfh8+bUSZTLf7mOTcYPDkaBxc7/aV3/6bq5z2ZQ+4T54Bwj567qpVTrvlzYFwGb8+
uscTLOWJNv5It4HO8H2aK/3ISjN/FOFZyC1132++UfYqG4L27CKRUNhFIqGwi0SivjfCpJwTHflJt9Pu3OyeYybDc+q5944G04JteTc3rPy755z28U2/HyyzcJc7UEZm/+tOO5/yRJiA34dPefIrTp2edhVMW0ZSsc19Ig+mf2tTtQy7x43myjUN2rOLREJhF4mEwi4SiaJhJ/kQyWGSr0ya1k3yGZKHkn8XT7cOEak/WtqjkibPQH4ewEkA/2pmVyfTvgvgmJndQ3I7gMVm9u1iG1vEbruR6z9pt6z8nWAeOzbqtK/c496w8sofFL/pJe0iFba4BwODUWxTLtY5ctcNTrv3AfdRVJYLL+AIbpbxDq6lLtPiPebZe+xwJuXmmXzaxTkSXMSU+iiwIvz3u5ne6722G8ftWGpIiu7ZzeyXAPwHhG8CsDP5fieAzeUUKCLVN9tTb8vN7Pxg50cBLJ9qRpLbAGwDgHbMn2o2Eamysg/QWaEfMOVnJTPbYWZrzWxtK8L7v0WkNma7Z3+fZI+ZDZHsATA8m5XYvPD5KOzuctqv3+oe+2NreAFN8KSWFEG/N+v9nUu5gaLvn19ya/nBNU77iruK3yFB72ec+MjvEQH0arHy7+WIVqbDHYQkf3rmV9Wc+fwVTnveU8+XVVOjmO2e/QkAW5LvtwDYVZlyRKRaSjn19lMAvwbwuyQHSW4FcA+AW0geAvDFpC0iDazox3gzu32Kl9ZPMV1EGlB9b4QZeC+Yxs5F7jzjbgc29SmdJZxL5XyvLzcy6rQzi8Jz2XbCPcd/5fbfOu07DzwbLPNPn3H79Txzxm2nnM8vdq2DeeuQqZVy/KaY+YdHnbZuhBGRpqKwi0RCYReJhMIuEonGG13WG7XFP7B2erN7cwoALHrpfaedO/x2uF7vopnggF3KaDHBwTRvtJv7rvq9lGXcg23ZJd4NgSkXEo33uqPz8FfuxTzZZUuDZXJDxUfniRGvdB+RbS+/PsWcUxv3RpedK3vEufJziEgRCrtIJBR2kUjUt8/uD/QAAGnTJlm0L3zaS0nbap/+jjt2pDzFtYjswgUzLyRlpNjWgY+cdnBZSEtd/5uaSmbEvRAqP8V802l7233aTvmX6TQG7dlFIqGwi0RCYReJRH1vhDk7FkzLj4y48/gDQ3qDM5a8Le8Gmrw/iATDv3v+tmHFe4CZ+e7QW/T69Zc/5f58ANC/bvrjBfnRj4tuVwps4cyPvfhOXd3jtOcNDE4xZ3PRnl0kEgq7SCQUdpFIKOwikajv1Rpd4egwOO4+Lte/GSXtkc1onf6JKkD4RBj/gp7gYBwAZrxlvFrSRs3JdHU67VPX9Drt/nVng2Ve/4crnfbqv9jrrnOpe6MM0FxPKaklezcc/WimOo647+1sLsxpRNqzi0RCYReJhMIuEon63ghzNuzzZi9eVvZ6U/uzGffvWmZJ2A+uhgWvDLkTvNFzAeDK77o39wQ3XuTmyvim1ZfpdgcLyZ86NfN1VOBmmkakPbtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRKBp2kitI7iH5GslXSd6RTO8m+QzJQ8m/i4utS0Tqp5Q9ew7AN81sDYCbAHyD5BoA2wHsNrPVAHYnbRFpUEXDbmZDZrY/+f4EgIMAegFsArAzmW0ngM1VqlFEKmBGfXaSKwFcD2AvgOVmdv7+zaMAlle2NBGppJLDTnIhgJ8BuNPMnIHizMwAhE8sLCy3jeQ+kvvGET4UQkRqo6Swk2xFIeg/NrPHk8nvk+xJXu8BMJy2rJntMLO1Zra2FdM/SVVEqqeUo/EE8CCAg2Z236SXngCwJfl+C4BdlS9PRCqllGGpbgbwJwB+Q/LFZNpfAbgHwKMktwJ4B8DXqlKhiFRE0bCb2f8BSBmsHQCwvrLliEi16Ao6kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGw
i0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIlE07CTbST5H8iWSr5L8TjJ9Fcm9JPtJPkKyrfrlishslbJnHwOwzsyuBXAdgA0kbwJwL4D7zexyACMAtlatShEpW9GwW8HJpNmafBmAdQAeS6bvBLC5GgWKSGWU1GcnmSX5IoBhAM8AeAvAqJnlklkGAfROsew2kvtI7hvHWAVKFpHZKCnsZjZhZtcB6ANwA4ArSt2Ame0ws7VmtrYV82ZXpYiUbUZH481sFMAeAJ8D0EWyJXmpD8CRypYmIpVUytH4ZSS7ku87ANwC4CAKob81mW0LgF1VqlFEKqCl+CzoAbCTZBaFPw6PmtmTJF8D8DDJvwdwAMCDVaxTRMpUNOxm9jKA61OmH0ah/y4iTUBX0IlEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIlHI/e9XY6dPhtHPjbvvMGXcGhn+fLDceTAvmOXnK2845d7V9PeEyA++5E/J5t3ku3G6mo91d7/z57gydC8PiRo+H0ybX8fH0r8sF535Ep51ZP/N1TFzc6U4YGCyjosahPbtIJBR2kUgo7CKRqGufPejPAqA/qaszmKeYXEofiwsXuG24bZw+Gy6zpHva7czqL+VYyvGFjo5pF2HnonDicfXj07T9qTnt3BTzTSc7/HHZ62hE2rOLREJhF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCYVdJBL1HbzCG1ACAPKXXeq0syPuPPn3jgbLpN4o4m/LHyjj4iVuOzcRLnT8pLuOYACM8KaW7CUXO+2jG1c47UueejesbdGCYNpk+dGPp31dLggGO5mF/EfHKlBJ49GeXSQSCrtIJEoOO8ksyQMkn0zaq0juJdlP8hGSbdUrU0TKNZM++x0ADgI430G+F8D9ZvYwyX8BsBXAD2aycX9ACSAcOMCXWbZ0Jpu4sC1/oIyTJfTt2ty/X+ye+d+zS54e8AphMA9PhANvTpZJGcAjf+LEjGuJAYsMBFKKjDdoST5lYNRmVNKenWQfgC8D+GHSJoB1AB5LZtkJYHMV6hORCin1Y/wDAL4F4PxYyksAjJrZ+RF7BgH0pi1IchvJfST3jWOsnFpFpAxFw07yKwCGzeyF2WzAzHaY2VozW9uKebNZhYhUQCl99psBfJXkRgDtKPTZvwegi2RLsnfvA3CkemWKSLmK7tnN7G4z6zOzlQBuA/BzM/s6gD0Abk1m2wJgV9WqFJGylXOe/dsA/pJkPwp9+AcrU5KIVMOMLpc1s18A+EXy/WEAN1S+JBGpBl1BJxIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFI0MxqtzHyAwDvAFgK4MOabbg8zVQr0Fz1NlOtQHPU+ykzW5b2Qk3D/slGyX1mtrbmG56FZqoVaK56m6lWoPnq9eljvEgkFHaRSNQr7DvqtN3ZaKZageaqt5lqBZqvXkdd+uwiUnv6GC8SiZqGneQGkm+Q7Ce5vZbbLgXJh0gOk3xl0rRuks+QPJT8u7ieNZ5HcgXJPSRfI/kqyTuS6Y1abzvJ50i+lNT7nWT6KpJ7k9+JR0i21bvW80hmSR4g+WTSbthaS1GzsJPMAvg+gC8BWAPgdpJrarX9Ev0IwAZv2nYAu81sNYDdSbsR5AB808zWALgJwDeS97NR6x0DsM7MrgVwHYANJG8CcC+A+83s
cgAjALbWr8TAHQAOTmo3cq1F1XLPfgOAfjM7bGbnADwMYFMNt1+Umf0SwDFv8iYAO5PvdwLYXMuapmJmQ2a2P/n+BAq/lL1o3HrNzE4mzdbkywCsA/BYMr1h6iXZB+DLAH6YtIkGrbVUtQx7L4CBSe3BZFqjW25mQ8n3RwEsr2cxaUiuBHA9gL1o4HqTj8UvAhgG8AyAtwCMmlkumaWRficeAPAtAPmkvQSNW2tJdIBuBqxw6qKhTl+QXAjgZwDuNLPjk19rtHrNbMLMrgPQh8InvSvqW1E6kl8BMGxmL9S7lkpqqeG2jgBYMandl0xrdO+T7DGzIZI9KOyVGgLJVhSC/mMzezyZ3LD1nmdmoyT3APgcgC6SLckes1F+J24G8FWSGwG0A1gE4HtozFpLVss9+/MAVidHNNsA3AbgiRpuf7aeALAl+X4LgF11rOUTSR/yQQAHzey+SS81ar3LSHYl33cAuAWF4wx7ANyazNYQ9ZrZ3WbWZ2YrUfg9/bmZfR0NWOuMmFnNvgBsBPAmCn21v67ltkus76cAhgCMo9An24pCX203gEMA/htAd73rTGr9IxQ+or8M4MXka2MD13sNgANJva8A+Ntk+mUAngPQD+DfAMyrd61e3V8A8GQz1FrsS1fQiURCB+hEIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKR+H8E/oEyaA8Q6AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.6138e-02, 8.9100e-03, 5.0388e-04, 6.5075e-04, 3.6405e-01, 4.1655e-02,\n",
      "         1.7010e-03],\n",
      "        [2.0344e-01, 1.6438e-01, 1.4774e-03, 3.3319e-05, 5.7084e-03, 1.3540e-02,\n",
      "         4.2704e-03]], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "def plot_grid(grid):\n",
    "    \"\"\"Show a tensor as an image: detach it, copy it to host memory and hand it to plt.imshow.\"\"\"\n",
    "    pixels = grid.detach().cpu().numpy()\n",
    "    plt.imshow(pixels)\n",
    "    plt.show()\n",
    "# `lm` is the model output kept from the last batch of the accuracy run above\n",
    "# plot the top-left 50x50 corner of one attention map\n",
    "# (indices [-4][0][0] are presumably layer, batch item, head -- confirm)\n",
    "plot_grid(lm.attentions[-4][0][0][:50, :50])\n",
    "# print a small slice of another attention map to eyeball the raw values\n",
    "print(lm.attentions[-2][0][0][:2, :7])\n",
    "#tensor([[0.0246, 0.0012, 0.0206, 0.0003, 0.0000, 0.0002, 0.0049]],"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark(eval, batch_size=8, N=100, WARM=20, amp=False, device=trainer.device, end_warm=None):\n",
    "    \"\"\"Time `eval(batch)` on one fixed test batch and report throughput.\n",
    "\n",
    "    The first WARM iterations are warm-up and are excluded from the measurement.\n",
    "    When the batch lives on a CUDA device, the device is synchronized right before the\n",
    "    clock starts and right before it stops, because CUDA kernels are launched\n",
    "    asynchronously and the wall clock would otherwise miss work that is still queued.\n",
    "\n",
    "    Args:\n",
    "        eval: callable taking the batch; its return value is ignored.\n",
    "            (The parameter name shadows the builtin; kept for keyword callers.)\n",
    "        batch_size: written to trainer.dataset.batch_size before the batch is fetched.\n",
    "        N: total number of iterations, warm-up included.\n",
    "        WARM: number of warm-up iterations; must satisfy 0 <= WARM < N * 0.33.\n",
    "        amp: passed as `enabled` to torch.cuda.amp.autocast.\n",
    "        device: device the batch is moved to.\n",
    "        end_warm: optional zero-argument callback invoked once when warm-up ends,\n",
    "            before the clock starts.\n",
    "\n",
    "    Returns:\n",
    "        (seconds spent in the timed iterations, milliseconds per item, items per second)\n",
    "    \"\"\"\n",
    "    # a negative WARM would never start the clock and leave `t` unbound\n",
    "    assert 0 <= WARM < (N * 0.33)\n",
    "    trainer.dataset.batch_size = batch_size\n",
    "    batch = trainer.get_batch(test=True)\n",
    "    batch = batch.to(device)\n",
    "    assert batch.input_ids.shape[0] == batch_size\n",
    "\n",
    "    def _wait_for_device():\n",
    "        # block until all queued CUDA work on the batch's device has finished; no-op on CPU\n",
    "        batch_device = batch.input_ids.device\n",
    "        if batch_device.type == 'cuda':\n",
    "            torch.cuda.synchronize(batch_device)\n",
    "\n",
    "    for i in tqdm.tqdm(range(N)):\n",
    "        if i == WARM:\n",
    "            # run the callback first so its cost is not part of the measurement\n",
    "            if end_warm is not None: end_warm()\n",
    "            _wait_for_device()\n",
    "            t = time.perf_counter()\n",
    "        with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp):\n",
    "            eval(batch)\n",
    "    _wait_for_device()\n",
    "    t = time.perf_counter() - t\n",
    "    t_item = t / (batch_size * (N-WARM))\n",
    "    return t, t_item * 1000, 1.0/t_item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▍      | 87/250 [00:17<00:33,  4.89it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-85-4c3b0d51e8bf>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     14\u001b[0m \u001b[0msparse_bert\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msparse_bert\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbenchmark_device\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     15\u001b[0m \u001b[0mapprox_bert\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapprox_bert\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbenchmark_device\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 16\u001b[1;33m time_approx = benchmark(\n\u001b[0m\u001b[0;32m     17\u001b[0m     \u001b[0meval\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mlambda\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mapprox_eval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msparse_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapprox_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfc\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     18\u001b[0m     \u001b[0mbatch_size\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbenchmark_batch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-81-289d4bd01514>\u001b[0m in \u001b[0;36mbenchmark\u001b[1;34m(eval, batch_size, N, WARM, amp, device, end_warm)\u001b[0m\n\u001b[0;32m     10\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mend_warm\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mend_warm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menabled\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mamp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m             \u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     13\u001b[0m     \u001b[0mt\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     14\u001b[0m     \u001b[0mt_item\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m \u001b[1;33m/\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mWARM\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-85-4c3b0d51e8bf>\u001b[0m in \u001b[0;36m<lambda>\u001b[1;34m(batch)\u001b[0m\n\u001b[0;32m     15\u001b[0m \u001b[0mapprox_bert\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mapprox_bert\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbenchmark_device\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m time_approx = benchmark(\n\u001b[1;32m---> 17\u001b[1;33m     \u001b[0meval\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mlambda\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mapprox_eval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msparse_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapprox_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfc\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     18\u001b[0m     \u001b[0mbatch_size\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbenchmark_batch_size\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     19\u001b[0m     \u001b[0mWARM\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m30\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-78-4969bf8071f1>\u001b[0m in \u001b[0;36mapprox_eval\u001b[1;34m(sparse_bert, approx_bert, fc, batch)\u001b[0m\n\u001b[0;32m     14\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     15\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mapprox_eval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msparse_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mapprox_bert\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfc\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 16\u001b[1;33m     lm_output = sparse.run_bert_with_approx(\n\u001b[0m\u001b[0;32m     17\u001b[0m         \u001b[0msparse_bert\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     18\u001b[0m         \u001b[0mapprox_bert\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mrun_bert_with_approx\u001b[1;34m(sparse_bert, approx_bert, input_dict, ks)\u001b[0m\n\u001b[0;32m    912\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    913\u001b[0m     \u001b[0mtimer_start\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'eval.sparse_bert'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 914\u001b[1;33m     \u001b[0mret_sparse\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msparse_bert\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0minput_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    915\u001b[0m     \u001b[0mtimer_end\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'eval.sparse_bert'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    916\u001b[0m     \u001b[0mtimer_end\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'eval'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m    792\u001b[0m             \u001b[0mpast_key_values_length\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpast_key_values_length\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    793\u001b[0m         )\n\u001b[1;32m--> 794\u001b[1;33m         encoder_outputs = self.encoder(\n\u001b[0m\u001b[0;32m    795\u001b[0m             \u001b[0membedding_output\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    796\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mextended_attention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m    580\u001b[0m                 )\n\u001b[0;32m    581\u001b[0m             \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 582\u001b[1;33m                 layer_outputs = layer_module(\n\u001b[0m\u001b[0;32m    583\u001b[0m                     \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    584\u001b[0m                     \u001b[0mattention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m    491\u001b[0m         \u001b[1;31m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    492\u001b[0m         \u001b[0mself_attn_past_key_value\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mpast_key_value\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m         self_attention_outputs = self.attention(\n\u001b[0m\u001b[0;32m    494\u001b[0m             \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    495\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m    415\u001b[0m         \u001b[0moutput_attentions\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    416\u001b[0m     ):\n\u001b[1;32m--> 417\u001b[1;33m         self_outputs = self.self(\n\u001b[0m\u001b[0;32m    418\u001b[0m             \u001b[0mhidden_states\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    419\u001b[0m             \u001b[0mattention_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m    309\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    310\u001b[0m             \u001b[0mkey_layer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose_for_scores\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 311\u001b[1;33m             \u001b[0mvalue_layer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose_for_scores\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    312\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    313\u001b[0m         \u001b[0mquery_layer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtranspose_for_scores\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmixed_query_layer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1102\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1103\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mf:\\Library\\discrete_edge_learning\\models\\sparse_token.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    204\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    205\u001b[0m             \u001b[0mtimer_start\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'sparselinear.scatter_'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 206\u001b[1;33m             x = torch.zeros(N, C, self.out_features, \n\u001b[0m\u001b[0;32m    207\u001b[0m                 dtype=x.dtype, device=x.device).\\\n\u001b[0;32m    208\u001b[0m                     \u001b[0mscatter_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mchannel_indices_unsqueeze\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexpand\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msrc\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "importlib.reload(sparse)\n",
    "sparse.timer_reset()\n",
    "benchmark_device = 'cuda'\n",
    "benchmark_batch_size = 128\n",
    "\n",
    "sparse_bert = sparse.SparseBertModel(bert.config)\n",
    "sparse_bert.to(benchmark_device)\n",
    "sparse_bert.eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())\n",
    "sparse.set_print(sparse_bert, False)\n",
    "sparse.set_backup_last_inputs(sparse_bert, False)\n",
    "sparse.set_output_masking(sparse_bert, False)\n",
    "\n",
    "sparse_bert=sparse_bert.to(benchmark_device)\n",
    "approx_bert=approx_bert.to(benchmark_device)\n",
    "time_approx = benchmark(\n",
    "    eval = lambda batch: approx_eval(sparse_bert, approx_bert, batch=batch, fc=lambda x: x),\n",
    "    batch_size = benchmark_batch_size,\n",
    "    WARM = 30,\n",
    "    N = 250,\n",
    "    device = benchmark_device,\n",
    "    end_warm = lambda: sparse.timer_reset()\n",
    ")\n",
    "sparse.timer_report()\n",
    "time_approx\n",
    "\n",
    "# bert.attention.output        : 0:00:00.545988 (3.04%)\n",
    "# bert.attention.probs.dropout : 0:00:00.006996 (0.04%)\n",
    "# bert.attention.qkv           : 0:00:01.473990 (8.20%)\n",
    "# bert.attention.scores.matmul : 0:00:00.113033 (0.63%)\n",
    "# bert.intermediate            : 0:00:02.113082 (11.75%)\n",
    "# bert.output                  : 0:00:01.848977 (10.28%)\n",
    "# sparselinear.linear          : 0:00:03.801073 (21.14%)\n",
    "\n",
    "# bert.attention.output        : 0:00:01.103962 (2.75%)\n",
    "# bert.attention.probs.dropout : 0:00:00.010004 (0.02%)\n",
    "# bert.attention.qkv           : 0:00:02.440944 (6.07%)\n",
    "# bert.attention.scores.matmul : 0:00:00.534112 (1.33%)\n",
    "# bert.intermediate            : 0:00:02.811061 (6.99%)\n",
    "# bert.output                  : 0:00:02.548062 (6.34%)\n",
    "# sparselinear                 : 0:00:07.770034 (19.33%)\n",
    "# sparselinear.gather          : 0:00:01.343004 (3.34%)\n",
    "# sparselinear.linear          : 0:00:04.463333 (11.11%)\n",
    "# sparselinear.scatter_        : 0:00:01.947706 (4.85%)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25/25 [00:01<00:00, 18.39it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1.343578815460205, 0.4771231588992205, 2095.8949096227443)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert = bert.to(benchmark_device)\n",
    "time_bert = benchmark(\n",
    "    lambda batch: eval(bert, batch=batch, fc=lambda x: x), \n",
    "    batch_size = benchmark_batch_size,\n",
    "    WARM = 3,\n",
    "    N=25,\n",
    "    device = benchmark_device\n",
    ")\n",
    "time_bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],\n",
       "\n",
       "        [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.ones(2,2).unsqueeze(-1).repeat(1, 1, 12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertLayer(\n",
       "  (attention): BertAttention(\n",
       "    (self): BertSelfAttention(\n",
       "      (query): Linear(in_features=96, out_features=96, bias=True)\n",
       "      (key): Linear(in_features=96, out_features=96, bias=True)\n",
       "      (value): Linear(in_features=96, out_features=96, bias=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (output): BertSelfOutput(\n",
       "      (dense): Linear(in_features=96, out_features=96, bias=True)\n",
       "      (LayerNorm): LayerNorm((96,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "  )\n",
       "  (intermediate): BertIntermediate(\n",
       "    (dense): Linear(in_features=96, out_features=3072, bias=True)\n",
       "  )\n",
       "  (output): BertOutput(\n",
       "    (dense): Linear(in_features=3072, out_features=96, bias=True)\n",
       "    (LayerNorm): LayerNorm((96,), eps=1e-12, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "approx_bert.encoder.layer[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f7b3ac0126d0d6fea024471ce24e510948bf6332f7ae1a66cdcb4ee9887514e9"
  },
  "kernelspec": {
   "display_name": "Python 3.8.3 64-bit ('tensorflow': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
