{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving PruneBERT\n",
    "\n",
    "\n",
    "This notebook aims at showcasing how we can leverage standard tools to save (and load) an extremely sparse model fine-pruned with [movement pruning](XXXX) (or any other unstructured pruning mehtod).\n",
    "\n",
    "In this example, we used BERT (base-uncased, but the procedure described here is not specific to BERT and can be applied to a large variety of models.\n",
    "\n",
    "We first obtain an extremely sparse model by fine-pruning with movement pruning on SQuAD v1.1. We then used the following combination of standard tools:\n",
    "- We reduce the precision of the model with Int8 dynamic quantization using [PyTorch implementation](XXXX). We only quantized the Fully Connected Layers.\n",
    "- Sparse quantized matrices are converted into the [Compressed Sparse Row format](XXXX).\n",
    "- We use HDF5 with `gzip` compression to store the weights.\n",
    "\n",
    "We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](XXXX)!\n",
    "\n",
    "<img src=\"XXXX\" width=\"200\">\n",
    "\n",
    "*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](XXXX).*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Includes\n",
    "\n",
    "import h5py\n",
    "import os\n",
    "import json\n",
    "from collections import OrderedDict\n",
    "\n",
    "from scipy import sparse\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from transformers import *\n",
    "\n",
    "os.chdir(\"../../\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dynamic quantization induces little or no loss of performance while significantly reducing the memory footprint."
   ]
  },
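  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does quantization preserve the sparsity that the CSR decomposition below relies on? As far as we understand, `torch.quantization.default_dynamic_qconfig` quantizes weights with a symmetric per-tensor int8 scheme. Under that scheme the zero point is 0, so a weight that is exactly 0 is stored as the integer 0. The NumPy sketch below mimics such a scheme on a random matrix with about 6% nonzeros. It is a simplification, because PyTorch's observer may compute the scale slightly differently. `quantize_symmetric_int8` and `dequantize` are illustrative helpers, not PyTorch APIs.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def quantize_symmetric_int8(w):\n",
    "    # Symmetric per-tensor scheme (simplified): zero_point is 0 and max |w| maps to 127\n",
    "    scale = float(np.abs(w).max()) / 127.0\n",
    "    if scale == 0.0:\n",
    "        scale = 1.0  # all-zero tensor: any positive scale works\n",
    "    q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)\n",
    "    return q, scale, 0\n",
    "\n",
    "def dequantize(q, scale, zero_point):\n",
    "    # Affine dequantization: x ~= (q - zero_point) * scale\n",
    "    return (q.astype(np.float32) - zero_point) * scale\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "w = rng.normal(size=(64, 64)).astype(np.float32)\n",
    "w[rng.random(w.shape) < 0.94] = 0.0  # keep about 6% of the weights (toy data)\n",
    "\n",
    "q, scale, zero_point = quantize_symmetric_int8(w)\n",
    "w_hat = dequantize(q, scale, zero_point)\n",
    "\n",
    "# Pruned (zero) weights are still exactly zero after quantization\n",
    "print('nonzeros float32:', np.count_nonzero(w), '/ int8:', np.count_nonzero(q))\n",
    "print('max abs error:', float(np.abs(w - w_hat).max()), '<= scale / 2 =', scale / 2)\n",
    "```\n",
    "\n",
    "With a zero point of 0, `int_repr()` has zeros exactly where the pruned float matrix has them, plus any tiny weights that round to 0. This keeps its CSR representation compact."
   ]
  },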
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load fine-pruned model and quantize the model\n",
    "\n",
    "model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
    "model.to(\"cpu\")\n",
    "\n",
    "quantized_model = torch.quantization.quantize_dynamic(\n",
    "    model=model,\n",
    "    qconfig_spec={\n",
    "        nn.Linear: torch.quantization.default_dynamic_qconfig,\n",
    "    },\n",
    "    dtype=torch.qint8,\n",
    ")\n",
    "# print(quantized_model)\n",
    "\n",
    "qtz_st = quantized_model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saving the original (encoder + classifier) in the standard torch.save format\n",
    "\n",
    "dense_st = {\n",
    "    name: param for name, param in model.state_dict().items() if \"embedding\" not in name and \"pooler\" not in name\n",
    "}\n",
    "torch.save(\n",
    "    dense_st,\n",
    "    \"dbg/dense_squad.pt\",\n",
    ")\n",
    "dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")"
   ]
  },
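  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 340MB of the dense file can be checked with a quick count. Nearly all of it comes from the 6 linear layers in each of the 12 encoder layers, stored as float32 (4 bytes per weight). The sketch assumes the standard BERT-base dimensions (hidden size 768, intermediate size 3072, 12 layers) and a 2-output QA head. These values are not read from the model config.\n",
    "\n",
    "```python\n",
    "# Standard BERT-base dimensions (assumed here, not read from the model config)\n",
    "hidden, intermediate, n_layers = 768, 3072, 12\n",
    "\n",
    "# Linear weights per layer: query, key, value and attention output (hidden x hidden),\n",
    "# plus intermediate.dense and output.dense (hidden x intermediate each)\n",
    "linear_weights = 4 * hidden * hidden + 2 * hidden * intermediate\n",
    "# Biases of those 6 linear layers, plus weight and bias of the 2 LayerNorms\n",
    "small_params = (4 * hidden + intermediate + hidden) + 2 * 2 * hidden\n",
    "qa_head = hidden * 2 + 2  # qa_outputs: weight and bias for start/end logits\n",
    "\n",
    "total = n_layers * (linear_weights + small_params) + qa_head\n",
    "print('linear weights in the encoder:', n_layers * linear_weights)  # 84934656\n",
    "print('float32 size (MB):', round(4 * total / 1e6, 2))  # 340.22\n",
    "```\n",
    "\n",
    "This estimate is within about 0.04MB of the 340.26MB measured further down. The remainder is presumably serialization overhead."
   ]
  },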
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decompose quantization for bert.encoder.layer.0.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.0.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.0.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.0.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.0.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.0.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.1.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.2.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.3.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.4.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.5.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.6.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.7.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.8.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.9.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.10.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.attention.self.query._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.attention.self.key._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.attention.self.value._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.attention.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.intermediate.dense._packed_params.weight\n",
      "Decompose quantization for bert.encoder.layer.11.output.dense._packed_params.weight\n",
      "Decompose quantization for bert.pooler.dense._packed_params.weight\n",
      "Decompose quantization for qa_outputs._packed_params.weight\n"
     ]
    }
   ],
   "source": [
    "# Elementary representation: we decompose the quantized tensors into (scale, zero_point, int_repr).\n",
    "# See XXXX\n",
    "\n",
    "# We further leverage the fact that int_repr is sparse matrix to optimize the storage: we decompose int_repr into\n",
    "# its CSR representation (data, indptr, indices).\n",
    "\n",
    "elementary_qtz_st = {}\n",
    "for name, param in qtz_st.items():\n",
    "    if \"dtype\" not in name and param.is_quantized:\n",
    "        print(\"Decompose quantization for\", name)\n",
    "        # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
    "        scale = param.q_scale()  # torch.tensor(1,) - float32\n",
    "        zero_point = param.q_zero_point()  # torch.tensor(1,) - int32\n",
    "        elementary_qtz_st[f\"{name}.scale\"] = scale\n",
    "        elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n",
    "\n",
    "        # We assume the int_repr is sparse and compute its CSR representation\n",
    "        # Only the FCs in the encoder are actually sparse\n",
    "        int_repr = param.int_repr()  # torch.tensor(nb_rows, nb_columns) - int8\n",
    "        int_repr_cs = sparse.csr_matrix(int_repr)  # scipy.sparse.csr.csr_matrix\n",
    "\n",
    "        elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data  # np.array int8\n",
    "        elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr  # np.array int32\n",
    "        assert max(int_repr_cs.indices) < 65535  # If not, we shall fall back to int32\n",
    "        elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices)  # np.array uint16\n",
    "        elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape  # tuple(int, int)\n",
    "    else:\n",
    "        elementary_qtz_st[name] = param"
   ]
  },
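  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `(data, indices, indptr, shape)` decomposition concrete, here is a NumPy-only sketch of the CSR layout that `scipy.sparse.csr_matrix` produces for a small int8 matrix. `to_csr` and `from_csr` are illustrative helpers, not SciPy functions. The notebook itself relies on SciPy.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def to_csr(dense):\n",
    "    # Nonzero coordinates in row-major order, which is the order CSR expects\n",
    "    rows, cols = np.nonzero(dense)\n",
    "    data = dense[rows, cols]\n",
    "    # indptr[r]:indptr[r + 1] delimits the entries of row r in data/indices\n",
    "    indptr = np.zeros(dense.shape[0] + 1, dtype=np.int32)\n",
    "    indptr[1:] = np.cumsum(np.bincount(rows, minlength=dense.shape[0]))\n",
    "    # Column indices fit in uint16 as long as there are at most 65536 columns\n",
    "    indices = cols.astype(np.uint16)\n",
    "    return data, indices, indptr, dense.shape\n",
    "\n",
    "def from_csr(data, indices, indptr, shape):\n",
    "    # Scatter each row's values back into a dense matrix\n",
    "    dense = np.zeros(shape, dtype=data.dtype)\n",
    "    for r in range(shape[0]):\n",
    "        start, end = indptr[r], indptr[r + 1]\n",
    "        dense[r, indices[start:end]] = data[start:end]\n",
    "    return dense\n",
    "\n",
    "m = np.array([[0, 5, 0], [0, 0, 0], [-3, 0, 7]], dtype=np.int8)\n",
    "data, indices, indptr, shape = to_csr(m)\n",
    "print(data, indices, indptr)  # [ 5 -3  7] [1 0 2] [0 1 1 3]\n",
    "assert (from_csr(data, indices, indptr, shape) == m).all()\n",
    "```\n",
    "\n",
    "Row 1 has no nonzeros, so `indptr[1] == indptr[2]`. CSR encodes empty rows without storing anything for them, which is what makes it so compact for heavily pruned weights."
   ]
  },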
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
    "str_2_dtype = {\"qint8\": torch.qint8}\n",
    "dtype_2_str = {torch.qint8: \"qint8\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder Size (MB) - Sparse & Quantized - `torch.save`: 21.29\n"
     ]
    }
   ],
   "source": [
    "# Saving the pruned (encoder + classifier) in the standard torch.save format\n",
    "\n",
    "dense_optimized_st = {\n",
    "    name: param for name, param in elementary_qtz_st.items() if \"embedding\" not in name and \"pooler\" not in name\n",
    "}\n",
    "torch.save(\n",
    "    dense_optimized_st,\n",
    "    \"dbg/dense_squad_optimized.pt\",\n",
    ")\n",
    "print(\n",
    "    \"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n",
    "    round(os.path.getsize(\"dbg/dense_squad_optimized.pt\") / 1e6, 2),\n",
    ")"
   ]
  },
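  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Storing `int_repr` in CSR form only pays off because the matrices are very sparse. The decomposition above uses int8 values, uint16 column indices and int32 row pointers. With those dtypes, a matrix with `nnz` nonzeros costs roughly `3 * nnz + 4 * (rows + 1)` bytes, compared with `rows * cols` bytes for the dense int8 matrix. The helpers below give a back-of-the-envelope estimate that ignores container overhead and compression.\n",
    "\n",
    "```python\n",
    "def csr_nbytes(rows, nnz):\n",
    "    # 1 byte per int8 value, 2 bytes per uint16 column index, 4 bytes per int32 row pointer\n",
    "    return nnz * (1 + 2) + (rows + 1) * 4\n",
    "\n",
    "def dense_int8_nbytes(rows, cols):\n",
    "    # 1 byte per entry, zeros included\n",
    "    return rows * cols\n",
    "\n",
    "rows, cols = 3072, 768  # e.g. an intermediate.dense weight (out_features x in_features)\n",
    "for density in (0.06, 0.25, 0.33, 0.5):\n",
    "    nnz = int(density * rows * cols)\n",
    "    csr, dense = csr_nbytes(rows, nnz), dense_int8_nbytes(rows, cols)\n",
    "    print(f'density {density:.2f}: CSR {csr} B, dense {dense} B, ratio {csr / dense:.2f}')\n",
    "```\n",
    "\n",
    "At 6% density the CSR form needs about 0.19 bytes per entry instead of 1. The break-even point is around one nonzero in three. Storing the column indices as int32 instead of uint16 would raise the cost to 5 bytes per nonzero."
   ]
  },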
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skip bert.embeddings.word_embeddings.weight\n",
      "Skip bert.embeddings.position_embeddings.weight\n",
      "Skip bert.embeddings.token_type_embeddings.weight\n",
      "Skip bert.embeddings.LayerNorm.weight\n",
      "Skip bert.embeddings.LayerNorm.bias\n",
      "Skip bert.pooler.dense.scale\n",
      "Skip bert.pooler.dense.zero_point\n",
      "Skip bert.pooler.dense._packed_params.weight.scale\n",
      "Skip bert.pooler.dense._packed_params.weight.zero_point\n",
      "Skip bert.pooler.dense._packed_params.weight.int_repr.data\n",
      "Skip bert.pooler.dense._packed_params.weight.int_repr.indptr\n",
      "Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
      "Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
      "Skip bert.pooler.dense._packed_params.bias\n",
      "Skip bert.pooler.dense._packed_params.dtype\n",
      "\n",
      "Encoder Size (MB) - Dense:              340.26\n",
      "Encoder Size (MB) - Sparse & Quantized: 11.28\n"
     ]
    }
   ],
   "source": [
    "# Save the decomposed state_dict with an HDF5 file\n",
    "# Saving only the encoder + QA Head\n",
    "\n",
    "with h5py.File(\"dbg/squad_sparse.h5\", \"w\") as hf:\n",
    "    for name, param in elementary_qtz_st.items():\n",
    "        if \"embedding\" in name:\n",
    "            print(f\"Skip {name}\")\n",
    "            continue\n",
    "\n",
    "        if \"pooler\" in name:\n",
    "            print(f\"Skip {name}\")\n",
    "            continue\n",
    "\n",
    "        if type(param) == torch.Tensor:\n",
    "            if param.numel() == 1:\n",
    "                # module scale\n",
    "                # module zero_point\n",
    "                hf.attrs[name] = param\n",
    "                continue\n",
    "\n",
    "            if param.requires_grad:\n",
    "                # LayerNorm\n",
    "                param = param.detach().numpy()\n",
    "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
    "        elif type(param) == float or type(param) == int or type(param) == tuple:\n",
    "            # float - tensor _packed_params.weight.scale\n",
    "            # int   - tensor _packed_params.weight.zero_point\n",
    "            # tuple - tensor _packed_params.weight.shape\n",
    "            hf.attrs[name] = param\n",
    "\n",
    "        elif type(param) == torch.dtype:\n",
    "            # dtype - tensor _packed_params.dtype\n",
    "            hf.attrs[name] = dtype_2_str[param]\n",
    "\n",
    "        else:\n",
    "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
    "\n",
    "with open(\"dbg/metadata.json\", \"w\") as f:\n",
    "    f.write(json.dumps(qtz_st._metadata))\n",
    "\n",
    "size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
    "print(\"\")\n",
    "print(\"Encoder Size (MB) - Dense:             \", round(dense_mb_size / 1e6, 2))\n",
    "print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size / 1e6, 2))"
   ]
  },
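  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`dbg/dense_squad_optimized.pt` (21.29MB) and the HDF5 file hold the same decomposed arrays. Most of the size gap therefore presumably comes from the `gzip` filter with `compression_opts=9`. That filter applies DEFLATE, the algorithm implemented by Python's `zlib` module, to each chunk of a dataset. The sketch below uses `zlib` directly as a stand-in on toy arrays (not model data), so it only approximates what h5py writes. `deflate_roundtrip` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import zlib\n",
    "import numpy as np\n",
    "\n",
    "def deflate_roundtrip(arr, level=9):\n",
    "    # Compress the raw bytes with DEFLATE (level 9, like compression_opts=9), then restore the array\n",
    "    packed = zlib.compress(arr.tobytes(), level)\n",
    "    restored = np.frombuffer(zlib.decompress(packed), dtype=arr.dtype).reshape(arr.shape)\n",
    "    return len(packed), restored\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Toy int8 values concentrated around 0, loosely like quantized weights\n",
    "values = np.clip(np.round(rng.normal(scale=20, size=100_000)), -128, 127).astype(np.int8)\n",
    "# Toy CSR row pointers: non-decreasing int32 values starting at 0\n",
    "indptr = np.concatenate(([0], np.cumsum(rng.integers(0, 60, size=768)))).astype(np.int32)\n",
    "\n",
    "for name, arr in (('values', values), ('indptr', indptr)):\n",
    "    size, restored = deflate_roundtrip(arr)\n",
    "    assert np.array_equal(restored, arr)\n",
    "    print(f'{name}: {arr.nbytes} B raw -> {size} B deflated')\n",
    "```\n",
    "\n",
    "The printed ratios reflect only the toy distributions. The real gain depends on the distribution of the quantized weights and their indices."
   ]
  },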
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Size (MB): 99.41\n"
     ]
    }
   ],
   "source": [
    "# Save the decomposed state_dict to HDF5 storage\n",
    "# Save everything in the architecutre (embedding + encoder + QA Head)\n",
    "\n",
    "with h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"w\") as hf:\n",
    "    for name, param in elementary_qtz_st.items():\n",
    "        #         if \"embedding\" in name:\n",
    "        #             print(f\"Skip {name}\")\n",
    "        #             continue\n",
    "\n",
    "        #         if \"pooler\" in name:\n",
    "        #             print(f\"Skip {name}\")\n",
    "        #             continue\n",
    "\n",
    "        if type(param) == torch.Tensor:\n",
    "            if param.numel() == 1:\n",
    "                # module scale\n",
    "                # module zero_point\n",
    "                hf.attrs[name] = param\n",
    "                continue\n",
    "\n",
    "            if param.requires_grad:\n",
    "                # LayerNorm\n",
    "                param = param.detach().numpy()\n",
    "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
    "        elif type(param) == float or type(param) == int or type(param) == tuple:\n",
    "            # float - tensor _packed_params.weight.scale\n",
    "            # int   - tensor _packed_params.weight.zero_point\n",
    "            # tuple - tensor _packed_params.weight.shape\n",
    "            hf.attrs[name] = param\n",
    "\n",
    "        elif type(param) == torch.dtype:\n",
    "            # dtype - tensor _packed_params.dtype\n",
    "            hf.attrs[name] = dtype_2_str[param]\n",
    "\n",
    "        else:\n",
    "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
    "\n",
    "\n",
    "with open(\"dbg/metadata.json\", \"w\") as f:\n",
    "    f.write(json.dumps(qtz_st._metadata))\n",
    "\n",
    "size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n",
    "print(\"\\nSize (MB):\", round(size / 1e6, 2))"
   ]
  },
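  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding the embeddings grows the file from 11.28MB to 99.41MB because `quantize_dynamic` only touches `nn.Linear` layers. The embedding matrices stay float32 and are written as ordinary gzip-compressed datasets. Here is a quick count, assuming the standard `bert-base-uncased` sizes (30522-token vocabulary, 512 positions, 2 token types, hidden size 768):\n",
    "\n",
    "```python\n",
    "# Standard bert-base-uncased sizes (assumed, not read from the model config)\n",
    "vocab_size, max_positions, type_vocab_size, hidden = 30522, 512, 2, 768\n",
    "\n",
    "embedding_params = (vocab_size + max_positions + type_vocab_size) * hidden\n",
    "layernorm_params = 2 * hidden  # weight and bias of the embeddings LayerNorm\n",
    "float32_mb = 4 * (embedding_params + layernorm_params) / 1e6\n",
    "print('embedding parameters:', embedding_params)  # 23835648\n",
    "print('float32 size before gzip (MB):', round(float32_mb, 2))  # 95.35\n",
    "```\n",
    "\n",
    "That is about 95MB of float32 weights before compression, and gzip typically shrinks float32 data only modestly. The embeddings therefore account for the bulk of the 99.41MB."
   ]
  },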
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct the elementary state dict\n",
    "\n",
    "reconstructed_elementary_qtz_st = {}\n",
    "\n",
    "hf = h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"r\")\n",
    "\n",
    "for attr_name, attr_param in hf.attrs.items():\n",
    "    if \"shape\" in attr_name:\n",
    "        attr_param = tuple(attr_param)\n",
    "    elif \".scale\" in attr_name:\n",
    "        if \"_packed_params\" in attr_name:\n",
    "            attr_param = float(attr_param)\n",
    "        else:\n",
    "            attr_param = torch.tensor(attr_param)\n",
    "    elif \".zero_point\" in attr_name:\n",
    "        if \"_packed_params\" in attr_name:\n",
    "            attr_param = int(attr_param)\n",
    "        else:\n",
    "            attr_param = torch.tensor(attr_param)\n",
    "    elif \".dtype\" in attr_name:\n",
    "        attr_param = str_2_dtype[attr_param]\n",
    "    reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
    "    # print(f\"Unpack {attr_name}\")\n",
    "\n",
    "# Get the tensors/arrays\n",
    "for data_name, data_param in hf.items():\n",
    "    if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n",
    "        reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
    "    elif \"embedding\" in data_name:\n",
    "        reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n",
    "    else:  # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n",
    "        data_param = np.array(data_param)\n",
    "        if \"indices\" in data_name:\n",
    "            data_param = np.array(data_param, dtype=np.int32)\n",
    "        reconstructed_elementary_qtz_st[data_name] = data_param\n",
    "    # print(f\"Unpack {data_name}\")\n",
    "\n",
    "\n",
    "hf.close()"
   ]
  },
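  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "HDF5 attributes do not remember Python types. As far as we know, h5py returns numeric attributes as NumPy scalars or arrays. That is why the cell above casts the `_packed_params` scale to `float`, the zero point to `int` and the shape back to a `tuple`. The sketch below reproduces that casting on NumPy values that stand in for what `hf.attrs` would return. This is an assumption, and the exact types may vary with the h5py version. `restore_attr` is a hypothetical helper that leaves out the torch-specific branches.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def restore_attr(name, value):\n",
    "    # Similar name-based rules to the loading cell, minus the torch.tensor and dtype branches\n",
    "    if 'shape' in name:\n",
    "        return tuple(int(v) for v in value)\n",
    "    if '.scale' in name and '_packed_params' in name:\n",
    "        return float(value)\n",
    "    if '.zero_point' in name and '_packed_params' in name:\n",
    "        return int(value)\n",
    "    return value\n",
    "\n",
    "# NumPy values standing in for what hf.attrs might return (assumed types)\n",
    "prefix = 'bert.encoder.layer.0.intermediate.dense._packed_params.weight'\n",
    "raw = {\n",
    "    prefix + '.int_repr.shape': np.array([3072, 768]),\n",
    "    prefix + '.scale': np.float64(0.0123),\n",
    "    prefix + '.zero_point': np.int64(0),\n",
    "}\n",
    "restored = {name: restore_attr(name, value) for name, value in raw.items()}\n",
    "for name, value in restored.items():\n",
    "    print(name[len(prefix) + 1:], type(value).__name__, value)\n",
    "```\n",
    "\n",
    "Casting back to plain Python types lets the sanity checks below compare each entry with the original `elementary_qtz_st` by type as well as by value."
   ]
  },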
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks\n",
    "\n",
    "for name, param in reconstructed_elementary_qtz_st.items():\n",
    "    assert name in elementary_qtz_st\n",
    "for name, param in elementary_qtz_st.items():\n",
    "    assert name in reconstructed_elementary_qtz_st, name\n",
    "\n",
    "for name, param in reconstructed_elementary_qtz_st.items():\n",
    "    assert type(param) == type(elementary_qtz_st[name]), name\n",
    "    if type(param) == torch.Tensor:\n",
    "        assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n",
    "    elif type(param) == np.ndarray:\n",
    "        assert (param == elementary_qtz_st[name]).all(), name\n",
    "    else:\n",
    "        assert param == elementary_qtz_st[name], name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-assemble the sparse int_repr from the CSR format\n",
    "\n",
    "reconstructed_qtz_st = {}\n",
    "\n",
    "for name, param in reconstructed_elementary_qtz_st.items():\n",
    "    if \"weight.int_repr.indptr\" in name:\n",
    "        prefix_ = name[:-16]\n",
    "        data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n",
    "        indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n",
    "        indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n",
    "        shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n",
    "\n",
    "        int_repr = sparse.csr_matrix(arg1=(data, indices, indptr), shape=shape)\n",
    "        int_repr = torch.tensor(int_repr.todense())\n",
    "\n",
    "        scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n",
    "        zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n",
    "        weight = torch._make_per_tensor_quantized_tensor(int_repr, scale, zero_point)\n",
    "\n",
    "        reconstructed_qtz_st[f\"{prefix_}\"] = weight\n",
    "    elif (\n",
    "        \"int_repr.data\" in name\n",
    "        or \"int_repr.shape\" in name\n",
    "        or \"int_repr.indices\" in name\n",
    "        or \"weight.scale\" in name\n",
    "        or \"weight.zero_point\" in name\n",
    "    ):\n",
    "        continue\n",
    "    else:\n",
    "        reconstructed_qtz_st[name] = param"
   ]
  },
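  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reassembly loop above recovers each weight name with `name[:-16]`, which works because `len('.int_repr.indptr') == 16`. A slightly more defensive variant strips the suffix explicitly and collects the four CSR components of each weight. A missing component then raises a `KeyError`, and no length is hard-coded. The sketch below applies a hypothetical helper, `group_csr_components`, to plain strings that stand in for the notebook's arrays.\n",
    "\n",
    "```python\n",
    "SUFFIX = '.int_repr.indptr'\n",
    "COMPONENTS = ('data', 'indices', 'indptr', 'shape')\n",
    "\n",
    "def group_csr_components(state):\n",
    "    # Map each weight name to its CSR components, keyed by component name\n",
    "    groups = {}\n",
    "    for name in state:\n",
    "        if name.endswith(SUFFIX):\n",
    "            weight_name = name[: -len(SUFFIX)]\n",
    "            # A KeyError here means one of the CSR components is missing\n",
    "            groups[weight_name] = {c: state[weight_name + '.int_repr.' + c] for c in COMPONENTS}\n",
    "    return groups\n",
    "\n",
    "prefix = 'bert.encoder.layer.0.attention.self.query._packed_params.weight'\n",
    "state = {prefix + '.int_repr.' + c: c + '-placeholder' for c in COMPONENTS}\n",
    "state['qa_outputs._packed_params.bias'] = 'bias-placeholder'\n",
    "\n",
    "print(len(SUFFIX))  # 16, the constant used in the cell above\n",
    "print(list(group_csr_components(state)))\n",
    "```\n",
    "\n",
    "Entries that are not CSR components, such as the bias above, are ignored by the helper and can be copied over unchanged as the original loop does."
   ]
  },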
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks\n",
    "\n",
    "for name, param in reconstructed_qtz_st.items():\n",
    "    assert name in qtz_st\n",
    "for name, param in qtz_st.items():\n",
    "    assert name in reconstructed_qtz_st, name\n",
    "\n",
    "for name, param in reconstructed_qtz_st.items():\n",
    "    assert type(param) == type(qtz_st[name]), name\n",
    "    if type(param) == torch.Tensor:\n",
    "        assert torch.all(torch.eq(param, qtz_st[name])), name\n",
    "    elif type(param) == np.ndarray:\n",
    "        assert (param == qtz_st[name]).all(), name\n",
    "    else:\n",
    "        assert param == qtz_st[name], name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the re-constructed state dict into a model\n",
    "\n",
    "dummy_model = BertForQuestionAnswering.from_pretrained(\"bert-base-uncased\")\n",
    "dummy_model.to(\"cpu\")\n",
    "\n",
    "reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n",
    "    model=dummy_model,\n",
    "    qconfig_spec=None,\n",
    "    dtype=torch.qint8,\n",
    ")\n",
    "\n",
    "reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n",
    "with open(\"dbg/metadata.json\", \"r\") as read_file:\n",
    "    metadata = json.loads(read_file.read())\n",
    "reconstructed_qtz_st._metadata = metadata\n",
    "\n",
    "reconstructed_qtz_model.load_state_dict(reconstructed_qtz_st)"
   ]
  },
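  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Converting to `OrderedDict` before setting `_metadata` matters. Instances of the built-in `dict` cannot take new attributes, and assigning one raises `AttributeError`. `OrderedDict` instances can, which is how PyTorch attaches `_metadata` (per-module version information) to the state dicts it returns. Here is a minimal illustration with placeholder metadata, whose contents are assumed:\n",
    "\n",
    "```python\n",
    "from collections import OrderedDict\n",
    "\n",
    "# Placeholder metadata shaped like a state dict's _metadata (assumed contents)\n",
    "metadata = {'': {'version': 1}, 'qa_outputs': {'version': 1}}\n",
    "\n",
    "plain = {'qa_outputs.bias': 0.0}\n",
    "try:\n",
    "    plain._metadata = metadata\n",
    "except AttributeError as err:\n",
    "    print('dict:', err)\n",
    "\n",
    "ordered = OrderedDict(plain)\n",
    "ordered._metadata = metadata  # allowed: OrderedDict instances carry a __dict__\n",
    "print('OrderedDict:', ordered._metadata)\n",
    "```\n",
    "\n",
    "`json.loads` returns a plain `dict` for the metadata itself, which is fine. Only the container that receives the `_metadata` attribute needs to accept attributes."
   ]
  },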
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity check passed\n"
     ]
    }
   ],
   "source": [
    "# Sanity checks on the infernce\n",
    "\n",
    "N = 32\n",
    "\n",
    "for _ in range(25):\n",
    "    inputs = torch.randint(low=0, high=30000, size=(N, 128))\n",
    "    mask = torch.ones(size=(N, 128))\n",
    "\n",
    "    y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n",
    "    y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n",
    "\n",
    "    assert torch.all(torch.eq(y, y_reconstructed))\n",
    "print(\"Sanity check passed\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
