{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> \n",
    "\n",
    "# Torchvision Classification Model Pruning using SparseML\n",
    "\n",
    "This notebook provides a step-by-step walkthrough for pruning a [torchvision model](https://pytorch.org/docs/stable/torchvision/models.html) using SparseML. You will:\n",
    "- Download a pre-trained torchvision model and generic dataset\n",
    "- Define a generic torchvision finetuning flow\n",
    "- Integrate the torchvision flow with SparseML\n",
    "- Prune the model using the torchvision+SparseML flow\n",
    "- Save the model and export to [ONNX](https://onnx.ai/)\n",
    "\n",
    "Reading through this notebook will be reasonably quick to gain an intuition for how to integrate SparseML with torchvision or more generically a PyTorch training flow. Rough time estimates for fully pruning the default model are given. Note that training with the PyTorch CPU implementation will be much slower than a GPU:\n",
    "- 15 minutes on a GPU\n",
    "- 45 minutes on a laptop CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Requirements\n",
    "To run this notebook, you will need the following packages already installed:\n",
    "* SparseML and SparseZoo\n",
    "* PyTorch and torchvision\n",
    "\n",
    "You can install any package that is not already present via `pip`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sparseml\n",
    "import sparsezoo\n",
    "import torch\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Setting Up the Model and Dataset\n",
    "\n",
    "By default, you will prune a [ResNet50](https://arxiv.org/abs/1512.03385) model while finetuning it on the [Imagenette dataset](https://github.com/fastai/imagenette). The model's pretrained weights are downloaded from torchvision. The Imagenette dataset is downloaded from its repository via a helper class from SparseML.\n",
    "\n",
    "Additionally, we will override the FC layer in the ResNet50 model to have 10 output classes instead of the ImageNet standard 1000.\n",
    "\n",
    "If you would like to try out your model for pruning, modify the appropriate lines for your model and dataset, speciﬁcally:\n",
    "- model = resnet50(pretrained=True)\n",
    "- train_dataset = ImagenetteDataset(...)\n",
    "- val_dataset = ImagenetteDataset(...)\n",
    "\n",
    "Take care to keep the variable names the same, as the rest of the notebook is set up according to those and update any parts of the training flow as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet50\n",
    "from torch.nn import Linear\n",
    "\n",
    "from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize\n",
    "\n",
    "#######################################################\n",
    "# Define your model below\n",
    "#######################################################\n",
    "print(\"loading model...\")\n",
    "model = resnet50(pretrained=True)\n",
    "print(model)\n",
    "#######################################################\n",
    "# Define your train and validation datasets below\n",
    "#######################################################\n",
    "\n",
    "print(\"\\nloading train dataset...\")\n",
    "train_dataset = ImagenetteDataset(\n",
    "    train=True, dataset_size=ImagenetteSize.s320, image_size=224\n",
    ")\n",
    "print(train_dataset)\n",
    "\n",
    "print(\"\\nloading val dataset...\")\n",
    "val_dataset = ImagenetteDataset(\n",
    "    train=False, dataset_size=ImagenetteSize.s320, image_size=224\n",
    ")\n",
    "print(val_dataset)\n",
    "\n",
    "# Overriding number of classes\n",
    "NUM_CLASSES = 10  # number of imagenette classes\n",
    "model.fc = Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES, bias=True)\n",
    "print(model.fc)"
   ]
  },
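  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, it can help to confirm that the replaced classifier head produces one output per class. The cell below is a minimal sketch of that check, not part of the original flow: it assumes an untrained `resnet18` (chosen so the cell runs quickly and downloads nothing, and kept separate from `model`), swaps its `fc` layer using the hypothetical helper `replace_fc_head`, and runs a random batch through it. The same `fc` swap applies to ResNet50, whose final layer is also named `fc`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Linear\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "\n",
    "def replace_fc_head(net, num_classes):\n",
    "    # hypothetical helper: swap the final fully connected layer for a new\n",
    "    # Linear layer with the same input width and `num_classes` outputs\n",
    "    net.fc = Linear(in_features=net.fc.in_features, out_features=num_classes, bias=True)\n",
    "    return net\n",
    "\n",
    "\n",
    "# untrained network, so no pretrained weights are downloaded\n",
    "check_net = replace_fc_head(resnet18(), num_classes=10)\n",
    "check_net.eval()\n",
    "\n",
    "# forward a random batch of 2 images without tracking gradients\n",
    "with torch.no_grad():\n",
    "    logits = check_net(torch.randn(2, 3, 224, 224))\n",
    "\n",
    "print(logits.shape)  # torch.Size([2, 10])"
   ]
  },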
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Set Up a Torchvision Finetuning Loop\n",
    "SparseML can plug directly into your existing PyTorch training flow by overriding the Optimizer object. To demonstrate this, in the cell below, we define a simple PyTorch training loop taken from the [torchvision finetuning example](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html).  To prune your existing models using SparseML, you can use your own training flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import copy\n",
    "import torch\n",
    "\n",
    "def train_model(\n",
    "    model, dataloaders, criterion, optimizer, device, num_epochs=25, is_inception=False\n",
    "):\n",
    "    since = time.time()\n",
    "\n",
    "    val_acc_history = []\n",
    "\n",
    "    best_acc = 0.0\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        print(\"Epoch {}/{}\".format(epoch, num_epochs - 1))\n",
    "        print(\"-\" * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in [\"train\", \"val\"]:\n",
    "            if phase == \"train\":\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()  # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "\n",
    "            # Iterate over data.\n",
    "            for inputs, labels in dataloaders[phase]:\n",
    "                inputs = inputs.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == \"train\"):\n",
    "                    # Get model outputs and calculate loss\n",
    "                    # Special case for inception because in training it has an auxiliary output. In train\n",
    "                    #   mode we calculate the loss by summing the final output and the auxiliary output\n",
    "                    #   but in testing we only consider the final output.\n",
    "                    if is_inception and phase == \"train\":\n",
    "                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958\n",
    "                        outputs, aux_outputs = model(inputs)\n",
    "                        loss1 = criterion(outputs, labels)\n",
    "                        loss2 = criterion(aux_outputs, labels)\n",
    "                        loss = loss1 + 0.4 * loss2\n",
    "                    else:\n",
    "                        outputs = model(inputs)\n",
    "                        loss = criterion(outputs, labels)\n",
    "\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == \"train\":\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "            epoch_loss = running_loss / len(dataloaders[phase].dataset)\n",
    "            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)\n",
    "\n",
    "            print(\"{} Loss: {:.4f} Acc: {:.4f}\".format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            # deep copy the model\n",
    "            if phase == \"val\" and epoch_acc > best_acc:\n",
    "                best_acc = epoch_acc\n",
    "            if phase == \"val\":\n",
    "                val_acc_history.append(epoch_acc)\n",
    "\n",
    "        print()\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print(\n",
    "        \"Training complete in {:.0f}m {:.0f}s\".format(\n",
    "            time_elapsed // 60, time_elapsed % 60\n",
    "        )\n",
    "    )\n",
    "    print(\"Best val Acc: {:4f}\".format(best_acc))\n",
    "\n",
    "    # load best model weights\n",
    "    return model, val_acc_history"
   ]
  },
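  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `train_model` only calls `optimizer.zero_grad()` and `optimizer.step()`, any object that provides those two methods can be passed in without changing the loop. The cell below is a minimal, hypothetical sketch of that idea and not SparseML's implementation: `StepHookOptimizer` delegates to an inner optimizer and runs a callback after every step, which is the kind of hook a pruning schedule can use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Linear\n",
    "from torch.optim import SGD\n",
    "\n",
    "\n",
    "class StepHookOptimizer:\n",
    "    # hypothetical wrapper: delegates to an inner optimizer and calls\n",
    "    # `after_step()` once after every optimizer step\n",
    "\n",
    "    def __init__(self, inner, after_step):\n",
    "        self.inner = inner\n",
    "        self.after_step = after_step\n",
    "\n",
    "    def zero_grad(self):\n",
    "        self.inner.zero_grad()\n",
    "\n",
    "    def step(self):\n",
    "        self.inner.step()\n",
    "        self.after_step()  # e.g. update a sparsity target or re-apply masks\n",
    "\n",
    "\n",
    "# toy usage: record every step the wrapper observes\n",
    "hook_calls = []\n",
    "toy_layer = Linear(4, 2)\n",
    "toy_optimizer = StepHookOptimizer(\n",
    "    SGD(toy_layer.parameters(), lr=0.1), after_step=lambda: hook_calls.append(1)\n",
    ")\n",
    "\n",
    "for _ in range(3):\n",
    "    toy_optimizer.zero_grad()\n",
    "    toy_loss = toy_layer(torch.randn(8, 4)).pow(2).mean()\n",
    "    toy_loss.backward()\n",
    "    toy_optimizer.step()\n",
    "\n",
    "print(len(hook_calls))  # 3"
   ]
  },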
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Set Up PyTorch Training Objects\n",
    "In this step, you will select a device to train your model with, set up DataLoader objects, a loss function, and optimizer.  All of these variables and objects can be replaced to fit your training flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import SGD\n",
    "\n",
    "# setup device\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)\n",
    "print(\"Using device: {}\".format(device))\n",
    "\n",
    "# setup data loaders\n",
    "batch_size = 128\n",
    "train_loader = DataLoader(\n",
    "    train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    val_dataset, batch_size, shuffle=False, pin_memory=True, num_workers=8\n",
    ")\n",
    "dataloaders = {\"train\": train_loader, \"val\": val_loader}\n",
    "\n",
    "# setup loss function and optimizer, LR will be overriden by sparseml\n",
    "criterion = CrossEntropyLoss()\n",
    "optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Apply a SparseML Recipe and Prune Model\n",
    "\n",
    "To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to wrap the optimizer object to gradually prune the model using unstructured weight magnitude pruning after each optimizer step.\n",
    "\n",
    "You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.\n",
    "\n",
    "Using the wrapped optimizer object, you will call the training function to prune your model. Finalize the model after training by making a call to manager's `finalize(...)` method.\n",
    "\n",
    "If the kernel shuts down during training, this may be an out of memory error, to resolve this, try lowering the `batch_size` in the cell above."
   ]
  },
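  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradual, unstructured magnitude pruning described above can be sketched in plain PyTorch with two pieces: a schedule that raises the target sparsity over training, and a mask that zeroes the smallest-magnitude weights. This is an illustration under stated assumptions, not SparseML's code: the cubic ramp is a common choice for gradual pruning but is assumed here (in SparseML, the schedule comes from the recipe), and the helper names `cubic_sparsity` and `magnitude_prune_` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def cubic_sparsity(step, start_step, end_step, init_sparsity, final_sparsity):\n",
    "    # hypothetical schedule: ramp from init_sparsity to final_sparsity along a\n",
    "    # cubic curve, pruning quickly at first and more slowly near the end\n",
    "    if step <= start_step:\n",
    "        return init_sparsity\n",
    "    if step >= end_step:\n",
    "        return final_sparsity\n",
    "    progress = (step - start_step) / (end_step - start_step)\n",
    "    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3\n",
    "\n",
    "\n",
    "def magnitude_prune_(weight, sparsity):\n",
    "    # zero, in place, the `sparsity` fraction of entries with the smallest |w|\n",
    "    # and return the 0/1 mask; re-applying the mask after later optimizer\n",
    "    # steps keeps the pruned weights at zero\n",
    "    num_prune = int(round(sparsity * weight.numel()))\n",
    "    flat_mask = torch.ones(weight.numel(), dtype=weight.dtype, device=weight.device)\n",
    "    prune_idx = torch.argsort(weight.detach().abs().flatten())[:num_prune]\n",
    "    flat_mask[prune_idx] = 0.0\n",
    "    mask = flat_mask.view_as(weight)\n",
    "    with torch.no_grad():\n",
    "        weight.mul_(mask)\n",
    "    return mask\n",
    "\n",
    "\n",
    "# toy usage: at the end of the schedule, prune a random 10x10 weight to 50%\n",
    "toy_weight = torch.nn.Parameter(torch.randn(10, 10))\n",
    "toy_target = cubic_sparsity(step=10, start_step=0, end_step=10, init_sparsity=0.05, final_sparsity=0.5)\n",
    "toy_mask = magnitude_prune_(toy_weight, toy_target)\n",
    "print(toy_target, (toy_weight == 0).float().mean().item())"
   ]
  },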
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Downloading a Recipe from SparseZoo\n",
    "The [SparseZoo](https://github.com/neuralmagic/sparsezoo) API provides precofigured recipes for its optimized model.  In the cell below, you will download a recipe for pruning ResNet50 on the Imagenette dataset and record it's saved path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparsezoo import Model, search_models\n",
    "\n",
    "zoo_model = search_models(\n",
    "    domain=\"cv\",\n",
    "    sub_domain=\"classification\",\n",
    "    architecture=\"resnet_v1\",\n",
    "    sub_architecture=\"50\",\n",
    "    framework=\"pytorch\",\n",
    "    repo=\"torchvision\",\n",
    "    dataset=\"imagenette\",\n",
    "    sparse_name=\"pruned\",\n",
    ")[0]  # unwrap search result\n",
    "\n",
    "recipe_path = zoo_model.recipes.default.path\n",
    "print(f\"Recipe downloaded to: {recipe_path}\")"
   ]
  },
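  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you switch to a different model, the pruning modifiers in the recipe need to target parameter names that exist in that model. A quick way to see the candidate names is to list the weight of every convolutional and linear layer. The cell below is a minimal plain-PyTorch sketch on an untrained `resnet18` (an assumption for speed); the helper `prunable_weight_names` is hypothetical, and you should compare its output against the parameter names in the downloaded recipe file rather than assume a particular recipe syntax. Calling `prunable_weight_names(model)` lists the names for the notebook's own model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import Conv2d, Linear\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "\n",
    "def prunable_weight_names(net):\n",
    "    # collect \"<module name>.weight\" for every Conv2d and Linear layer\n",
    "    return [\n",
    "        f\"{name}.weight\"\n",
    "        for name, module in net.named_modules()\n",
    "        if isinstance(module, (Conv2d, Linear))\n",
    "    ]\n",
    "\n",
    "\n",
    "prunable_names = prunable_weight_names(resnet18())\n",
    "print(len(prunable_names))\n",
    "print(prunable_names[:3])\n",
    "print(prunable_names[-1])"
   ]
  },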
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.optim import (\n",
    "    ScheduledModifierManager,\n",
    ")\n",
    "\n",
    "# create ScheduledModifierManager and Optimizer wrapper\n",
    "manager = ScheduledModifierManager.from_yaml(recipe_path)\n",
    "optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))\n",
    "\n",
    "\n",
    "train_model(\n",
    "    model,\n",
    "    dataloaders,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    num_epochs=manager.max_epochs,\n",
    "    is_inception=False,\n",
    ")\n",
    "\n",
    "manager.finalize(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - View Model Sparsity\n",
    "To see the effects of the model pruning, in this step, you will print out the sparsities of each Conv and FC layer in your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity\n",
    "\n",
    "# print sparsities of each layer\n",
    "for (name, layer) in get_prunable_layers(model):\n",
    "    print(\"{}.weight: {:.4f}\".format(name, tensor_sparsity(layer.weight).item()))"
   ]
  },
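  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-layer values above can be summarized as a single overall sparsity by weighting each layer by its number of weights. The cell below is a minimal plain-PyTorch sketch; the helper `overall_sparsity` is hypothetical and counts exactly-zero weights in every Conv2d and Linear layer. The toy check zeroes half of one layer by hand; calling `overall_sparsity(model)` gives the figure for the pruned model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Conv2d, Linear\n",
    "\n",
    "\n",
    "def overall_sparsity(net):\n",
    "    # fraction of exactly-zero weights across all Conv2d and Linear layers,\n",
    "    # so larger layers contribute proportionally more\n",
    "    zeros, total = 0, 0\n",
    "    for module in net.modules():\n",
    "        if isinstance(module, (Conv2d, Linear)):\n",
    "            zeros += int((module.weight == 0).sum())\n",
    "            total += module.weight.numel()\n",
    "    return zeros / total if total else 0.0\n",
    "\n",
    "\n",
    "# toy check: zero the first two rows (8 of 16 weights) of the first layer,\n",
    "# giving 8 zeros out of 16 + 8 = 24 weights in total\n",
    "toy_net = torch.nn.Sequential(Linear(4, 4), Linear(4, 2))\n",
    "with torch.no_grad():\n",
    "    toy_net[0].weight[:2].zero_()\n",
    "print(\"{:.4f}\".format(overall_sparsity(toy_net)))"
   ]
  },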
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6 - Save Model and Export to ONNX\n",
    "\n",
    "Now that the model is fully recalibrated, you need to export it to an ONNX format, which is the format used by the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). For PyTorch, exporting to ONNX is natively supported. In the cell block below, a convenience class, ModuleExporter(), is used to handle exporting.\n",
    "\n",
    "Once the model is saved as an ONNX ﬁle, it is ready to be used for inference with the DeepSparse Engine.  For saving a custom model, you can override the sample batch for ONNX graph freezing and locations to save to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.pytorch.utils import ModuleExporter\n",
    "\n",
    "save_dir = \"torchvision_models\"\n",
    "\n",
    "exporter = ModuleExporter(model, output_dir=save_dir)\n",
    "exporter.export_pytorch(name=\"resnet50_imagenette_pruned.pth\")\n",
    "exporter.export_onnx(torch.randn(1, 3, 224, 224), name=\"resnet50_imagenette_pruned.onnx\")"
   ]
  },
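  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pruned weights are stored as ordinary zeros in the dense weight tensors, so saving and reloading a `state_dict` keeps the sparsity. The cell below is a minimal plain-PyTorch sketch of that check, using a toy layer and an in-memory buffer (both assumptions, so nothing is written to `save_dir`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import torch\n",
    "from torch.nn import Linear\n",
    "\n",
    "# toy layer with half of its weights zeroed, standing in for a pruned layer\n",
    "source_layer = Linear(8, 8)\n",
    "with torch.no_grad():\n",
    "    source_layer.weight[:, :4].zero_()\n",
    "\n",
    "# save to an in-memory buffer and load into a freshly initialized layer\n",
    "buffer = io.BytesIO()\n",
    "torch.save(source_layer.state_dict(), buffer)\n",
    "buffer.seek(0)\n",
    "restored_layer = Linear(8, 8)\n",
    "restored_layer.load_state_dict(torch.load(buffer))\n",
    "\n",
    "# the zeros (and every other value) survive the round trip\n",
    "print(torch.equal(source_layer.weight, restored_layer.weight))\n",
    "print(int((restored_layer.weight == 0).sum()))"
   ]
  },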
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:\n",
    "* Pruning different models using SparseML\n",
    "* Trying different pruning and optimization recipes\n",
    "* Running your model on the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
