{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> \n",
    "\n",
    "# TensorFlow v1 Classification Model Pruning Using SparseML\n",
    "\n",
    "This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the DeepSparse Engine. You will:\n",
    "- Set up the model and dataset\n",
    "- Define a TensorFlow training flow with a simple SparseML integration\n",
    "- Prune the model using the TensorFlow+SparseML flow\n",
    "- Export to [ONNX](https://onnx.ai/)\n",
    "\n",
    "Reading through this notebook will be reasonably quick to gain an intuition for how to plug SparseML into your TensorFlow training flow. Rough time estimates for fully pruning the default model are given. Note that training with the PyTorch CPU implementation will be much slower than a GPU:\n",
    "- 20 minutes on a GPU\n",
    "- 60 minutes on a laptop CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Requirements\n",
    "To run this notebook, you will need the following packages already installed:\n",
    "* SparseML and SparseZoo\n",
    "* TensorFlow v1 and tf2onnx\n",
    "\n",
    "You can install any package that is not already present via `pip`."
   ]
  },
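  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before importing, it can help to see which of the required packages are missing. The snippet below is a minimal sketch using only the Python standard library; the helper name `missing_packages` is hypothetical and not part of SparseML. It checks whether each top-level package can be located without importing it.\n",
    "\n",
    "```python\n",
    "import importlib.util\n",
    "\n",
    "\n",
    "def missing_packages(names):\n",
    "    # find_spec returns None when a top-level package cannot be located\n",
    "    return [name for name in names if importlib.util.find_spec(name) is None]\n",
    "\n",
    "\n",
    "required = ['sparseml', 'sparsezoo', 'tensorflow', 'tf2onnx']\n",
    "print('missing packages:', missing_packages(required))\n",
    "```\n",
    "\n",
    "Any name printed here can then be installed via `pip` before running the remaining cells."
   ]
  },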
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sparseml\n",
    "import sparsezoo\n",
    "import tensorflow as tf\n",
    "import tf2onnx\n",
    "\n",
    "assert tf.__version__ < \"2\"\n",
    "\n",
    "# suppress warnings\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Setting Up the Model and Dataset\n",
    "\n",
    "By default, you will prune a [ResNet-50](https://arxiv.org/abs/1512.03385) model trained on the [Imagenette dataset](https://github.com/fastai/imagenette). The model's pretrained weights are downloaded from the SparseZoo model repo.   The Imagenette dataset is downloaded from its repository via a helper class from SparseML.\n",
    "\n",
    "In the cells below, functions are defined to load the dataset, model, and training objects to be called during training from within the `Graph` context.\n",
    "\n",
    "If you would like to try out your model for pruning, modify the appropriate function for to load your model or dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "from sparseml.tensorflow_v1.models import ModelRegistry\n",
    "from sparseml.tensorflow_v1.datasets import (\n",
    "    ImagenetteDataset,\n",
    "    ImagenetteSize,\n",
    ")\n",
    "from sparseml.tensorflow_v1.utils import (\n",
    "    batch_cross_entropy_loss,\n",
    "    accuracy,\n",
    ")\n",
    "\n",
    "MODEL_NAME = \"resnet50\"\n",
    "BATCH_SIZE = 128\n",
    "INPUT_SIZE = 224\n",
    "NUM_CLASSES = 10\n",
    "SAVE_DIR = \"tensorflow_v1_classification_pruning\"\n",
    "\n",
    "\n",
    "def load_dataset():\n",
    "    with tf.device(\"/cpu:0\"):\n",
    "        print(\"loading datasets\")\n",
    "        train_dataset = ImagenetteDataset(\n",
    "            train=True, dataset_size=ImagenetteSize.s320, image_size=INPUT_SIZE\n",
    "        )\n",
    "        train_len = len(train_dataset)\n",
    "        train_steps = math.ceil(train_len / float(BATCH_SIZE))\n",
    "        train_dataset = train_dataset.build(\n",
    "            BATCH_SIZE,\n",
    "            shuffle_buffer_size=1000,\n",
    "            prefetch_buffer_size=BATCH_SIZE,\n",
    "            num_parallel_calls=4,\n",
    "        )\n",
    "\n",
    "        val_dataset = ImagenetteDataset(\n",
    "            train=False, dataset_size=ImagenetteSize.s320, image_size=INPUT_SIZE\n",
    "        )\n",
    "        val_len = len(val_dataset)\n",
    "        val_steps = math.ceil(val_len / float(BATCH_SIZE))\n",
    "        val_dataset = val_dataset.build(\n",
    "            BATCH_SIZE,\n",
    "            shuffle_buffer_size=1000,\n",
    "            prefetch_buffer_size=BATCH_SIZE,\n",
    "            num_parallel_calls=4,\n",
    "        )\n",
    "\n",
    "    return train_dataset, val_dataset, (train_steps, val_steps)\n",
    "\n",
    "\n",
    "def create_model(sample_input, training):\n",
    "    print(\"Creating model graph for {}\".format(MODEL_NAME))\n",
    "    logits = ModelRegistry.create(\n",
    "        MODEL_NAME,\n",
    "        inputs=sample_input,\n",
    "        training=training,\n",
    "        num_classes=NUM_CLASSES,\n",
    "    )\n",
    "    return logits\n",
    "\n",
    "\n",
    "def create_training_objects(sample_labels):\n",
    "    print(\"Creating loss, accuracy, and optimizer in graph\")\n",
    "    loss = batch_cross_entropy_loss(logits, labels)\n",
    "    acc = accuracy(logits, labels)\n",
    "    global_step = tf.train.get_or_create_global_step()\n",
    "    train_op = tf.train.AdamOptimizer(learning_rate=0.00008).minimize(\n",
    "        loss, global_step=global_step\n",
    "    )\n",
    "    return loss, acc, global_step, train_op\n",
    "\n",
    "\n",
    "def load_pretrained():\n",
    "    print(\"loading pre-trained model weights\")\n",
    "    ModelRegistry.load_pretrained(\n",
    "        MODEL_NAME, pretrained=\"base\", remove_dynamic_tl_vars=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Create a SparseML Modifier Manager\n",
    "\n",
    "To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to create operations that modify the training process.\n",
    "\n",
    "You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.\n",
    "\n",
    "Using the operators generated from this manager object, you will be able to prune your model."
   ]
  },
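  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for what a pruning recipe ultimately does to a layer's weights, the sketch below zeroes out the smallest-magnitude values of a toy weight list until a target sparsity is reached. This is an illustrative, pure-Python example of one-shot unstructured magnitude pruning; the function names are hypothetical, and it is neither SparseML's implementation nor necessarily the schedule used by the downloaded recipe.\n",
    "\n",
    "```python\n",
    "def magnitude_prune(weights, sparsity):\n",
    "    # number of weights to set to zero for the requested sparsity\n",
    "    num_to_prune = int(len(weights) * sparsity)\n",
    "    # indices ordered by absolute value, smallest first\n",
    "    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))\n",
    "    pruned = list(weights)\n",
    "    for i in order[:num_to_prune]:\n",
    "        pruned[i] = 0.0\n",
    "    return pruned\n",
    "\n",
    "\n",
    "def measured_sparsity(weights):\n",
    "    # fraction of weights that are exactly zero\n",
    "    return sum(1 for w in weights if w == 0.0) / len(weights)\n",
    "\n",
    "\n",
    "toy_weights = [0.5, -0.1, 0.03, -0.8, 0.2, -0.02, 0.9, 0.04]\n",
    "toy_pruned = magnitude_prune(toy_weights, sparsity=0.5)\n",
    "print(toy_pruned, measured_sparsity(toy_pruned))\n",
    "```\n",
    "\n",
    "In practice, a recipe typically raises sparsity gradually over many training steps so the remaining weights can adapt, rather than pruning in a single step as this toy example does."
   ]
  },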
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.tensorflow_v1.optim import (\n",
    "    ScheduledModifierManager,\n",
    ")\n",
    "from sparsezoo import Model, search_models\n",
    "\n",
    "\n",
    "def create_sparseml_manager():\n",
    "    model = search_models(\n",
    "        domain=\"cv\",\n",
    "        sub_domain=\"classification\",\n",
    "        architecture=\"resnet_v1\",\n",
    "        sub_architecture=\"50\",\n",
    "        framework=\"tensorflow_v1\",\n",
    "        repo=\"sparseml\",\n",
    "        dataset=\"imagenette\",\n",
    "        sparse_name=\"pruned\",\n",
    "    )[0]  # unwrap search result\n",
    "\n",
    "    recipe_path = model.recipes.default.path\n",
    "    print(f\"Recipe downloaded to: {recipe_path}\")\n",
    "\n",
    "    return ScheduledModifierManager.from_yaml(recipe_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4 - Prune your model using a TensorFlow training loop\n",
    "SparseML can plug directly into your existing TensorFlow training flow by creating additional operators to run. To demonstrate this, in the cell below, prune the model using a standard TensorFlow training loop while also running the operators created by the manager object.  To prune your existing models using SparseML, you can use your own training flow with the additional operators added.\n",
    "\n",
    "For your convienence the lines needed for integrating with SparseML are preceeded by large comment blocks.\n",
    "\n",
    "If the kernel shuts down during training, this may be an out of memory error, to resolve this, try lowering the `BATCH_SIZE` in Step 2 above."
   ]
  },
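  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lowering `BATCH_SIZE` also changes the number of steps per epoch, which Step 2 computes with `math.ceil` and which is passed to the manager when its operators are created. A minimal sketch of that arithmetic follows; the helper name `steps_per_epoch` is hypothetical, and the sample count of 10,000 is an assumed, illustrative value rather than the actual Imagenette size.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def steps_per_epoch(num_samples, batch_size):\n",
    "    # round up so the final partial batch still counts as a step\n",
    "    return math.ceil(num_samples / float(batch_size))\n",
    "\n",
    "\n",
    "example_num_samples = 10000  # assumed dataset size, for illustration only\n",
    "for example_batch_size in (128, 64, 32):\n",
    "    print(example_batch_size, steps_per_epoch(example_num_samples, example_batch_size))\n",
    "```\n",
    "\n",
    "Halving the batch size roughly doubles the steps per epoch, so each epoch takes more iterations while using less memory per step."
   ]
  },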
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "import os\n",
    "from tqdm.auto import tqdm\n",
    "from sparseml.utils import create_unique_dir, create_dirs\n",
    "from sparseml.tensorflow_v1.datasets import create_split_iterators_handle\n",
    "\n",
    "with tf.Graph().as_default() as graph:\n",
    "    # create dataset\n",
    "    train_dataset, val_dataset, (train_steps, val_steps) = load_dataset()\n",
    "    handle, iterator, (train_iter, val_iter) = create_split_iterators_handle(\n",
    "        [train_dataset, val_dataset]\n",
    "    )\n",
    "    images, labels = iterator.get_next()\n",
    "\n",
    "    # create base training objects\n",
    "    training = tf.placeholder(dtype=tf.bool, shape=[])\n",
    "    logits = create_model(images, training)\n",
    "    loss, acc, global_step, train_op = create_training_objects(labels)\n",
    "    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "\n",
    "    #######################################################\n",
    "    # create sparseml training manager\n",
    "    #######################################################\n",
    "    manager = create_sparseml_manager()\n",
    "    mod_ops, mod_extras = manager.create_ops(train_steps, global_step, graph=graph)\n",
    "\n",
    "    with tf.Session() as sess:\n",
    "        print(\"initializing session\")\n",
    "        sess.run(\n",
    "            [\n",
    "                tf.global_variables_initializer(),\n",
    "                tf.local_variables_initializer(),\n",
    "            ]\n",
    "        )\n",
    "        train_iter_handle, val_iter_handle = sess.run(\n",
    "            [train_iter.string_handle(), val_iter.string_handle()]\n",
    "        )\n",
    "\n",
    "        # initialize sparseml manager after pretrained weights loaded\n",
    "        load_pretrained()\n",
    "        manager.initialize_session()\n",
    "\n",
    "        num_epochs = manager.max_epochs\n",
    "        for epoch in tqdm(range(num_epochs), desc=\"pruning\"):\n",
    "            print(\"training for epoch {}...\".format(epoch))\n",
    "            sess.run(train_iter.initializer)\n",
    "            train_losses = []\n",
    "            train_acc = []\n",
    "\n",
    "            for step in range(train_steps):\n",
    "                _, __, meas_step, meas_loss, meas_acc = sess.run(\n",
    "                    [train_op, update_ops, global_step, loss, acc],\n",
    "                    feed_dict={handle: train_iter_handle, training: True},\n",
    "                )\n",
    "                train_losses.append(meas_loss)\n",
    "                train_acc.append(meas_acc)\n",
    "\n",
    "                #######################################################\n",
    "                # Modifier update ops line for transfer learning from a sparse model in TensorFlow\n",
    "                #######################################################\n",
    "                sess.run(mod_ops)\n",
    "            print(\n",
    "                \"completed epoch {} training with: loss {} / acc {}\".format(\n",
    "                    epoch,\n",
    "                    numpy.mean(train_losses).item(),\n",
    "                    numpy.mean(train_acc).item() * 100,\n",
    "                )\n",
    "            )\n",
    "\n",
    "            print(\"validating for epoch {}...\".format(epoch))\n",
    "            sess.run(val_iter.initializer)\n",
    "            val_losses = []\n",
    "            val_acc = []\n",
    "\n",
    "            for step in range(val_steps):\n",
    "                meas_loss, meas_acc = sess.run(\n",
    "                    [loss, acc],\n",
    "                    feed_dict={handle: val_iter_handle, training: False},\n",
    "                )\n",
    "                val_losses.append(meas_loss)\n",
    "                val_acc.append(meas_acc)\n",
    "\n",
    "            print(\n",
    "                \"completed epoch {} validation with: loss {} / acc {}\".format(\n",
    "                    epoch,\n",
    "                    numpy.mean(val_losses).item(),\n",
    "                    numpy.mean(val_acc).item() * 100,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        #######################################################\n",
    "        # Final line for sparseml training in TensorFlow, complete the graph\n",
    "        #######################################################\n",
    "        manager.complete_graph()\n",
    "\n",
    "        NAME = \"resnet50-imagenette-pruned\"\n",
    "        checkpoint_path = create_unique_dir(\n",
    "            os.path.join(\".\", SAVE_DIR, NAME, \"checkpoint\")\n",
    "        )\n",
    "        checkpoint_path = os.path.join(checkpoint_path, \"model\")\n",
    "        create_dirs(checkpoint_path)\n",
    "        saver = ModelRegistry.saver(MODEL_NAME)\n",
    "        saver.save(sess, checkpoint_path)\n",
    "        print(\"saved model checkpoint to {}\".format(checkpoint_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 - Exporting to ONNX\n",
    "\n",
    "Now that the model is fully recalibrated, you need to export it to an ONNX format, which is the format used by the DeepSparse Engine. For TensorFlow, exporting to ONNX is not natively supported. To add support, you will use the `tf2onnx` Python package. In the cell block below, a convenience class, `GraphExporter()`, is used to handle exporting. It wraps the somewhat complicated API for `tf2onnx` into an easy to use interface.\n",
    "\n",
    "Note, for some configurations, the tf2onnx code does not work properly in a Jupyter Notebook. To remedy this, you should run the `exporter.export_onnx()` function call in a Python console or script.\n",
    "\n",
    "Once the model is saved as an ONNX ﬁle, it is ready to be used for inference with the DeepSparse Engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sparseml.utils import clean_path\n",
    "from sparseml.tensorflow_v1.utils import GraphExporter\n",
    "\n",
    "\n",
    "export_path = clean_path(os.path.join(\".\", SAVE_DIR, NAME))\n",
    "exporter = GraphExporter(export_path)\n",
    "\n",
    "with tf.Graph().as_default() as graph:\n",
    "    print(\"Recreating graph...\", flush=True)\n",
    "\n",
    "    input_placeholder = tf.placeholder(\n",
    "        tf.float32, [None, INPUT_SIZE, INPUT_SIZE, 3], name=\"inputs\"\n",
    "    )\n",
    "    logits = create_model(input_placeholder, training=False)\n",
    "\n",
    "    input_names = [input_placeholder.name]\n",
    "    output_names = [logits.name]\n",
    "\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        print(\"Restoring previous weights...\", flush=True)\n",
    "        saver = ModelRegistry.saver(MODEL_NAME)\n",
    "        saver.restore(sess, checkpoint_path)\n",
    "\n",
    "        print(\"Exporting to pb...\", flush=True)\n",
    "        exporter.export_pb(outputs=[logits])\n",
    "        print(\"Exported pb file to {}\".format(exporter.pb_path), flush=True)\n",
    "\n",
    "print(\"Exporting to onnx...\", flush=True)\n",
    "exporter.export_onnx(inputs=input_names, outputs=output_names)\n",
    "print(\"Exported onnx file to {}\".format(exporter.onnx_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:\n",
    "* Pruning different models using SparseML\n",
    "* Trying different pruning and optimization recipes\n",
    "* Running your model on the DeepSparse Engine"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
