{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7b755249",
   "metadata": {},
   "source": [
    "# 🥡 Exporting for Inference\n",
    "\n",
    "Now that you've trained using Composer, do you need to make your model available for inference? We've got you covered.\n",
    "\n",
    "Composer provides model export support for inference using a dedicated export API and a callback. In this tutorial, we walk through how to export your models into various common formats (e.g., [ONNX][onnx] and [TorchScript][torchscript]) using the dedicated export API as well as Composer's callback mechanism. Composer models can also be exported like any other PyTorch module since Composer models are `torch.nn.Module` instances.\n",
    "\n",
    "For more detailed options and configuration settings, please consult the linked documentation. \n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "This tutorial assumes that you're familiar with basic export formats like ONNX and TorchScript, and that you're generally up to speed on using Composer for training. If you haven't already, you may find it helpful to review our [callback docs][callback_docs] and [checkpointing docs][checkpointing_docs].\n",
    "\n",
    "### Tutorial Goals and Covered Concepts\n",
    "\n",
    "The goal of this tutorial is to showcase Composer's export utilities for making a model available for inference. \n",
    "\n",
    "We'll touch on:\n",
    "\n",
    "* [Our standalone export API](#Torchscript-Export-Using-Standalone-API)\n",
    "* [Exporting from the trainer with callbacks](#Export-Using-a-Callback)\n",
    "* [Exporting from the trainer directly](#Exporting-from-Trainer-Directly)\n",
    "* [Exporting from an existing checkpoint](#Exporting-from-an-Existing-Checkpoint)\n",
    "* [Supported Composer algorithms](#Algorithm-Compatibility)\n",
    "\n",
    "[onnx]: https://onnx.ai/\n",
    "[torchscript]: https://pytorch.org/docs/stable/jit.html\n",
    "[callback_docs]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/callbacks.html\n",
    "[checkpointing_docs]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/checkpointing.html\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46de5fce",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "First, we install Composer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c3e0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3ccf56d",
   "metadata": {},
   "source": [
    "## Create the Model\n",
    "To start, we create the model we’d like to export, which in this case is ResNet-50 with our `SqueezeExcite` algorithm applied. This algorithm adds `SqueezeExcite` modules after certain `Conv2d` layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbde37e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet\n",
    "from composer.models import ComposerClassifier\n",
    "import composer.functional as cf\n",
    "\n",
    "model = ComposerClassifier(module=resnet.resnet50(), num_classes=1000)\n",
    "cf.apply_squeeze_excite(model)\n",
    "\n",
    "# switch to eval mode\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87cc6bc6",
   "metadata": {},
   "source": [
    "## Torchscript Export Using Standalone API\n",
    "<a id=\"Torchscript-Export-Using-Standalone-API\"></a>\n",
    "\n",
    "Torchscript creates models from PyTorch code that can be saved and also optimized for deployment, and is the tooling is native to PyTorch. \n",
    "\n",
    "The `ComposerClassifier`’s forward method takes as input a pair of tensors `(input, label)`, so we create dummy tensors to run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02603737",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "input = (torch.rand(4, 3, 224, 224), torch.Tensor())\n",
    "\n",
    "output = model(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5668ffed",
   "metadata": {},
   "source": [
    "Now we run export using our standalone export API. Composer also supports exporting to an object store such as S3. For more info on using an object store, please checkout our full [documentation](https://docs.mosaicml.com/projects/composer/en/stable/api_reference/composer.utils.inference.html) for the `export_for_inference` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e654bff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "from composer.utils import export_for_inference\n",
    "\n",
    "save_format = 'torchscript'\n",
    "working_dir = tempfile.TemporaryDirectory()\n",
    "model_save_path = os.path.join(working_dir.name, 'model.pt')\n",
    "\n",
    "export_for_inference(model=model, \n",
    "                     save_format=save_format, \n",
    "                     save_path=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33d5a8ec",
   "metadata": {},
   "source": [
    "Check to make sure that the model exists in our working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07519d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(path=working_dir.name))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda90bdf",
   "metadata": {},
   "source": [
    "Reload the saved model and run inference on it. We'll also compare the results with the previously computed results on the same input as a sanity check.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8726defb",
   "metadata": {},
   "outputs": [],
   "source": [
    "scripted_model = torch.jit.load(model_save_path)\n",
    "scripted_model.eval()\n",
    "scripted_output = scripted_model(input)\n",
    "print(torch.allclose(output, scripted_output))"
   ]
  },
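  {
   "cell_type": "markdown",
   "id": "b7d2e9a1",
   "metadata": {},
   "source": [
    "Because Composer models are `torch.nn.Module` instances, the usual PyTorch export calls should also apply to them. The sketch below is only an illustration under stated assumptions: it uses a small stand-in `nn.Sequential` module rather than a Composer model, and it serializes to an in-memory buffer instead of a file.\n",
    "\n",
    "```python\n",
    "import io\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Stand-in module (an assumption for illustration); any nn.Module is traced the same way\n",
    "module = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()\n",
    "example = torch.rand(2, 8)\n",
    "\n",
    "# Trace with an example input, then serialize to an in-memory buffer\n",
    "traced = torch.jit.trace(module, example)\n",
    "buffer = io.BytesIO()\n",
    "torch.jit.save(traced, buffer)\n",
    "buffer.seek(0)\n",
    "\n",
    "# Reload and check that the outputs match the eager module\n",
    "reloaded = torch.jit.load(buffer)\n",
    "print(torch.allclose(module(example), reloaded(example)))\n",
    "```"
   ]
  },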
  {
   "cell_type": "markdown",
   "id": "e154ac3d",
   "metadata": {},
   "source": [
    "## Export Using a Callback\n",
    "<a id=\"Export-Using-a-Callback\"></a>\n",
    "\n",
    "The Composer trainer also lets you specify an export callback that automatically exports at the end of training. Since we will be training a model for a few epochs, we'll first create a dataloader with CIFAR for this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5141e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "mnist_transforms = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "dataset = datasets.MNIST(\"./data\", train=True, download=True, transform=mnist_transforms)\n",
    "dataloader = DataLoader(dataset=dataset, batch_size=4)\n",
    "input_mnist = (torch.rand(4, 1, 28, 28), torch.Tensor())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4de762c7",
   "metadata": {},
   "source": [
    "## Create the Model\n",
    "\n",
    "We create the model we are training, which in this case is a ResNet-50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6197713c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "class ToyModel(nn.Module):\n",
    "    \"\"\"Toy convolutional neural network architecture in pytorch for MNIST.\"\"\"\n",
    "\n",
    "    def __init__(self, num_classes: int = 10):\n",
    "        super().__init__()\n",
    "\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "        self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)\n",
    "        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)\n",
    "        self.bn = nn.BatchNorm2d(32)\n",
    "        self.fc1 = nn.Linear(32 * 16, 32)\n",
    "        self.fc2 = nn.Linear(32, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn(out)\n",
    "        out = F.relu(out)\n",
    "        out = F.adaptive_avg_pool2d(out, (4, 4))\n",
    "        out = torch.flatten(out, 1, -1)\n",
    "        out = self.fc1(out)\n",
    "        out = F.relu(out)\n",
    "        return self.fc2(out)\n",
    "\n",
    "\n",
    "model = ComposerClassifier(module=ToyModel(num_classes=10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2b15fb6",
   "metadata": {},
   "source": [
    "## Create the Export Callback\n",
    "Now we create a callback that is used by the trainer to export the model for inference. Since we already saw torchscript export using Composer's standalone export API, we are using `onnx` as our export format for this section to showcase both capabilities. You can easily choose between these options by setting `save_format` to whichever of `'onnx'` or `'torchscript'` you prefer.\n",
    "\n",
    "**Note**: ONNX does not have a prebuilt wheel for Mac M1/M2 chips yet, so is not pip installable on recent Mac computers. Skip this section if your computer has an M1/M2 chip."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24b649ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import composer.functional as cf\n",
    "from composer.callbacks import ExportForInferenceCallback\n",
    "# change to 'torchscript' for exporting to torchscript format \n",
    "save_format = 'onnx'\n",
    "model_save_path = os.path.join(working_dir.name, 'model1.onnx')\n",
    "export_callback = ExportForInferenceCallback(save_format=save_format, save_path=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5408fbe1",
   "metadata": {},
   "source": [
    "## Run Training\n",
    "Now we construct the trainer using this callback. The model is exported at the end of the training. In the later part of this tutorail we show model exporting from a checkpoint, so we also supply trainer `save_folder` and `save_interval` to save some checkpoints. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "220d936b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from composer import Trainer\n",
    "from composer.algorithms import SqueezeExcite\n",
    "from composer.optim import DecoupledSGDW\n",
    "\n",
    "optimizer = DecoupledSGDW(model.parameters(), lr=0.01)\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=dataloader,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=scheduler,\n",
    "    save_folder=working_dir.name,\n",
    "    algorithms=[SqueezeExcite()],\n",
    "    callbacks=[export_callback],\n",
    "    max_duration='2ep',\n",
    "    save_interval='1ep',\n",
    "    save_overwrite=True,\n",
    ")\n",
    "trainer.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c06abf1f",
   "metadata": {},
   "source": [
    "Let's list the content of the `working_dir` to check if the checkpoints and exported model is available. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b38073",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(path=working_dir.name))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beb2f655",
   "metadata": {},
   "source": [
    "## Exporting from Trainer Directly\n",
    "<a id=\"Exporting-From-Trainer-Directly\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3612cdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_save_path = os.path.join(working_dir.name, 'model2.onnx')\n",
    "\n",
    "trainer.export_for_inference(save_format='onnx', save_path=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55470cc5",
   "metadata": {},
   "source": [
    "Similarly, let's list the content of the `working_dir` to see if this exported model is available. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac1ef9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(path=working_dir.name))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8517a2d2",
   "metadata": {},
   "source": [
    "## Load and Run Exported ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e496af25",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install onnx\n",
    "%pip install onnxruntime"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b255546",
   "metadata": {},
   "source": [
    "Let's load the model and check that everything was exported properly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1438c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "onnx_model = onnx.load(model_save_path)\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8623a917",
   "metadata": {},
   "source": [
    "Lastly, we can run inference with the model and check that the model indeed runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac417341",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as ort\n",
    "import numpy as np\n",
    "\n",
    "# run inference\n",
    "ort_session = ort.InferenceSession(model_save_path, providers=['CPUExecutionProvider'])\n",
    "outputs = ort_session.run(\n",
    "    None,\n",
    "    {'input': input_mnist[0].numpy()})\n",
    "print(f\"The predicted classes are {np.argmax(outputs[0], axis=1)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bc52f62",
   "metadata": {},
   "source": [
    "If our input is a dictionary, as if often the case when using a Composer [HuggingFaceModel](https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html), we'll need to make sure all the elements of our input dictionary are numpy arrays before calling `ort_session.run()`."
   ]
  },
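  {
   "cell_type": "markdown",
   "id": "c4e1f0a2",
   "metadata": {},
   "source": [
    "As a minimal sketch of that conversion (not part of Composer's API): the helper `to_numpy_feed` and the keys `input_ids` and `attention_mask` below are hypothetical placeholders, and the keys you pass must match the input names of your exported ONNX graph.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "def to_numpy_feed(batch):\n",
    "    # Convert every entry of a dict-style batch into a NumPy array\n",
    "    feed = {}\n",
    "    for name, value in batch.items():\n",
    "        if isinstance(value, torch.Tensor):\n",
    "            value = value.detach().cpu().numpy()\n",
    "        feed[name] = np.asarray(value)\n",
    "    return feed\n",
    "\n",
    "\n",
    "# Hypothetical tokenized batch; the key names are assumptions for illustration\n",
    "batch = {\n",
    "    'input_ids': torch.randint(0, 100, (2, 8)),\n",
    "    'attention_mask': torch.ones(2, 8, dtype=torch.int64),\n",
    "}\n",
    "feed = to_numpy_feed(batch)\n",
    "print({name: (array.dtype, array.shape) for name, array in feed.items()})\n",
    "```\n",
    "\n",
    "The resulting dictionary could then be passed as the second argument to `ort_session.run()`."
   ]
  },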
  {
   "cell_type": "markdown",
   "id": "ca091f8e",
   "metadata": {},
   "source": [
    "**Note**: Since the model is randomly initialized, and the input tensor is random, the output classes in this example have no meaning. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "84137136",
   "metadata": {},
   "source": [
    "## Exporting from an Existing Checkpoint\n",
    "<a id=\"Exporting-from-an-Existing-Checkpoint\"></a>\n",
    "\n",
    "In this part of the tutorial, we will look at exporting a model from a previously created checkpoint that is stored locally. Composer also supports exporting from a checkpoint stored in an object store such as S3. Please checkout the [full documentation][docs] for `export_for_inference` API for using an object store. \n",
    "\n",
    "Some of our algorithms alter the model architecture. For example, [SqueezeExcite][squeezeexcite] adds a channel-wise attention operator in CNNs and modifies the model architecure. Therefore, we need to provide a function that takes the mode and applies the algorithm before we can load the model weights from a checkpoint. The functional form of SqueezeExcite does exactly that, and we pass this function in the `surgery_algs` argument to the `export_for_inference` API. \n",
    "\n",
    "[docs]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.utils.export_for_inference.html\n",
    "[squeezeexcite]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/squeeze_excite.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9da03fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(working_dir.name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d51ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.utils import export_for_inference\n",
    "# We call it model2.onnx to make it different from our previous export\n",
    "model_save_path = os.path.join(working_dir.name, 'model2.onnx')\n",
    "checkpoint_path = os.path.join(working_dir.name, 'ep2-ba4-rank0.pt')\n",
    "\n",
    "model = ComposerClassifier(module=ToyModel(num_classes=10))\n",
    "\n",
    "export_for_inference(model=model,\n",
    "                     save_format=save_format, \n",
    "                     save_path=model_save_path, \n",
    "                     sample_input=(input_mnist, {}),\n",
    "                     surgery_algs=[cf.apply_squeeze_excite],\n",
    "                     load_path=checkpoint_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b1aab7a",
   "metadata": {},
   "source": [
    "Let us list the content of the working_dir to check if the newly exported model is available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28fd1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(path=working_dir.name))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e02a7a9d",
   "metadata": {},
   "source": [
    "Make sure the model loaded from a checkpoint produces the same results as before"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fea266b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ort_session = ort.InferenceSession(model_save_path, providers=['CPUExecutionProvider'])\n",
    "new_outputs = ort_session.run(\n",
    "    None,\n",
    "    {'input': input_mnist[0].numpy()},\n",
    ")\n",
    "print(np.allclose(outputs[0], new_outputs[0], atol=1e-07))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7a48014",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up working directory\n",
    "working_dir.cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5909d39f",
   "metadata": {},
   "source": [
    "## Torch.fx\n",
    "\n",
    "FX is a recent toolkit to transform PyTorch modules that allows for advanced graph manipulation and code generation capabilities. Eventually, PyTorch will add quantization and other optimization procedures on top of FX (e.g. see [FX Graph Mode Quantization][torchfx]. Composer is also starting to add algorithms that use `torch.fx` for graph optimization, so look forward to more of these in the future!\n",
    "\n",
    "Tracing a model with `torch.fx` is fairly straightforward:\n",
    "\n",
    "[torchfx]: https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0039f4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_model = torch.fx.symbolic_trace(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8788f688",
   "metadata": {},
   "source": [
    "Then, we can see all the nodes in the graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4fd3c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_model.graph.print_tabular()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "140aebd9",
   "metadata": {},
   "source": [
    "And also run inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91ce75a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = traced_model(input_mnist)\n",
    "print(f\"The predicted classes are {torch.argmax(output, dim=1)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "418159b8",
   "metadata": {},
   "source": [
    "`torch.fx` is powerful, but one of the key limitations of this tool is that it does not support dynamic control flow (e.g. `if` statements or loops that are data-dependant). Therefore, some algorithms, such as BlurPool, are currently not supported. We have ongoing work to bring `torch.fx` support to all our algorithms."
   ]
  },
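  {
   "cell_type": "markdown",
   "id": "e8a3b6d5",
   "metadata": {},
   "source": [
    "To make this limitation concrete, here is a small sketch with a hypothetical `Gate` module (not a Composer algorithm) whose `forward` branches on the values of its input. Symbolic tracing is expected to fail on the `if` statement, because the condition is only known at run time.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.fx\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class Gate(nn.Module):\n",
    "    # Hypothetical module whose output depends on a data-dependent branch\n",
    "    def forward(self, x):\n",
    "        if x.sum() > 0:  # the condition depends on the values in x\n",
    "            return x\n",
    "        return -x\n",
    "\n",
    "\n",
    "try:\n",
    "    torch.fx.symbolic_trace(Gate())\n",
    "    traced_ok = True\n",
    "except Exception as err:  # torch.fx reports the unsupported branch as a tracing error\n",
    "    traced_ok = False\n",
    "    print(f'symbolic_trace failed: {err}')\n",
    "```"
   ]
  },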
  {
   "cell_type": "markdown",
   "id": "9d58df1b",
   "metadata": {},
   "source": [
    "## Algorithm Compatibility\n",
    "<a id=\"Algorithm-Compatability\"></a>\n",
    "\n",
    "Some of our algorithms alter the model architecture in ways that may render them incompatible with some of the export procedures above. For example, BlurPool replaces some instances of `Conv2d` with `BlurConv2d` layers which are not yet compatible with `torch.fx`. \n",
    "\n",
    "The following table shows which algorithms are compatible with which export formats for inference.\n",
    "\n",
    "|                        | torchscript | torch.fx | ONNX |\n",
    "|------------------------|-------------|----------|------|\n",
    "| apply_blurpool         | &check;           |          | &check;    |\n",
    "| apply_factorization    |             | &check;        | &check;    |\n",
    "| apply_ghost_batchnorm  | &check;           |          | &check;    |\n",
    "| apply_squeeze_excite   | &check;           | &check;        | &check;    |\n",
    "| apply_stochastic_depth | &check;           | &check;        | &check;    |\n",
    "| apply_channels_last    | &check;           | &check;        | &check;    |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b802d02a",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "You've now seen all the ways that Composer enables you to make your trained models available for downstream inference.\n",
    "\n",
    "To keep learning more, please continue to explore our tutorials! Here's a suggestion:\n",
    "\n",
    "* Check out our beta support for [training on TPUs][tpu_training].\n",
    "\n",
    "[tpu_training]: https://docs.mosaicml.com/projects/composer/en/stable/examples/TPU_Training_in_composer.html\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3870c72",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
