{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAELklEQVR4nO3YMW4cZQCG4X/Wu7GDghSQ0kRCAomWAlo4QkrEAThVboFEhVCOQINoOQMKAju21+vlAo5pdndehucp528+a/TuP/K03+8H0LOaewDwMHFClDghSpwQJU6IWj92+PL114v8V+53X70a3375au4ZB7ff78fd/d3cM45iNa3G2eps7hlH8c3mi+mh525OiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEqPVjh5+/+PREM07r6ebp+OPy7dwzDm6axrjYnI9pTHNPObjr3c3Y7u7mnnEczx9+/C9xfnaEJfN7urkYb9/9OfeMg5umabxYfTxWq+V9EF3dvhuXt1dzzziO5w8/Xt5bhIUQJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHrxw4v1uen2nFSm7P1WE3L+11aTdOYpmnuGUex3W3H3zdXc884qWm/37/38M3VL+8//M9b3p82jWls1psxjeUF+sOvP40ff/t57hlH8fv3bx58YY/fnJtl3py7+92439/PPeMoprHM23O7uxtXt+/mnnFSy/u2g4UQJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHrxw5vtren2nFS13c343p7PfeMg1tNq/HhxbMxjWnuKQf37PyD8clHL+eecVKPxrndbU+146Quby7HXzeXc884uNU0jSdnT8bZankfRP/HOJf3FmEhxAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFq2u/3c28AHuDmhChxQpQ4IUqcECVOiBInRP0Di9dW5VavYc4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAABiCAYAAADz0wB7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAABtElEQVR4nO3aoU0FQRhG0V2yLIKnKQBDAbSGpgEsTaFeQgV0QDDYoQJQ5E6y7xw75lM3v5h1jLEA0LiaPQDgkoguQEh0AUKiCxASXYCQ6AKEtr8en99f/Sf7J6eb2+Xh7n72jMPYt+tl3/bZMw7j7eO8fH5/zZ5xGC+PT+tvby5dgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyAkugAh0QUIiS5ASHQBQqILEBJdgJDoAoREFyC0jjFmbwC4GC5dgJDoAoREFyAkugAh0QUIiS5A6Af0MhC/+I/o9QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAABiCAYAAADz0wB7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAABq0lEQVR4nO3asU0DQRRF0V3kjAzRAjVsSEnEVEBMSc7sGmjBkiPHQwNAhO5I3nPSSV509YNZxxgLAI2H2QMA9kR0AUKiCxASXYCQ6AKERBcgdPjr8e384T/ZP3l+fFpeX7bZM+BHx6/TcrldZ8+4G5/b+/rbm0sXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQiJLkBIdAFCogsQEl2AkOgChEQXICS6ACHRBQitY4zZGwB2w6ULEBJdgJDoAoREFyAkugAh0QUIfQP1fhC/52XRKwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAADgUlEQVR4nO3YPWpUUQCG4TuTIZoMalQQAhZqoY2dNlZuIZ0g7sC9uRsbm4ApgkjARPyJxlw3MExn7jvyPOU5zde8nMudjeM4AD3zqQcAq4kTosQJUeKEKHFC1GLd5ct3b/zKXePp/pPh9fODqWekHZ4cDac/v049I+3tg1ezVedeTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFqse5ya751VTs20uV4OZz+OJt6RtrWfD4st3emnrGR1sZ5fXHtqnZspN9/LoajL8dTz0i7u7w93Nq5OfWMjeSzFqLECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlToharLt88fDZVe3YSOcXv4b3xx+mnpH2+N6j4c7u3tQz2m6sPl4b5/29/X8x5b/x6ezzcHjyceoZad/Ovw/L7d2pZ2wkn7UQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFC1Gwcx6k3ACt4OSFKnBAlTogSJ0SJE6LECVF/AYVFLUZ+VLvxAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAB+CAYAAACHx8KbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAACCklEQVR4nO3aIU5DQRhG0fcIe6htgukKsGyiCYZ0A+yNHSHqcQRRNV0BiCfuhLxz7JhP3fxi1jHGAkDjYfYAgD0RXYCQ6AKERBcgJLoAIdEFCD3+9fjy8eY/2T91Ojwtl+fz7Bls9Pl1Xb5vP7NnsNH78XX97c2lCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQEh0AUKiCxASXYCQ6AKERBcgJLoAIdEFCIkuQGgdY8zeALAbLl2AkOgChEQXICS6ACHRBQiJLkDoDlQDEPdVbhk3AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# for dummy attention\n",
    "import torch\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import cv2\n",
    "\n",
    "def imshow_attention(mat, output_path=None):\n",
    "    \"\"\"Render a 2-D weight map as a green heat-map and optionally save it.\n",
    "\n",
    "    Every entry is enlarged to a 128x128 pixel tile (nearest-neighbour) and\n",
    "    coloured by a linear blend between a light green (value 0) and a dark\n",
    "    green (value 1). The blend is cast to uint8 without clipping, so entries\n",
    "    are expected to lie in [0, 1].\n",
    "\n",
    "    Args:\n",
    "        mat: 2-D torch.Tensor or NumPy array of weights.\n",
    "        output_path: if given, the rendered image is also written to this path\n",
    "            with cv2.imwrite; a missing parent directory is created first.\n",
    "\n",
    "    Returns:\n",
    "        None. The figure is shown with matplotlib as a side effect.\n",
    "    \"\"\"\n",
    "    if isinstance(mat, torch.Tensor):\n",
    "        mat = mat.detach().cpu().numpy()\n",
    "    mat = cv2.resize(mat, None, fx=128, fy=128, interpolation=cv2.INTER_NEAREST)\n",
    "    H, W = mat.shape\n",
    "    mat = mat.reshape(H, W, 1)\n",
    "    # Colour endpoints are RGB triples, the channel order plt.imshow expects.\n",
    "    high_color = np.array([10, 143, 45]).reshape(1,1,3)\n",
    "    low_color = np.array([199, 252, 213]).reshape(1,1,3)\n",
    "    mat = mat*high_color + (1-mat)*low_color\n",
    "    mat = mat.astype(np.uint8)\n",
    "    plt.imshow(mat)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    if output_path is not None:\n",
    "        import os  # local import: only needed when a file is written\n",
    "        out_dir = os.path.dirname(output_path)\n",
    "        if out_dir:\n",
    "            # cv2.imwrite does not create directories, so make sure it exists.\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "        # cv2.imwrite stores channels as BGR; convert so the saved file has the\n",
    "        # same colours as the figure shown above.\n",
    "        cv2.imwrite(output_path, cv2.cvtColor(mat, cv2.COLOR_RGB2BGR))\n",
    "\n",
    "# Figure 1: 4x4 toy attention scores, soft-maxed over the last axis.\n",
    "score = torch.tensor(\n",
    "    [[5, 1, 2, 1],\n",
    "     [3, 2, 1, 2],\n",
    "     [1, 2, 1, 5],\n",
    "     [1, 2, 1, 3]],\n",
    "    dtype=torch.float32,\n",
    ")\n",
    "prob = score.softmax(dim=-1)\n",
    "# Mean over dim 0, kept as a single-row matrix so it can be drawn.\n",
    "reduced_prob = prob.mean(dim=0, keepdim=True)\n",
    "\n",
    "imshow_attention(prob, './saves_plot/fig_struct_atten.png')\n",
    "imshow_attention(reduced_prob, './saves_plot/fig_struct_ff.png')\n",
    "# Same reduction, restricted to rows 0 and 2.\n",
    "imshow_attention(prob[[0, 2]].mean(dim=0, keepdim=True), './saves_plot/fig_struct_abt.png')\n",
    "\n",
    "# Algorithm 1: 3x3 toy scores; only the first soft-max row is kept.\n",
    "score = torch.tensor(\n",
    "    [[5, 1, 4],\n",
    "     [3, 2, 1],\n",
    "     [1, 2, 1]],\n",
    "    dtype=torch.float32,\n",
    ")\n",
    "prob = score.softmax(dim=-1)\n",
    "reduced_prob = prob[0].view(1, 3)\n",
    "\n",
    "imshow_attention(prob, './saves_plot/fig_alg_atten.png')\n",
    "imshow_attention(reduced_prob, './saves_plot/fig_alg_ff.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bert-base\t22.4 GFLOPs\t100%\n",
      "approx-bert@f4\t1.5 GFLOPs\t6.86%\n",
      "approx-bert@f8\t436.9 MFLOPs\t1.95%\n",
      "approx-bert@f16\t140.6 MFLOPs\t0.63%\n",
      "fwd-oh-0.2\t559.5 KFLOPs\t0.00%\n",
      "fwd-oh-0.4\t1.1 MFLOPs\t0.00%\n",
      "fwd-oh-1.0\t2.8 MFLOPs\t0.01%\n",
      "abt-oh@0.2\t597.4 KFLOPs\t0.00%\n",
      "abt-oh@0.4\t1.2 MFLOPs\t0.01%\n",
      "abt-oh@1.0\t2.8 MFLOPs\t0.01%\n",
      "conc-oh@0.2\t9.2 MFLOPs\t0.04%\n",
      "conc-oh@0.4\t18.5 MFLOPs\t0.08%\n",
      "conc-oh@1.0\t46.1 MFLOPs\t0.21%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_615774/3688780476.py:125: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  df.to_latex('./saves_plot/table_overhead_flops.tex')\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>FLOPs</th>\n",
       "      <th>Relative %</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BERT_{base}</th>\n",
       "      <th></th>\n",
       "      <td>22.4 GFLOPs</td>\n",
       "      <td>100%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Attention Approx. Net.</th>\n",
       "      <th>4</th>\n",
       "      <td>1.5 GFLOPs</td>\n",
       "      <td>6.86%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>436.9 MFLOPs</td>\n",
       "      <td>1.95%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>140.6 MFLOPs</td>\n",
       "      <td>0.63%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Manual Top-k</th>\n",
       "      <th>0.2</th>\n",
       "      <td>559.5 KFLOPs</td>\n",
       "      <td>0.00%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>1.1 MFLOPs</td>\n",
       "      <td>0.00%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>2.8 MFLOPs</td>\n",
       "      <td>0.01%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Attention Back-tracking</th>\n",
       "      <th>0.2</th>\n",
       "      <td>597.4 KFLOPs</td>\n",
       "      <td>0.00%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>1.2 MFLOPs</td>\n",
       "      <td>0.01%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>2.8 MFLOPs</td>\n",
       "      <td>0.01%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">Concrete Masking</th>\n",
       "      <th>0.2</th>\n",
       "      <td>9.2 MFLOPs</td>\n",
       "      <td>0.04%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>18.5 MFLOPs</td>\n",
       "      <td>0.08%</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>46.1 MFLOPs</td>\n",
       "      <td>0.21%</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    FLOPs Relative %\n",
       "BERT_{base}                   22.4 GFLOPs       100%\n",
       "Attention Approx. Net.  4      1.5 GFLOPs      6.86%\n",
       "                        8    436.9 MFLOPs      1.95%\n",
       "                        16   140.6 MFLOPs      0.63%\n",
       "Manual Top-k            0.2  559.5 KFLOPs      0.00%\n",
       "                        0.4    1.1 MFLOPs      0.00%\n",
       "                        1.0    2.8 MFLOPs      0.01%\n",
       "Attention Back-tracking 0.2  597.4 KFLOPs      0.00%\n",
       "                        0.4    1.2 MFLOPs      0.01%\n",
       "                        1.0    2.8 MFLOPs      0.01%\n",
       "Concrete Masking        0.2    9.2 MFLOPs      0.04%\n",
       "                        0.4   18.5 MFLOPs      0.08%\n",
       "                        1.0   46.1 MFLOPs      0.21%"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# flops table\n",
    "import copy, random\n",
    "import pandas as pd\n",
    "from utils import sparse_flops_calculation as calc\n",
    "\n",
    "SEQ = 128\n",
    "FACTORS = [4,8,16]\n",
    "OCCUPIES = [0.2, 0.4, 1.0]\n",
    "# Number of configurations averaged for every occupancy setting.\n",
    "SAMPLES = 5000\n",
    "# Seed of the occupancy sampler, so the table can be regenerated exactly.\n",
    "SEED = 0\n",
    "\n",
    "cols = []\n",
    "# Private generator: reproducible without touching the global `random` state.\n",
    "_rng = random.Random(SEED)\n",
    "\n",
    "base_config = calc.ModelConfig(\n",
    "    num_layer=12,\n",
    "    num_heads=12,\n",
    "    hidden_size=768,\n",
    "    intermediate_size=768*4,\n",
    "    seq_len=SEQ,\n",
    "    arch='bert',\n",
    "    token_occupies=None\n",
    ")\n",
    "\n",
    "flops_bert_base = calc.flops_sparse_bert_model(base_config)\n",
    "print('bert-base', calc.human_readable(flops_bert_base), '100%', sep='\\t')\n",
    "cols.append((('BERT_{base}', ''), [calc.human_readable(flops_bert_base), '100%']))\n",
    "\n",
    "\n",
    "def _report(tag, group, key, flops):\n",
    "    \"\"\"Print one tab-separated result line and queue the same row for the table.\n",
    "\n",
    "    Args:\n",
    "        tag: label printed in the first column of the log line.\n",
    "        group: first level of the table's row index.\n",
    "        key: second level of the table's row index.\n",
    "        flops: FLOPs count; also reported relative to `flops_bert_base`.\n",
    "    \"\"\"\n",
    "    readable = calc.human_readable(flops)\n",
    "    relative = f'{flops/flops_bert_base*100:.2f}%'\n",
    "    print(tag, readable, relative, sep='\\t')\n",
    "    cols.append(((group, key), [readable, relative]))\n",
    "\n",
    "\n",
    "def _sample_token_occupies(occupy, count):\n",
    "    \"\"\"Draw `count` occupancies uniformly from a range whose mean is `occupy`.\"\"\"\n",
    "    spread = 2*(occupy-0.05)\n",
    "    if 0.05 + spread < 1:\n",
    "        # Uniform on [0.05, 0.05 + spread).\n",
    "        return [_rng.random() * spread + 0.05 for _ in range(count)]\n",
    "    # The range above would exceed 1, so sample downwards from 1 instead.\n",
    "    return [0.95-(1.0-occupy)*_rng.random()*2 + 0.05 for _ in range(count)]\n",
    "\n",
    "\n",
    "def _mean_overhead(sparse_mode, occupy, sampled):\n",
    "    \"\"\"Average `calc.flops_sparse_update` over SAMPLES copies of `base_config`.\n",
    "\n",
    "    Args:\n",
    "        sparse_mode: value assigned to `config.sparse_mode`.\n",
    "        occupy: target token occupancy.\n",
    "        sampled: if True, set the 4x-reduced approx sizes and draw a random\n",
    "            occupancy per entry; if False, use `occupy` for every entry.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the SAMPLES FLOPs values.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for _ in range(SAMPLES):\n",
    "        # Fresh copy each time so no sample sees another sample's settings.\n",
    "        config = copy.deepcopy(base_config)\n",
    "        count = config.num_layer+1\n",
    "        if sampled:\n",
    "            config.approx_hidden_size = config.hidden_size / 4\n",
    "            config.approx_intermediate_size = config.intermediate_size / 4\n",
    "        config.sparse_mode = sparse_mode\n",
    "        if sampled:\n",
    "            config.token_occupies = _sample_token_occupies(occupy, count)\n",
    "        else:\n",
    "            config.token_occupies = [occupy for _ in range(count)]\n",
    "        total += calc.flops_sparse_update(config)\n",
    "    return total / SAMPLES\n",
    "\n",
    "\n",
    "# Approximation networks: BERT-base with both widths divided by `factor`.\n",
    "for factor in FACTORS:\n",
    "    config = copy.deepcopy(base_config)\n",
    "    config.hidden_size /= factor\n",
    "    config.intermediate_size /= factor\n",
    "    _report(\n",
    "        f'approx-bert@f{factor}',\n",
    "        'Attention Approx. Net.',\n",
    "        f'{factor}',\n",
    "        calc.flops_sparse_bert_model(config),\n",
    "    )\n",
    "\n",
    "# Sparse-update overhead of each method: (log prefix, table group, sparse_mode, sampled).\n",
    "for tag, group, sparse_mode, sampled in [\n",
    "    ('fwd-oh-', 'Manual Top-k', 'forward', False),\n",
    "    ('abt-oh@', 'Attention Back-tracking', 'approx', True),\n",
    "    ('conc-oh@', 'Concrete Masking', 'concrete', True),\n",
    "]:\n",
    "    for occupy in OCCUPIES:\n",
    "        _report(\n",
    "            f'{tag}{occupy}',\n",
    "            group,\n",
    "            f'{occupy}',\n",
    "            _mean_overhead(sparse_mode, occupy, sampled),\n",
    "        )\n",
    "\n",
    "# One row per (group, key) entry; columns hold the two formatted strings.\n",
    "df = pd.DataFrame(\n",
    "    [content for _, content in cols],\n",
    "    index=pd.MultiIndex.from_tuples([header for header, _ in cols]),\n",
    "    columns=['FLOPs', 'Relative %'],\n",
    ")\n",
    "df.to_latex('./saves_plot/table_overhead_flops.tex')\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.5 ('torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b60c80f82950b752af00aa00b10c7536815afcc403735623c9df587c041a5446"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
