{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/user/anaconda3/envs/torch/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<module 'models.sparse_token' from '/home/user/library/discrete_edge_learning/models/sparse_token.py'>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch, random, math, time, sys, os, tqdm\n",
    "import numpy as np\n",
    "import numba\n",
    "import importlib\n",
    "import models.sparse_token as sparse\n",
    "importlib.reload(sparse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 120000/120000 [00:00<00:00, 398319.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:1012, avg_len:236.477525, count:120000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7600/7600 [00:00<00:00, 381784.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:892, avg_len:235.2992105263158, count:7600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at google/bert_uncased_L-12_H-768_A-12 were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer.__init__: Model initialized. model = bert-base\n",
      "Trainer.load: Loading... saves/cls_bert-base.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 120000/120000 [00:00<00:00, 388140.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:1012, avg_len:236.477525, count:120000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7600/7600 [00:00<00:00, 390249.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification Dataset Stat.: name:AG_NEWS, nclass:5, max_len:892, avg_len:235.2992105263158, count:7600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at google/bert_uncased_L-12_H-768_A-12 were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer.__init__: Model initialized. model = bert-base\n",
      "Trainer.load: Loading... saves/cls_bert-base.pth\n",
      "Trainer.load: saves/att_approx_16_bert-base.pth\n",
      "approx trained 150000\n"
     ]
    }
   ],
   "source": [
    "from trainer.classification import Trainer\n",
    "from trainer.attention_approx import Trainer as ApproxTrainer\n",
    "batch_size = 4\n",
    "device = 0\n",
    "factor = 16\n",
    "\n",
    "trainer = Trainer(batch_size=batch_size, model='bert-base', device=device)\n",
    "trainer.load()\n",
    "trainer.model.eval()\n",
    "bert = trainer.model.bert\n",
    "fc = trainer.model.classifier\n",
    "batch = trainer.get_batch()\n",
    "test_batch = trainer.get_batch(test=False)\n",
    "\n",
    "approx_trainer = ApproxTrainer(batch_size=batch_size, factor=factor, model=trainer.model_type, device=trainer.device)\n",
    "approx_trainer.load()\n",
    "approx_bert = approx_trainer.bert\n",
    "approx_bert = approx_bert.eval()\n",
    "print('approx trained', approx_trainer.steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sparse_bert = sparse.SparseBertModel(bert.config)\n",
    "sparse_bert.to(trainer.device)\n",
    "sparse_bert.eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([1, 2, 4, 1], device='cuda:0'),\n",
       "  tensor([1, 2, 4, 1], device='cuda:0')),\n",
       " (tensor([1, 2, 4, 1], device='cuda:0'),\n",
       "  tensor([1, 2, 4, 1], device='cuda:0')))"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "importlib.reload(sparse)\n",
    "sparse.benchmark_reset()\n",
    "sparse.timer_reset()\n",
    "\n",
    "def eval_fc(lm_output, fc=fc, batch=batch):\n",
    "    last_hidden = lm_output.last_hidden_state[:,0,:]\n",
    "    x = fc(last_hidden)\n",
    "    return torch.argmax(x, dim=-1), batch.labels, lm_output\n",
    "\n",
    "def eval(bert, fc=fc, batch=batch):\n",
    "    lm_output = bert(\n",
    "        input_ids = batch.input_ids, \n",
    "        attention_mask = batch.attention_masks,\n",
    "        output_hidden_states = True,\n",
    "        output_attentions = True,\n",
    "    )\n",
    "    return eval_fc(lm_output, fc=fc, batch=batch)\n",
    "\n",
    "def approx_eval(sparse_bert, approx_bert, fc=fc, batch=batch, k=0.5):\n",
    "    lm_output = sparse.run_bert_with_approx(\n",
    "        sparse_bert, \n",
    "        approx_bert, \n",
    "        {\n",
    "            'input_ids': batch.input_ids,\n",
    "            'attention_mask': batch.attention_masks,\n",
    "            'output_hidden_states': True,\n",
    "            'output_attentions': True,\n",
    "        },\n",
    "        ks = [k]*(len(sparse_bert.encoder.layer)),\n",
    "    )\n",
    "    return eval_fc(lm_output, fc=fc, batch=batch)\n",
    "\n",
    "def forward_eval(sparse_bert, fc=fc, batch=batch, k=0.5):\n",
    "    lm_output = sparse.run_bert_forward_sparsity(\n",
    "        sparse_bert, \n",
    "        {\n",
    "            'input_ids': batch.input_ids,\n",
    "            'attention_mask': batch.attention_masks,\n",
    "            'output_hidden_states': True,\n",
    "            'output_attentions': True,\n",
    "        },\n",
    "        ks = k,\n",
    "    )\n",
    "    return eval_fc(lm_output, fc=fc, batch=batch)\n",
    "\n",
    "eval(bert)[:2], forward_eval(sparse_bert)[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2871, 0.1998, 0.0007,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.2825, 0.2313, 0.0010,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.0885, 0.0806, 0.0056,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        ...,\n",
       "        [0.5495, 0.2041, 0.0011,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.5037, 0.2072, 0.0011,  ..., 0.0000, 0.0000, 0.0000],\n",
       "        [0.4037, 0.2040, 0.0038,  ..., 0.0000, 0.0000, 0.0000]],\n",
       "       device='cuda:0', grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_output = sparse.run_bert_forward_sparsity(\n",
    "    sparse_bert, \n",
    "    {\n",
    "        'input_ids': batch.input_ids,\n",
    "        'attention_mask': batch.attention_masks,\n",
    "        'output_hidden_states': True,\n",
    "        'output_attentions': True,\n",
    "    },\n",
    "    ks = 0.5,\n",
    ")\n",
    "lm_output.attentions[0][0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 83%|████████▎ | 124/150 [00:02<00:00, 43.67it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/user/library/discrete_edge_learning/forward_only.ipynb Cell 6'\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=34'>35</a>\u001b[0m sparse\u001b[39m.\u001b[39mtimer_reset()\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=35'>36</a>\u001b[0m k \u001b[39m=\u001b[39m \u001b[39m0.25\u001b[39m\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=36'>37</a>\u001b[0m acc, lm \u001b[39m=\u001b[39m accuracy(\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=37'>38</a>\u001b[0m     \u001b[39mlambda\u001b[39;49;00m batch: approx_eval(sparse_bert, bert, batch\u001b[39m=\u001b[39;49mbatch, k\u001b[39m=\u001b[39;49mk),\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=38'>39</a>\u001b[0m     return_lm \u001b[39m=\u001b[39;49m \u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=39'>40</a>\u001b[0m )\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=40'>41</a>\u001b[0m acc_, lm \u001b[39m=\u001b[39m accuracy(\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=41'>42</a>\u001b[0m     \u001b[39mlambda\u001b[39;00m 
batch: forward_eval(sparse_bert, batch\u001b[39m=\u001b[39mbatch, k\u001b[39m=\u001b[39mk),\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=42'>43</a>\u001b[0m     return_lm \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=43'>44</a>\u001b[0m )\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=44'>45</a>\u001b[0m \u001b[39m#acc = accuracy(lambda batch: eval(bert, batch=batch))\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=45'>46</a>\u001b[0m \u001b[39m# sparse.timer_report()\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=46'>47</a>\u001b[0m \u001b[39m# sparse.benchmark_report()\u001b[39;00m\n",
      "\u001b[1;32m/home/user/library/discrete_edge_learning/forward_only.ipynb Cell 6'\u001b[0m in \u001b[0;36maccuracy\u001b[0;34m(batch_eval, N, return_lm)\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=9'>10</a>\u001b[0m acc_sum \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=10'>11</a>\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m tqdm\u001b[39m.\u001b[39mtqdm(\u001b[39mrange\u001b[39m(N)):\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=11'>12</a>\u001b[0m     batch \u001b[39m=\u001b[39m trainer\u001b[39m.\u001b[39;49mget_batch(test\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=12'>13</a>\u001b[0m     \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Blab-desk/home/user/library/discrete_edge_learning/forward_only.ipynb#ch0000004vscode-remote?line=13'>14</a>\u001b[0m         output, label, _ \u001b[39m=\u001b[39m batch_eval(batch)\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/trainer/classification.py:56\u001b[0m, in \u001b[0;36mTrainer.get_batch\u001b[0;34m(self, test)\u001b[0m\n\u001b[1;32m     <a href='file:///home/user/library/discrete_edge_learning/trainer/classification.py?line=54'>55</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_batch\u001b[39m(\u001b[39mself\u001b[39m, test\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m):\n\u001b[0;32m---> <a href='file:///home/user/library/discrete_edge_learning/trainer/classification.py?line=55'>56</a>\u001b[0m     batch \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset\u001b[39m.\u001b[39;49mbatch(test\u001b[39m=\u001b[39;49mtest)\n\u001b[1;32m     <a href='file:///home/user/library/discrete_edge_learning/trainer/classification.py?line=56'>57</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m batch\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice)\n",
      "File \u001b[0;32m~/library/discrete_edge_learning/dataset/classification_dataset.py:108\u001b[0m, in \u001b[0;36mClassificationDataset.batch\u001b[0;34m(self, test)\u001b[0m\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/dataset/classification_dataset.py?line=104'>105</a>\u001b[0m     labels\u001b[39m.\u001b[39mappend(idx)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/dataset/classification_dataset.py?line=105'>106</a>\u001b[0m     texts\u001b[39m.\u001b[39mappend(text)\n\u001b[0;32m--> <a href='file:///home/user/library/discrete_edge_learning/dataset/classification_dataset.py?line=107'>108</a>\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtokenizer(texts, padding\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, truncation\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, return_tensors\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mpt\u001b[39;49m\u001b[39m'\u001b[39;49m, max_length\u001b[39m=\u001b[39;49m\u001b[39m512\u001b[39;49m)\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/dataset/classification_dataset.py?line=109'>110</a>\u001b[0m entry \u001b[39m=\u001b[39m ClassificationBatchEntry()\n\u001b[1;32m    <a href='file:///home/user/library/discrete_edge_learning/dataset/classification_dataset.py?line=110'>111</a>\u001b[0m entry\u001b[39m.\u001b[39mlabels \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor(labels, dtype\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mint64)\n",
      "File \u001b[0;32m~/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:2434\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.__call__\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2430'>2431</a>\u001b[0m     is_batched \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(text, (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m))\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2432'>2433</a>\u001b[0m \u001b[39mif\u001b[39;00m is_batched:\n\u001b[0;32m-> <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2433'>2434</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39;49m(text_pair, \u001b[39mstr\u001b[39;49m):\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2434'>2435</a>\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2435'>2436</a>\u001b[0m             \u001b[39m\"\u001b[39m\u001b[39mwhen tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2436'>2437</a>\u001b[0m         
)\n\u001b[1;32m   <a href='file:///home/user/anaconda3/envs/torch/lib/python3.8/site-packages/transformers/tokenization_utils_base.py?line=2437'>2438</a>\u001b[0m     \u001b[39mif\u001b[39;00m text_pair \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(text) \u001b[39m!=\u001b[39m \u001b[39mlen\u001b[39m(text_pair):\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "importlib.reload(sparse)\n",
    "sparse.benchmark_reset()\n",
    "sparse.timer_reset()\n",
    "\n",
    "#7600 is original\n",
    "def accuracy(batch_eval, N=600//4, return_lm=False):\n",
    "    #N = 10\n",
    "    trainer.seed()\n",
    "    trainer.dataset.batch_size = 4\n",
    "    acc_sum = 0\n",
    "    for i in tqdm.tqdm(range(N)):\n",
    "        batch = trainer.get_batch(test=True)\n",
    "        with torch.no_grad():\n",
    "            output, label, _ = batch_eval(batch)\n",
    "        acc_sum += torch.mean((output == label) * 1.0)\n",
    "    if return_lm: return acc_sum.item() / N, _\n",
    "    return acc_sum.item() / N\n",
    "\n",
    "# setup for evaluation\n",
    "sparse_bert = sparse.SparseBertModel(bert.config)\n",
    "sparse_bert.to(trainer.device)\n",
    "sparse_bert.eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())\n",
    "sparse.set_print(sparse_bert, False)\n",
    "sparse.set_backup_last_inputs(sparse_bert, False)\n",
    "sparse.set_output_masking(sparse_bert, False)\n",
    "\n",
    "sparse_bert = sparse_bert.to(trainer.device)\n",
    "approx_bert = approx_bert.to(trainer.device)\n",
    "bert = bert.to(trainer.device)\n",
    "sparse.set_print(sparse_bert, False)\n",
    "sparse.set_backup_last_inputs(sparse_bert, False)\n",
    "sparse.set_output_masking(sparse_bert, False)\n",
    "\n",
    "sparse.timer_reset()\n",
    "k = 0.25\n",
    "acc, lm = accuracy(\n",
    "    lambda batch: approx_eval(sparse_bert, bert, batch=batch, k=k),\n",
    "    return_lm = True,\n",
    ")\n",
    "acc_, lm = accuracy(\n",
    "    lambda batch: forward_eval(sparse_bert, batch=batch, k=k),\n",
    "    return_lm = True,\n",
    ")\n",
    "#acc = accuracy(lambda batch: eval(bert, batch=batch))\n",
    "# sparse.timer_report()\n",
    "# sparse.benchmark_report()\n",
    "acc_, acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#0.15 79         45.5\n",
    "#0.25 86.69      80.6184\n",
    "#0.5  91.4736    91.802     0.5:0.91842 0.2:0.91855 0.1:0.9186 0:0.9180"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 150/150 [00:01<00:00, 129.70it/s]\n",
      "100%|██████████| 150/150 [00:03<00:00, 44.28it/s]\n",
      "100%|██████████| 150/150 [00:21<00:00,  6.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1 0.5203836930455636 0.3616666666666667 0.695\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 150/150 [00:03<00:00, 44.49it/s]\n",
      "100%|██████████| 150/150 [00:21<00:00,  6.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.25 0.9214145383104124 0.7816666666666666 0.8483333333333334\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 150/150 [00:03<00:00, 44.33it/s]\n",
      "100%|██████████| 150/150 [00:21<00:00,  6.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.375 0.9816513761467891 0.8916666666666667 0.9083333333333333\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 150/150 [00:03<00:00, 44.41it/s]\n",
      "100%|██████████| 150/150 [00:21<00:00,  6.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5 1.0018281535648994 0.9133333333333333 0.9116666666666666\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 150/150 [00:03<00:00, 44.41it/s]\n",
      "100%|██████████| 150/150 [00:22<00:00,  6.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.75 1.0072332730560578 0.9283333333333333 0.9216666666666666\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtTklEQVR4nO3dd3hUZfrG8e+TSe9AQk1C0aBUKZEiBJGVFVcFUXcFC4usi+WHa921rLLYy7K2tWLZxYqKDVlcywImFBHQgBAEQg81nUBImcz7++MMGEJiBpjMmUyez3XlYubMm5k7h+TOmzOniDEGpZRSTV+Q3QGUUkp5hxa6UkoFCC10pZQKEFroSikVILTQlVIqQATb9cIJCQmmU6dOdr28Uko1SStXrsw3xiTW9Zhthd6pUydWrFhh18srpVSTJCLb6ntMN7kopVSA0EJXSqkAoYWulFIBwrZt6HWpqqoiNzeX8vJyu6M0eeHh4SQlJRESEmJ3FKWUj/hVoefm5hITE0OnTp0QEbvjNFnGGAoKCsjNzaVz5852x1FK+YhfbXIpLy+nVatWWuYnSURo1aqV/qWjVDPjV4UOaJl7ia5HpZofv9rkopRSAaG6Cg4Vw6EiKHf/e/j+oSLoeh506Of1l9VCr+XZZ5/lxRdfpF+/frz99tu25Zg2bRrR0dHccccdtmVQqlkzBipK6y/lY5YV/7ys8sAvP3d0ay10X3jhhRf4+uuvSUpKanCs0+kkOPjkV6ExBmMMQUF+twVMqabPWVmjfBsq5VrLTHX9z+sIg4gWEBFv/RufDOG9jl4W0QITHkdFSBxljhjKJIYDQZG0joumZSN8qVroNVx//fVs3ryZ888/n4kTJ5KZmcnmzZuJjIxkxowZ9O7dm2nTprFp0yY2b95MSkoK+/fv59FHH6V379707duXsWPHMnXqVKZOnUpycjLjx49nzJgxFBUVUVVVxUMPPcSYMWPYunUr5513HgMHDmTlypXMmzePt956i5kzZ9K6dWuSk5Pp37+/3atEKf/gckFl6bGlXFdRl5ccvazq4C88sUB4HETEYyJa4AqPxxmTRFVIHBUhsZQ7YikLjuVgUDQHJIYSoigx0RSZKPZXB1NWUc3BSieHKqs5eKiasmInByurOVRp/VtW4aSsqhJj8oC8I6/68NieXDmwo9dXk98W+v2frSV7136vPmf39rH87aIe9T7+0ksv8d///pcFCxZw//3307dvXz755BPmz5/PhAkTyMrKAiA7O5tFixYRERHBY489RmZmJh07diQ4OJjFixcDkJmZyUsvvUR4eDgff/wxsbGx5OfnM2jQIEaPHg3Axo0bmTlzJoMGDWLlypXMmjWLrKwsnE4n/fr100JXgcdZ4WEp11pWXgzGVe/TuhzhOEPjqAyNozw4lvLgRMqiTuFgTDSlQTHsN9EUE0WRK4pCVyT51ZHsc0ZSUBlKaZWhrKiasr3VVLsauiRnhfujgNDgIKJCHUSGBhMZ6iAyLJioUAfxkSFEhgYTFVbjsVr3e3WI8876rMVvC91uixYt4sMPPwRgxIgRFBQUsH+/9Qtm9OjRREREAJCens6zzz5L586dueCCC/jqq68oKytjy5YtnHbaaVRVVXHPPfeQkZFBUFAQO3fuZO/evQB07NiRQYMGAdYvgLFjxxIZGXnkNZTyW8ZYZVu6Bw7sqWNzRjGUF2PKCnG5l0l5EUHO+neldRFEuSOaMkcsB4KiKSWa/ZJCselGYUgUhdVR5FdHsM8ZQWF1FMVEU2yi2E8UFYRCPRPx4CAhKuxwsTqO3I6KCiYx1EFUaDARoY4jhXukpMOsxyJr3w9zEBniINjhf5tI/bbQf2kmbbeoqKgjt88880xWrFhBly5dGDlyJPn5+bzyyitHZtdvv/02eXl5rFy5kpCQEDp16nRk//Caz6MUgMtlcLoM1S5DtTFUVxucLhfVNZZb/7pwugzOaoPL1His+vCY
Yz/HVetzjx7vXl5dTWhlMRHleURU5BNZmUdUZT5RlXlEVxYQU5VPjLOAWGcBwaaqzq/hEGGUuMvW+oimxLSm2L25ogT3MtyPE81+E8VBiSAyNJRIh4Oo4MPFWbtYHXQJC6ZXqIOIw+XrnhlHuMu59sw4NNj/irex+G2h2y09PZ23336b++67j4ULF5KQkEBsbOwx40JDQ0lOTuaDDz5g6tSp5OXlcccddxzZO6WkpITWrVsTEhLCggUL2Lat7jNfDhs2jIkTJ3L33XfjdDr57LPPuO666xr1a1RHO1jhZNmWAr7fVkx5VbVVqIfLrrqOMjymLGuVaB3l6nLVLuajHzcN/cV/ggQXLSmljRTRWoppLUW0ppjWUkyHGssSKSZUjn0jsIRoCohnr7RkXdDpFAa3pCioJUWOVpQEtaAyzNr+THgcIWERNWayPxds2zAHnUNqzJDDgomocT8sOEiPnzhJHhW6iIwCngEcwKvGmMdqPd4ReB1IBAqBq4wxuV7O6lPTpk1j0qRJ9O7dm8jISGbOnFnv2PT0dP73v/8RERFBeno6ubm5pKenA3DllVdy0UUX0atXL9LS0jj99NPrfI5+/fpx+eWXc8YZZ9C6dWvOPPPMRvm61M9cLkP27v1kbMwjc0M+K7YVUlVtcAQJYcFBOIKE4CDBERTk/ldqLHPfdhz9eIgjiPCQWmOCgo76POtzjl4eVPPxWq/583hruSOInx8XF5GVRURUWDPq8PI8wsv3EVq+j9CyfYQc2kdI2T6CD+UhLuex6yCiJSa6DcQkQ/SZBMW0hdh2EN0GYtpBTBuIbkNcSARxQBff/zep4yCmgSmBiDiADcBIIBdYDow3xmTXGPMBMNcYM1NERgDXGGOu/qXnTUtLM7UvcLFu3Tq6det2Ql+IOpauz2Pt219O5sZ8MjbmsWhjPgUHKwE4vW0MZ3dNJD01kbROLQgPcdgbtNoJB/Os7dOlNT4O7IHSvVC6Gw7shQP76t61LrKVVcg1i/nI/bbWR3QbCA7z/demToqIrDTGpNX1mCcz9AFAjjFms/vJZgFjgOwaY7oDt7lvLwA+OeG0SnlReVU1y7cWWiW+IY+f9pQC0CoqlPTUBIZ1TWToqQm0jg33TaDqKquEaxdz6e6j7x/Mq2OvDoGoBHcZt4W2PWuVtruoo1pDcKhvvh7lVzwp9A7Ajhr3c4GBtcasAi7B2iwzFogRkVbGmIKag0RkMjAZICUl5UQzK1UvYwwb9x0gY0MeGRvzWba5gAqni1BHEGmdWnDnqNNJT02ge7tYgoK8uL22uspdzIdn07uPvn94pn0wH6j9V7FYRw4eLub2feqYTbe1xjj0dMiqft56U/QO4DkRmQhkADuBY/4ONMbMAGaAtcnFS6+tmrnCg5Usysknc0MemRvz2bPf2ovolMQoxg9I4eyuiQzs0pLI0BP4dndWuIv5F2bTpXugLP/Yz5Ugq5Sj20BsB+jQ3yrmmLZHF3VUIjh0/wR18jz5LtoJJNe4n+RedoQxZhfWDB0RiQYuNcYUeymjUkepdLr4YXvRkW3hP+4swRiIiwhh6KkJpKcmMDQ1gaQWkcf3xC4X/DQXfngTSnKtoj5UeOw4cfw8e45PgeQB7qJuc/QmkKgECLJ5W7xqVjwp9OVAqoh0xiryccAVNQeISAJQaIxxAXdj7fGilFcYY9hWUEbGxjwyNuSzdFM+ByurcQQJfZPjueVXXRnWNYHeSfE4TmQzijHw039g4WOw90erpNv2hpTBdb+hGJkAet4d5YcaLHRjjFNEpgBfYO22+LoxZq2IPACsMMbMAYYDj4qIwdrk8n+NmFk1A/vLq1iSU2DtUrgxjx2FhwBIbhnBxX07kJ6ayFmntiI2/CS2KRsD6z+HhY/CntXQsguMfRl6XqabQFST5NF3rTFmHjCv1rKpNW7PBmZ7N5o9tm7dyoUXXsiaNWtO+DkWLlzI9OnTmTt3rheT1W/ixIlceOGFXHbZZT55vcZQ
7TKszi0mY0M+mRvz+GFHMdUuQ1Sog8GnJPDH9C4MS02kY6vIkz/4xBjY8IVV5LuzoEVnuPhF6PU7LXLVpOl3bxPjrVP2+oOdxYfI3JBHxsY8FucUUHKoChHo1SGOG84+hfTUBPp1bEGIt86ZYQzkfA0LHoFd30N8RxjzPPS+XPceUQEhMJrBy5xOJ1deeSXff/89PXr04I033mD69Ol89tlnHDp0iLPOOouXX34ZESEnJ4frr7+evLw8HA4HH3zwwVHPtXz5ciZPnszs2bO5+OKLyczMJC4ujoSEBJ566ikmTJjAhAkTuPrqq0lNTeXqq6/m4EHrLEPPPfccZ511FgsXLuS+++6jRYsW/PTTT6xfv56bbrqJr776iuTkZEJDm8Y+x2WVTr7dXHBkFr4pz/o628SG8evubUh37xPeMsrLX48xsOl/sOBR2LkC4lLgomehzxVa5Cqg+G+hf34X7PnRu8/Zthec/1iDw9avX89rr73GkCFDmDRpEi+88AJTpkxh6lRrK9PVV1/N3Llzueiii7jyyiu56667GDt2LOXl5bhcLnbssHbbX7JkCTfddBOffvopKSkpDBkyhMWLF9OxY0e6dOlCZmYmEyZMYOnSpbz44ouICF999RXh4eFs3LiR8ePHc/ho2u+//541a9bQuXNnPvroI9avX092djZ79+6le/fuTJo0ybvrygsOH1p/+KCew4fWhwUHMbBLK8YPSGFY10RSW0c3zjk8jIHNC6wiz/0O4pLhwqehz5V64I0KSP5b6DZKTk5myJAhAFx11VVHTo/7xBNPUFZWRmFhIT169GD48OHs3LmTsWPHAhAe/vPRhuvWrWPy5Ml8+eWXtG/fHrDO+ZKRkUHHjh254YYbmDFjBjt37qRFixZERUVRUlLClClTyMrKwuFwsGHDhiPPN2DAADp37gxARkYG48ePx+Fw0L59e0aMGOGrVdOgfaXlZLpn4Ity8sk/8POh9dcM6cwwXxxabwxs+cYq8h3fWvuAX/Ak9L1ai1wFNP8tdA9m0o2l9mxRRLjxxhtZsWIFycnJTJs27cgpcOvTrl07ysvL+eGHH44U+rBhw3j++efZvn07Dz/8MB9//DGzZ88+ciKvp556ijZt2rBq1SpcLtdRvyD89VS75VXVrNhaRObGPL6pdWj90NQEhqUmkp7qw0Prt2Rab3ZuWwwx7eE306HfBD1niWoW/LfQbbR9+3aWLl3K4MGDeeeddxg6dChLliwhISGBAwcOMHv2bC677DJiYmJISkrik08+4eKLL6aiooLqausA2fj4eF577TVGjhxJVFQUw4cPJzk5mfz8fCorK+nSpQtDhw5l+vTpPPfcc4B1qt2kpCSCgoKYOXPmkeeqbdiwYbz88sv8/ve/Z9++fSxYsIArrriizrHeZowhZ98BvnEflblsSwHlVS5CHEJax5b8ZdRpDEtN9P6h9Q3Zutgq8q2Z1j7j5//dKvIQH/0iUcoPaKHX4bTTTuP5559n0qRJdO/enRtuuIGioiJ69uxJ27Ztjzq17Ztvvsl1113H1KlTCQkJOepN0TZt2jB37lzOP/98Xn/9dQYOHMjAgQOPFHV6ejp33303Q4cOBeDGG2/k0ksv5Y033mDUqFH1zsrHjh3L/Pnz6d69OykpKQwePLgR1wYUuQ+tz6h1aH2XxCjGnZnCsK4JDOzciqgwG76dti2FhY/AlgzrwJ9Rj0P/iVrkqllq8PS5jUVPn9v4TnR9VlW7+GF7sbvA81jtPrQ+NjyYoakJpLs3oxz3ofXetH2ZVeSbF1pnFxx6C6RNgpAI+zIp5QMne/pc1QxszT/o3g6ez7ebCzhQ4cQRJPRxH1qf3jWB3h3i7L+O4o7lVpFvmm8dgv/rhyDtDxBq4y8XpfyEFnozdfjQ+syN1maU7YVlACS1iGB0n/YMS01g8CkJxEX4yX7auSutIs/52rp4w8gH4MxrIdQ/3yxWyg5+V+jGGL2uoBfU3pR2+ND6w/uE
H31ofSuuTe9MemoinbxxaL037fzeOmnWxi8goiWcOw3O/COERdudTCm/41eFHh4eTkFBAa1atfKvUmlijDEUFBQQHh7Out37eW5BDos25h91aP31Z3chPTWRfikt/POq6LuyrCLf8DlEtIBfTYUBkyEsxu5kSvktvyr0pKQkcnNzycvLsztKkxcWFk7Gzmoe+e9iIsMcjOzehmFdExlySitaRfvxPtm7V1tFvv4/EB4H59wLA6+D8Fi7kynl9/yq0ENCQo4cDalO3O6SQ9z+/iqWbCpgZPc2PHZJL/8ucYA9a6z9yH+aC2FxMPweGHS9VepKKY/4VaGrkzd39S7u+ehHnC7D45f24ndpyf69+WpvtlXk6+ZAWCycfRcMugEi4u1OplSTo4UeIPaXVzHt07V89MNO+iTH8/TlfeiU4Md7gOz7Cb55DNZ+AqHRMOwvMPhGa3u5UuqEaKEHgO+2FHLre1ns2V/OLeemMuWcU+3fX7w+eevhm8dhzUfWLofpt8HgKRDZ0u5kSjV5WuhNWKXTxVNfb+ClbzaR0jKSD64fTL8UP53h5m+0ivzH2RASaR3ZOfgmiGpldzKlAoYWehOVs6+Um2dlsXbXfsadmcx9F3a351wqDSnY5C7yDyA4HIb8Cc76E0Ql2J1MqYDjhw2gfokxhje/3cbD/1lHZKiDl6/uz3k92tod61gFmyDj77D6PXCEweD/g7NuhuhEu5MpFbC00JuQfaXl/GX2ahauz+Psron8/bLevjvPuKcKt0DGdFj1rnV5t0E3wpCbIbq13cmUCnha6E3El2v3cNdHP3KwwskDY3pw9aCO/rU7YtE2a0a+6l0Qh3Uw0JBbIKaN3cmUaja00P3cwQonD87NZtbyHfRoH8sz4/pwams/Ovy9eLs1I8962yrytD/A0Fshtp3dyZRqdrTQ/dj324u47b0sthWWccPwU7j13K7+c96V4h2Q+Q/44S0Qgf7XWLsgxra3O5lSzZYWuh9yVrt4bkEO/5yfQ9vYcGb9cRADu/jJ7n0lO60i//4N636/CVaRxyXZm0sppYXub7bmH+SW97LI2lHM2L4duH9MD2LD/eCc5Pt3QeaT8P1MMAb6XgXpt0N8st3JlFJuHhW6iIwCngEcwKvGmMdqPZ4CzATi3WPuMsbM827UwGaM4f0VO7j/s2yCg4Rnx/dl9Bl+sPmidA8segpW/AtMNfS5EobdAfEpdidTStXSYKGLiAN4HhgJ5ALLRWSOMSa7xrB7gfeNMS+KSHdgHtCpEfIGpMKDldz14Wq+zN7LWae0Yvpvz6B9vM3XxizdC4ufhhWvQ3UV9BkPw/4MLTrZm0spVS9PZugDgBxjzGYAEZkFjAFqFroBDp+wOg7Y5c2QgWzh+n38efZqSsqquPeCbkwa0pmgIBt3RzywDxY/A8tfg+pKOGOcNSNv2cW+TEopj3hS6B2AHTXu5wIDa42ZBnwpIjcBUcC5dT2RiEwGJgOkpDTvP9kPVVbz2OfrmLl0G6e1ieGNSQPo1s7GizgczLdm5MtfA2c59L7cmpG3OsW+TEqp4+KtN0XHA/82xvxDRAYDb4pIT2OMq+YgY8wMYAZAWlqaqeN5moU1O0u45b0scvYd4A9DO/Pn804jPMRhT5iDBbDkGfjuFavIe/3WOpVtwqn25FFKnTBPCn0nUHNXhiT3spr+AIwCMMYsFZFwIAHY542QgaLaZZiRsZknv1pPy6hQ3vrDQIam2nSSqrJCWPJP+G4GVB6EXpdZRZ7Y1Z48SqmT5kmhLwdSRaQzVpGPA66oNWY78Cvg3yLSDQgH9MKgNeQWlXHb+6v4bkshv+nVlkfG9iI+MtT3QcoKYenzsOxlqDwAPcbC2XdC69N9n0Up5VUNFroxxikiU4AvsHZJfN0Ys1ZEHgBWGGPmALcDr4jIrVhvkE40xjTbTSo1GWP4NGsX932yBgP847dncEm/Dr4/D8uhIlj6Aix7CSr2Q/eLYfhd0Lqbb3MopRqN
R9vQ3fuUz6u1bGqN29nAEO9Ga/pKyqq499M1fLZqF2kdW/DU5X1Ibhnp2xDOCms/8qUvQEUJdBttFXmbHr7NoZRqdHqkaCNZkpPP7R+sIq+0gj+fdxrXn30KDjt2R8z4u/Vx+oVWkbft5fsMSimf0EL3sgpnNdO/WM8rmVvokhDFRzeeRe+keHvCHMizZuY9xsJv/21PBqWUz2ihe9H6PaXcPOsHftpTylWDUrjnN92IDLVxFS96CpyHYPg99mVQSvmMFroXuFyGfy3ZyuP//YnY8GBen5jGiNNtvrBDyU5Y/iqcMV53RVSqmdBCP0l7Ssq544NVLMrJ59xurXns0t4kRIfZHQsyp4NxWbskKqWaBS30kzDvx93c/dGPVDpdPDK2F+MHJPvHZeEKt1jnK+8/EVp0tDuNUspHtNBPQGl5Ffd/ls3slbmckRTHU5f3oUtitN2xfvbN4xAUDOl32J1EKeVDWujHafnWQm59L4tdxYf404hTuelXqYQ4/OSycAB562H1ezD4//S6nko1M1roHqqqdvHM1xt5YWEOSS0i+eD6wfTv2NLuWMda8DCERMGQW+1OopTyMS10D2zKO8Ct72WxOreE36UlMfWiHkSH+eGq270Ksj+13giN8pNrkCqlfMYPW8l/GGN4e9l2HvpPNuEhDl66qh+jevrxZoz5D0F4vLW5RSnV7Gih1yOvtII7P1zN/J/2kZ6awPTfnkGb2HC7Y9Vv+zLY+CWcOw3C4+xOo5SygRZ6Hb7O3sudH66mtMLJtIu6M2FwJ3svC9cQY2D+gxDVGgZMtjuNUsomWug1lFU6eeg/63hn2Xa6tYvl3XF96Nomxu5YDdu8ELZmwvlPQGiU3WmUUjbRQndbtaOYW97LYmvBQa47uwu3jexKWLBNl4U7Hodn53HJ1oFESqlmq9kXurPaxYsLN/H0/zbSJiaMd64dxOBTmtAeIus/h50rYfQ/IdgPTjmglLJNsy707QVl3Pp+Fiu3FTGmT3seGNOTuIgQu2N5zuWy9jtv2cU6CZdSqllrloVujGH2ylymzVlLUJDwzLg+jOnTwe5Yx2/tR7B3DVzyKjia0C8ipVSjaHaFXnSwkns+/pHP1+xhYOeWPHl5HzrER9gd6/hVO2Hho9C6O/S81O40Sik/0KwKPWNDHnd8sIqiskruPv90rk3vYs9l4bxh1btQkAOXvw1BfnQuGaWUbZpFoZdXVfPY5z/x7yVbSW0dzb+uOZMe7ZvwwTfOCuuMiu37wekX2J1GKeUnAr7Q1+4q4ZZZWWzcd4BrhnTizlGnEx7SBHZH/CUrZ0LJDrjoGfCH868rpfxCwBa6y2V4JXMz079cT4vIUN6YNIBhXRPtjnXyKsusqxF1HAKnjLA7jVLKjwRkoe8qPsRt72fx7eZCRvVoy6OX9KJFVKjdsbzjuxlwYC/8dqbOzpVSRwm4Qp+zahd//fhHXC7DE5f15rf9k/zjsnDeUF4Ci5+GU8+FjoPtTqOU8jMBU+glh6r426dr+CRrF/07tuCp3/UhpVWk3bG8a+kLcKgIRtxrdxKllB/yqNBFZBTwDOAAXjXGPFbr8aeAc9x3I4HWxph4L+b8RUs3FXD7+1nsLa3g9pFduWH4KQT702XhvKGsEJY+D90ugvZ97U6jlPJDDRa6iDiA54GRQC6wXETmGGOyD48xxtxaY/xNgE8ap8JZzZNfbWBGxmY6tYriwxvOok9yvC9e2vcWPw2VB+Ccv9qdRCnlpzyZoQ8AcowxmwFEZBYwBsiuZ/x44G/eiVe/jXtLuXlWFtm793PFwBTuvaAbkaEBswXpaKV7YNkM6P07aN3N7jRKKT/lSQN2AHbUuJ8LDKxroIh0BDoD8+t5fDIwGSAlJeW4gh5mjGHmkq08+vlPRIcF8+qENM7t3uaEnqvJyPwHuKpg+F12J1FK+TFvT2nHAbONMdV1PWiMmQHMAEhLSzMn8gJPf72RZ/63kRGnt+bxS3uTGBPgp4wt3g4r
/gV9r7LOqqiUUvXwpNB3Ask17ie5l9VlHNCoVyi+YmAKbWLDGT8gOXB2R/wl3zwOEgTD/mJ3EqWUn/NkV5DlQKqIdBaRUKzSnlN7kIicDrQAlno34tHaxIZzxcCU5lHm+TmQ9S6c+QeIa4Kn91VK+VSDhW6McQJTgC+AdcD7xpi1IvKAiIyuMXQcMMsYc0KbUlQdFj4CweEw9Da7kyilmgCPtqEbY+YB82otm1rr/jTvxVLsWQNrPoT02yE6AM5Bo5RqdAF29E0AWfAwhMXBWTfZnUQp1URoofuj3BWwfh4MuQkiWtidRinVRGih+6P5D0JkKxh4g91JlFJNiBa6v9mSCZsXWm+EhkXbnUYp1YRoofsTY6zZeUw7a1dFpZQ6Dlro/mTjV7BjGQz7M4RE2J1GKdXEaKH7C5fLmp3Hd4S+V9udRinVBAXo6QmboHVzYM9quPglCA6Qy+UppXxKZ+j+wFUNCx6BhNOsU+QqpdQJ0Bm6P1j9PuSvty78HOSwO41SqonSGbrdnJWw8FFo2xu6jW54vFJK1UNn6Hb74U0o3gZXfABB+vtVKXXitEHsVHUIMv4OyQMhdaTdaZRSTZzO0O20/DUo3Q2XvALN4fzuSqlGpTN0u1SUwqInoctw6JxudxqlVADQQrfLspegrABGTG14rFJKeUAL3Q6HimDxP+G030BSf7vTKKUChBa6HZb8Eyr2wzl/tTuJUiqAaKH72oE8+PYl6HkJtO1pdxqlVADRQve1RU+CsxyG32N3EqVUgNFC96WSndauin3GQ8KpdqdRSgUYLXRfyngCjAvOvtPuJEqpAKSF7iuFm+GHtyDtGohPsTuNUioAaaH7ysLHICgE0m+3O4lSKkBpofvCvnXWKXIH/BFi2tqdRikVoLTQfWHBwxAaDUNvtTuJUiqAeVToIjJKRNaLSI6I3FXPmN+JSLaIrBWRd7wbswnb9QOs+wwG/x9EtrQ7jVIqgDV4tkURcQDPAyOBXGC5iMwxxmTXGJMK3A0MMcYUiUjrxgrc5Mx/CCJaWIWulFKNyJMZ+gAgxxiz2RhTCcwCxtQa80fgeWNMEYAxZp93YzZR25ZCztcw5BYIj7U7jVIqwHlS6B2AHTXu57qX1dQV6Coii0XkWxEZVdcTichkEVkhIivy8vJOLHFTYQzMfxCi28CAyXanUUo1A956UzQYSAWGA+OBV0QkvvYgY8wMY0yaMSYtMTHRSy/tpzbNh22LIf0OCI20O41SqhnwpNB3Ask17ie5l9WUC8wxxlQZY7YAG7AKvnk6PDuPS4b+v7c7jVKqmfCk0JcDqSLSWURCgXHAnFpjPsGanSMiCVibYDZ7L2YT89N/rL1bzr4TgsPsTqOUaiYaLHRjjBOYAnwBrAPeN8asFZEHRGS0e9gXQIGIZAMLgD8bYwoaK7Rfc1Vb+523OhXOGG93GqVUM+LRRaKNMfOAebWWTa1x2wC3uT+atzUfwb5suPQ1cOg1uJVSvqNHinpTdRUsfATa9IQel9idRinVzOgU0puy3rHOqjjuXQjS35VKKd/S1vEWZwV88wR06A+nnW93GqVUM6QzdG9Z8S/YnwtjngMRu9MopZohnaF7Q+VByPwHdEqHLsPtTqOUaqZ0hu4N382Ag/vg8rd0dq6Uso3O0E9WeQksehpSfw0pA+1Oo5RqxrTQT9bS56G8GEbca3cSpVQzp4V+Mg4WWIXefQy0O8PuNEqpZk4L/WQsfgqqyuCcv9qdRCmltNBP2P7d8N0r0PtySDzN7jRKKaWFfsIyp4PLaZ1RUSml/IAW+oko2gYrZ0Lfq6FlZ7vTKKUUoIV+Yr55HCQIhv3Z7iRKKXWEFvrxytsAq96FM6+FuNqXVlVKKftooR+vhY9AcAQMvdXuJEopdRQt9OOxezWs/RgG3QDRAX6Ra6VUk6OFfjwWPAzhcXDWTXYnUUqpY2ihe2rHctjwXzjrTxARb3capZQ6hha6p+Y/AFGJMPB6u5MopVSdtNA9
sfkb2JIBQ2+DsGi70yilVJ200BtiDMx/EGI7QNoku9MopVS9tNAbsuELyF1uHUQUEm53GqWUqpcW+i9xuWD+Q9CiM/S9yu40Sin1i/QSdL8k+xPY+yOMnQGOELvTKKXUL9IZen2qnbDgEUg8HXpdZncapZRqkEeFLiKjRGS9iOSIyF11PD5RRPJEJMv9ca33o/rY6vegYKN18Yogh91plFKqQQ1uchERB/A8MBLIBZaLyBxjTHatoe8ZY6Y0Qkbfc1bCN49Buz7Q7SK70yillEc8maEPAHKMMZuNMZXALGBM48ay2Q9vQPF2GHEfiNidRimlPOJJoXcAdtS4n+teVtulIrJaRGaLSHJdTyQik0VkhYisyMvLO4G4PlB1CDKmQ8pgOPVXdqdRSimPeetN0c+ATsaY3sBXwMy6BhljZhhj0owxaYmJfnq2wuWvQulunZ0rpZocTwp9J1Bzxp3kXnaEMabAGFPhvvsq0N878XysohQWPQWnjIBOQ+xOo5RSx8WTQl8OpIpIZxEJBcYBc2oOEJF2Ne6OBtZ5L6IPffsilBXAOffanUQppY5bg3u5GGOcIjIF+AJwAK8bY9aKyAPACmPMHOBPIjIacAKFwMRGzNw4ygphyT/htAsgqWn+gaGUat48OlLUGDMPmFdr2dQat+8G7vZuNB9b8qy1yWXEX+1OopRSJ0SPFAUo3QvLXoael0KbHnanUUqpE6KFDrDoSXBWwDn32J1EKaVOmBZ68Q5Y8Tr0uQJanWJ3GqWUOmFa6BlPWP+efae9OZRS6iQ170Iv2AQ/vA39r4H4Og9uVUqpJqN5F/rCR8ERCum3251EKaVOWvMt9L3Z8ONsGHgdxLSxO41SSp205lvoCx6GsBgYcrPdSZRSyiuaZ6HvXAk/zYXBUyCypd1plFLKK5pnoc9/CCJawqAb7E6ilFJe0/wKfeti2DQfht4K4bF2p1FKKa9pXoVuDMx/EKLbwoA/2p1GKaW8qnkVes7/YPtSGHYHhETYnUYppbyq+RT64dl5fAr0+73daZRSyuuaT6Gv+wx2Z8HZd0FwqN1plFLK65pHobuqYcEj0CoVel9udxqllGoUHl3goslb8yHkrYPL/gWO5vElK6Wan8CfoVdXWbPzNr2g+8V2p1FKqUYT+NPVrLehaAuMfw+CAv/3l1Kq+Qrshqsqh2+egA5p0PU8u9MopVSjCuwZ+sp/wf6dcPELIGJ3GqWUalSBO0OvPAiZ/4BO6dBluN1plFKq0QXuDH3ZS3AwD8a9Y3cSpZTyicCcoR8qhsXPQOp5kDzA7jRKKeUTgVnoS5+D8hIYca/dSZRSymc8KnQRGSUi60UkR0Tu+oVxl4qIEZE070U8Tgfz4dsXrX3O2/W2LYZSSvlag4UuIg7geeB8oDswXkS61zEuBrgZWObtkMdl0VNQVQbn/NXWGEop5WuezNAHADnGmM3GmEpgFjCmjnEPAo8D5V7Md3z274LvXoHe4yCxq20xlFLKDp4UegdgR437ue5lR4hIPyDZGPMfL2Y7fhl/B+OC4XfaGkMppexw0m+KikgQ8CRwuwdjJ4vIChFZkZeXd7IvfbTCLfD9G9BvArTo5N3nVkqpJsCTQt8JJNe4n+RedlgM0BNYKCJbgUHAnLreGDXGzDDGpBlj0hITE088dV2+eRyCgmHYn737vEop1UR4UujLgVQR6SwiocA4YM7hB40xJcaYBGNMJ2NMJ+BbYLQxZkWjJK5L3npY/R6ceS3EtvPZyyqllD9psNCNMU5gCvAFsA543xizVkQeEJHRjR3QIwsehpBIGHqb3UmUUso2Hh36b4yZB8yrtWxqPWOHn3ys47B7FWR/CsP+AlGtfPrSSinlT5r+kaLzH4LweDhrit1JlFLKVk270Lcvg41fwpCbITzO7jRKKWWrplvoxsD8ByGqNQy8zu40Sillu6Zb6JsXwtZMSL8dQqPsTqOUUrZrmoVujLXtPDYJ0q6xO41SSvmFplnoG/4L
O1fA2X+B4DC70yillF9oeoXuclmz85ZdoM8VdqdRSim/0fQuQZf9MexdA5e8Ao4Qu9MopZTfaHoz9NAYOP1C6Hmp3UmUUsqvNL0ZetdfWx9KKaWO0vRm6Eoppeqkha6UUgFCC10ppQKEFrpSSgUILXSllAoQWuhKKRUgtNCVUipAaKErpVSAEGOMPS8skgdsO8FPTwDyvRjHWzTX8dFcx89fs2mu43MyuToaYxLresC2Qj8ZIrLCGJNmd47aNNfx0VzHz1+zaa7j01i5dJOLUkoFCC10pZQKEE210GfYHaAemuv4aK7j56/ZNNfxaZRcTXIbulJKqWM11Rm6UkqpWrTQlVIqQPh1oYvIKBFZLyI5InJXHY+Hich77seXiUgnP8k1UUTyRCTL/XGtj3K9LiL7RGRNPY+LiDzrzr1aRPr5Sa7hIlJSY31N9UGmZBFZICLZIrJWRG6uY4zP15eHuexYX+Ei8p2IrHLnur+OMT7/efQwly0/j+7XdojIDyIyt47HvL++jDF++QE4gE1AFyAUWAV0rzXmRuAl9+1xwHt+kmsi8JwN62wY0A9YU8/jvwE+BwQYBCzzk1zDgbk+XlftgH7u2zHAhjr+H32+vjzMZcf6EiDafTsEWAYMqjXGjp9HT3LZ8vPofu3bgHfq+v9qjPXlzzP0AUCOMWazMaYSmAWMqTVmDDDTfXs28CsRET/IZQtjTAZQ+AtDxgBvGMu3QLyItPODXD5njNltjPnefbsUWAd0qDXM5+vLw1w+514HB9x3Q9wftfeo8PnPo4e5bCEiScAFwKv1DPH6+vLnQu8A7KhxP5djv7GPjDHGOIESoJUf5AK41P1n+mwRSW7kTJ7yNLsdBrv/bP5cRHr48oXdf+r2xZrd1WTr+vqFXGDD+nJvPsgC9gFfGWPqXV8+/Hn0JBfY8/P4NPAXwFXP415fX/5c6E3ZZ0AnY0xv4Ct+/i2s6vY91vkpzgD+CXziqxcWkWjgQ+AWY8x+X71uQxrIZcv6MsZUG2P6AEnAABHp6YvXbYgHuXz+8ygiFwL7jDErG/u1avLnQt8J1PxNmuReVucYEQkG4oACu3MZYwqMMRXuu68C/Rs5k6c8Wac+Z4zZf/jPZmPMPCBERBIa+3VFJASrNN82xnxUxxBb1ldDuexaXzVevxhYAIyq9ZAdP48N5rLp53EIMFpEtmJtlh0hIm/VGuP19eXPhb4cSBWRziISivWmwZxaY+YAv3ffvgyYb9zvMNiZq9Z21tFY20H9wRxggnvvjUFAiTFmt92hRKTt4W2HIjIA6/uyUYvA/XqvAeuMMU/WM8zn68uTXDatr0QRiXffjgBGAj/VGubzn0dPctnx82iMudsYk2SM6YTVEfONMVfVGub19RV8Mp/cmIwxThGZAnyBtWfJ68aYtSLyALDCGDMH6xv/TRHJwXrTbZyf5PqTiIwGnO5cExs7F4CIvIu1B0SCiOQCf8N6kwhjzEvAPKw9N3KAMuAaP8l1GXCDiDiBQ8A4H/xiHgJcDfzo3v4KcA+QUiOXHevLk1x2rK92wEwRcWD9AnnfGDPX7p9HD3PZ8vNYl8ZeX3rov1JKBQh/3uSilFLqOGihK6VUgNBCV0qpAKGFrpRSAUILXSmlAoQWulJKBQgtdKWUChD/D/ds23CH4U5JAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accuracy vs. k for approx_eval (column 'backward') and forward_eval (column 'forward').\n",
    "import pandas as pd\n",
    "\n",
    "importlib.reload(sparse)\n",
    "sparse.benchmark_reset()\n",
    "sparse.timer_reset()\n",
    "\n",
    "# Build the sparse model from the dense BERT weights and configure its flags once.\n",
    "sparse_bert = sparse.SparseBertModel(bert.config)\n",
    "sparse_bert = sparse_bert.to(trainer.device)\n",
    "sparse_bert.eval()\n",
    "sparse_bert.load_state_dict(bert.state_dict())\n",
    "sparse.set_print(sparse_bert, False)\n",
    "sparse.set_backup_last_inputs(sparse_bert, False)\n",
    "sparse.set_output_masking(sparse_bert, False)\n",
    "\n",
    "approx_bert = approx_bert.to(trainer.device)\n",
    "bert = bert.to(trainer.device)\n",
    "\n",
    "# Dense baseline, reported so the sparse results below have a reference point.\n",
    "acc_bert = accuracy(lambda batch: eval(bert, batch=batch))\n",
    "print('bert', acc_bert)\n",
    "\n",
    "ks = [0.1, 0.25, 0.375, 0.5, 0.75]\n",
    "accs_backward = []\n",
    "accs_forward = []\n",
    "for k in ks:\n",
    "    # k=k binds this iteration's value, so the lambda is correct even if evaluated lazily.\n",
    "    acc_backward, _ = accuracy(\n",
    "        lambda batch, k=k: approx_eval(sparse_bert, approx_bert, batch=batch, k=k),\n",
    "        return_lm = True,\n",
    "    )\n",
    "    acc_forward, _ = accuracy(\n",
    "        lambda batch, k=k: forward_eval(sparse_bert, batch=batch, k=k),\n",
    "        return_lm = True,\n",
    "    )\n",
    "    accs_backward.append(acc_backward)\n",
    "    accs_forward.append(acc_forward)\n",
    "    print(k, acc_backward / acc_forward, acc_backward, acc_forward)\n",
    "\n",
    "# Index by k so the plot's x-axis shows k rather than the row position.\n",
    "df = pd.DataFrame(\n",
    "    {'forward': accs_forward, 'backward': accs_backward},\n",
    "    index=pd.Index(ks, name='k'),\n",
    ")\n",
    "ax = df.plot(marker='o')\n",
    "ax.set(xlabel='k', ylabel='accuracy');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df  # label: 'forward'. Shows whatever df the evaluation cell above last produced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>forward</th>\n",
       "      <th>backward</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.360000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.831447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.900921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.917895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.926053</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   forward  backward\n",
       "0      1.0  0.360000\n",
       "1      1.0  0.831447\n",
       "2      1.0  0.900921\n",
       "3      1.0  0.917895\n",
       "4      1.0  0.926053"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df  # snapshot label: 'not adjust' (presumably softmax not adjusted). Stored output is from an earlier run; re-running shows the current df. forward == 1.0 there suggests forward eval was not run for it - TODO confirm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>forward</th>\n",
       "      <th>backward</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.301842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.665132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.837105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.902500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.925789</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   forward  backward\n",
       "0      1.0  0.301842\n",
       "1      1.0  0.665132\n",
       "2      1.0  0.837105\n",
       "3      1.0  0.902500\n",
       "4      1.0  0.925789"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df  # snapshot label: 'adjust softmax'. Stored output is from an earlier run; re-running shows the current df. forward == 1.0 there suggests forward eval was not run for it - TODO confirm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-batch inspection of both sparsity methods.\n",
    "# NOTE(review): relies on `batch` defined in an earlier cell (not visible here) - confirm.\n",
    "# debug_k pins the value the evaluation loop's k ended on, instead of reading the leaked loop variable.\n",
    "debug_k = ks[-1]\n",
    "bert_inputs = {\n",
    "    'input_ids': batch.input_ids,\n",
    "    'attention_mask': batch.attention_masks,\n",
    "    'output_hidden_states': True,\n",
    "    'output_attentions': True,\n",
    "}\n",
    "\n",
    "# One ks entry per encoder layer: 0.999 for the first layer, debug_k for the rest.\n",
    "# Each call gets its own dict copy in case the callee mutates its inputs.\n",
    "lm_output_approx = sparse.run_bert_with_approx(\n",
    "    sparse_bert,\n",
    "    approx_bert,\n",
    "    dict(bert_inputs),\n",
    "    ks = [0.999] + [debug_k] * (len(sparse_bert.encoder.layer) - 1),\n",
    ")\n",
    "\n",
    "# Kept as lm_output (the name the original cell left bound to this result).\n",
    "lm_output = sparse.run_bert_forward_sparsity(\n",
    "    sparse_bert,\n",
    "    dict(bert_inputs),\n",
    "    ks = debug_k,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "58c896f8fe28377dc6f47dbc9814b9367447c8ff4b1090ace6962dd6db7d2533"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
