{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b378f049",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=6\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc96d42c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "from torchvision import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "446d37da",
   "metadata": {},
   "outputs": [],
   "source": [
    "MEMORY_BUDGET = 2**10 * 40\n",
    "\n",
    "class Identity(t.nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x\n",
    "\n",
    "            \n",
    "def get_mem():\n",
    "    return t.cuda.memory_allocated() / 2**20\n",
    "\n",
    "\n",
    "def pass_through_model(model, forward_fn):\n",
    "    out = forward_fn(model)\n",
    "    return get_mem()\n",
    "\n",
    "\n",
    "def calc_max_batch(param_memory, act_memory, nonlinear_memory, n_bits, total_bits):\n",
    "    model_memory = param_memory * 3\n",
    "    element_memory = act_memory + nonlinear_memory * n_bits / total_bits\n",
    "    return int((MEMORY_BUDGET - model_memory) / element_memory)\n",
    "\n",
    "\n",
    "def run_test(model_fn, forward_fn, delete_nonlinearity, total_bits):\n",
    "    print(get_mem())\n",
    "\n",
    "    model = model_fn()\n",
    "    model_size = get_mem()\n",
    "    print(model_size)\n",
    "    \n",
    "    act_sz = pass_through_model(model, forward_fn) - model_size\n",
    "    print(f'{act_sz=}, {get_mem()=}')\n",
    "    \n",
    "    delete_nonlinearity(model)\n",
    "    act_no_nonlin = pass_through_model(model, forward_fn) - model_size\n",
    "    nonlin = act_sz - act_no_nonlin\n",
    "    \n",
    "    print(f'{act_no_nonlin=}, {nonlin=}')\n",
    "    \n",
    "    act_saves = {}\n",
    "    for n_bits in [total_bits, 8, 4, 3, 2, 1]:\n",
    "        act_save = (1 - n_bits / total_bits) * nonlin / act_sz * 100\n",
    "        print(f'{act_save:.1f}', end=' ')\n",
    "        act_saves[n_bits] = act_save\n",
    "    print()\n",
    "\n",
    "    max_batchs = {}\n",
    "    for n_bits in [total_bits, 8, 4, 3, 2, 1]:\n",
    "        max_batch = calc_max_batch(model_size, act_sz, nonlin, n_bits, total_bits)\n",
    "        max_batchs[n_bits] = max_batch\n",
    "        print(f'{max_batch:.1f}', end=' ')\n",
    "    print()\n",
    "    print()\n",
    "    \n",
    "    return {\n",
    "        'model_size': model_size, \n",
    "        'act_size': act_sz, \n",
    "        'nonlin_size': nonlin, \n",
    "        'savings': act_saves, \n",
    "        'max_batches': max_batchs,\n",
    "    }\n",
    "              \n",
    "    \n",
    "    \n",
    "def run(func):\n",
    "    return run_test(*func())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "e779396e",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "505157eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "171.4013671875\n",
      "act_sz=234.4794921875, get_mem()=171.4013671875\n",
      "act_no_nonlin=161.1044921875, nonlin=73.375\n",
      "0.0 23.5 27.4 28.4 29.3 30.3 \n",
      "131.0 159.0 165.0 167.0 169.0 170.0 \n",
      "\n",
      "0.0\n",
      "30.85888671875\n",
      "act_sz=243.7548828125, get_mem()=30.85888671875\n",
      "act_no_nonlin=164.6923828125, nonlin=79.0625\n",
      "0.0 24.3 28.4 29.4 30.4 31.4 \n",
      "126.0 155.0 161.0 162.0 164.0 165.0 \n",
      "\n",
      "0.0\n",
      "256.31005859375\n",
      "act_sz=673.4794921875, get_mem()=256.31005859375\n",
      "act_no_nonlin=494.591796875, nonlin=178.8876953125\n",
      "0.0 19.9 23.2 24.1 24.9 25.7 \n",
      "47.0 55.0 57.0 58.0 58.0 59.0 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "def conv_forward_fn(model):\n",
    "    return model(t.randn(1, 3, 256, 256).cuda())\n",
    "\n",
    "\n",
    "def apply(model, frm, to):\n",
    "    for name, child in model.named_children():\n",
    "        if isinstance(child, frm):\n",
    "            setattr(model, name, to())\n",
    "        else:\n",
    "            apply(child, frm, to)\n",
    "\n",
    "\n",
    "def conv_model(name, frm):\n",
    "    def create_model():\n",
    "        model = models.__dict__[name]().cuda()\n",
    "        if frm is nn.ReLU:\n",
    "            apply(model, nn.ReLU, nn.GELU)\n",
    "        return model\n",
    "    \n",
    "    def remove_nonlin(m):\n",
    "        apply(m, nn.GELU if frm is nn.ReLU else frm, Identity)\n",
    "        \n",
    "    return create_model, conv_forward_fn, remove_nonlin, 32\n",
    "\n",
    "\n",
    "CONV_MODELS = {\n",
    "    'ResNet-101': 'resnet101',\n",
    "    'DenseNet-121': 'densenet121',\n",
    "    'Efficient Net B7': 'efficientnet_b7',\n",
    "}\n",
    "\n",
    "for model_name, model in CONV_MODELS.items():\n",
    "    data[model_name] = run(lambda: conv_model(model, nn.ReLU if model != 'efficientnet_b7' else nn.SiLU))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "803a7d52",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 2)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_684876/1802646866.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mTRANS_MODELS\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m     \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtransformer_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "\n",
    "def transformer_model(name):\n",
    "    def create_model():\n",
    "        return AutoModelForSequenceClassification.from_pretrained(\n",
    "            name,\n",
    "            num_labels=2,\n",
    "        ).cuda()\n",
    "    \n",
    "    def forward_model(model):\n",
    "        inp = t.zeros(1, 256, dtype=t.long, device='cuda')\n",
    "        return model(inp)\n",
    "    \n",
    "    if 'roberta' in name:\n",
    "        def delete_nonlinearity(model):\n",
    "            for layer in model.roberta.encoder.layer:\n",
    "                layer.intermediate.intermediate_act_fn = Identity()\n",
    "    elif 'gpt' in name:\n",
    "        def delete_nonlinearity(model):\n",
    "            for h in model.transformer.h:\n",
    "                h.mlp.act = Identity()\n",
    "\n",
    "    \n",
    "    return create_model, forward_model, delete_nonlinearity, 32\n",
    "\n",
    "\n",
    "TRANS_MODELS = {\n",
    "    'RoBERTa-base': 'roberta-base',\n",
    "    'RoBERTa-large': 'roberta-large',\n",
    "    'GPT2': 'gpt2'\n",
    "}\n",
    "\n",
    "\n",
    "for model_name, model in TRANS_MODELS.items():\n",
    "    data[model_name] = run(lambda: transformer_model(model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "0527ccad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    \\begin{table}[h] \\label{tab:memory}\n",
      "    \\begin{tabular}{|l|l|c|c|}\n",
      "        \\hline\n",
      "        Model     & Quant & Saving & \\makecell{Max \\\\ Batch \\\\ Size} \\\\ \\hline\n",
      "\n",
      "         \\textbf{ResNet-101} & & & 131 \\\\\n",
      "         256x256 size  & 1-bit & 30\\% & 170 (+29.8\\%) \\\\\n",
      "            & 2-bit & 29\\% & 169 (+29.0\\%) \\\\\n",
      "            & 3-bit & 28\\% & 167 (+27.5\\%) \\\\\n",
      "            & 4-bit & 27\\% & 165 (+26.0\\%) \\\\\n",
      "         \\textbf{DenseNet-121} & & & 126 \\\\\n",
      "         256x256 size  & 1-bit & 31\\% & 165 (+31.0\\%) \\\\\n",
      "            & 2-bit & 30\\% & 164 (+30.2\\%) \\\\\n",
      "            & 3-bit & 29\\% & 162 (+28.6\\%) \\\\\n",
      "            & 4-bit & 28\\% & 161 (+27.8\\%) \\\\\n",
      "         \\textbf{Efficient Net B7} & & & 47 \\\\\n",
      "         256x256 size  & 1-bit & 26\\% & 59 (+25.5\\%) \\\\\n",
      "            & 2-bit & 25\\% & 58 (+23.4\\%) \\\\\n",
      "            & 3-bit & 24\\% & 58 (+23.4\\%) \\\\\n",
      "            & 4-bit & 23\\% & 57 (+21.3\\%) \\\\\n",
      "         \\textbf{RoBERTa-base} & & & 154 \\\\\n",
      "         256 seq. len  & 1-bit & 16\\% & 179 (+16.2\\%) \\\\\n",
      "            & 2-bit & 15\\% & 178 (+15.6\\%) \\\\\n",
      "            & 3-bit & 15\\% & 177 (+14.9\\%) \\\\\n",
      "            & 4-bit & 14\\% & 176 (+14.3\\%) \\\\\n",
      "         \\textbf{RoBERTa-large} & & & 54 \\\\\n",
      "         256 seq. len  & 1-bit & 16\\% & 63 (+16.7\\%) \\\\\n",
      "            & 2-bit & 16\\% & 63 (+16.7\\%) \\\\\n",
      "            & 3-bit & 15\\% & 62 (+14.8\\%) \\\\\n",
      "            & 4-bit & 15\\% & 62 (+14.8\\%) \\\\\n",
      "         \\textbf{GPT2} & & & 83 \\\\\n",
      "         256 seq. len  & 1-bit & 42\\% & 117 (+41.0\\%) \\\\\n",
      "            & 2-bit & 41\\% & 116 (+39.8\\%) \\\\\n",
      "            & 3-bit & 39\\% & 114 (+37.3\\%) \\\\\n",
      "            & 4-bit & 38\\% & 113 (+36.1\\%) \\\\\n",
      "        \\hline\n",
      "    \\end{tabular}\n",
      "    \\caption{Memory savings and maximum batch size for popular models for different quantization budget. Memory is calculated with the expectation, that three copies of the model are stored on device (one copy is the actual weights, second copy is gradients and third copy is optimizer statistics like momentum for SGD with momentum) for NVIDIA A100 GPU with 40Gb memory. Memory saving is calculated with respect to only activation memory, because total storage gain depends on chosen batch size.}\n",
      "    \\end{table}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(r'''\n",
    "    \\begin{table}[h] \\label{table:memory}\n",
    "    \\centering\n",
    "    \\resizebox{.41\\textwidth{!}{\n",
    "    \\begin{tabular}{|l|l|c|c|}\n",
    "        \\hline\n",
    "        Model     & Quant & Saving & \\makecell{Max \\\\ Batch \\\\ Size} \\\\ \\hline\n",
    "''')\n",
    "\n",
    "\n",
    "\n",
    "for model_name, (s, v) in data.items():\n",
    "    def get(bits):\n",
    "        return f'' \n",
    "    \n",
    "    for i, bits in enumerate([32, 1, 2, 3, 4]):\n",
    "        if i == 0:\n",
    "            prefix = f'\\\\textbf{{{model_name}}}'\n",
    "        elif i == 1:\n",
    "            prefix = f'256x256 size' if model_name in CONV_MODELS else '256 seq. len'\n",
    "        else:\n",
    "            prefix = f' '\n",
    "        \n",
    "        if bits == 32:\n",
    "            suffix = f'& & & {v[bits]} \\\\\\\\'\n",
    "        else:\n",
    "            suffix = f' & {bits}-bit & {s[bits]:.0f}\\% & {v[bits]} (+{100 * v[bits] / v[32] - 100:.1f}\\\\%) \\\\\\\\'\n",
    "\n",
    "        print(' ' * 8, prefix, suffix)\n",
    "            \n",
    "print('''        \\hline\n",
    "    \\end{tabular}\n",
    "    }\n",
    "    \\caption{Memory savings and maximum batch size for popular models for different quantization budget. Memory is calculated with the expectation, that three copies of the model are stored on device (one copy is the actual weights, second copy is gradients and third copy is optimizer statistics like momentum for SGD with momentum) for NVIDIA A100 GPU with 40Gb memory. Memory saving is calculated with respect to only activation memory, because total storage gain depends on chosen batch size.}\n",
    "    \\end{table}\n",
    "''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fe050d",
   "metadata": {},
   "outputs": [],
   "source": [
    "models.v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "a5fa218d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "44.6435546875\n",
      "act_sz=40.044921875, get_mem()=44.6435546875\n",
      "act_no_nonlin=28.544921875, nonlin=11.5\n",
      "0.0 21.5 25.1 26.0 26.9 27.8 \n",
      "792.0 951.0 984.0 992.0 1001.0 1010.0 \n",
      "\n",
      "0.0\n",
      "99.22802734375\n",
      "act_sz=156.7802734375, get_mem()=99.22802734375\n",
      "act_no_nonlin=108.9052734375, nonlin=47.875\n",
      "0.0 22.9 26.7 27.7 28.6 29.6 \n",
      "198.0 240.0 249.0 252.0 254.0 256.0 \n",
      "\n",
      "0.0\n",
      "171.4013671875\n",
      "act_sz=234.4794921875, get_mem()=171.4013671875\n",
      "act_no_nonlin=161.1044921875, nonlin=73.375\n",
      "0.0 23.5 27.4 28.4 29.3 30.3 \n",
      "131.0 159.0 165.0 167.0 169.0 170.0 \n",
      "\n",
      "0.0\n",
      "232.27783203125\n",
      "act_sz=328.1552734375, get_mem()=232.27783203125\n",
      "act_no_nonlin=223.2802734375, nonlin=104.875\n",
      "0.0 24.0 28.0 29.0 30.0 31.0 \n",
      "92.0 113.0 117.0 119.0 120.0 121.0 \n",
      "\n",
      "0.0\n",
      "30.85888671875\n",
      "act_sz=243.7548828125, get_mem()=30.85888671875\n",
      "act_no_nonlin=164.6923828125, nonlin=79.0625\n",
      "0.0 24.3 28.4 29.4 30.4 31.4 \n",
      "126.0 155.0 161.0 162.0 164.0 165.0 \n",
      "\n",
      "0.0\n",
      "112.3837890625\n",
      "act_sz=457.205078125, get_mem()=112.3837890625\n",
      "act_no_nonlin=311.861328125, nonlin=145.34375\n",
      "0.0 23.8 27.8 28.8 29.8 30.8 \n",
      "67.0 82.0 85.0 86.0 87.0 87.0 \n",
      "\n",
      "0.0\n",
      "54.72412109375\n",
      "act_sz=296.26611328125, get_mem()=54.72412109375\n",
      "act_no_nonlin=200.92236328125, nonlin=95.34375\n",
      "0.0 24.1 28.2 29.2 30.2 31.2 \n",
      "104.0 127.0 132.0 133.0 134.0 136.0 \n",
      "\n",
      "0.0\n",
      "77.39208984375\n",
      "act_sz=382.22216796875, get_mem()=77.39208984375\n",
      "act_no_nonlin=258.28466796875, nonlin=123.9375\n",
      "0.0 24.3 28.4 29.4 30.4 31.4 \n",
      "80.0 98.0 102.0 103.0 104.0 105.0 \n",
      "\n",
      "0.0\n",
      "20.41064453125\n",
      "act_sz=112.43603515625, get_mem()=20.41064453125\n",
      "act_no_nonlin=79.99853515625, nonlin=32.4375\n",
      "0.0 21.6 25.2 26.1 27.0 27.9 \n",
      "282.0 339.0 351.0 354.0 357.0 360.0 \n",
      "\n",
      "0.0\n",
      "47.5458984375\n",
      "act_sz=218.56298828125, get_mem()=47.5458984375\n",
      "act_no_nonlin=159.06982421875, nonlin=59.4931640625\n",
      "0.0 20.4 23.8 24.7 25.5 26.4 \n",
      "146.0 174.0 180.0 182.0 183.0 185.0 \n",
      "\n",
      "0.0\n",
      "256.31005859375\n",
      "act_sz=673.4794921875, get_mem()=256.31005859375\n",
      "act_no_nonlin=494.591796875, nonlin=178.8876953125\n",
      "0.0 19.9 23.2 24.1 24.9 25.7 \n",
      "47.0 55.0 57.0 58.0 58.0 59.0 \n",
      "\n",
      "0.0\n",
      "507.208984375\n",
      "act_sz=100.919921875, get_mem()=507.208984375\n",
      "act_no_nonlin=63.888671875, nonlin=37.03125\n",
      "0.0 27.5 32.1 33.3 34.4 35.5 \n",
      "285.0 357.0 373.0 377.0 382.0 386.0 \n",
      "\n",
      "0.0\n",
      "528.16796875\n",
      "act_sz=163.794921875, get_mem()=528.16796875\n",
      "act_no_nonlin=95.263671875, nonlin=68.53125\n",
      "0.0 31.4 36.6 37.9 39.2 40.5 \n",
      "169.0 217.0 228.0 231.0 234.0 237.0 \n",
      "\n",
      "0.0\n",
      "548.4228515625\n",
      "act_sz=178.794921875, get_mem()=548.4228515625\n",
      "act_no_nonlin=103.763671875, nonlin=75.03125\n",
      "0.0 31.5 36.7 38.0 39.3 40.7 \n",
      "154.0 199.0 208.0 211.0 214.0 217.0 \n",
      "\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.bias', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "480.73388671875\n",
      "act_sz=219.55615234375, get_mem()=480.73388671875\n",
      "act_no_nonlin=183.55615234375, nonlin=36.0\n",
      "0.0 12.3 14.3 14.9 15.4 15.9 \n",
      "154.0 172.0 176.0 177.0 178.0 179.0 \n",
      "\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.bias', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1355.60693359375\n",
      "act_sz=578.10400390625, get_mem()=1355.60693359375\n",
      "act_no_nonlin=482.10400390625, nonlin=96.0\n",
      "0.0 12.5 14.5 15.0 15.6 16.1 \n",
      "54.0 61.0 62.0 62.0 63.0 63.0 \n",
      "\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "490.974609375\n",
      "act_sz=331.0537109375, get_mem()=490.974609375\n",
      "act_no_nonlin=187.0537109375, nonlin=144.0\n",
      "0.0 32.6 38.1 39.4 40.8 42.1 \n",
      "83.0 107.0 113.0 114.0 116.0 117.0 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "CONV_MODELS = {\n",
    "    'ResNet-18': 'resnet18',\n",
    "    'ResNet-50': 'resnet50',\n",
    "    'ResNet-101': 'resnet101',\n",
    "    'ResNet-152': 'resnet152',\n",
    "    'DenseNet-121': 'densenet121',\n",
    "    'DenseNet-161': 'densenet161',\n",
    "    'DenseNet-169': 'densenet169',\n",
    "    'DenseNet-201': 'densenet201',\n",
    "    'Efficient Net B0': 'efficientnet_b0',\n",
    "#     'Efficient Net B1': 'efficientnet_b1',\n",
    "#     'Efficient Net B2': 'efficientnet_b2',\n",
    "    'Efficient Net B3': 'efficientnet_b3',\n",
    "#     'Efficient Net B4': 'efficientnet_b4',\n",
    "#     'Efficient Net B5': 'efficientnet_b5',\n",
    "#     'Efficient Net B6': 'efficientnet_b6',\n",
    "    'Efficient Net B7': 'efficientnet_b7',\n",
    "    'VGG 11': 'vgg11',\n",
    "    'VGG 16': 'vgg16',\n",
    "    'VGG 19': 'vgg19',\n",
    "}\n",
    "\n",
    "\n",
    "TRANS_MODELS = {\n",
    "    'RoBERTa-base': 'roberta-base',\n",
    "    'RoBERTa-large': 'roberta-large',\n",
    "    'GPT2': 'gpt2'\n",
    "}\n",
    "\n",
    "\n",
    "data_full = {}\n",
    "\n",
    "for model_name, model in CONV_MODELS.items():\n",
    "    data_full[model_name] = run(lambda: conv_model(model, nn.ReLU if 'efficientnet' not in model else nn.SiLU))\n",
    "\n",
    "for model_name, model in TRANS_MODELS.items():\n",
    "    data_full[model_name] = run(lambda: transformer_model(model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "4bdcfb14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    \\begin{table}[H] \\label{tab:memory}\n",
      "    \\begin{tabular}{|l|c|c|c|c|c|c|c|}\n",
      "        \\hline\n",
      "             & \\makecell{Model \\\\ Size \\\\ (Mb)} & \\makecell{All\\\\\\small Activations\\\\Size\\\\(Mb)} & \\makecell{{\\small Nonlinearity\\\\\\small Activations\\\\Size\\\\(Mb)}} & \\makecell{1-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{2-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{3-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{4-bit\\\\Saving\\\\\\small{Max batch size}} \\\\ \\hline\n",
      "\n",
      "        \\textbf{ResNet-18} & 44.6 & 40.0 & 11.5 & \\makecell{28\\% \\\\ 1010 (+27.5\\%)} & \\makecell{27\\% \\\\ 1001 (+26.4\\%)} & \\makecell{26\\% \\\\ 992 (+25.3\\%)} & \\makecell{25\\% \\\\ 984 (+24.2\\%)} \\\\ \\hline\n",
      "        \\textbf{ResNet-50} & 99.2 & 156.8 & 47.9 & \\makecell{30\\% \\\\ 256 (+29.3\\%)} & \\makecell{29\\% \\\\ 254 (+28.3\\%)} & \\makecell{28\\% \\\\ 252 (+27.3\\%)} & \\makecell{27\\% \\\\ 249 (+25.8\\%)} \\\\ \\hline\n",
      "        \\textbf{ResNet-101} & 171.4 & 234.5 & 73.4 & \\makecell{30\\% \\\\ 170 (+29.8\\%)} & \\makecell{29\\% \\\\ 169 (+29.0\\%)} & \\makecell{28\\% \\\\ 167 (+27.5\\%)} & \\makecell{27\\% \\\\ 165 (+26.0\\%)} \\\\ \\hline\n",
      "        \\textbf{ResNet-152} & 232.3 & 328.2 & 104.9 & \\makecell{31\\% \\\\ 121 (+31.5\\%)} & \\makecell{30\\% \\\\ 120 (+30.4\\%)} & \\makecell{29\\% \\\\ 119 (+29.3\\%)} & \\makecell{28\\% \\\\ 117 (+27.2\\%)} \\\\ \\hline\n",
      "        \\textbf{DenseNet-121} & 30.9 & 243.8 & 79.1 & \\makecell{31\\% \\\\ 165 (+31.0\\%)} & \\makecell{30\\% \\\\ 164 (+30.2\\%)} & \\makecell{29\\% \\\\ 162 (+28.6\\%)} & \\makecell{28\\% \\\\ 161 (+27.8\\%)} \\\\ \\hline\n",
      "        \\textbf{DenseNet-161} & 112.4 & 457.2 & 145.3 & \\makecell{31\\% \\\\ 87 (+29.9\\%)} & \\makecell{30\\% \\\\ 87 (+29.9\\%)} & \\makecell{29\\% \\\\ 86 (+28.4\\%)} & \\makecell{28\\% \\\\ 85 (+26.9\\%)} \\\\ \\hline\n",
      "        \\textbf{DenseNet-169} & 54.7 & 296.3 & 95.3 & \\makecell{31\\% \\\\ 136 (+30.8\\%)} & \\makecell{30\\% \\\\ 134 (+28.8\\%)} & \\makecell{29\\% \\\\ 133 (+27.9\\%)} & \\makecell{28\\% \\\\ 132 (+26.9\\%)} \\\\ \\hline\n",
      "        \\textbf{DenseNet-201} & 77.4 & 382.2 & 123.9 & \\makecell{31\\% \\\\ 105 (+31.2\\%)} & \\makecell{30\\% \\\\ 104 (+30.0\\%)} & \\makecell{29\\% \\\\ 103 (+28.8\\%)} & \\makecell{28\\% \\\\ 102 (+27.5\\%)} \\\\ \\hline\n",
      "        \\textbf{Efficient Net B0} & 20.4 & 112.4 & 32.4 & \\makecell{28\\% \\\\ 360 (+27.7\\%)} & \\makecell{27\\% \\\\ 357 (+26.6\\%)} & \\makecell{26\\% \\\\ 354 (+25.5\\%)} & \\makecell{25\\% \\\\ 351 (+24.5\\%)} \\\\ \\hline\n",
      "        \\textbf{Efficient Net B3} & 47.5 & 218.6 & 59.5 & \\makecell{26\\% \\\\ 185 (+26.7\\%)} & \\makecell{26\\% \\\\ 183 (+25.3\\%)} & \\makecell{25\\% \\\\ 182 (+24.7\\%)} & \\makecell{24\\% \\\\ 180 (+23.3\\%)} \\\\ \\hline\n",
      "        \\textbf{Efficient Net B7} & 256.3 & 673.5 & 178.9 & \\makecell{26\\% \\\\ 59 (+25.5\\%)} & \\makecell{25\\% \\\\ 58 (+23.4\\%)} & \\makecell{24\\% \\\\ 58 (+23.4\\%)} & \\makecell{23\\% \\\\ 57 (+21.3\\%)} \\\\ \\hline\n",
      "        \\textbf{VGG 11} & 507.2 & 100.9 & 37.0 & \\makecell{36\\% \\\\ 386 (+35.4\\%)} & \\makecell{34\\% \\\\ 382 (+34.0\\%)} & \\makecell{33\\% \\\\ 377 (+32.3\\%)} & \\makecell{32\\% \\\\ 373 (+30.9\\%)} \\\\ \\hline\n",
      "        \\textbf{VGG 16} & 528.2 & 163.8 & 68.5 & \\makecell{41\\% \\\\ 237 (+40.2\\%)} & \\makecell{39\\% \\\\ 234 (+38.5\\%)} & \\makecell{38\\% \\\\ 231 (+36.7\\%)} & \\makecell{37\\% \\\\ 228 (+34.9\\%)} \\\\ \\hline\n",
      "        \\textbf{VGG 19} & 548.4 & 178.8 & 75.0 & \\makecell{41\\% \\\\ 217 (+40.9\\%)} & \\makecell{39\\% \\\\ 214 (+39.0\\%)} & \\makecell{38\\% \\\\ 211 (+37.0\\%)} & \\makecell{37\\% \\\\ 208 (+35.1\\%)} \\\\ \\hline\n",
      "        \\textbf{RoBERTa-base} & 480.7 & 219.6 & 36.0 & \\makecell{16\\% \\\\ 179 (+16.2\\%)} & \\makecell{15\\% \\\\ 178 (+15.6\\%)} & \\makecell{15\\% \\\\ 177 (+14.9\\%)} & \\makecell{14\\% \\\\ 176 (+14.3\\%)} \\\\ \\hline\n",
      "        \\textbf{RoBERTa-large} & 1355.6 & 578.1 & 96.0 & \\makecell{16\\% \\\\ 63 (+16.7\\%)} & \\makecell{16\\% \\\\ 63 (+16.7\\%)} & \\makecell{15\\% \\\\ 62 (+14.8\\%)} & \\makecell{15\\% \\\\ 62 (+14.8\\%)} \\\\ \\hline\n",
      "        \\textbf{GPT2} & 491.0 & 331.1 & 144.0 & \\makecell{42\\% \\\\ 117 (+41.0\\%)} & \\makecell{41\\% \\\\ 116 (+39.8\\%)} & \\makecell{39\\% \\\\ 114 (+37.3\\%)} & \\makecell{38\\% \\\\ 113 (+36.1\\%)} \\\\ \\hline\n",
      "    \\end{tabular}\n",
      "    \\end{table}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(r'''\n",
    "    \\begin{table}[H] \\label{table:appendix-memory}\n",
    "    \\begin{tabular}{|l|c|c|c|c|c|c|c|}\n",
    "        \\hline\n",
    "             & \\makecell{Model \\\\ Size \\\\ (Mb)} & \\makecell{All\\\\\\small Activations\\\\Size\\\\(Mb)} & \\makecell{{\\small Nonlinearity\\\\\\small Activations\\\\Size\\\\(Mb)}} & \\makecell{1-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{2-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{3-bit\\\\Saving\\\\\\small{Max batch size}} & \\makecell{4-bit\\\\Saving\\\\\\small{Max batch size}} \\\\ \\hline\n",
    "''')\n",
    "\n",
    "for model_name, values in data_full.items():\n",
    "    model_size = values['model_size']\n",
    "    act_size = values['act_size']\n",
    "    nonlin_size = values['nonlin_size']\n",
    "    savings = values['savings']\n",
    "    mbs = values['max_batches']\n",
    "    \n",
    "    row = f'        \\\\textbf{{{model_name}}} & {model_size:.1f} & {act_size:.1f} & {nonlin_size:.1f} '\n",
    "    for n_bits in range(1, 5):\n",
    "        row = row + f'& \\\\makecell{{{savings[n_bits]:.0f}\\\\% \\\\\\\\ {mbs[n_bits]} (+{100 * mbs[n_bits] / mbs[32] - 100:.1f}\\\\%)}} '\n",
    "    \n",
    "    print(row + r'\\\\ \\hline')\n",
    "    \n",
    "print('''    \\end{tabular}\n",
    "    \\end{table}\n",
    "''')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
