{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🖼️ Getting Started\n",
    "\n",
    "Welcome to Composer! If you're wondering how this thing works, then you're in the right place. This is the first in our series of tutorials aimed at getting you familiar with the core concepts, capabilities, and patterns that make Composer such a useful tool!\n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "For the sake of this tutorial, we won't assume you know much about Composer, but a quick read of the [Welcome Tour][welcome_tour] wouldn't hurt.\n",
    "\n",
    "Composer is heavily built around PyTorch, though, so we will assume some basic familiarity with it. If you're unfamiliar with PyTorch, we recommend spending some time getting used to its [basics][pytorch_basics].\n",
    "\n",
    "### Tutorial Goals and Concepts Covered\n",
    "\n",
    "We're going to focus this introduction to Composer around an old classic: training a ResNet56 model on CIFAR10.\n",
    "\n",
    "Our focus here running this with the [Composer Trainer][trainer].\n",
    "\n",
    "As part of using the Trainer, we'll set up:\n",
    "\n",
    "* [Dataloaders](#Dataset-and-DataLoader)\n",
    "* [A ResNet56 model](#Model)\n",
    "* [An optimizer and learning rate scheduler](#Optimizer-and-Scheduler)\n",
    "* [A logger](#Logging)\n",
    "\n",
    "We'll show how to [assemble these pieces and train with the Trainer](#Train-a-Baseline-Model)\n",
    "\n",
    "And finally we'll demonstrate [training with speed-up algorithms!](#Use-Algorithms-to-Speed-Up-Training)\n",
    "\n",
    "Along the way we'll illustrate some ways to visualize training progress using loggers.\n",
    "\n",
    "Let's get started!\n",
    "\n",
    "\n",
    "[welcome_tour]: https://docs.mosaicml.com/projects/composer/en/stable/getting_started/welcome_tour.html\n",
    "[pytorch_basics]: https://pytorch.org/tutorials/beginner/basics/intro.html\n",
    "[trainer]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Trainer.html#trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Composer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start by installing Composer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Up Our Workspace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports\n",
    "\n",
    "In this section we'll set up our workspace. We'll import the necessary packages, and set up our dataset and trainer. First, the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.utils.data\n",
    "\n",
    "import composer\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchvision import datasets, transforms\n",
    "from composer.loggers import InMemoryLogger\n",
    "\n",
    "torch.manual_seed(42) # For replicability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset and DataLoader\n",
    "<a id=\"Dataset-&-DataLoader\"></a>\n",
    "\n",
    "Now we'll start setting up the ingredients for the Composer Trainer! Let's start with dataloaders...\n",
    "\n",
    "Here, we instantiate our CIFAR-10 dataset and dataloader. Composer has its own CIFAR-10 dataset and dataloaders for convenience, but we'll stick with the Torchvision CIFAR-10 and PyTorch dataloader for the sake of familiarity.\n",
    "\n",
    "As a bit of extra detail, there are [three ways][composer_dataloaders] of passing training and/or evaluation dataloaders to the Trainer. If you're used to working with [PyTorch dataloaders][pytorch_dataloaders], good news—that's one of the ways! The below code should look pretty familiar if you're coming from PyTorch, but if not, now may be a good time to familiarize yourself with those basics.\n",
    "\n",
    "[composer_dataloaders]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/dataloaders.html\n",
    "[pytorch_dataloaders]: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_directory = \"./data\"\n",
    "\n",
    "# Normalization constants\n",
    "mean = (0.507, 0.487, 0.441)\n",
    "std = (0.267, 0.256, 0.276)\n",
    "\n",
    "batch_size = 1024\n",
    "\n",
    "cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)\n",
    "test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)\n",
    "\n",
    "# Our train and test dataloaders are PyTorch DataLoader objects!\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "<a id=\"Model\"></a>\n",
    "\n",
    "Next, we create our model. We're using Composer's built-in ResNet56. To use your own custom model, please see this [custom model example][custom_model_example].\n",
    "\n",
    "**Note**: The model below is an instance of a [ComposerModel][composer_model]. Models need to be wrapped in this class, which provides a convenient interface between the underlying PyTorch model and the Trainer.\n",
    "\n",
    "[custom_model_example]: https://docs.mosaicml.com/projects/composer/en/stable/composer_model.html#using-your-own-model\n",
    "[composer_model]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.ComposerModel.html#composer.ComposerModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"A ResNet block.\"\"\"\n",
    "\n",
    "    def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        stride = 2 if downsample else 1\n",
    "        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(f_out)\n",
    "        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(f_out)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # No parameters for shortcut connections.\n",
    "        if downsample or f_in != f_out:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
    "                nn.BatchNorm2d(f_out),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return self.relu(out)\n",
    "\n",
    "class ResNetCIFAR(nn.Module):\n",
    "    \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
    "\n",
    "    def __init__(self, outputs: int = 10):\n",
    "        super(ResNetCIFAR, self).__init__()\n",
    "\n",
    "        depth = 56\n",
    "        width = 16\n",
    "        num_blocks = (depth - 2) // 6\n",
    "\n",
    "        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
    "\n",
    "        self.num_classes = outputs\n",
    "\n",
    "        # Initial convolution.\n",
    "        current_filters = plan[0][0]\n",
    "        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(current_filters)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # The subsequent blocks of the ResNet.\n",
    "        blocks = []\n",
    "        for segment_index, (filters, num_blocks) in enumerate(plan):\n",
    "            for block_index in range(num_blocks):\n",
    "                downsample = segment_index > 0 and block_index == 0\n",
    "                blocks.append(Block(current_filters, filters, downsample))\n",
    "                current_filters = filters\n",
    "\n",
    "        self.blocks = nn.Sequential(*blocks)\n",
    "\n",
    "        # Final fc layer. Size = number of filters in last segment.\n",
    "        self.fc = nn.Linear(plan[-1][0], outputs)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        out = self.blocks(out)\n",
    "        out = F.avg_pool2d(out, out.size()[3])\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimizer and Scheduler\n",
    "<a id=\"Optimizer-and-Scheduler\"></a>\n",
    "\n",
    "Now let's create the optimizer and LR scheduler. We're using [MosaicML's SGD with decoupled weight decay][weight_decay]:\n",
    "\n",
    "[weight_decay]: https://arxiv.org/abs/1711.05101"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = composer.optim.DecoupledSGDW(\n",
    "    model.parameters(), # Model parameters to update\n",
    "    lr=0.05, # Peak learning rate\n",
    "    momentum=0.9,\n",
    "    weight_decay=2.0e-3 # If this looks large, it's because its not scaled by the LR as in non-decoupled weight decay\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll assume this is being run on Colab, which means training for hundreds of epochs would take a very long time. Instead we'll train our baseline model for three epochs. The first epoch will be linear warmup, followed by two epochs of constant LR. We achieve this by instantiating a `LinearWithWarmupScheduler` class.\n",
    "\n",
    "**Note**: Composer provides a handful of different [schedulers][schedulers] to help customize your training!\n",
    "\n",
    "[schedulers]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/schedulers.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_scheduler = composer.optim.LinearWithWarmupScheduler(\n",
    "    t_warmup=\"1ep\", # Warm up over 1 epoch\n",
    "    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f\n",
    "    alpha_f=1.0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logging\n",
    "<a id=\"Logging\"></a>\n",
    "\n",
    "Finally, we instantiate an [InMemoryLogger][in_memory_logger] that records all the data from the Composer Trainer. We will use this logger to generate data plots after we complete training.\n",
    "\n",
    "[in_memory_logger]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.loggers.InMemoryLogger.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"baseline\" = no algorithms (which is what we're doing now)\n",
    "logger_for_baseline = InMemoryLogger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a Baseline Model\n",
    "<a id=\"Train-a-Baseline-Model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we create our trainer!\n",
    "\n",
    "**The Trainer class is the workhorse of Composer.** You may be wondering what exactly it does. In short, the `Trainer` class takes a handful of ingredients (e.g., the model, data loaders, algorithms) and instructions (e.g., training duration, device) and *composes* them into a single object (here, `trainer`) that can manage **the entire training loop** described by those inputs. This lets you focus on the higher-level details of training without having to worry about things like distributed training, scheduling, memory issues, and all other kinds of low-level headaches.\n",
    "\n",
    "If you want to learn more about the Trainer, we recommend a deeper dive through our [docs][trainer_docs] and the [API reference][api]! In the meantime, you can follow the patterns in this tutorial to get going quickly.\n",
    "\n",
    "Here's a quick reference for the parameters we're specifying below:\n",
    "\n",
    "-   `model`: The model to train, an instance of [ComposerModel][composer_model] (in this case a ResNet-56)\n",
    "-   `train_dataloader`: A data loader supplying the training data. More info [here][dataloaders].\n",
    "-   `eval_dataloader`: A data loader supplying the data used during evaluation (see same reference for `train_dataloader`). Model-defined evaluation metrics will be aggregated across this dataset each evaluation round.\n",
    "-   `max_duration`: The training duration. You can use integers to specify the number of epochs or provide a [Time][time] string -- e.g., `\"50ba\"` or `\"2ep\"` for 50 batches and 2 epochs, respectively.\n",
    "-   `optimizer`: The optimizer used to update the model during training. More info [here][optimizers].\n",
    "-   `schedulers`: Any schedulers used to schedule hyperparameter (e.g., learning rate) values over the course of training. More info [here][schedulers].\n",
    "-   `device`: The device to use (e.g., CPU or GPU).\n",
    "-   `loggers`: Any loggers to use during training. More info [here][loggers].\n",
    "\n",
    "\n",
    "[trainer_docs]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/using_the_trainer.html\n",
    "[api]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Trainer.html#trainer\n",
    "[composer_model]: https://docs.mosaicml.com/projects/composer/en/stable/composer_model.html\n",
    "[dataloaders]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/dataloaders.html\n",
    "[time]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/time.html\n",
    "[optimizers]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/composer.optim.html\n",
    "[schedulers]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/schedulers.html\n",
    "[loggers]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/logging.html\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_epochs = \"3ep\" # Train for 3 epochs because we're assuming Colab environment and hardware\n",
    "device = \"gpu\" if torch.cuda.is_available() else \"cpu\" # select the device\n",
    "\n",
    "trainer = composer.trainer.Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    max_duration=train_epochs,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=lr_scheduler,\n",
    "    device=device,\n",
    "    loggers=logger_for_baseline,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train and measure the training time below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.perf_counter()\n",
    "trainer.fit() # <-- Your training loop in action!\n",
    "end_time = time.perf_counter()\n",
    "print(f\"It took {end_time - start_time:0.4f} seconds to train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're running this on Colab, the runtime will vary a lot based on the instance. We found that the three epochs of training could take anywhere from 120-550 seconds to run, and the mean validation accuracy was typically in the range of 25%-40%."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract and Plot Logged Data\n",
    "\n",
    "We can now plot our validation accuracy over training...\n",
    "\n",
    "Now, you might be thinking, \"I don't remember logging any validation accuracy!\" And you'd be right—sort of. That's because all the logging happened automatically during the `trainer.fit()` call. The values that we'll plot were set up by the model, which defines which metric(s) to measure during evaluation (and possibly also during training). Check out the [documentation][loggers] for a deeper dive on the loggers Composer offers and how to interact with them!\n",
    "\n",
    "[loggers]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/logging.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timeseries_raw = logger_for_baseline.get_timeseries(\"metrics/eval/MulticlassAccuracy\")\n",
    "plt.plot(timeseries_raw['epoch'], timeseries_raw[\"metrics/eval/MulticlassAccuracy\"])\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Validation Accuracy\")\n",
    "plt.title(\"Accuracy per epoch with Baseline Training\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use Algorithms to Speed Up Training\n",
    "<a id=\"Use-Algorithms-to-Speed-Up-Training\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One of the things we're most excited about at MosaicML is our arsenal of speed-up [algorithms][algorithms]. We used these algorithms to [speed up training of ResNet-50 on ImageNet by up to 7.6x][explorer]. Let's try applying a few algorithms to make our ResNet-56 more efficient.\n",
    "\n",
    "Before we jump in, here's a quick primer on Composer speed-up algorithms. Each one is implemented as an `Algorithm` class, which basically just adds some structure that controls what happens when the algorithm is applied and when in the training loop it should be applied. Adding a particular algorithm into the training loop is as simple as creating an instance of it (using args/kwargs to set any hyperparameters) and passing it to the `Trainer` during initialization. We'll see that in action below...\n",
    "\n",
    "For our first algorithm here, let's start with [Label Smoothing][label_smoothing], which serves as a form of regularization by interpolating between the target distribution and another distribution that usually has higher entropy.\n",
    "\n",
    "[algorithms]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/algorithms.html\n",
    "[explorer]: https://app.mosaicml.com/explorer/imagenet\n",
    "[label_smoothing]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/label_smoothing.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_smoothing = composer.algorithms.LabelSmoothing(0.1) # We're creating an instance of the LabelSmoothing algorithm class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's also use [BlurPool][blurpool], which increases accuracy by applying a spatial low-pass filter before the pool in max pooling and whenever using a strided convolution.\n",
    "\n",
    "[blurpool]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/blurpool.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blurpool = composer.algorithms.BlurPool(\n",
    "    replace_convs=True, # Blur before convs\n",
    "    replace_maxpools=True, # Blur before max-pools\n",
    "    blur_first=True # Blur before conv/max-pool\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our final algorithm in our improved training recipe is [Progressive Image Resizing][progressive_image_resizing]. Progressive Image Resizing initially shrinks the size of training images and slowly scales them back to their full size over the course of training. It increases throughput during the early phase of training, when the network may learn coarse-grained features that do not require the details lost by reducing image resolution.\n",
    "\n",
    "[progressive_image_resizing]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/progressive_resizing.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prog_resize = composer.algorithms.ProgressiveResizing(\n",
    "    initial_scale=.6, # Size of images at the beginning of training = .6 * default image size\n",
    "    finetune_fraction=0.34, # Train on default size images for 0.34 of total training time.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll assemble all our algorithms into a list to pass to our trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algorithms = [label_smoothing, blurpool, prog_resize]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's instantiate our model, optimizer, logger, and trainer again. No need to instantiate our scheduler again because it's stateless!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"A ResNet block.\"\"\"\n",
    "\n",
    "    def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        stride = 2 if downsample else 1\n",
    "        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(f_out)\n",
    "        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(f_out)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # No parameters for shortcut connections.\n",
    "        if downsample or f_in != f_out:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
    "                nn.BatchNorm2d(f_out),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return self.relu(out)\n",
    "\n",
    "class ResNetCIFAR(nn.Module):\n",
    "    \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
    "\n",
    "    def __init__(self, outputs: int = 10):\n",
    "        super(ResNetCIFAR, self).__init__()\n",
    "\n",
    "        depth = 56\n",
    "        width = 16\n",
    "        num_blocks = (depth - 2) // 6\n",
    "\n",
    "        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
    "\n",
    "        self.num_classes = outputs\n",
    "\n",
    "        # Initial convolution.\n",
    "        current_filters = plan[0][0]\n",
    "        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(current_filters)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # The subsequent blocks of the ResNet.\n",
    "        blocks = []\n",
    "        for segment_index, (filters, num_blocks) in enumerate(plan):\n",
    "            for block_index in range(num_blocks):\n",
    "                downsample = segment_index > 0 and block_index == 0\n",
    "                blocks.append(Block(current_filters, filters, downsample))\n",
    "                current_filters = filters\n",
    "\n",
    "        self.blocks = nn.Sequential(*blocks)\n",
    "\n",
    "        # Final fc layer. Size = number of filters in last segment.\n",
    "        self.fc = nn.Linear(plan[-1][0], outputs)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        out = self.blocks(out)\n",
    "        out = F.avg_pool2d(out, out.size()[3])\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger_for_algorithm_run = InMemoryLogger()\n",
    "\n",
    "optimizer = composer.optim.DecoupledSGDW(\n",
    "    model.parameters(),\n",
    "    lr=0.05,\n",
    "    momentum=0.9,\n",
    "    weight_decay=2.0e-3\n",
    ")\n",
    "\n",
    "trainer = composer.trainer.Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    max_duration=train_epochs,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=lr_scheduler,\n",
    "    device=device,\n",
    "    loggers=logger_for_algorithm_run,\n",
    "    algorithms=algorithms # Adding algorithms this time!\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And let's get training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we're cooking with algorithms!\n",
    "start_time = time.perf_counter()\n",
    "trainer.fit()\n",
    "end_time = time.perf_counter()\n",
    "three_epochs_accelerated_time = end_time - start_time\n",
    "print(f\"It took {three_epochs_accelerated_time:0.4f} seconds to train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, the runtime will vary based on the instance, but we found that it took about **0.43x-0.75x** as long to train (a **1.3x-2.3x** speedup, which corresponds to 90-400 seconds) relative to the baseline recipe without augmentations. We also found that validation accuracy was similar for the algorithm-enhanced and baseline recipes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bonus Training!\n",
    "\n",
    "Because Progressive Resizing increases data throughput (i.e. more samples per second), we can train for more iterations in the same amount of wall clock time. Let's train our model for one additional epoch!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_epochs = \"1ep\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Resuming training means we'll need to use a flat LR schedule:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_scheduler = composer.optim.scheduler.ConstantScheduler(alpha=1.0, t_max='1dur')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can also get rid of progressive resizing (because we want to train on the full size images for this additional epoch), and the model already has blurpool enabled, so we don't need to pass that either:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algorithms = [label_smoothing]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger_for_bonus_1ep = InMemoryLogger()\n",
    "\n",
    "trainer = composer.trainer.Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    max_duration=train_epochs,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=lr_scheduler,\n",
    "    device=device,\n",
    "    loggers=logger_for_bonus_1ep,\n",
    "    algorithms=algorithms,\n",
    ")\n",
    "\n",
    "start_time = time.perf_counter()\n",
    "trainer.fit()\n",
    "\n",
    "end_time = time.perf_counter()\n",
    "final_epoch_accelerated_time = end_time - start_time\n",
    "# Time for four epochs = time for three epochs + time for fourth epoch\n",
    "four_epochs_accelerated_time = three_epochs_accelerated_time + final_epoch_accelerated_time\n",
    "print(f\"It took {four_epochs_accelerated_time:0.4f} seconds to train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We found that using these speed-up algorithms for four epochs resulted in runtime similar to or less than three epochs *without* speed-up algorithms (120-550 seconds, depending on the instance), and that they usually improved validation accuracy by 5-15 percentage points, yielding validation accuracy in the range of 30%-50%."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot the results from using Label Smoothing and Progressive Resizing!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline (no algorithms) data\n",
    "baseline_timeseries = logger_for_baseline.get_timeseries(\"metrics/eval/MulticlassAccuracy\")\n",
    "baseline_epochs = baseline_timeseries['epoch']\n",
    "baseline_acc = baseline_timeseries[\"metrics/eval/MulticlassAccuracy\"]\n",
    "\n",
    "# Composer data\n",
    "with_algorithms_timeseries = logger_for_algorithm_run.get_timeseries(\"metrics/eval/MulticlassAccuracy\")\n",
    "with_algorithms_epochs = list(with_algorithms_timeseries[\"epoch\"])\n",
    "with_algorithms_acc = list(with_algorithms_timeseries[\"metrics/eval/MulticlassAccuracy\"])\n",
    "\n",
    "# Concatenate 3 epochs with Label Smoothing and ProgRes with 1 epoch without ProgRes\n",
    "bonus_epoch_timeseries = logger_for_bonus_1ep.get_timeseries(\"metrics/eval/MulticlassAccuracy\")\n",
    "bonus_epoch_epochs = [with_algorithms_epochs[-1] + i for i in bonus_epoch_timeseries[\"epoch\"]]\n",
    "with_algorithms_epochs.extend(bonus_epoch_epochs)\n",
    "with_algorithms_acc.extend(bonus_epoch_timeseries[\"metrics/eval/MulticlassAccuracy\"])\n",
    "\n",
    "#Print mean validation accuracies\n",
    "print(\"Baseline Validation Mean: \" + str(sum(baseline_acc)/len(baseline_acc)))\n",
    "print(\"W/ Algs. Validation Mean: \" + str(sum(with_algorithms_acc)/len(with_algorithms_acc)))\n",
    "\n",
    "# Plot both sets of data\n",
    "plt.plot(baseline_epochs, baseline_acc, label=\"Baseline\")\n",
    "plt.plot(with_algorithms_epochs, with_algorithms_acc, label=\"With Algorithms\")\n",
    "plt.legend()\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Validation Accuracy\")\n",
    "plt.title(\"Accuracy and speed improvements with equivalent WCT\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "You've now seen a simple example of how to use the Composer Trainer and how to take advantage of useful features like learning rate scheduling, logging, and speed-up algorithms.\n",
    "\n",
    "If you want to keep learning more, try looking through some of the documents linked throughout this tutorial to see if you can form a deeper intuition for how these examples were structured.\n",
    "\n",
    "In addition, please continue to explore our tutorials! Here's a couple suggestions:\n",
    "\n",
    "* Get to know the [functional API][functional_tutorial] for using algorithms outside the Trainer.\n",
    "\n",
    "* Explore more advanced applications of Composer like [applying image segmentation to medical images][image_segmentation_tutorial] or [fine-tuning a transformer for sentiment classification][huggingface_tutorial].\n",
    "\n",
    "* Learn about callbacks and how to apply [early stopping][early_stopping_tutorial].\n",
    "\n",
    "[functional_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/functional_api.html\n",
    "[image_segmentation_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/medical_image_segmentation.html\n",
    "[huggingface_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html\n",
    "[early_stopping_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/early_stopping.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
