{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🤖 Custom Speedup Methods\n",
    "\n",
    "One of Composer's superpowers is its suite of speed-up algorithms. By default, Composer comes packed with a (growing) set of algorithms, but it is also easy to add your own! This tutorial shows you how.\n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "This tutorial assumes that you have a working familiarity with PyTorch training loops and a general familiarity with Composer. Make sure you're comfortable with the material in the [getting started tutorial][getting_started] before you embark on this slightly more advanced tutorial.\n",
    "\n",
    "### Tutorial Goals and Concepts Covered\n",
    "\n",
    "The goal of this tutorial is to illustrate the process of implementing a custom method to work with the Composer trainer. In order, it will cover:\n",
    "\n",
    "* [Implementing a method using standard PyTorch](#implementing-a-method-with-pytorch)\n",
    "* [Brief overview of events and state in Composer](#events-and-state-in-composer)\n",
    "* [Modifying the implementation to work with Composer](#implementing-a-method-with-composer)\n",
    "* [Composing multiple methods together](#composing-multiple-methods)\n",
    "\n",
    "We will see that getting a method working with Composer is not very different from getting it working with vanilla PyTorch. With just a couple extra steps we can use our algorithm while enjoying the flexibility and simplicity of Composer!\n",
    "\n",
    "[getting_started]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Composer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, installation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git\n",
    "\n",
    "%pip install matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll set up a toy resnet model for CIFAR. This will be useful throughout this tutorial as an example model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"A ResNet block.\"\"\"\n",
    "\n",
    "    def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        stride = 2 if downsample else 1\n",
    "        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(f_out)\n",
    "        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(f_out)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # No parameters for shortcut connections.\n",
    "        if downsample or f_in != f_out:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
    "                nn.BatchNorm2d(f_out),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return self.relu(out)\n",
    "\n",
    "class ResNetCIFAR(nn.Module):\n",
    "    \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
    "\n",
    "    def __init__(self, outputs: int = 10):\n",
    "        super(ResNetCIFAR, self).__init__()\n",
    "\n",
    "        depth = 56\n",
    "        width = 16\n",
    "        num_blocks = (depth - 2) // 6\n",
    "\n",
    "        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
    "\n",
    "        self.num_classes = outputs\n",
    "\n",
    "        # Initial convolution.\n",
    "        current_filters = plan[0][0]\n",
    "        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(current_filters)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # The subsequent blocks of the ResNet.\n",
    "        blocks = []\n",
    "        for segment_index, (filters, num_blocks) in enumerate(plan):\n",
    "            for block_index in range(num_blocks):\n",
    "                downsample = segment_index > 0 and block_index == 0\n",
    "                blocks.append(Block(current_filters, filters, downsample))\n",
    "                current_filters = filters\n",
    "\n",
    "        self.blocks = nn.Sequential(*blocks)\n",
    "\n",
    "        # Final fc layer. Size = number of filters in last segment.\n",
    "        self.fc = nn.Linear(plan[-1][0], outputs)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        out = self.blocks(out)\n",
    "        out = F.avg_pool2d(out, out.size()[3])\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='implementing-a-method-with-pytorch'></a>\n",
    "\n",
    "## Implementing a Method with PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we'll go through the process of implementing a new method without using Composer. First, some relevant imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.data\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "torch.manual_seed(42)\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, set up some training data. For simplicity, we will use CIFAR10 with minimal preprocessing. All we do is convert elements to tensors and normalize them using a randomly chosen mean and standard deviation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = (0.507, 0.487, 0.441)\n",
    "std = (0.267, 0.256, 0.276)\n",
    "\n",
    "c10_transforms = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean, std)\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=c10_transforms)\n",
    "test_dataset = datasets.CIFAR10('./data', train=False, download=True, transform=c10_transforms)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will define a model. For this, we will simply use Composer's ResNet-50. One quirk to be aware of with this model is that the forward method takes in an `(X, y)` pair of inputs and targets, essentially what the dataloaders will spit out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet\n",
    "from composer.models.tasks import ComposerClassifier\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we'll define a function to test the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, test_loader, device):\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model((data, target))\n",
    "            output = F.log_softmax(output, dim=1)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(test_loss, correct, len(test_loader.dataset),\n",
    "          100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will train it for a single epoch to check that things are working. We'll also test before and after training to check how accuracy changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model.to(device)\n",
    "test(model, test_dataloader, device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "model.train()\n",
    "for batch_idx, (data, target) in enumerate(train_dataloader):\n",
    "    data, target = data.to(device), target.to(device)\n",
    "    optimizer.zero_grad()\n",
    "    output = model((data, target))\n",
    "    output = F.log_softmax(output, dim=1)\n",
    "    loss = F.nll_loss(output, target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "test(model, test_dataloader, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like things are working! Time to implement our own modification to the training procedure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementing ColOut\n",
    "\n",
    "\n",
    "For this tutorial, we'll look at how to implement one of the simpler speedup methods currently in our composer library: [ColOut](https://docs.mosaicml.com/projects/composer/en/stable/method_cards/colout.html). This method works on image data by dropping random rows and columns from the training images. This reduces the size of the training images, which reduces the time per training iteration and hopefully does not alter the semantic content of the image too much. Additionally, dropping a small fraction of random rows and columns can also slightly distort objects and perhaps provide a data augmentation effect.\n",
    "\n",
    "To start our implementation, we'll write a function to drop random rows and columns from a batch of input images. We'll assume that these are torch tensors and operate on a batch, rather than individual images, for simplicity here.\n",
    "\n",
    "[colout]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/colout.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_colout(X, p_row, p_col):\n",
    "    # Get the dimensions of the image\n",
    "    row_size = X.shape[2]\n",
    "    col_size = X.shape[3]\n",
    "\n",
    "    # Determine how many rows and columns to keep\n",
    "    kept_row_size = int((1 - p_row) * row_size)\n",
    "    kept_col_size = int((1 - p_col) * col_size)\n",
    "\n",
    "    # Randomly choose indices to keep. Must be sorted for slicing\n",
    "    kept_row_idx = sorted(torch.randperm(row_size)[:kept_row_size].numpy())\n",
    "    kept_col_idx = sorted(torch.randperm(col_size)[:kept_col_size].numpy())\n",
    "\n",
    "    # Keep only the selected row and columns\n",
    "    X_colout = X[:, :, kept_row_idx, :]\n",
    "    X_colout = X_colout[:, :, :, kept_col_idx]\n",
    "    return X_colout"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is very simple, but as a check, we should visualize what this does to the data:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = next(iter(train_dataloader))\n",
    "\n",
    "X_colout_1 = batch_colout(X, 0.1, 0.1)\n",
    "X_colout_2 = batch_colout(X, 0.2, 0.2)\n",
    "X_colout_3 = batch_colout(X, 0.3, 0.3)\n",
    "\n",
    "def unnormalize(X, mean, std):\n",
    "    X *= torch.tensor(std).view(1, 3, 1, 1)\n",
    "    X += torch.tensor(mean).view(1, 3, 1, 1)\n",
    "    X = X.permute(0,2,3,1)\n",
    "    return X\n",
    "\n",
    "X = unnormalize(X, mean, std)\n",
    "X_colout_1 = unnormalize(X_colout_1, mean, std)\n",
    "X_colout_2 = unnormalize(X_colout_2, mean, std)\n",
    "X_colout_3 = unnormalize(X_colout_3, mean, std)\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, figsize=(20,5))\n",
    "axes[0].imshow(X[0])\n",
    "axes[0].set_title(\"Unmodified\", fontsize=18)\n",
    "axes[1].imshow(X_colout_1[0])\n",
    "axes[1].set_title(\"p_row = 0.1, p_col = 0.1\", fontsize=18)\n",
    "axes[2].imshow(X_colout_2[0])\n",
    "axes[2].set_title(\"p_row = 0.2, p_col = 0.2\", fontsize=18)\n",
    "axes[3].imshow(X_colout_3[0])\n",
    "axes[3].set_title(\"p_row = 0.3, p_col = 0.3\", fontsize=18)\n",
    "\n",
    "for ax in axes:\n",
    "    ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like things are behaving as they should! Now let's insert it into our training loop. We'll also reinitialize the model here for a fair comparison with our earlier, single epoch run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.models import ComposerClassifier\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we perform colout on the batch of data that the dataloader spits out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "p_row = 0.15\n",
    "p_col = 0.15\n",
    "model.to(device)\n",
    "test(model, test_dataloader, device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "model.train()\n",
    "for batch_idx, (data, target) in enumerate(train_dataloader):\n",
    "    data, target = data.to(device), target.to(device)\n",
    "    ### Insert ColOut here ###\n",
    "    data = batch_colout(data, p_row, p_col)\n",
    "    ### ------------------ ###\n",
    "    optimizer.zero_grad()\n",
    "    output = model((data, target))\n",
    "    output = F.log_softmax(output, dim=1)\n",
    "    loss = F.nll_loss(output, target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "test(model, test_dataloader, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy is pretty similar, and the wall clock time is definitely faster. Perhaps this method does provide a speedup in this case, but to do a proper evaluation we would want to look at the Pareto curves, similar to the data in our [explorer][explorer].\n",
    "\n",
    "This style of implementation is similar to the functional implementations in the Composer library. Since ColOut is already implemented there, we could have simply done:\n",
    "\n",
    "```python\n",
    "import composer.functional as cf\n",
    "data = cf.colout_batch(data, p_row, p_col)\n",
    "```\n",
    "\n",
    "[explorer]: https://app.mosaicml.com/explorer/imagenet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A natural question here is how do we combine ColOut with other methods? One way to do so would be to simply repeat the process we went through above for each method we want to try, inserting it into the training loop where appropriate. However, this can quickly become unwieldy and makes it difficult to run experiments using many different combinations of methods. **This is the problem Composer aims to solve!**\n",
    "\n",
    "In the following sections, we will modify our above implementation to work with Composer so that we can run many methods together. The modifications we need to make are fairly simple. In essence, we just need a way to tell Composer where in the training loop to insert our method and track the appropriate objects our method acts on. With Composer, we can insert our method into the training loop at what Composer calls an `event` and track what our method needs to modify in what Composer calls `state`."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"events-and-state-in-composer\"></a>\n",
    "\n",
    "## Events and State in Composer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop Composer uses provides multiple different locations where method code can be inserted and run. These are called events. Diagramatically, the training loop looks as follows:\n",
    "\n",
    "![Training Loop](https://storage.googleapis.com/docs.mosaicml.com/images/training_loop.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the top, we see the different steps of the Composer training loop. At the bottom, we see the many events that occur at different places within the training loop. For example, the event `EVENT.BEFORE_FORWARD` occurs just before the forward pass through the model, but after dataloading and preprocessing has taken place. These events are the points at which we can run the code for the method we want to implement.\n",
    "\n",
    "Most methods require making modifications to some object used in training, such as the model itself, the input/output data, training hyperparameters, etc. These quantities are tracked in Composer's `State` object, which can be found [here][state]. \n",
    "\n",
    "[state]: https://github.com/mosaicml/composer/blob/dev/composer/core/state.py"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"implementing-a-method-with-composer\"></a>\n",
    "\n",
    "## Implementing a method with Composer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's time to set up our method to work with Composer's trainer. To do this we will wrap the ColOut transformation we wrote up above in a class that inherits from Composer's base `Algorithm` class. Then we will need to implement two methods within that class: a `match` method that tells Composer which `event` we want ColOut to run on and an `apply` method that tells Composer how to run ColOut. First, some relevant imports from Composer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer import Trainer\n",
    "from composer.core import Algorithm, Event"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before, we inserted ColOut into the training loop after getting a batch from the dataloader, but before the forward pass. As such, it makes sense for us to run ColOut on the event `EVENT.AFTER_DATALOADER`. The `match` method for this will simply check that the current event is this one and return `True` if it is, and `False` otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def match(self, event, state):\n",
    "    return event == Event.AFTER_DATALOADER"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `apply` method is also simple in this case. It will just tell Composer how to run the function we already wrote and how to save the results in `state`.\n",
    "\n",
    "*Useful note: The active minibatch is always kept in `state.batch`, and algorithms are free to access and modify the minibatch, as shown below.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply(self, event, state, logger):\n",
    "    inputs, labels = state.batch\n",
    "    new_inputs = batch_colout(\n",
    "        inputs,\n",
    "        p_row=self.p_row,\n",
    "        p_col=self.p_col\n",
    "    )\n",
    "    state.batch = (new_inputs, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Packaging this together into an `Algorithm` class gives our full Composer-ready ColOut implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ColOut(Algorithm):\n",
    "    def __init__(self, p_row=0.15, p_col=0.15):\n",
    "        self.p_row = p_row\n",
    "        self.p_col = p_col\n",
    "\n",
    "    def match(self, event, state):\n",
    "        return event == Event.AFTER_DATALOADER\n",
    "\n",
    "    def apply(self, event, state, logger):\n",
    "        inputs, labels = state.batch\n",
    "        new_inputs = batch_colout(\n",
    "            inputs,\n",
    "            p_row=self.p_row,\n",
    "            p_col=self.p_col\n",
    "        )\n",
    "        state.batch = (new_inputs, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we'll create a model/optimizer, train it, and test it similar to how we did before with PyTorch, but this time we'll let Composer handle the training!\n",
    "\n",
    "To do this we'll create a Composer `trainer`. This should look familiar if you've already gone through the other tutorials.\n",
    "\n",
    "We'll have to give the trainer our `model` and the two dataloaders. We'll tell it to run for one epoch by setting `max_duration='1ep'` and run on the gpu by setting `device='gpu'`. Since we're handling the testing ourselves in this example, we'll turn off Composer's validation by setting `eval_interval=0`. Finally, we'll set the seed for reproducibility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline (i.e., without ColOut)\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    max_duration='1ep',\n",
    "    eval_interval=0,\n",
    "    seed=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Use our testing code and train with trainer.fit()\n",
    "test(model, test_dataloader, device)\n",
    "trainer.fit()\n",
    "test(model, test_dataloader, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's do the same thing but this time we'll add ColOut!\n",
    "\n",
    "We'll recreate the model and optimizer and also create `colout_method`, an instance of the `ColOut` class we implemented above. To train using Colout, we just need to set `algorithms=[colout_method]` when constructing the trainer!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With ColOut\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "# An instance of our ColOut algorithm class!\n",
    "colout_method = ColOut()\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    max_duration='1ep',\n",
    "    algorithms=[colout_method],\n",
    "    eval_interval=0,\n",
    "    seed=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Use our testing code and train with trainer.fit()\n",
    "test(model, test_dataloader, device)\n",
    "trainer.fit()\n",
    "test(model, test_dataloader, device)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"composing-multiple-methods\"></a>\n",
    "\n",
    "## Composing multiple methods\n",
    "\n",
    "Now that we've implemented ColOut as a Composer algorithm, composing it with other methods is easy!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we'll compose our custom ColOut method with another method from the composer library, `BlurPool`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.algorithms.blurpool import BlurPool"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the blurpool object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blurpool = BlurPool(\n",
    "    replace_convs=True,\n",
    "    replace_maxpools=True,\n",
    "    blur_first=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And add it to the list of methods for the trainer to run. In general, we can pass in as many methods as we want, and Composer will run them all together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With ColOut and BlurPool\n",
    "from composer.models import ComposerClassifier\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=test_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    max_duration='1ep',\n",
    "    algorithms=[colout_method, blurpool],\n",
    "    eval_interval=0,\n",
    "    seed=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Use our testing code and train with trainer.fit()\n",
    "test(model, test_dataloader, device)\n",
    "trainer.fit()\n",
    "test(model, test_dataloader, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "You've now seen a simple example of how to take custom algorithms and make them usable with the Composer Trainer! For more complex methods, the process is the same, but might require using different `events`, or modifying different things in `state`. Here are some interesting examples of other methods in the Composer library, which we encourage you to dig into if you want to see more:\n",
    "\n",
    "* [BlurPool][blurpool] swaps out some of the layers of the network.\n",
    "* [LayerFreezing][layerfreezing] changes which network parameters are trained at different epochs.\n",
    "* [RandAugment][randaugment] adds an additional data augmentation.\n",
    "* [SelectiveBackprop][selectivebackprop] changes which samples are used to compute gradients.\n",
    "* [SAM][sam] changes the optimizer used for training.\n",
    "\n",
    "[blurpool]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/blurpool.html\n",
    "[layerfreezing]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/layer_freezing.html\n",
    "[randaugment]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/randaugment.html\n",
    "[selectivebackprop]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/selective_backprop.html\n",
    "[sam]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/sam.html\n",
    "\n",
    "In addition, please continue to explore our tutorials! Here's a couple suggestions:\n",
    "\n",
    "* Explore more advanced applications of Composer like [applying image segmentation to medical images][image_segmentation_tutorial] or [fine-tuning a transformer for sentiment classification][huggingface_tutorial].\n",
    "\n",
    "* Learn about callbacks and how to apply [early stopping][early_stopping_tutorial].\n",
    "\n",
    "* A [transition guide][ptl_tutorial] for switching from PyTorch Lightening to Composer.\n",
    "\n",
    "[ptl_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/migrate_from_ptl.html\n",
    "[image_segmentation_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/medical_image_segmentation.html\n",
    "[huggingface_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html\n",
    "[early_stopping_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/early_stopping.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
