{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import pickle\n",
    "\n",
    "import trainer.glue_base as glue_base\n",
    "import models.sparse_token as sparse\n",
    "\n",
    "# Reload the project modules so edits to them are picked up without restarting the kernel.\n",
    "importlib.reload(glue_base)\n",
    "importlib.reload(sparse)\n",
    "\n",
    "# Trainer class used by the benchmark cell.\n",
    "Glue = glue_base.GlueAttentionApproxTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: cola\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-6b1dfe928518b6a9.arrow\n",
      "Loading cached processed dataset at ./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f7e44aa0ee690f24.arrow\n",
      "Reusing dataset glue (./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-99bc25cbc8b7096f.arrow\n",
      "Loading cached processed dataset at ./cache/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-49888407280455f9.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: Save checkpoint path saves/glue-cola-16.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:05<00:00, 24.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5338774230813111}\n",
      "avg occupy 0.34522852946633237\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:10<00:00, 12.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.05521043543404519}\n",
      "avg occupy 0.34522852946633237\n",
      "1/27 | cola 0.1 = 0.05521\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:11<00:00, 11.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.408526513461678}\n",
      "avg occupy 0.34522852946633237\n",
      "2/27 | cola 0.25 = 0.40853\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:10<00:00, 12.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.4586369195095741}\n",
      "avg occupy 0.34522852946633237\n",
      "3/27 | cola 0.375 = 0.45864\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:08<00:00, 15.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5025517897100551}\n",
      "avg occupy 0.34522852946633237\n",
      "4/27 | cola 0.5 = 0.50255\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:07<00:00, 17.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5259790835948196}\n",
      "avg occupy 0.34522852946633237\n",
      "5/27 | cola 0.625 = 0.52598\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:07<00:00, 16.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5311686988596115}\n",
      "avg occupy 0.34522852946633237\n",
      "6/27 | cola 0.75 = 0.53117\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:07<00:00, 17.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5311983410233877}\n",
      "avg occupy 0.34522852946633237\n",
      "7/27 | cola 0.875 = 0.53120\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:08<00:00, 16.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5338774230813111}\n",
      "avg occupy 0.34522852946633237\n",
      "8/27 | cola 0.999 = 0.53388\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 131/131 [00:07<00:00, 16.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'matthews_correlation': 0.5338774230813111}\n",
      "avg occupy 0.34522852946633237\n",
      "9/27 | cola dynamic = 0.53388\n",
      "Trainer: mnli\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c8de9ec562ad8c1a.arrow\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e8a73fc1e1933d2b.arrow\n",
      "Reusing dataset glue (./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-baccd56b3c558104.arrow\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e968fe9456be4543bee7467a5caac4d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: Save checkpoint path saves/glue-mnli-16.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:00<00:00, 20.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8420784513499745}\n",
      "avg occupy 0.2057366697815377\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:10<00:00, 17.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.5840040753948039}\n",
      "avg occupy 0.2057366697815377\n",
      "10/27 | mnli 0.1 = 0.58400\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:12<00:00, 16.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8178298522669384}\n",
      "avg occupy 0.2057366697815377\n",
      "11/27 | mnli 0.25 = 0.81783\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:18<00:00, 15.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8383087111563933}\n",
      "avg occupy 0.2057366697815377\n",
      "12/27 | mnli 0.375 = 0.83831\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:23<00:00, 14.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8416709118695873}\n",
      "avg occupy 0.2057366697815377\n",
      "13/27 | mnli 0.5 = 0.84167\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:25<00:00, 14.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8419765664798777}\n",
      "avg occupy 0.2057366697815377\n",
      "14/27 | mnli 0.625 = 0.84198\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:29<00:00, 13.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8419765664798777}\n",
      "avg occupy 0.2057366697815377\n",
      "15/27 | mnli 0.75 = 0.84198\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:30<00:00, 13.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8420784513499745}\n",
      "avg occupy 0.2057366697815377\n",
      "16/27 | mnli 0.875 = 0.84208\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:39<00:00, 12.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8420784513499745}\n",
      "avg occupy 0.2057366697815377\n",
      "17/27 | mnli 0.999 = 0.84208\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:20<00:00, 15.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.841874681609781}\n",
      "avg occupy 0.2057366697815377\n",
      "18/27 | mnli dynamic = 0.84187\n",
      "Trainer: mrpc\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (./cache/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9afadd04a309a170.arrow\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0bddda4d866e6dbd.arrow\n",
      "Reusing dataset glue (./cache/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "Loading cached processed dataset at ./cache/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-efab46bb85d6ed58.arrow\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6d3fa32d8d64371999a31159788f50b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: Save checkpoint path saves/glue-mrpc-16.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:07<00:00, 30.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8440579710144928, 'f1': 0.8865457612821593}\n",
      "avg occupy 0.5399536083129833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 17.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.33507246376811595, 'f1': 0.0}\n",
      "avg occupy 0.5399536083129833\n",
      "19/27 | mrpc 0.1 = 0.33507\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:10<00:00, 19.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.34550724637681157, 'f1': 0.035866780529461996}\n",
      "avg occupy 0.5399536083129833\n",
      "20/27 | mrpc 0.25 = 0.34551\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 17.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.4405797101449275, 'f1': 0.3170559094125973}\n",
      "avg occupy 0.5399536083129833\n",
      "21/27 | mrpc 0.375 = 0.44058\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 17.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.6539130434782608, 'f1': 0.6842940243257536}\n",
      "avg occupy 0.5399536083129833\n",
      "22/27 | mrpc 0.5 = 0.65391\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 17.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.807536231884058, 'f1': 0.8525754884547068}\n",
      "avg occupy 0.5399536083129833\n",
      "23/27 | mrpc 0.625 = 0.80754\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 17.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.84, 'f1': 0.8828522920203735}\n",
      "avg occupy 0.5399536083129833\n",
      "24/27 | mrpc 0.75 = 0.84000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:13<00:00, 16.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8440579710144928, 'f1': 0.8865457612821593}\n",
      "avg occupy 0.5399536083129833\n",
      "25/27 | mrpc 0.875 = 0.84406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:11<00:00, 18.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8440579710144928, 'f1': 0.8865457612821593}\n",
      "avg occupy 0.5399536083129833\n",
      "26/27 | mrpc 0.999 = 0.84406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 216/216 [00:12<00:00, 16.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8440579710144928, 'f1': 0.8865457612821593}\n",
      "avg occupy 0.5399536083129833\n",
      "27/27 | mrpc dynamic = 0.84406\n",
      "est_k: (2592, 1894.1703296703308), avg: 0.7307755901505906\n"
     ]
    }
   ],
   "source": [
    "# GLUE subsets to evaluate; the remaining ones are commented out for a faster run.\n",
    "subsets = [\"cola\",\"mnli\",\"mrpc\"]#,\"qnli\",\"qqp\",\"rte\",\"sst2\",\"stsb\",\"wnli\",]\n",
    "# Settings passed to eval_sparse_model(ks=...): numeric ks values, or 'dynamic' (which also reports an estimated k).\n",
    "kss = [\n",
    "    0.1, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 0.999, 'dynamic', \n",
    "    # 'dynamic:avg:avg:true', 'dynamic:avg:avg:false', 'dynamic:avg:max:true', 'dynamic:avg:max:false',\n",
    "    # 'dynamic:max:avg:true', 'dynamic:max:avg:false', 'dynamic:max:max:true', 'dynamic:max:max:false',\n",
    "]\n",
    "sparse.benchmark_reset()\n",
    "# Quick debug overrides:\n",
    "# subsets = [\"mrpc\"]\n",
    "# kss = ['dynamic:avg:avg:f',0.1]\n",
    "\n",
    "def get_score(score):\n",
    "    \"\"\"Return (value, short name) of the headline metric in an eval result dict.\n",
    "\n",
    "    Prefers 'accuracy' (reported as \"acc\"); otherwise falls back to the first metric key.\n",
    "    \"\"\"\n",
    "    if 'accuracy' in score:\n",
    "        return score['accuracy'], \"acc\"\n",
    "    first_metric = list(score.keys())[0]\n",
    "    return score[first_metric], first_metric\n",
    "\n",
    "results = {}\n",
    "count = len(subsets) * len(kss)  # total number of sparse evaluations, for progress output\n",
    "i = 0\n",
    "for subset in subsets:\n",
    "    trainer = Glue(dataset=subset, factor=16, batch_size=-1, device=0)\n",
    "    trainer.load()\n",
    "    scores = {}\n",
    "    # Dense baseline first, so 'bert' is the first key of every per-subset dict.\n",
    "    bert_score, metric_name = get_score(trainer.eval_base_model())\n",
    "    scores['bert'] = f'{bert_score:.5f}'\n",
    "    for ks in kss:\n",
    "        # Reset before each run so est_k presumably reflects only this ks setting.\n",
    "        sparse.benchmark_reset()\n",
    "        sparse_score, _ = get_score(trainer.eval_sparse_model(ks=ks))\n",
    "        if isinstance(ks, str) and ks.startswith('dynamic'):\n",
    "            est_k = sparse.benchmark_get_average('est_k')\n",
    "            scores[str(ks)] = f'{sparse_score:.5f} (k:{est_k:.2f})'\n",
    "        else:\n",
    "            scores[str(ks)] = f'{sparse_score:.5f}'\n",
    "        i += 1\n",
    "        print(f'{i}/{count} | {subset} {ks} = {sparse_score:.5f}')\n",
    "    results[f\"{subset} ({metric_name})\"] = scores\n",
    "\n",
    "with open('glue_benchmark.pkl', 'wb') as f:\n",
    "    pickle.dump(results, f)\n",
    "\n",
    "# NOTE(review): benchmark_reset() runs before every ks, so this report presumably covers only the last run.\n",
    "sparse.benchmark_report()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cola (matthews_correlation)</th>\n",
       "      <th>mnli (acc)</th>\n",
       "      <th>mrpc (acc)</th>\n",
       "      <th>reproduce</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>bert</th>\n",
       "      <td>0.53388</td>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.1</th>\n",
       "      <td>0.05521</td>\n",
       "      <td>0.58400</td>\n",
       "      <td>0.33507</td>\n",
       "      <td>39.80</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.25</th>\n",
       "      <td>0.40853</td>\n",
       "      <td>0.81783</td>\n",
       "      <td>0.34551</td>\n",
       "      <td>71.53</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.375</th>\n",
       "      <td>0.45864</td>\n",
       "      <td>0.83831</td>\n",
       "      <td>0.44058</td>\n",
       "      <td>79.22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.5</th>\n",
       "      <td>0.50255</td>\n",
       "      <td>0.84167</td>\n",
       "      <td>0.65391</td>\n",
       "      <td>90.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.625</th>\n",
       "      <td>0.52598</td>\n",
       "      <td>0.84198</td>\n",
       "      <td>0.80754</td>\n",
       "      <td>98.06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.75</th>\n",
       "      <td>0.53117</td>\n",
       "      <td>0.84198</td>\n",
       "      <td>0.84000</td>\n",
       "      <td>99.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.875</th>\n",
       "      <td>0.53120</td>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>99.83</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.999</th>\n",
       "      <td>0.53388</td>\n",
       "      <td>0.84208</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dynamic</th>\n",
       "      <td>0.53388 (k:0.48)</td>\n",
       "      <td>0.84187 (k:0.37)</td>\n",
       "      <td>0.84406 (k:0.73)</td>\n",
       "      <td>99.99</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        cola (matthews_correlation)        mnli (acc)        mrpc (acc)  \\\n",
       "bert                        0.53388           0.84208           0.84406   \n",
       "0.1                         0.05521           0.58400           0.33507   \n",
       "0.25                        0.40853           0.81783           0.34551   \n",
       "0.375                       0.45864           0.83831           0.44058   \n",
       "0.5                         0.50255           0.84167           0.65391   \n",
       "0.625                       0.52598           0.84198           0.80754   \n",
       "0.75                        0.53117           0.84198           0.84000   \n",
       "0.875                       0.53120           0.84208           0.84406   \n",
       "0.999                       0.53388           0.84208           0.84406   \n",
       "dynamic            0.53388 (k:0.48)  0.84187 (k:0.37)  0.84406 (k:0.73)   \n",
       "\n",
       "        reproduce  \n",
       "bert       100.00  \n",
       "0.1         39.80  \n",
       "0.25        71.53  \n",
       "0.375       79.22  \n",
       "0.5         90.52  \n",
       "0.625       98.06  \n",
       "0.75        99.67  \n",
       "0.875       99.83  \n",
       "0.999      100.00  \n",
       "dynamic     99.99  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "import pandas as pd\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load the file written by the benchmark cell.\n",
    "with open('glue_benchmark.pkl', 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "# results maps \"<subset> (<metric>)\" -> {setting: formatted score string}; 'bert' is the dense baseline.\n",
    "subsets = list(results.keys())\n",
    "factors = list(results[subsets[0]].keys())\n",
    "data = [[results[subset][factor] for subset in subsets] for factor in factors]\n",
    "\n",
    "# Reproducibility of a setting = mean over subsets of (score / dense baseline score), as a percentage.\n",
    "# Cells look like \"0.53388\" or \"0.53388 (k:0.48)\"; the first token is the score.\n",
    "data_scalar = [[float(item.split()[0]) for item in row] for row in data]\n",
    "baseline = data_scalar[0]  # row 0 is 'bert', the first key written per subset\n",
    "for row, row_scalar in zip(data, data_scalar):\n",
    "    # Skip subsets whose baseline score is 0 to avoid a ZeroDivisionError.\n",
    "    ratios = [s / b for s, b in zip(row_scalar, baseline) if b != 0]\n",
    "    r = sum(ratios) / len(ratios) if ratios else float('nan')\n",
    "    row.append(f\"{r*100:.2f}\")\n",
    "columns = subsets + [\"reproduce\"]\n",
    "\n",
    "df = pd.DataFrame(data, columns=columns, index=factors)\n",
    "os.makedirs('saves_plot', exist_ok=True)\n",
    "with open('saves_plot/glue_benchmark.tex', 'w') as f:\n",
    "    f.write(df.to_latex())\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Results for all nine GLUE subsets, recorded from a separate run:\n",
    "\n",
    "| | cola (matthews_correlation) | mnli (acc) | mrpc (acc) | qnli (acc) | qqp (acc) | rte (acc) | sst2 (acc) | stsb (pearson) | wnli (acc) | reproduce |\n",
    "|---|---|---|---|---|---|---|---|---|---|---|\n",
    "| bert | 0.53388 | 0.84198 | 0.84406 | 0.91543 | 0.90908 | 0.72563 | 0.92431 | 0.88047 | 0.56338 | 100.00 |\n",
    "| 0.1 | 0.02214 | 0.57932 | 0.33507 | 0.59217 | 0.73391 | 0.52347 | 0.70872 | 0.01637 | 0.60563 | 57.36 |\n",
    "| 0.25 | 0.33346 | 0.81559 | 0.36174 | 0.83892 | 0.89115 | 0.70758 | 0.85092 | 0.55826 | 0.56338 | 82.76 |\n",
    "| 0.375 | 0.45301 | 0.83770 | 0.54377 | 0.89731 | 0.90616 | 0.70758 | 0.89335 | 0.75968 | 0.56338 | 91.88 |\n",
    "| 0.5 | 0.49403 | 0.84167 | 0.76928 | 0.90976 | 0.90885 | 0.73285 | 0.90596 | 0.84956 | 0.56338 | 97.61 |\n",
    "| 0.625 | 0.51112 | 0.84167 | 0.82957 | 0.91360 | 0.90915 | 0.72563 | 0.91628 | 0.87337 | 0.56338 | 99.12 |\n",
    "| 0.75 | 0.52328 | 0.84198 | 0.84290 | 0.91433 | 0.90913 | 0.72924 | 0.92202 | 0.87936 | 0.56338 | 99.77 |\n",
    "| 0.875 | 0.53120 | 0.84208 | 0.84348 | 0.91525 | 0.90913 | 0.72563 | 0.92317 | 0.88039 | 0.56338 | 99.92 |\n",
    "| 0.999 | 0.53388 | 0.84198 | 0.84406 | 0.91543 | 0.90908 | 0.72563 | 0.92431 | 0.88047 | 0.56338 | 100.00 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug scratch (disabled): build and load a single GLUE trainer for manual inspection.\n",
    "# trainer = Glue(dataset='qnli', factor=16, batch_size=-1, device=0)\n",
    "# trainer.load() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug scratch (disabled): assign trainer.model_bert back to trainer.model.bert, then run the dense eval.\n",
    "# trainer.model.bert = trainer.model_bert\n",
    "# trainer.eval_base_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug scratch (disabled): manually wrap trainer.model_bert in sparse.ApproxSparseBertModel (ks=0.1),\n",
    "# copy trainer.model's weights into a fresh BertForSequenceClassification, swap in the wrapped encoder,\n",
    "# and evaluate it through trainer.eval_base_model.\n",
    "# trainer.seed()\n",
    "# import models.sparse_token as sparse\n",
    "# import transformers.models.bert.modeling_bert as berts\n",
    "# import importlib\n",
    "# importlib.reload(sparse)\n",
    "\n",
    "# wrapped_bert = sparse.ApproxSparseBertModel(trainer.model_bert, approx_bert=trainer.approx_bert, ks=0.1)\n",
    "# sparse_cls_bert = berts.BertForSequenceClassification(trainer.model_bert.config)\n",
    "# sparse_cls_bert.load_state_dict(trainer.model.state_dict())\n",
    "# sparse_cls_bert.bert = wrapped_bert\n",
    "# sparse_cls_bert.to(trainer.device).eval()\n",
    "\n",
    "# trainer.eval_base_model(model = sparse_cls_bert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f7b3ac0126d0d6fea024471ce24e510948bf6332f7ae1a66cdcb4ee9887514e9"
  },
  "kernelspec": {
   "display_name": "Python 3.8.3 64-bit ('tensorflow': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
