{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import trainer.glue_base as glue_base\n",
    "import models.sparse_token as sparse\n",
    "import pickle, importlib\n",
    "# Pick up on-disk edits to the project modules without restarting the kernel.\n",
    "# Reload the dependency first: importlib.reload() does not rebind names that another\n",
    "# module already imported with `from ... import ...`, so a dependent reloaded too early\n",
    "# keeps pointing at the stale objects.\n",
    "# NOTE(review): glue_base appears to build its models from models.sparse_token\n",
    "# (see eval_sparse_model) - confirm the import direction.\n",
    "importlib.reload(sparse)\n",
    "importlib.reload(glue_base)\n",
    "# Bind the alias after the reloads so it refers to the freshly defined trainer class.\n",
    "Glue = glue_base.GlueAttentionApproxTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: mrpc\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|██████████| 516/516 [00:00<00:00, 889kB/s]\n",
      "Downloading: 100%|██████████| 418M/418M [00:17<00:00, 24.5MB/s] \n",
      "Reusing dataset glue (/home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9433d17ba13b5e5f.arrow\n",
      "Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-a94cfdbaf3fde127.arrow\n",
      "Reusing dataset glue (/home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7814f5729032a0b8.arrow\n",
      "Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-1fdda5bad27ee52a.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: Save checkpoint path saves/glue-mrpc-16.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:03<00:00, 60.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8440579710144928, 'f1': 0.8865457612821593}\n",
      "avg occupy 0.5399536083129833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:57<00:00,  3.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.34840579710144925, 'f1': 0.04421768707482993}\n",
      "avg occupy 0.5399536083129833\n",
      "1/8 | mrpc 0.1 = 0.34841\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [01:00<00:00,  3.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.5797101449275363, 'f1': 0.5777518928363424}\n",
      "avg occupy 0.5399536083129833\n",
      "2/8 | mrpc 0.25 = 0.57971\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [01:01<00:00,  3.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.7878260869565218, 'f1': 0.8319559228650139}\n",
      "avg occupy 0.5399536083129833\n",
      "3/8 | mrpc 0.375 = 0.78783\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval:   5%|▌         | 11/216 [00:03<01:08,  3.00it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb Cell 2'\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb#ch0000001?line=26'>27</a>\u001b[0m \u001b[39mfor\u001b[39;00m ks \u001b[39min\u001b[39;00m kss:\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb#ch0000001?line=27'>28</a>\u001b[0m     sparse\u001b[39m.\u001b[39mbenchmark_reset()\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb#ch0000001?line=28'>29</a>\u001b[0m     sparse_score, _ \u001b[39m=\u001b[39m get_score(trainer\u001b[39m.\u001b[39;49meval_sparse_model(ks\u001b[39m=\u001b[39;49mks, use_forward\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m))\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb#ch0000001?line=29'>30</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(ks, \u001b[39mstr\u001b[39m) \u001b[39mand\u001b[39;00m ks\u001b[39m.\u001b[39mstartswith(\u001b[39m'\u001b[39m\u001b[39mdynamic\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/user/library/discrete_edge_learning/plot_benchmark_glue_forward.ipynb#ch0000001?line=30'>31</a>\u001b[0m         est_k \u001b[39m=\u001b[39m sparse\u001b[39m.\u001b[39mbenchmark_get_average(\u001b[39m'\u001b[39m\u001b[39mest_k\u001b[39m\u001b[39m'\u001b[39m)\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/trainer/glue_base.py:231\u001b[0m, in \u001b[0;36mGlueAttentionApproxTrainer.eval_sparse_model\u001b[0;34m(self, ks, use_forward)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=227'>228</a>\u001b[0m sparse_cls_bert\u001b[39m.\u001b[39mbert \u001b[39m=\u001b[39m wrapped_bert\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=228'>229</a>\u001b[0m sparse_cls_bert\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice)\u001b[39m.\u001b[39meval()\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=230'>231</a>\u001b[0m sparse_result \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meval_base_model(model \u001b[39m=\u001b[39;49m sparse_cls_bert)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=231'>232</a>\u001b[0m \u001b[39mreturn\u001b[39;00m sparse_result\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/trainer/glue_base.py:182\u001b[0m, in \u001b[0;36mGlueAttentionApproxTrainer.eval_base_model\u001b[0;34m(self, model, amp)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=178'>179</a>\u001b[0m \u001b[39mdel\u001b[39;00m batch[\u001b[39m'\u001b[39m\u001b[39mlabels\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=180'>181</a>\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad(), torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mamp\u001b[39m.\u001b[39mautocast(enabled\u001b[39m=\u001b[39mamp):\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=181'>182</a>\u001b[0m     outputs \u001b[39m=\u001b[39m model(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mbatch)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=182'>183</a>\u001b[0m predictions \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/trainer/glue_base.py?line=184'>185</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset \u001b[39m!=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mstsb\u001b[39m\u001b[39m'\u001b[39m: \n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py:1545\u001b[0m, in \u001b[0;36mBertForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1536'>1537</a>\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1537'>1538</a>\u001b[0m \u001b[39mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1538'>1539</a>\u001b[0m \u001b[39m    Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1539'>1540</a>\u001b[0m \u001b[39m    config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1540'>1541</a>\u001b[0m \u001b[39m    `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1541'>1542</a>\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1542'>1543</a>\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1544'>1545</a>\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbert(\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1545'>1546</a>\u001b[0m     input_ids,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1546'>1547</a>\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1547'>1548</a>\u001b[0m     token_type_ids\u001b[39m=\u001b[39;49mtoken_type_ids,\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1548'>1549</a>\u001b[0m     position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1549'>1550</a>\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1550'>1551</a>\u001b[0m     inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1551'>1552</a>\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1552'>1553</a>\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1553'>1554</a>\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1554'>1555</a>\u001b[0m )\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1556'>1557</a>\u001b[0m pooled_output \u001b[39m=\u001b[39m outputs[\u001b[39m1\u001b[39m]\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py?line=1558'>1559</a>\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout(pooled_output)\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:1106\u001b[0m, in \u001b[0;36mApproxSparseBertModel.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1098'>1099</a>\u001b[0m     output \u001b[39m=\u001b[39m run_bert_with_approx(\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1099'>1100</a>\u001b[0m         sparse_bert \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msparse_bert,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1100'>1101</a>\u001b[0m         approx_bert \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mapprox_bert,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1101'>1102</a>\u001b[0m         input_dict  \u001b[39m=\u001b[39m kwargs,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1102'>1103</a>\u001b[0m         ks          \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mks,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1103'>1104</a>\u001b[0m     )\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1104'>1105</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1105'>1106</a>\u001b[0m     output \u001b[39m=\u001b[39m run_bert_forward_sparsity(\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1106'>1107</a>\u001b[0m         sparse_bert \u001b[39m=\u001b[39;49m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msparse_bert,\n\u001b[1;32m   <a 
href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1107'>1108</a>\u001b[0m         input_dict  \u001b[39m=\u001b[39;49m kwargs,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1108'>1109</a>\u001b[0m         ks          \u001b[39m=\u001b[39;49m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mks,\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1109'>1110</a>\u001b[0m     )\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1110'>1111</a>\u001b[0m \u001b[39mreturn\u001b[39;00m output\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:1021\u001b[0m, in \u001b[0;36mrun_bert_forward_sparsity\u001b[0;34m(sparse_bert, input_dict, ks)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1017'>1018</a>\u001b[0m layer\u001b[39m.\u001b[39mattention\u001b[39m.\u001b[39mself\u001b[39m.\u001b[39mkey\u001b[39m.\u001b[39mchannel_indices \u001b[39m=\u001b[39m indices_unsqueeze\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1018'>1019</a>\u001b[0m layer\u001b[39m.\u001b[39mattention\u001b[39m.\u001b[39mself\u001b[39m.\u001b[39mvalue\u001b[39m.\u001b[39mchannel_indices \u001b[39m=\u001b[39m indices_unsqueeze\n\u001b[0;32m-> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1020'>1021</a>\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad(): sparse_bert(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49minput_dict)\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1022'>1023</a>\u001b[0m last_att \u001b[39m=\u001b[39m layer\u001b[39m.\u001b[39mattention\u001b[39m.\u001b[39mself\u001b[39m.\u001b[39mlast_attention_probs \u001b[39m#(N, H, T, T)\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=1023'>1024</a>\u001b[0m impact_factor \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mmean(last_att, dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m) \u001b[39m#reduce head\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:842\u001b[0m, in \u001b[0;36mSparseBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=832'>833</a>\u001b[0m head_mask \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_head_mask(head_mask, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=834'>835</a>\u001b[0m embedding_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membeddings(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=835'>836</a>\u001b[0m     input_ids\u001b[39m=\u001b[39minput_ids,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=836'>837</a>\u001b[0m     position_ids\u001b[39m=\u001b[39mposition_ids,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=839'>840</a>\u001b[0m     past_key_values_length\u001b[39m=\u001b[39mpast_key_values_length,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=840'>841</a>\u001b[0m )\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=841'>842</a>\u001b[0m encoder_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mencoder(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=842'>843</a>\u001b[0m     embedding_output,\n\u001b[1;32m    
<a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=843'>844</a>\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mextended_attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=844'>845</a>\u001b[0m     head_mask\u001b[39m=\u001b[39;49mhead_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=845'>846</a>\u001b[0m     encoder_hidden_states\u001b[39m=\u001b[39;49mencoder_hidden_states,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=846'>847</a>\u001b[0m     encoder_attention_mask\u001b[39m=\u001b[39;49mencoder_extended_attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=847'>848</a>\u001b[0m     past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=848'>849</a>\u001b[0m     use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=849'>850</a>\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=850'>851</a>\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=851'>852</a>\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=852'>853</a>\u001b[0m )\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=853'>854</a>\u001b[0m sequence_output \u001b[39m=\u001b[39m 
encoder_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=854'>855</a>\u001b[0m pooled_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler(sequence_output) \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpooler \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:641\u001b[0m, in \u001b[0;36mBertEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=631'>632</a>\u001b[0m     layer_outputs \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mcheckpoint\u001b[39m.\u001b[39mcheckpoint(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=632'>633</a>\u001b[0m         create_custom_forward(layer_module),\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=633'>634</a>\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=637'>638</a>\u001b[0m         encoder_attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=638'>639</a>\u001b[0m     )\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=639'>640</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=640'>641</a>\u001b[0m     layer_outputs \u001b[39m=\u001b[39m layer_module(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=641'>642</a>\u001b[0m         hidden_states,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=642'>643</a>\u001b[0m         attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=643'>644</a>\u001b[0m         
layer_head_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=644'>645</a>\u001b[0m         encoder_hidden_states,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=645'>646</a>\u001b[0m         encoder_attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=646'>647</a>\u001b[0m         past_key_value,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=647'>648</a>\u001b[0m         output_attentions,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=648'>649</a>\u001b[0m     )\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=650'>651</a>\u001b[0m hidden_states \u001b[39m=\u001b[39m layer_outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=651'>652</a>\u001b[0m \u001b[39mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:552\u001b[0m, in \u001b[0;36mBertLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=539'>540</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=540'>541</a>\u001b[0m     \u001b[39mself\u001b[39m,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=541'>542</a>\u001b[0m     hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=548'>549</a>\u001b[0m ):\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=549'>550</a>\u001b[0m     \u001b[39m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=550'>551</a>\u001b[0m     self_attn_past_key_value \u001b[39m=\u001b[39m past_key_value[:\u001b[39m2\u001b[39m] \u001b[39mif\u001b[39;00m past_key_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=551'>552</a>\u001b[0m     self_attention_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=552'>553</a>\u001b[0m         hidden_states,\n\u001b[1;32m    <a 
href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=553'>554</a>\u001b[0m         attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=554'>555</a>\u001b[0m         head_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=555'>556</a>\u001b[0m         output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=556'>557</a>\u001b[0m         past_key_value\u001b[39m=\u001b[39;49mself_attn_past_key_value,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=557'>558</a>\u001b[0m     )\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=558'>559</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mself_attention_outputs \u001b[39m=\u001b[39m self_attention_outputs\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=559'>560</a>\u001b[0m     attention_output \u001b[39m=\u001b[39m self_attention_outputs[\u001b[39m0\u001b[39m]\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:476\u001b[0m, in \u001b[0;36mBertAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=465'>466</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=466'>467</a>\u001b[0m     \u001b[39mself\u001b[39m,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=467'>468</a>\u001b[0m     hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=473'>474</a>\u001b[0m     output_attentions\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=474'>475</a>\u001b[0m ):\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=475'>476</a>\u001b[0m     self_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mself(\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=476'>477</a>\u001b[0m         hidden_states,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=477'>478</a>\u001b[0m         attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=478'>479</a>\u001b[0m         head_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=479'>480</a>\u001b[0m         encoder_hidden_states,\n\u001b[1;32m    <a 
href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=480'>481</a>\u001b[0m         encoder_attention_mask,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=481'>482</a>\u001b[0m         past_key_value,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=482'>483</a>\u001b[0m         output_attentions,\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=483'>484</a>\u001b[0m     )\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=484'>485</a>\u001b[0m     timer_start(\u001b[39m'\u001b[39m\u001b[39mbert.attention.output\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=485'>486</a>\u001b[0m     attention_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput(self_outputs[\u001b[39m0\u001b[39m], hidden_states)\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py:1102\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1097'>1098</a>\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1098'>1099</a>\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1099'>1100</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1100'>1101</a>\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1101'>1102</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1102'>1103</a>\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   <a 
href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py?line=1103'>1104</a>\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/models/sparse_token.py:364\u001b[0m, in \u001b[0;36mBertSelfAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=361'>362</a>\u001b[0m timer_start(\u001b[39m'\u001b[39m\u001b[39mbert.attention.scores.matmul\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=362'>363</a>\u001b[0m \u001b[39m# Take the dot product between \"query\" and \"key\" to get the raw attention scores.\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=363'>364</a>\u001b[0m attention_scores \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mmatmul(query_layer, key_layer\u001b[39m.\u001b[39;49mtranspose(\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m, \u001b[39m-\u001b[39;49m\u001b[39m2\u001b[39;49m))\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=364'>365</a>\u001b[0m timer_end(\u001b[39m'\u001b[39m\u001b[39mbert.attention.scores.matmul\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/models/sparse_token.py?line=365'>366</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mposition_embedding_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mrelative_key\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mposition_embedding_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mrelative_key_query\u001b[39m\u001b[39m\"\u001b[39m:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# GLUE subsets evaluated in this run. Other subsets used with this notebook:\n",
    "# cola, mnli, qnli, qqp, rte, sst2, stsb, wnli.\n",
    "subsets = [\"mrpc\"]\n",
    "# Values passed as `ks` to eval_sparse_model. A string starting with 'dynamic'\n",
    "# (e.g. 'dynamic:avg:avg:f') is also handled by the loop below.\n",
    "kss = [\n",
    "    0.1, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 0.999,\n",
    "]\n",
    "sparse.benchmark_reset()\n",
    "\n",
    "def get_score(score):\n",
    "    \"\"\"Return (value, label) from a metric dict: 'accuracy' (labelled 'acc') if present, else the first metric.\"\"\"\n",
    "    if 'accuracy' in score:\n",
    "        return score['accuracy'], \"acc\"\n",
    "    first_metric = list(score.keys())[0]\n",
    "    return score[first_metric], first_metric\n",
    "\n",
    "results = {}\n",
    "# Total number of sparse evaluations, used only for the progress line.\n",
    "count = len(subsets) * len(kss)\n",
    "done = 0\n",
    "for subset in subsets:\n",
    "    trainer = Glue(dataset=subset, factor=16, batch_size=-1, device=0)\n",
    "    trainer.load()\n",
    "    # Baseline score first, so 'bert' is the first key of every per-subset dict.\n",
    "    bert_score, metric_name = get_score(trainer.eval_base_model())\n",
    "    scores = {'bert': f'{bert_score:.5f}'}\n",
    "    for ks in kss:\n",
    "        # Reset the benchmark counters so the 'est_k' average covers this ks only.\n",
    "        sparse.benchmark_reset()\n",
    "        sparse_score, _ = get_score(trainer.eval_sparse_model(ks=ks, use_forward=True))\n",
    "        if isinstance(ks, str) and ks.startswith('dynamic'):\n",
    "            est_k = sparse.benchmark_get_average('est_k')\n",
    "            scores[str(ks)] = f'{sparse_score:.5f} (k:{est_k:.2f})'\n",
    "        else:\n",
    "            scores[str(ks)] = f'{sparse_score:.5f}'\n",
    "        done += 1\n",
    "        print(f'{done}/{count} | {subset} {ks} = {sparse_score:.5f}')\n",
    "    results[f\"{subset} ({metric_name})\"] = scores\n",
    "\n",
    "with open('glue_benchmark_forward.pkl', 'wb') as f:\n",
    "    pickle.dump(results, f)\n",
    "\n",
    "sparse.benchmark_report()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "drop cola\n",
      "drop wnli\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_123192/3086650456.py:48: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  tex = df.to_latex()\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mnli (acc)</th>\n",
       "      <th>mrpc (acc)</th>\n",
       "      <th>qnli (acc)</th>\n",
       "      <th>qqp (acc)</th>\n",
       "      <th>rte (acc)</th>\n",
       "      <th>sst2 (acc)</th>\n",
       "      <th>stsb (pearson)</th>\n",
       "      <th>reproduce</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>bert</th>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>0.91543</td>\n",
       "      <td>0.90908</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.92431</td>\n",
       "      <td>0.88047</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.1</th>\n",
       "      <td>0.49638</td>\n",
       "      <td>0.33681</td>\n",
       "      <td>0.62328</td>\n",
       "      <td>0.66876</td>\n",
       "      <td>0.54874</td>\n",
       "      <td>0.66284</td>\n",
       "      <td>0.20566</td>\n",
       "      <td>58.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.25</th>\n",
       "      <td>0.78686</td>\n",
       "      <td>0.41797</td>\n",
       "      <td>0.85210</td>\n",
       "      <td>0.88078</td>\n",
       "      <td>0.66426</td>\n",
       "      <td>0.83142</td>\n",
       "      <td>0.79653</td>\n",
       "      <td>86.41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.375</th>\n",
       "      <td>0.83209</td>\n",
       "      <td>0.58841</td>\n",
       "      <td>0.90445</td>\n",
       "      <td>0.90566</td>\n",
       "      <td>0.68953</td>\n",
       "      <td>0.87385</td>\n",
       "      <td>0.86073</td>\n",
       "      <td>93.47</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.5</th>\n",
       "      <td>0.83984</td>\n",
       "      <td>0.71420</td>\n",
       "      <td>0.91323</td>\n",
       "      <td>0.90846</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.90711</td>\n",
       "      <td>0.87449</td>\n",
       "      <td>97.36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.625</th>\n",
       "      <td>0.84147</td>\n",
       "      <td>0.80464</td>\n",
       "      <td>0.91506</td>\n",
       "      <td>0.90903</td>\n",
       "      <td>0.72202</td>\n",
       "      <td>0.91858</td>\n",
       "      <td>0.87796</td>\n",
       "      <td>99.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.75</th>\n",
       "      <td>0.84187</td>\n",
       "      <td>0.83942</td>\n",
       "      <td>0.91525</td>\n",
       "      <td>0.90915</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.92202</td>\n",
       "      <td>0.87986</td>\n",
       "      <td>99.87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.875</th>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>0.91525</td>\n",
       "      <td>0.90913</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.92431</td>\n",
       "      <td>0.88047</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.999</th>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>0.91543</td>\n",
       "      <td>0.90908</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.92431</td>\n",
       "      <td>0.88047</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      mnli (acc) mrpc (acc) qnli (acc) qqp (acc) rte (acc) sst2 (acc)  \\\n",
       "bert     0.84208    0.84406    0.91543   0.90908   0.72563    0.92431   \n",
       "0.1      0.49638    0.33681    0.62328   0.66876   0.54874    0.66284   \n",
       "0.25     0.78686    0.41797    0.85210   0.88078   0.66426    0.83142   \n",
       "0.375    0.83209    0.58841    0.90445   0.90566   0.68953    0.87385   \n",
       "0.5      0.83984    0.71420    0.91323   0.90846   0.72563    0.90711   \n",
       "0.625    0.84147    0.80464    0.91506   0.90903   0.72202    0.91858   \n",
       "0.75     0.84187    0.83942    0.91525   0.90915   0.72563    0.92202   \n",
       "0.875    0.84208    0.84406    0.91525   0.90913   0.72563    0.92431   \n",
       "0.999    0.84208    0.84406    0.91543   0.90908   0.72563    0.92431   \n",
       "\n",
       "      stsb (pearson) reproduce  \n",
       "bert         0.88047    100.00  \n",
       "0.1          0.20566     58.74  \n",
       "0.25         0.79653     86.41  \n",
       "0.375        0.86073     93.47  \n",
       "0.5          0.87449     97.36  \n",
       "0.625        0.87796     99.12  \n",
       "0.75         0.87986     99.87  \n",
       "0.875        0.88047    100.00  \n",
       "0.999        0.88047    100.00  "
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load a file written by the benchmark cell.\n",
    "with open('glue_benchmark_forward.pkl', 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "# Row labels ('bert' plus one entry per ks value) are taken from the first stored subset.\n",
    "factors = list(results[list(results.keys())[0]].keys())\n",
    "\n",
    "# remove cola and wnli, matched by key so the filter does not depend on their position\n",
    "dropped = (\"cola (matthews_correlation)\", \"wnli (acc)\")\n",
    "for key in dropped:\n",
    "    if key in results:\n",
    "        print(f\"drop {key.split()[0]}\")\n",
    "subsets = [key for key in results if key not in dropped]\n",
    "\n",
    "# One row per factor, one column per remaining subset.\n",
    "data = [[results[subset][factor] for subset in subsets] for factor in factors]\n",
    "\n",
    "# calculate reproducibility: mean ratio of each row to the first row, in percent\n",
    "# Only the leading number of each entry is parsed, so entries with a trailing '(k:...)' note work too.\n",
    "data_scalar = [[float(item.split()[0]) for item in line] for line in data]\n",
    "baseline = data_scalar[0]\n",
    "for line, xs in zip(data, data_scalar):\n",
    "    # A zero baseline score yields nan for that ratio instead of aborting the whole table.\n",
    "    ratios = [x / b if b != 0 else float('nan') for x, b in zip(xs, baseline)]\n",
    "    line.append(f\"{sum(ratios) / len(ratios) * 100:.2f}\")\n",
    "subsets.append(\"reproduce\")\n",
    "\n",
    "df = pd.DataFrame(data, columns=subsets, index=factors)\n",
    "tex = df.to_latex()\n",
    "# Create the output directory if needed so the write below cannot fail on a missing folder.\n",
    "os.makedirs('saves_plot', exist_ok=True)\n",
    "with open('saves_plot/glue_benchmark_forward.tex', 'w') as f:\n",
    "    f.write(tex)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: qnli\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-95cdbfccba10767f.arrow\n",
      "Loading cached processed dataset at C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-a121d057d8b0504c.arrow\n",
      "Reusing dataset glue (C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-81af1b1d0da19e67.arrow\n",
      "Loading cached processed dataset at C:\\Users\\user\\.cache\\huggingface\\datasets\\glue\\qnli\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-efc6aadc78835551.arrow\n"
     ]
    }
   ],
   "source": [
    "# Build the QNLI trainer with the same arguments as the benchmark loop, then call its load().\n",
    "trainer = Glue(\n",
    "    dataset='qnli',\n",
    "    factor=16,\n",
    "    batch_size=-1,\n",
    "    device=0,\n",
    ")\n",
    "trainer.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 683/683 [00:25<00:00, 26.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.9154310818231741}\n",
      "avg occupy 0.30243502260079336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9154310818231741}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assign trainer.model_bert to the classifier's `bert` attribute, then evaluate it.\n",
    "# NOTE(review): presumably this restores the original encoder after a sparse wrapper was swapped in; confirm.\n",
    "trainer.model.bert = trainer.model_bert\n",
    "trainer.eval_base_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 683/683 [00:36<00:00, 18.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.5996705107084019}\n",
      "avg occupy 0.30243502260079336\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.5996705107084019}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# trainer.seed()\n",
    "# import models.sparse_token as sparse\n",
    "# import transformers.models.bert.modeling_bert as berts\n",
    "# import importlib\n",
    "# importlib.reload(sparse)\n",
    "\n",
    "# wrapped_bert = sparse.ApproxSparseBertModel(trainer.model_bert, approx_bert=trainer.approx_bert, ks=0.1)\n",
    "# sparse_cls_bert = berts.BertForSequenceClassification(trainer.model_bert.config)\n",
    "# sparse_cls_bert.load_state_dict(trainer.model.state_dict())\n",
    "# sparse_cls_bert.bert = wrapped_bert\n",
    "# sparse_cls_bert.to(trainer.device).eval()\n",
    "\n",
    "# trainer.eval_base_model(model = sparse_cls_bert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f7b3ac0126d0d6fea024471ce24e510948bf6332f7ae1a66cdcb4ee9887514e9"
  },
  "kernelspec": {
   "display_name": "Python 3.8.3 64-bit ('tensorflow': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
