{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa79369b-5d7e-4046-99c8-9a59d2c29634",
   "metadata": {},
   "source": [
    "# Rotation Invariant Quantization (RIQ) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f04b3fab-7496-4731-be30-2372aeaad5a3",
   "metadata": {},
   "source": [
    "This notebook is a step-by-step walkthrough for downloading an off-the-shelf model and quantizing it with RIQ such that a given cosine distance requirement\n",
    "is satisfied at the model's output. The quantized model is then compressed layer by layer with ANS (asymmetric numeral systems).\n",
    "\n",
    "You will:\n",
    "\n",
    " * Set up the environment\n",
    " * Select a model\n",
    " * Download the model\n",
    " * Quantize the model with RIQ and compress it using an ANS encoder\n",
    "     * The compressed model is saved to a dictionary file\n",
    " * Load the compressed model from the file into ONNX format\n",
    " * Run inference with the quantized weights and compare the output to that of the original model\n",
    "    \n",
    "Reading through this notebook gives a quick overview of the RIQ approach and its performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74cf9bc5-d791-46b9-8efc-278c180705fc",
   "metadata": {},
   "source": [
    "#### Set up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55a424e5-e5a5-46f7-be82-3910cc8b83bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from torchvision.models import vgg16, VGG16_Weights, resnet50, ResNet50_Weights, alexnet, AlexNet_Weights\n",
    "\n",
    "from utils.quantize import get_quantized_model, bytes_to_bitstring, measure_cos_sim\n",
    "from utils.dataset import prepare_dataset_images\n",
    "from utils.onnx_bridge import OnnxBridge"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05e69d1a-de3e-46d6-8bb4-909ef64990b8",
   "metadata": {},
   "source": [
    "#### Select a model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f92b2e4c-7a19-44d8-a740-4980ab5ab2f9",
   "metadata": {},
   "source": [
    "* Pick a model: `resnet`, `vgg`, or `alexnet`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8921286-4623-438b-92de-3e0b1d339c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "picked_model = \"resnet\"\n",
    "model_path = \"models/\" + picked_model + \".onnx\"\n",
    "\n",
    "# Load the selected torchvision model with ImageNet weights (None for an unknown name)\n",
    "model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) if picked_model == \"resnet\" else \\\n",
    "            vgg16(weights=VGG16_Weights.IMAGENET1K_V1) if picked_model == \"vgg\" else \\\n",
    "                alexnet(weights=AlexNet_Weights.IMAGENET1K_V1) if picked_model == \"alexnet\" else None\n",
    "\n",
    "if model is not None:\n",
    "    model.eval()\n",
    "    # Export to ONNX by tracing the model with a dummy 224x224 RGB input\n",
    "    inp = torch.randn(1, 3, 224, 224)\n",
    "    in_names = [\"actual_input\"]\n",
    "    out_name = [\"output\"]\n",
    "    torch.onnx.export(model, inp, model_path, verbose=False, input_names=in_names, output_names=out_name, export_params=True)\n",
    "else:\n",
    "    print(\"Please pick one of the following: \\\"resnet\\\", \\\"vgg\\\", \\\"alexnet\\\". \")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "920b5577-2765-43d1-97e8-ac6186b070d0",
   "metadata": {},
   "source": [
    "* Set the path to a small set of calibration images\n",
    "    * This is optional; however, **better compression results are attained with a small calibration set**\n",
    "    * Leave `calibration_path=\"\"` for no calibration data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4e3241-26a2-4657-992b-2c647fa54a6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "calibration_path = \"<path to small set of pictures (optional, you can leave an empty string)>\"\n",
    "# os.path.join tolerates a missing trailing slash; an empty path still yields \"*.jpg\"\n",
    "calibration_data = prepare_dataset_images(os.path.join(calibration_path, \"*.jpg\"), model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6e1b041-a345-4dcb-ae00-da2454504b39",
   "metadata": {},
   "source": [
    "#### Quantize with RIQ and compress with ANS "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b37b9573-0d16-42ba-b63a-f80431180eb3",
   "metadata": {},
   "source": [
    "For the given distortion requirement at the output (0.005 in this demonstration):\n",
    "* RIQ finds a solution with minimal entropy\n",
    "* The solution is compressed with ANS layer by layer\n",
    "* The compressed weights are saved to a `.riq` file\n",
    "    * This makes it possible to verify the reported compression ratio\n",
    "    * It also allows the compressed weights to be used for inference later in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997cc153-995e-4eb4-bea5-ba98931f29e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = picked_model + '.riq'\n",
    "get_quantized_model(model_path, calibration_data, distortion=0.005, save_to=filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "208606d8-162d-44de-890b-900c0c36e2ce",
   "metadata": {},
   "source": [
    "#### Load and decode the compressed model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba12aa51-c32c-4aa9-8c32-1c9ced8ab4a3",
   "metadata": {},
   "source": [
    "The weights are decoded into a dict, which we then use for comparison with the original model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fda2117-dd32-478d-b704-dd62cdf4b892",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dictionary of compressed layers saved in the previous step\n",
    "with open(filename, 'rb') as fn:\n",
    "    load_compressed_weights_dict = pickle.load(fn)\n",
    "\n",
    "qws = {}\n",
    "for k, tup in tqdm(load_compressed_weights_dict.items()):\n",
    "    # Each entry unpacks into the ANS coder, the encoded bytes, the scale delta, and the tensor shape\n",
    "    tans, bytestream, delta, shape = tup\n",
    "    bitstream = bytes_to_bitstring(bytestream)\n",
    "    if delta < 1.0:\n",
    "        # Decode the bitstream, restore the tensor shape, and rescale by delta\n",
    "        qw = tans.decode_data(bitstream)\n",
    "        qw = np.array(qw).reshape(shape)\n",
    "        qw = qw * delta\n",
    "        qws[k] = qw\n",
    "    else:\n",
    "        # Otherwise the stored value is used as-is\n",
    "        qws[k] = bytestream"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65c1423c-00f0-4f48-b673-5931db94d0b8",
   "metadata": {},
   "source": [
    "#### Compare outputs of the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b223380-26ab-4efd-b3d2-06302b8477c0",
   "metadata": {
    "tags": []
   },
   "source": [
    "This part requires a small validation set for comparing the outputs.\n",
    "\n",
    "* The calibration set can be used for this comparison as well (though the result then does not reflect behavior on unseen images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2282650-6f69-4337-a318-50db7021d5e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "ob = OnnxBridge(model_path)  # load the original model\n",
    "validation_dataset = calibration_data  # reuse the calibration images for validation\n",
    "original_outs = np.array([ob(i) for i in validation_dataset])\n",
    "ob.set_weights(qws)  # swap in the decoded quantized weights\n",
    "outs = np.array([ob(i) for i in validation_dataset])\n",
    "cos_distance = 1 - measure_cos_sim(original_outs, outs, _)\n",
    "print(\"cos distance is \", cos_distance)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
