{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import trainer.glue_base as glue_base\n",
    "import models.sparse_token as sparse\n",
    "import pickle, importlib\n",
    "importlib.reload(glue_base)\n",
    "importlib.reload(sparse)\n",
    "Glue = glue_base.GlueAttentionApproxTrainer\n",
    "PICKLE_PATH = \"glue_benchmark_wiki.pkl\"\n",
    "TEX_PATH = \"saves_plot/glue_benchmark_wiki.tex\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainer: mnli\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:datasets.builder:Reusing dataset glue (./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "WARNING:datasets.arrow_dataset:Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c8de9ec562ad8c1a.arrow\n",
      "WARNING:datasets.arrow_dataset:Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e8a73fc1e1933d2b.arrow\n",
      "WARNING:datasets.builder:Reusing dataset glue (./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n",
      "WARNING:datasets.arrow_dataset:Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-baccd56b3c558104.arrow\n",
      "WARNING:datasets.arrow_dataset:Loading cached processed dataset at ./cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-80b2f7c5c23d1000.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WikitextBatchLoader: worker_main\n",
      "WikitextBatchLoader: worker_main\n",
      "WikitextBatchLoader: worker_main\n",
      "WikitextBatchLoader: worker_main\n",
      "Trainer: Save checkpoint path saves/glue-mnli-16-wiki-b5.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "eval: 100%|██████████| 1227/1227 [01:04<00:00, 18.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "metric score {'accuracy': 0.8420784513499745}\n",
      "avg occupy 0.2057366697815377\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.76 GiB total capacity; 1.05 GiB already allocated; 12.44 MiB free; 1.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-4f69b3655d81>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     28\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mks\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkss\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbenchmark_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m         \u001b[0msparse_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval_sparse_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'dynamic'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m             \u001b[0mest_k\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msparse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbenchmark_get_average\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'est_k'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/discrete_edge_learning/trainer/glue_base.py\u001b[0m in \u001b[0;36meval_sparse_model\u001b[0;34m(self, ks, use_forward)\u001b[0m\n\u001b[1;32m    227\u001b[0m         \u001b[0msparse_cls_bert\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    228\u001b[0m         \u001b[0msparse_cls_bert\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbert\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrapped_bert\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 229\u001b[0;31m         \u001b[0msparse_cls_bert\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    230\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    231\u001b[0m         \u001b[0msparse_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval_base_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msparse_cls_bert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    897\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    898\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 899\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    900\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    901\u001b[0m     def register_backward_hook(\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    568\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 570\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    571\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    572\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    591\u001b[0m             \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    592\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 593\u001b[0;31m                 \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    594\u001b[0m             \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    595\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/w10_backup/user/anaconda3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m    895\u001b[0m                 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,\n\u001b[1;32m    896\u001b[0m                             non_blocking, memory_format=convert_to_format)\n\u001b[0;32m--> 897\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    898\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    899\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.76 GiB total capacity; 1.05 GiB already allocated; 12.44 MiB free; 1.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "subsets = [\"cola\",\"mnli\",\"mrpc\",\"qnli\",\"qqp\",\"rte\",\"sst2\",\"stsb\",\"wnli\",]\n",
    "subsets = [\"mnli\",\"mrpc\",\"qnli\",\"qqp\",\"rte\",\"sst2\",\"stsb\"]\n",
    "kss = [\n",
    "    0.1, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 0.999, 'dynamic', \n",
    "    #'dynamic:avg:avg:true', 'dynamic:avg:avg:false', 'dynamic:avg:max:true', 'dynamic:avg:max:false',\n",
    "    #'dynamic:max:avg:true', 'dynamic:max:avg:false', 'dynamic:max:max:true', 'dynamic:max:max:false',\n",
    "]\n",
    "#kss = ['dynamic']\n",
    "sparse.benchmark_reset()\n",
    "# subsets = [\"mrpc\"]\n",
    "# kss = ['dynamic:avg:avg:f',0.1]\n",
    "\n",
    "def get_score(score):\n",
    "    if 'accuracy' in score:\n",
    "        return score['accuracy'], \"acc\"\n",
    "    first_metric = list(score.keys())[0]\n",
    "    return score[first_metric], first_metric\n",
    "\n",
    "results = {}\n",
    "i = 0\n",
    "for subset in subsets:\n",
    "    trainer = Glue(dataset=subset, factor=16, batch_size=-1, device=0, wiki_train=True)\n",
    "    trainer.load()\n",
    "    scores = {}\n",
    "    metric_name = \"\"\n",
    "    bert_score, metric_name = get_score(trainer.eval_base_model())\n",
    "    scores['bert'] = f'{bert_score:.5f}'\n",
    "    for ks in kss:\n",
    "        sparse.benchmark_reset()\n",
    "        sparse_score, _ = get_score(trainer.eval_sparse_model(ks=ks))\n",
    "        if isinstance(ks, str) and ks.startswith('dynamic'):\n",
    "            est_k = sparse.benchmark_get_average('est_k')\n",
    "            scores[str(ks)] = f'{sparse_score:.5f} (k:{est_k:.2f})'\n",
    "        else:\n",
    "            scores[str(ks)] = f'{sparse_score:.5f}'\n",
    "        i += 1\n",
    "        count = len(subsets) * len(kss)\n",
    "        print(f'{i}/{count}')\n",
    "    results[f\"{subset} ({metric_name})\"] = scores\n",
    "\n",
    "with open(PICKLE_PATH, 'wb') as f:\n",
    "    pickle.dump(results, f)\n",
    "\n",
    "sparse.benchmark_report()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_159507/610143257.py:54: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  tex = df.to_latex()\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cola</th>\n",
       "      <th>mnli</th>\n",
       "      <th>mrpc</th>\n",
       "      <th>qnli</th>\n",
       "      <th>qqp</th>\n",
       "      <th>rte</th>\n",
       "      <th>sst2</th>\n",
       "      <th>stsb</th>\n",
       "      <th>wnli</th>\n",
       "      <th>reproduce</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>bert</th>\n",
       "      <td>0.53388</td>\n",
       "      <td>0.84198</td>\n",
       "      <td>0.84406</td>\n",
       "      <td>0.91543</td>\n",
       "      <td>0.90908</td>\n",
       "      <td>0.72563</td>\n",
       "      <td>0.92431</td>\n",
       "      <td>0.88047</td>\n",
       "      <td>0.56338</td>\n",
       "      <td>100.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dynamic:original</th>\n",
       "      <td>0.53388 (k:0.48)</td>\n",
       "      <td>0.84004 (k:0.33)</td>\n",
       "      <td>0.83942 (k:0.71)</td>\n",
       "      <td>0.91287 (k:0.44)</td>\n",
       "      <td>0.90920 (k:0.45)</td>\n",
       "      <td>0.72563 (k:0.55)</td>\n",
       "      <td>0.92317 (k:0.73)</td>\n",
       "      <td>0.88043 (k:0.51)</td>\n",
       "      <td>0.56338 (k:0.63)</td>\n",
       "      <td>99.87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dynamic:w_augment</th>\n",
       "      <td>0.53388 (k:0.48)</td>\n",
       "      <td>0.83647 (k:0.33)</td>\n",
       "      <td>0.83826 (k:0.72)</td>\n",
       "      <td>0.91506 (k:0.48)</td>\n",
       "      <td>0.90905 (k:0.45)</td>\n",
       "      <td>0.66426 (k:0.40)</td>\n",
       "      <td>0.92202 (k:0.68)</td>\n",
       "      <td>0.88034 (k:0.51)</td>\n",
       "      <td>0.56338 (k:0.61)</td>\n",
       "      <td>98.88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dynamic:wo_augment</th>\n",
       "      <td>0.53388 (k:0.48)</td>\n",
       "      <td>0.81559 (k:0.30)</td>\n",
       "      <td>0.74667 (k:0.66)</td>\n",
       "      <td>0.90445 (k:0.41)</td>\n",
       "      <td>0.90893 (k:0.43)</td>\n",
       "      <td>0.70758 (k:0.38)</td>\n",
       "      <td>0.92202 (k:0.69)</td>\n",
       "      <td>0.86572 (k:0.49)</td>\n",
       "      <td>0.56338 (k:0.54)</td>\n",
       "      <td>97.74</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                cola              mnli              mrpc  \\\n",
       "bert                         0.53388           0.84198           0.84406   \n",
       "dynamic:original    0.53388 (k:0.48)  0.84004 (k:0.33)  0.83942 (k:0.71)   \n",
       "dynamic:w_augment   0.53388 (k:0.48)  0.83647 (k:0.33)  0.83826 (k:0.72)   \n",
       "dynamic:wo_augment  0.53388 (k:0.48)  0.81559 (k:0.30)  0.74667 (k:0.66)   \n",
       "\n",
       "                                qnli               qqp               rte  \\\n",
       "bert                         0.91543           0.90908           0.72563   \n",
       "dynamic:original    0.91287 (k:0.44)  0.90920 (k:0.45)  0.72563 (k:0.55)   \n",
       "dynamic:w_augment   0.91506 (k:0.48)  0.90905 (k:0.45)  0.66426 (k:0.40)   \n",
       "dynamic:wo_augment  0.90445 (k:0.41)  0.90893 (k:0.43)  0.70758 (k:0.38)   \n",
       "\n",
       "                                sst2              stsb              wnli  \\\n",
       "bert                         0.92431           0.88047           0.56338   \n",
       "dynamic:original    0.92317 (k:0.73)  0.88043 (k:0.51)  0.56338 (k:0.63)   \n",
       "dynamic:w_augment   0.92202 (k:0.68)  0.88034 (k:0.51)  0.56338 (k:0.61)   \n",
       "dynamic:wo_augment  0.92202 (k:0.69)  0.86572 (k:0.49)  0.56338 (k:0.54)   \n",
       "\n",
       "                   reproduce  \n",
       "bert                  100.00  \n",
       "dynamic:original       99.87  \n",
       "dynamic:w_augment      98.88  \n",
       "dynamic:wo_augment     97.74  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "\n",
    "with open(PICKLE_PATH, 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "with open('glue_benchmark.pkl', 'rb') as f:\n",
    "    results_original = pickle.load(f)\n",
    "\n",
    "def convert_data_to_results(results):\n",
    "    data = []\n",
    "    subsets = list(results.keys())\n",
    "    factors = list(results[subsets[0]].keys())\n",
    "    for factor in factors:             \n",
    "        row = []\n",
    "        for subset in subsets:\n",
    "            row.append(results[subset][factor])\n",
    "        data.append(row)\n",
    "    return data, factors\n",
    "\n",
    "data, factors = convert_data_to_results(results)\n",
    "data_origin, _ = convert_data_to_results(results_original)\n",
    "data_wo_data_augment = (\"0.53388 (k:0.48)\t0.81559 (k:0.30)\t0.74667 (k:0.66)\t0.90445 (k:0.41)\t\"+\\\n",
    "    \"0.90893 (k:0.43)\t0.70758 (k:0.38)\t0.92202 (k:0.69)\t0.86572 (k:0.49)\t0.56338 (k:0.54)\").split(\"\\t\")\n",
    "\n",
    "factors[-1] = \"dynamic:w_augment\"\n",
    "indicies = factors+[\"dynamic:original\", \"dynamic:wo_augment\"]\n",
    "columns = subsets[:]\n",
    "df_data = data + [data_origin[9], data_wo_data_augment]\n",
    "\n",
    "#calculate reproducibility\n",
    "data_scalar = []\n",
    "for line in df_data:\n",
    "    xs = []\n",
    "    for item in line:\n",
    "        xs.append(float(item.split()[0]))\n",
    "    data_scalar.append(xs)\n",
    "reproducibilities = []\n",
    "for i in range(len(data_scalar)):\n",
    "    rsum = 0\n",
    "    for k in range(len(data_scalar[i])):\n",
    "        rsum += data_scalar[i][k]/data_scalar[0][k]\n",
    "    rsum /= len(data_scalar[i])\n",
    "    reproducibilities.append(rsum)\n",
    "for i, r in enumerate(reproducibilities):\n",
    "    df_data[i].append(f\"{r*100:.2f}\")\n",
    "columns.append(\"reproduce\")\n",
    "\n",
    "df = pd.DataFrame(\n",
    "    df_data,\n",
    "    columns=columns, \n",
    "    index=indicies\n",
    ")\n",
    "#df = df.reindex([\"bert\", \"dynamic:original\", \"dynamic:w_augment\", \"dynamic:wo_augment\"])\n",
    "tex = df.to_latex()\n",
    "with open(TEX_PATH, 'w') as f:\n",
    "    f.write(tex)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\tcola\tmnli\tmrpc\tqnli\tqqp\trte\tsst2\tstsb\twnli\treproduce\n",
    "bert\t0.53388\t0.84198\t0.84406\t0.91543\t0.90908\t0.72563\t0.92431\t0.88047\t0.56338\t100.00\n",
    "dynamic:original\t0.53388 (k:0.48)\t0.84004 (k:0.33)\t0.83942 (k:0.71)\t0.91287 (k:0.44)\t0.90920 (k:0.45)\t0.72563 (k:0.55)\t0.92317 (k:0.73)\t0.88043 (k:0.51)\t0.56338 (k:0.63)\t99.87\n",
    "dynamic:w_augment\t0.53388 (k:0.48)\t0.83647 (k:0.33)\t0.83826 (k:0.72)\t0.91506 (k:0.48)\t0.90905 (k:0.45)\t0.66426 (k:0.40)\t0.92202 (k:0.68)\t0.88034 (k:0.51)\t0.56338 (k:0.61)\t98.88\n",
    "dynamic:wo_augment\t0.53388 (k:0.48)\t0.81559 (k:0.30)\t0.74667 (k:0.66)\t0.90445 (k:0.41)\t0.90893 (k:0.43)\t0.70758 (k:0.38)\t0.92202 (k:0.69)\t0.86572 (k:0.49)\t0.56338 (k:0.54)\t97.74"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# pre data augment\n",
    "\n",
    "cola (matthews_correlation)\tmnli (acc)\tmrpc (acc)\tqnli (acc)\tqqp (acc)\trte (acc)\tsst2 (acc)\tstsb (pearson)\twnli (acc)\n",
    "\n",
    "bert\t0.53388\t0.84198\t0.84406\t0.91543\t0.90908\t0.72563\t0.92431\t0.88047\t0.56338\n",
    "\n",
    "dynamic\t0.53388 (k:0.48)\t0.81559 (k:0.30)\t0.74667 (k:0.66)\t0.90445 (k:0.41)\t0.90893 (k:0.43)\t0.70758 (k:0.38)\t0.92202 (k:0.69)\t0.86572 (k:0.49)\t0.56338 (k:0.54)\n",
    "\n",
    "dynamic:original\t0.53388 (k:0.48)\t0.84004 (k:0.33)\t0.83942 (k:0.71)\t0.91287 (k:0.44)\t0.90920 (k:0.45)\t0.72563 (k:0.55)\t0.92317 (k:0.73)\t0.88043 (k:0.51)\t0.56338 (k:0.63)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f7b3ac0126d0d6fea024471ce24e510948bf6332f7ae1a66cdcb4ee9887514e9"
  },
  "kernelspec": {
   "display_name": "Python 3.8.3 64-bit ('tensorflow': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
