{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
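    "# Install PuLP; skip this step if it is already installed.\n",
    "# Note: the solver cells below call GLPK_CMD, which additionally needs the GLPK `glpsol` executable.\n",
    "# ! pip install pulp"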
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pulp import *\n",
    "import pulp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class args():\n",
    "    \"\"\"Plain container for the constraint settings used by the ILP cells below.\n",
    "\n",
    "    Every limit is an interpolation factor between the all-4-bit and the all-8-bit\n",
    "    configuration: limit = cost_4bit + (cost_8bit - cost_4bit) * factor.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "# Rebind the name to an instance so the settings live on an object rather than on the class itself.\n",
    "args = args()\n",
    "# Model size limit, e.g. 0.25 means 4bit_model_size * 0.75 + 8bit_model_size * 0.25.\n",
    "args.model_size_limit = 0.5\n",
    "# BOPs limit, same definition as above.\n",
    "args.bops_limit = 0.5\n",
    "# Latency limit, same definition as above.\n",
    "args.latency_limit = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-layer metrics for ResNet18 used by the ILP formulations below.\n",
    "\n",
    "# Hutchinson_trace is the trace of the Hessian for each weight matrix.\n",
    "# It has 19 elements: ResNet18 has 18 layers, and removing the first and last layer\n",
    "# leaves 16 layers (the other three entries come from the residual connection layers).\n",
    "# These values are already normalized, i.e., Trace / # parameters.\n",
    "Hutchinson_trace = np.array([0.06857826, 0.03162379, 0.03298575, 0.01205663, 0.02222431, 0.00596336, 0.06931772, 0.00807129, 0.00372905, 0.00530698, 0.00209011, 0.00737569, 0.00210454, 0.00151197, 0.00158041,0.00078146, 0.00451841, 0.00098745, 0.00072944])\n",
    "# Squared 8-bit weight perturbation: \\| W_fp32 - W_int8 \\|_2^2\n",
    "delta_weights_8bit_square = np.array([0.0235, 0.0125, 0.0102, 0.0082, 0.0145, 0.0344, 0.0023, 0.0287, 0.0148, 0.0333, 0.0682, 0.0027, 0.0448, 0.0336, 0.0576, 0.1130, 0.0102, 0.0947, 0.0532]) #  = (w_fp32 - w_int8)^2\n",
    "# Squared 4-bit weight perturbation: \\| W_fp32 - W_int4 \\|_2^2\n",
    "delta_weights_4bit_square = np.array([6.7430, 3.9691, 3.3281, 2.6796, 4.7277, 10.5966, 0.6827, 9.0942, 4.8857, 10.7599, 21.7546, 0.8603, 14.5324, 10.9651, 18.7706, 36.4044, 3.1572, 29.6994, 17.4016]) #  = (w_fp32 - w_int4)^2\n",
    "# Number of parameters of each layer.\n",
    "parameters = np.array([ 36864, 36864, 36864, 36864, 73728, 147456, 8192, 147456, 147456, 294912, 589824, 32768, 589824, 589824, 1179648, 2359296, 131072, 2359296, 2359296]) / 1024 / 1024 # expressed in units of 1024 * 1024 parameters (byte-to-MB scaling; see the model size computation in the next cell)\n",
    "# Bit operations (BOPs) of each layer.\n",
    "bops = np.array([115605504, 115605504, 115605504, 115605504, 57802752, 115605504, 6422528, 115605504, 115605504, 57802752, 115605504, 6422528, 115605504, 115605504, 57802752, 115605504, 6422528, 115605504, 115605504]) / 1000000 # BOPs of each layer/block, divided by 1e6\n",
    "# Latency of each layer for INT4 and INT8, measured on a T4 GPU (time unit not stated here).\n",
    "latency_int4 = np.array([0.21094404, 0.21092674, 0.21104113, 0.21086851, 0.13642465, 0.19167506, 0.02532183, 0.19148203, 0.19142914, 0.11395316, 0.20556011, 0.01917474, 0.20566918, 0.20566509, 0.13185102, 0.22287786, 0.01790088, 0.22304611, 0.22286099])\n",
    "latency_int8 = np.array([0.36189111, 0.36211718, 0.31141909, 0.30454471, 0.19184896, 0.38948934, 0.0334169, 0.38904905, 0.3892859, 0.19134735, 0.34307431, 0.02802354, 0.34313329, 0.34310756, 0.21117103, 0.37376585, 0.02896843, 0.37398187, 0.37405185])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the three budgets (model size, BOPs, latency) from the per-layer metrics.\n",
    "\n",
    "# --- model size ---\n",
    "model_size_32bit = np.sum(parameters) * 4. # MB, 4 bytes per fp32 parameter\n",
    "model_size_8bit = model_size_32bit / 4. # an 8-bit model is 1/4 of the 32-bit model\n",
    "model_size_4bit = model_size_32bit / 8. # a 4-bit model is 1/8 of the 32-bit model\n",
    "# The limit interpolates between the all-4-bit and the all-8-bit size.\n",
    "model_size_gap = model_size_8bit - model_size_4bit\n",
    "model_size_limit = model_size_4bit + model_size_gap * args.model_size_limit\n",
    "\n",
    "# --- bops ---\n",
    "# For Wx both operands are quantized, hence the two divisions per bit width.\n",
    "bops_8bit = bops / 4. / 4.\n",
    "bops_4bit = bops / 8. / 8.\n",
    "total_bops_4bit = np.sum(bops_4bit)\n",
    "total_bops_8bit = np.sum(bops_8bit)\n",
    "bops_limit = total_bops_4bit + (total_bops_8bit - total_bops_4bit) * args.bops_limit # same interpolation as the model size\n",
    "\n",
    "# --- latency ---\n",
    "total_latency_int4 = np.sum(latency_int4)\n",
    "total_latency_int8 = np.sum(latency_int8)\n",
    "latency_limit = total_latency_int4 + (total_latency_int8 - total_latency_int4) * args.latency_limit # same interpolation as the model size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Size Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the model size stays within budget.\n",
    "prob = LpProblem(\"Model_Size\", LpMinimize)\n",
    "\n",
    "# Model size constraint: 1 million 8-bit numbers means 1 MB, so 0.5 * x * parameters is the\n",
    "# layer size (x = 2 -> 0.5 * 2 = 1 for 8 bit, x = 1 -> 0.5 for 4 bit).\n",
    "prob += sum([0.5 * variable[f\"x{i}\"] * parameters[i] for i in range(num_variable)]) <= model_size_limit\n",
    "\n",
    "# Downsampling layer constraint: each residual connection layer uses the same\n",
    "# bit width as the main-stream layer it is paired with.\n",
    "prob += variable[\"x4\"] == variable[\"x6\"]\n",
    "prob += variable[\"x9\"] == variable[\"x11\"]\n",
    "prob += variable[\"x14\"] == variable[\"x16\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the model size budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bops Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the total BOPs stay within budget.\n",
    "prob = LpProblem(\"BOPS\", LpMinimize)\n",
    "\n",
    "bops_difference_between_4_8 = bops_8bit - bops_4bit\n",
    "\n",
    "# BOPs constraint: for 4 bit, x - 1 = 0 and the layer costs bops_4bit; for 8 bit, x - 1 = 1\n",
    "# and the layer costs bops_4bit + bops_difference = bops_8bit.\n",
    "prob += sum([ bops_4bit[i] + (variable[f\"x{i}\"] - 1) * bops_difference_between_4_8[i] for i in range(num_variable) ]) <= bops_limit\n",
    "\n",
    "# Downsampling layer constraint: each residual connection layer uses the same\n",
    "# bit width as the main-stream layer it is paired with.\n",
    "prob += variable[\"x4\"] == variable[\"x6\"]\n",
    "prob += variable[\"x9\"] == variable[\"x11\"]\n",
    "prob += variable[\"x14\"] == variable[\"x16\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the BOPs budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Latency Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the total latency stays within budget.\n",
    "prob = LpProblem(\"Latency\", LpMinimize)\n",
    "\n",
    "latency_difference_between_4_8 = latency_int8 - latency_int4\n",
    "\n",
    "# Latency constraint: for 4 bit, x - 1 = 0 and the layer costs latency_int4; for 8 bit, x - 1 = 1\n",
    "# and the layer costs latency_int4 + latency_difference = latency_int8.\n",
    "prob += sum([ latency_int4[i] + (variable[f\"x{i}\"] - 1) * latency_difference_between_4_8[i] for i in range(num_variable) ]) <= latency_limit\n",
    "\n",
    "# Downsampling layer constraint: each residual connection layer uses the same\n",
    "# bit width as the main-stream layer it is paired with.\n",
    "prob += variable[\"x4\"] == variable[\"x6\"]\n",
    "prob += variable[\"x9\"] == variable[\"x11\"]\n",
    "prob += variable[\"x14\"] == variable[\"x16\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the latency budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class args():\n",
    "    \"\"\"Plain container for the constraint settings used by the ResNet50 ILP cells below.\n",
    "\n",
    "    Every limit is an interpolation factor between the all-4-bit and the all-8-bit\n",
    "    configuration: limit = cost_4bit + (cost_8bit - cost_4bit) * factor.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "# Rebind the name to an instance so the settings live on an object rather than on the class itself.\n",
    "args = args()\n",
    "args.model_size_limit = 0.5  # interpolation factor for the model size budget\n",
    "args.bops_limit = 0.5  # interpolation factor for the BOPs budget\n",
    "args.latency_limit = 0.5  # interpolation factor for the latency budget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-layer metrics for ResNet50; these rebind the same names used for ResNet18 above.\n",
    "# Hessian trace per weight matrix (presumably normalized by the parameter count as for ResNet18 -- TODO confirm).\n",
    "Hutchinson_trace = np.array([0.19270050525665214, 0.026702478528024147, 0.09448622912168206, 0.07639402896166123, 0.009688886813818761, 0.006711237132549379, 0.024378638714551495, 0.011769247241317709, 0.00250671175308565, 0.008039989508690044, 0.008477769792080075, 0.0017263438785415808, 0.00313899479806392, 0.004365175962448905, 0.0015507987700405636, 0.0015619333134963605, 0.002701229648662333, 0.002383587881922602, 0.0010320576839146385, 0.0025724433362483905, 0.0046385480090998495, 0.0008540530106982067, 0.003809824585913301, 0.0019535077735770134, 0.0005609235377052988, 0.0018696154002090816, 0.0016933275619534234, 0.000607571622821307, 0.00044498199713291893, 0.0008382285595869651, 0.0006031448720029394, 0.0003951200051233584, 0.000718781026080718, 0.0008380718063555912, 0.00037973490543638077, 0.000828751828521126, 0.0013068615226087446, 0.0004476536996660732, 0.0012669296702365405, 0.0013098433846609064, 0.00032695921254355777, 0.0012137993471687096, 0.0010635341750463423, 0.0002480300900066659, 0.0005943655269220451, 0.0007215500227179646, 0.000612712872680123, 0.00025412911781990733, 0.0005049526807847556, 0.0007025265367691074, 0.0002485941222403296, 0.000525604293216296])\n",
    "# Squared weight perturbation at 8 bit and at 4 bit (same quantities as in the ResNet18 section).\n",
    "delta_weights_8bit_square = np.array([0.0031, 0.0025, 0.0123, 0.0025, 0.0024, 0.0062, 0.0019, 0.0023, 0.0030, 0.0010, 0.0140, 0.0052, 0.0204, 0.0054, 0.0064, 0.0141, 0.0041, 0.0067, 0.0094, 0.0029, 0.0103, 0.0095, 0.0030, 0.0243, 0.0140, 0.0371, 0.0131, 0.0144, 0.0270, 0.0078, 0.0147, 0.0192, 0.0066, 0.0162, 0.0210, 0.0063, 0.0161, 0.0169, 0.0054, 0.0194, 0.0241, 0.0061, 0.0436, 0.0267, 0.0551, 0.0217, 0.0533, 0.0589, 0.0212, 0.0562, 0.0392, 0.0246])\n",
    "delta_weights_4bit_square = np.array([0.9257, 0.6624, 3.4605, 0.7805, 0.8040, 2.0657, 0.5947, 0.7596, 0.9826, 0.3381, 4.0922, 1.5506, 6.0288, 1.7186, 2.0701, 4.6080, 1.3128, 2.1613, 3.0710, 0.9209, 3.2794, 3.0883, 0.9841, 7.7154, 4.3872, 11.7131,4.1949, 4.6766, 8.7215, 2.5369, 4.8132, 6.1627, 2.1765, 5.2257, 6.7500, 2.0171, 5.2612, 5.5459, 1.7509, 6.2498, 7.4798, 1.9959, 13.8412,8.3407, 18.0200,6.9997, 15.3605,17.5150,6.4017, 16.0258, 10.7682, 8.0800])\n",
    "# Parameter count and BOPs of each layer, given as pre-scaled literals.\n",
    "# NOTE(review): presumably in millions (no 1024 * 1024 division as in the ResNet18 cell) -- confirm the units.\n",
    "parameters = np.array([0.004096, 0.036864, 0.016384, 0.016384, 0.016384, 0.036864, 0.016384, 0.016384, 0.036864, 0.016384, 0.032768, 0.147456, 0.065536, 0.131072, 0.065536, 0.147456, 0.065536, 0.065536, 0.147456, 0.065536, 0.065536, 0.147456, 0.065536, 0.131072, 0.589824, 0.262144, 0.524288, 0.262144, 0.589824, 0.262144, 0.262144, 0.589824, 0.262144, 0.262144, 0.589824, 0.262144, 0.262144, 0.589824, 0.262144, 0.262144, 0.589824, 0.262144, 0.524288, 2.359296, 1.048576, 2.097152, 1.048576, 2.359296, 1.048576, 1.048576, 2.359296, 1.048576])\n",
    "bops = np.array([12.845056, 115.605504, 51.380224, 51.380224, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 25.690112, 115.605504, 51.380224, 102.760448, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 25.690112, 115.605504, 51.380224, 102.760448, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224, 25.690112, 115.605504, 51.380224, 102.760448, 51.380224, 115.605504, 51.380224, 51.380224, 115.605504, 51.380224])\n",
    "\n",
    "# Latency of each layer for INT4 and INT8 (presumably measured like the ResNet18 values -- TODO confirm).\n",
    "latency_int4 = np.array([0.07018191, 0.22481827, 0.27510223, 0.27512618, 0.10512171, 0.22473146, 0.21536969, 0.10506169, 0.22481533, 0.18938694, 0.05091347, 0.19677368, 0.15108796, 0.16824758, 0.0848311, 0.19684267, 0.15120609, 0.08481235, 0.19667935, 0.15121574, 0.08476754, 0.19673882, 0.15123765, 0.04866963, 0.20606097, 0.10150094, 0.169953, 0.09355929, 0.20598819, 0.10146334, 0.09362463, 0.20609737, 0.10188489, 0.09360134, 0.20618122, 0.10150974, 0.09352375, 0.20603792, 0.1015803, 0.09365549, 0.20599906, 0.10143032, 0.08047627, 0.207775, 0.08819472, 0.17924989, 0.10268331, 0.20780669, 0.08795617, 0.10265705, 0.20781191, 0.08793056])\n",
    "latency_int8 = np.array([0.07079369, 0.36040637, 0.27649506, 0.276513, 0.1611783, 0.36101955, 0.276361, 0.16111273, 0.36064923, 0.24884412, 0.07218692, 0.34899773, 0.19600394, 0.29505887, 0.16074604, 0.34904888, 0.19603683, 0.16067593, 0.34919751, 0.1960506, 0.16089535, 0.34916001, 0.19616036, 0.08172319, 0.33166738, 0.15534599, 0.30712053, 0.15510333, 0.33172639, 0.15541728, 0.15290036, 0.3317953, 0.15551916, 0.15252759, 0.33165544, 0.15529572, 0.15229388, 0.33187364, 0.15545004, 0.15275891, 0.33161443, 0.15566678, 0.08390548, 0.33057286, 0.15200003, 0.33919933, 0.16635967, 0.33092922, 0.15179218, 0.16672966, 0.33086648, 0.1519619])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the three budgets (model size, BOPs, latency) from the per-layer metrics.\n",
    "\n",
    "# --- model size ---\n",
    "model_size_32bit = np.sum(parameters) * 4. # MB\n",
    "model_size_8bit = model_size_32bit / 4.\n",
    "model_size_4bit = model_size_32bit / 8.\n",
    "# The limit interpolates between the all-4-bit and the all-8-bit size.\n",
    "model_size_gap = model_size_8bit - model_size_4bit\n",
    "model_size_limit = model_size_4bit + model_size_gap * args.model_size_limit\n",
    "\n",
    "# --- bops ---\n",
    "bops_8bit = bops / 4. / 4.\n",
    "bops_4bit = bops / 8. / 8.\n",
    "total_bops_4bit = np.sum(bops_4bit)\n",
    "total_bops_8bit = np.sum(bops_8bit)\n",
    "bops_limit = total_bops_4bit + (total_bops_8bit - total_bops_4bit) * args.bops_limit\n",
    "\n",
    "# --- latency ---\n",
    "total_latency_int4 = np.sum(latency_int4)\n",
    "total_latency_int8 = np.sum(latency_int8)\n",
    "latency_limit = total_latency_int4 + (total_latency_int8 - total_latency_int4) * args.latency_limit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Size Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the model size stays within budget.\n",
    "prob = LpProblem(\"Model_Size\", LpMinimize)\n",
    "\n",
    "# Model size constraint: 1 million 8-bit numbers means 1 MB, so 0.5 * x * parameters is the\n",
    "# layer size (x = 2 -> 0.5 * 2 = 1 for 8 bit, x = 1 -> 0.5 for 4 bit).\n",
    "prob += sum([0.5 * variable[f\"x{i}\"] * parameters[i] for i in range(num_variable)]) <= model_size_limit\n",
    "\n",
    "# Downsampling layer constraint: the paired layers must share the same bit width.\n",
    "prob += variable[\"x0\"] == variable[\"x3\"]\n",
    "prob += variable[\"x10\"] == variable[\"x13\"]\n",
    "prob += variable[\"x23\"] == variable[\"x26\"]\n",
    "prob += variable[\"x42\"] == variable[\"x45\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the model size budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bops Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the total BOPs stay within budget.\n",
    "prob = LpProblem(\"BOPS\", LpMinimize)\n",
    "\n",
    "bops_difference_between_4_8 = bops_8bit - bops_4bit\n",
    "\n",
    "# BOPs constraint: for 4 bit, x - 1 = 0 and the layer costs bops_4bit; for 8 bit, x - 1 = 1\n",
    "# and the layer costs bops_4bit + bops_difference = bops_8bit.\n",
    "prob += sum([ bops_4bit[i] + (variable[f\"x{i}\"] - 1) * bops_difference_between_4_8[i] for i in range(num_variable) ]) <= bops_limit\n",
    "\n",
    "# Downsampling layer constraint: the paired layers must share the same bit width.\n",
    "prob += variable[\"x0\"] == variable[\"x3\"]\n",
    "prob += variable[\"x10\"] == variable[\"x13\"]\n",
    "prob += variable[\"x23\"] == variable[\"x26\"]\n",
    "prob += variable[\"x42\"] == variable[\"x45\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the BOPs budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Latency Constraint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's construct the problem \n",
    "num_variable = Hutchinson_trace.shape[0]\n",
    "\n",
    "# first get the variables, here 1 means 4 bit and 2 means 8 bit\n",
    "variable = {}\n",
    "for i in range(num_variable):\n",
    "    variable[f\"x{i}\"] = LpVariable(f\"x{i}\", 1, 2, cat=LpInteger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer linear program: choose 4 bit (x = 1) or 8 bit (x = 2) for every layer so that the\n",
    "# Hessian-weighted quantization perturbation is minimal while the total latency stays within budget.\n",
    "prob = LpProblem(\"Latency\", LpMinimize)\n",
    "\n",
    "latency_difference_between_4_8 = latency_int8 - latency_int4\n",
    "\n",
    "# Latency constraint: for 4 bit, x - 1 = 0 and the layer costs latency_int4; for 8 bit, x - 1 = 1\n",
    "# and the layer costs latency_int4 + latency_difference = latency_int8.\n",
    "prob += sum([ latency_int4[i] + (variable[f\"x{i}\"] - 1) * latency_difference_between_4_8[i] for i in range(num_variable) ]) <= latency_limit\n",
    "\n",
    "# Downsampling layer constraint: the paired layers must share the same bit width.\n",
    "prob += variable[\"x0\"] == variable[\"x3\"]\n",
    "prob += variable[\"x10\"] == variable[\"x13\"]\n",
    "prob += variable[\"x23\"] == variable[\"x26\"]\n",
    "prob += variable[\"x42\"] == variable[\"x45\"]\n",
    "\n",
    "# Per-layer sensitivity difference between 8 bit and 4 bit (negative for the data above,\n",
    "# because the 8-bit weight perturbation is smaller than the 4-bit one).\n",
    "sensitivity_difference_between_4_8 = Hutchinson_trace * (delta_weights_8bit_square - delta_weights_4bit_square)\n",
    "\n",
    "# Objective: x = 1 (4 bit) contributes 0 and x = 2 (8 bit) contributes the negative difference,\n",
    "# so the minimization prefers 8 bit everywhere; the latency budget forces the trade-off.\n",
    "prob += sum([(variable[f\"x{i}\"] - 1) * sensitivity_difference_between_4_8[i] for i in range(num_variable)])\n",
    "\n",
    "# Solve the problem.\n",
    "status = prob.solve(GLPK_CMD(msg=1, options=[\"--tmlim\", \"10000\", \"--simplex\"]))\n",
    "# status = prob.solve(COIN_CMD(msg=1, options=['dualSimplex']))\n",
    "# status = prob.solve(GLPK(msg=1, options=[\"--tmlim\", \"10\",\"--dualSimplex\"]))\n",
    "\n",
    "# Report the solver status and stop before reading values from a failed solve.\n",
    "print(\"Status:\", LpStatus[status])\n",
    "if LpStatus[status] != \"Optimal\":\n",
    "    raise RuntimeError(f\"Solver did not return an optimal solution: {LpStatus[status]}\")\n",
    "\n",
    "# Selected encoding per layer (1 = 4 bit, 2 = 8 bit).\n",
    "result = np.array([value(variable[f\"x{i}\"]) for i in range(num_variable)])\n",
    "\n",
    "print(result)\n",
    "# Note that this model size does not count the first and last layer of the network.\n",
    "# In the paper, these two layers were added manually to the final result.\n",
    "print('Model Size:', np.sum(result * parameters * 4 * 4 / 32))\n",
    "result_4 = (result == 1)\n",
    "result_8 = (result == 2)\n",
    "print(\"Bops: \", np.sum(bops_4bit[result_4]) + np.sum(bops_8bit[result_8]))\n",
    "print(\"Latency: \", np.sum(latency_int4[result_4]) + np.sum(latency_int8[result_8]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
