{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> \n",
    "\n",
    "# PyTorch Detection Model Pruning using SparseML\n",
    "\n",
    "This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). You will:\n",
    "- Set up the model and dataset\n",
    "- Define a generic PyTorch training flow\n",
    "- Integrate the PyTorch flow with SparseML\n",
    "- Prune the model using the PyTorch+SparseML flow\n",
    "- Export to [ONNX](https://onnx.ai/)\n",
    "\n",
    "Reading through this notebook will be reasonably quick to gain an intuition for how to plug SparseML into your PyTorch training flow. Rough time estimates for fully pruning the default model are given. Note that training with the PyTorch CPU implementation will be much slower than a GPU:\n",
    "- 30 minutes on a GPU\n",
    "- 90 minutes on a laptop CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Requirements\n",
    "To run this notebook, you will need the following packages already installed:\n",
    "* SparseML and SparseZoo\n",
    "* PyTorch and torchvision\n",
    "\n",
    "You can install any package that is not already present via `pip`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sparseml\n",
    "import sparsezoo\n",
    "import torch\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Setting Up the Model and Dataset\n",
    "\n",
    "By default, you will prune a [SSD](https://arxiv.org/abs/1512.02325) model trained on the [VOC detection dataset](http://host.robots.ox.ac.uk/pascal/VOC/). The model's pretrained weights are downloaded from the SparseZoo model repo. The VOC detection dataset is downloaded from its repository via a helper class from SparseML.\n",
    "\n",
    "Note, for this notebook, you will use a ResNet18 backbone for the object detector.  This is to save training time and demonstrate the general pruning flow. For better accuracy, you can use a different backbone or model.\n",
    "\n",
    "If you would like to try out your model for pruning, modify the appropriate lines for your model and dataset, speciﬁcally:\n",
    "- model = ModelRegistry.create(...)\n",
    "- train_dataset = VOCDetectionDataset(...)\n",
    "- val_dataset = VOCDetectionDataset(...)\n",
    "\n",
    "Take care to keep the variable names the same, as the rest of the notebook is set up according to those and update any parts of the training flow as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.models import ModelRegistry\n",
    "from sparseml.pytorch.datasets import VOCDetectionDataset\n",
    "from sparseml.pytorch.utils import get_default_boxes_300\n",
    "\n",
    "#######################################################\n",
    "# Define your model below\n",
    "#######################################################\n",
    "print(\"loading model...\")\n",
    "model = ModelRegistry.create(\n",
    "    key=\"ssd300_resnet18\",\n",
    "    pretrained=True,\n",
    "    pretrained_dataset=\"voc\",\n",
    "    pretrained_backbone=False,  # no need to download initial weights\n",
    "    num_classes=21,\n",
    ")\n",
    "model_name = model.__class__.__name__\n",
    "input_shape = ModelRegistry.input_shape(\"ssd300_resnet18\")\n",
    "input_size = input_shape[-1]\n",
    "print(model)\n",
    "#######################################################\n",
    "# Define your train and validation datasets below\n",
    "#######################################################\n",
    "\n",
    "print(\"\\nloading train dataset...\")\n",
    "default_boxes = get_default_boxes_300(\"voc\")\n",
    "train_dataset = VOCDetectionDataset(\n",
    "    train=True, rand_trans=True, preprocessing_type=\"ssd\", default_boxes=default_boxes\n",
    ")\n",
    "print(train_dataset)\n",
    "\n",
    "print(\"\\nloading val dataset...\")\n",
    "val_dataset = VOCDetectionDataset(\n",
    "    train=False, preprocessing_type=\"ssd\", default_boxes=default_boxes\n",
    ")\n",
    "print(val_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Set Up a PyTorch Training Loop\n",
    "SparseML can plug directly into your existing PyTorch training flow by overriding the Optimizer object. To demonstrate this, in the cell below, we define a simple PyTorch training loop adapted from [here](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) to work with our detection model and loss function.  To prune your existing models using SparseML, you can use your own training flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import math\n",
    "import torch\n",
    "\n",
    "from sparseml.pytorch.utils import DEFAULT_LOSS_KEY\n",
    "\n",
    "\n",
    "def run_model_one_epoch(\n",
    "    model, data_loader, criterion, device, train=False, optimizer=None\n",
    "):\n",
    "    if train:\n",
    "        model.train()\n",
    "    else:\n",
    "        model.eval()\n",
    "\n",
    "    running_loss = 0.0\n",
    "\n",
    "    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = [\n",
    "            label.to(device) if isinstance(label, torch.Tensor) else label\n",
    "            for label in labels\n",
    "        ]\n",
    "\n",
    "        if train:\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion((inputs, labels), outputs)[DEFAULT_LOSS_KEY]\n",
    "\n",
    "        if train:\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    loss = running_loss / (step + 1.0)\n",
    "    return loss"
   ]
  },
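  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for what overriding the optimizer can mean in a loop like the one above, here is a minimal sketch. It is not SparseML's implementation: `MaskingOptimizer` is a hypothetical class, and the random mask only stands in for whatever a real pruning schedule would choose. The sketch assumes the wrapper needs just the `zero_grad()` and `step()` calls that the training loop makes.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.optim import SGD\n",
    "\n",
    "\n",
    "class MaskingOptimizer:\n",
    "    # Hypothetical wrapper: delegates to a real optimizer, then re-applies\n",
    "    # fixed 0/1 masks so that pruned weights stay at zero.\n",
    "\n",
    "    def __init__(self, optimizer, params_and_masks):\n",
    "        self.optimizer = optimizer\n",
    "        self.params_and_masks = params_and_masks  # list of (param, mask) pairs\n",
    "\n",
    "    def zero_grad(self):\n",
    "        self.optimizer.zero_grad()\n",
    "\n",
    "    def step(self):\n",
    "        self.optimizer.step()\n",
    "        with torch.no_grad():\n",
    "            for param, mask in self.params_and_masks:\n",
    "                param.mul_(mask)\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "layer = torch.nn.Linear(4, 2)\n",
    "# Illustrative mask only: keep roughly half of the weights.\n",
    "mask = (torch.rand_like(layer.weight) > 0.5).float()\n",
    "wrapped = MaskingOptimizer(SGD(layer.parameters(), lr=0.1), [(layer.weight, mask)])\n",
    "\n",
    "# One training step, written exactly as the loop above would call it.\n",
    "wrapped.zero_grad()\n",
    "layer(torch.randn(8, 4)).sum().backward()\n",
    "wrapped.step()\n",
    "print(layer.weight)\n",
    "```\n",
    "\n",
    "Because the wrapper exposes the same two methods the loop calls, the training function itself does not need to change."
   ]
  },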
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Set Up Detection Training Objects\n",
    "In this step, you will select a device to train your model with, set up DataLoader objects, a loss function, and optimizer.  All of these variables and objects can be replaced to fit your training flow.  The loss function and collate function are standard for SSD training and are defined in the sparseml API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import SGD\n",
    "\n",
    "from sparseml.pytorch.datasets import ssd_collate_fn\n",
    "from sparseml.pytorch.utils import SSDLossWrapper\n",
    "\n",
    "# setup device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)\n",
    "print(\"Using device: {}\".format(device))\n",
    "\n",
    "# setup data loaders\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size,\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    num_workers=12,\n",
    "    collate_fn=ssd_collate_fn,\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    val_dataset,\n",
    "    batch_size,\n",
    "    shuffle=False,\n",
    "    pin_memory=True,\n",
    "    num_workers=12,\n",
    "    collate_fn=ssd_collate_fn,\n",
    ")\n",
    "\n",
    "# setup loss function and optimizer\n",
    "criterion = SSDLossWrapper()\n",
    "optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Apply a SparseML Recipe and Prune Model\n",
    "\n",
    "To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to wrap the optimizer object to gradually prune the model using unstructured weight magnitude pruning after each optimizer step.\n",
    "\n",
    "You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.\n",
    "\n",
    "Using the wrapped optimizer object, you will call the training function to prune your model. Finalize the model after training by making a call to manager's `finalize(...)` method.\n",
    "\n",
    "If the kernel shuts down during training, this may be an out of memory error, to resolve this, try lowering the `batch_size` in the cell above."
   ]
  },
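  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind gradual, unstructured magnitude pruning can be sketched on a toy tensor. This is an illustration under stated assumptions, not SparseML's code: `magnitude_mask` is a hypothetical helper, and the sparsity targets are made-up values standing in for the schedule that a recipe defines.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def magnitude_mask(weight, sparsity):\n",
    "    # Hypothetical helper: 0/1 mask that drops the `sparsity` fraction of\n",
    "    # entries with the smallest absolute value.\n",
    "    num_prune = int(sparsity * weight.numel())\n",
    "    mask = torch.ones_like(weight).flatten()\n",
    "    if num_prune > 0:\n",
    "        _, idx = torch.topk(weight.abs().flatten(), num_prune, largest=False)\n",
    "        mask[idx] = 0.0\n",
    "    return mask.reshape(weight.shape)\n",
    "\n",
    "\n",
    "weight = torch.tensor([[0.5, -0.1, 0.03], [-0.8, 0.2, -0.05]])\n",
    "# Made-up sparsity targets standing in for a recipe's gradual schedule.\n",
    "for target in (0.0, 0.5, 0.8):\n",
    "    weight = weight * magnitude_mask(weight, target)\n",
    "    print(target, weight.tolist())\n",
    "```\n",
    "\n",
    "In this toy example, each pass keeps earlier zeros in place, because weights that are already zero have the smallest magnitude and are selected again before any surviving weight."
   ]
  },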
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Downloading a Recipe from SparseZoo\n",
    "The [SparseZoo](https://github.com/neuralmagic/sparsezoo) API provides precofigured recipes for its optimized model.  In the cell below, you will download a recipe for pruning SSD-ResNet18 on the VOC dataset and record it's saved path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparsezoo import Model, search_models\n",
    "\n",
    "zoo_model = search_models(\n",
    "    domain=\"cv\",\n",
    "    sub_domain=\"detection\",\n",
    "    architecture=\"ssd\",\n",
    "    sub_architecture=\"resnet18_300\",\n",
    "    framework=\"pytorch\",\n",
    "    repo=\"sparseml\",\n",
    "    dataset=\"voc\",\n",
    "    sparse_name=\"pruned\",\n",
    ")[0]  # unwrap search result\n",
    "recipe_path = zoo_model.recipes.default.path\n",
    "print(f\"Recipe downloaded to: {recipe_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.optim import (\n",
    "    ScheduledModifierManager,\n",
    ")\n",
    "\n",
    "# create ScheduledModifierManager and Optimizer wrapper\n",
    "manager = ScheduledModifierManager.from_yaml(recipe_path)\n",
    "optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))\n",
    "\n",
    "epoch = manager.min_epochs\n",
    "while epoch < manager.max_epochs:\n",
    "    # run training loop\n",
    "    epoch_name = \"{}/{}\".format(epoch + 1, manager.max_epochs)\n",
    "    print(\"Running Training Epoch {}\".format(epoch_name))\n",
    "    train_loss = run_model_one_epoch(\n",
    "        model, train_loader, criterion, device, train=True, optimizer=optimizer\n",
    "    )\n",
    "    print(\"Training Epoch: {}\\nTraining Loss: {}\\n\".format(epoch_name, train_loss))\n",
    "\n",
    "    # run validation loop\n",
    "    print(\"Running Validation Epoch {}\".format(epoch_name))\n",
    "    val_loss = run_model_one_epoch(model, val_loader, criterion, device)\n",
    "    print(\"Validation Epoch: {}\\nVal Loss: {}\\n\".format(epoch_name, val_loss))\n",
    "\n",
    "    epoch += 1\n",
    "\n",
    "manager.finalize(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - View Model Sparsity\n",
    "To see the effects of the model pruning, in this step, you will print out the sparsities of each Conv and FC layer in your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity\n",
    "\n",
    "# print sparsities of each layer\n",
    "for (name, layer) in get_prunable_layers(model):\n",
    "    print(\"{}.weight: {:.4f}\".format(name, tensor_sparsity(layer.weight).item()))"
   ]
  },
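  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a point of reference for reading the numbers above, a layer's sparsity can be understood as the fraction of its weights that are exactly zero. The sketch below assumes that definition. `fraction_zero` is a hypothetical helper written for illustration and is not the SparseML `tensor_sparsity` function.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def fraction_zero(tensor):\n",
    "    # Hypothetical helper: share of entries that are exactly zero.\n",
    "    return (tensor == 0).sum().item() / tensor.numel()\n",
    "\n",
    "\n",
    "weight = torch.tensor([[0.0, 1.5, 0.0, -2.0], [0.0, 0.0, 0.3, 0.0]])\n",
    "print(fraction_zero(weight))  # 5 of the 8 entries are zero\n",
    "```\n",
    "\n",
    "A value close to 1.0 means almost all of a layer's weights were pruned, while 0.0 means the layer is still dense."
   ]
  },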
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7 - Exporting to ONNX\n",
    "\n",
    "Now that the model is fully recalibrated, you need to export it to an ONNX format, which is the format used by the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). For PyTorch, exporting to ONNX is natively supported. In the cell block below, a convenience class, ModuleExporter(), is used to handle exporting.\n",
    "\n",
    "Once the model is saved as an ONNX ﬁle, it is ready to be used for inference with the DeepSparse Engine.  For saving a custom model, you can override the sample batch for ONNX graph freezing and locations to save to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.utils import ModuleExporter\n",
    "\n",
    "save_dir = \"pytorch_detection\"\n",
    "\n",
    "exporter = ModuleExporter(model, output_dir=save_dir)\n",
    "exporter.export_pytorch(name=\"ssd_resnet18_pruned.pth\")\n",
    "exporter.export_onnx(torch.randn(1, 3, 300, 300), name=\"ssd_resnet18_pruned.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:\n",
    "* Pruning different models using SparseML\n",
    "* Trying different pruning and optimization recipes\n",
    "* Running your model on the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
