{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🛑 Early Stopping\n",
    "\n",
    "In this tutorial, we're going to learn about how to perform early stopping in Composer using **callbacks**.\n",
    "\n",
    "In Composer, callbacks modify trainer behavior and are called at the relevant Events in the training loop. This tutorial focuses on two callbacks, the [EarlyStopper][earlystopper] and [ThresholdStopper][thresholdstopper], both of which halt training early depending on different criteria.\n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "This tutorial assumes that you're generally familiar with techniques such as early stopping.\n",
    "\n",
    "You should also be comfortable with the material in the [Getting Started tutorial][getting_started] before you embark on this slightly more advanced tutorial.\n",
    "\n",
    "Finally, you'll probably have a better intuition for how the demonstrated features work if you brush up on Composer's event-driven design in our [Welcome Tour][welcome_tour].\n",
    "\n",
    "### Tutorial Goals and Concepts Covered\n",
    "\n",
    "The goal of this tutorial is to demonstrate a basic training run using one of our callbacks to control if/when training stops before the maximum training duration.\n",
    "\n",
    "We'll demonstrate:\n",
    "\n",
    "* [Basic training setup with Composer](#Setup)\n",
    "* [Using the EarlyStopper](#Earlystopper), and\n",
    "* [Using the ThresholdStopper](#Thresholdstopper)\n",
    "\n",
    "A comprehensive overview of Composer callbacks is outside the scope of this tutorial, but this should introduce you to some useful tools and give you a sense for the ways callbacks can be used to modify training behavior.\n",
    "\n",
    "Let's get started!\n",
    "\n",
    "[earlystopper]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.callbacks.EarlyStopper.html\n",
    "[thresholdstopper]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.callbacks.ThresholdStopper.html\n",
    "[getting_started]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html\n",
    "[welcome_tour]: https://docs.mosaicml.com/projects/composer/en/stable/getting_started/welcome_tour.html"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "<a id=\"Setup\"></a>\n",
    "\n",
    "In this tutorial, we'll train a `ComposerModel` and halt training for criteria that we'll set. We'll use the same basic setup as in the [Getting Started][tutorial] tutorial. If you want to better understand the details of the setup, that's a good place to review.\n",
    "\n",
    "[tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Composer\n",
    "\n",
    "First, install Composer if you haven't already:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Seed\n",
    "\n",
    "Next, we'll set the seed for reproducibility:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.utils.reproducibility import seed_all\n",
    "\n",
    "seed_all(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloader Setup\n",
    "\n",
    "Next, instantiate the training and evaluation datasets for CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "data_directory = \"./data\"\n",
    "\n",
    "# Normalization constants\n",
    "mean = (0.507, 0.487, 0.441)\n",
    "std = (0.267, 0.256, 0.276)\n",
    "\n",
    "batch_size = 1024\n",
    "\n",
    "cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)\n",
    "eval_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)\n",
    "\n",
    "# Setting shuffle=False to allow for easy overfitting in this example\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)\n",
    "eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model, Optimizer, Scheduler, and Evaluator Setup\n",
    "\n",
    "Finally, set up the model, optimizer, scheduler, and an [evaluator][evaluator].\n",
    "\n",
    "[evaluator]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Evaluator.html#evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.models import ComposerClassifier\n",
    "from composer.optim import DecoupledSGDW, LinearWithWarmupScheduler\n",
    "from composer.core import Evaluator\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"A ResNet block.\"\"\"\n",
    "\n",
    "    def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        stride = 2 if downsample else 1\n",
    "        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(f_out)\n",
    "        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(f_out)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # No parameters for shortcut connections.\n",
    "        if downsample or f_in != f_out:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
    "                nn.BatchNorm2d(f_out),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return self.relu(out)\n",
    "\n",
    "class ResNetCIFAR(nn.Module):\n",
    "    \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
    "\n",
    "    def __init__(self, outputs: int = 10):\n",
    "        super(ResNetCIFAR, self).__init__()\n",
    "\n",
    "        depth = 56\n",
    "        width = 16\n",
    "        num_blocks = (depth - 2) // 6\n",
    "\n",
    "        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
    "\n",
    "        self.num_classes = outputs\n",
    "\n",
    "        # Initial convolution.\n",
    "        current_filters = plan[0][0]\n",
    "        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(current_filters)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # The subsequent blocks of the ResNet.\n",
    "        blocks = []\n",
    "        for segment_index, (filters, num_blocks) in enumerate(plan):\n",
    "            for block_index in range(num_blocks):\n",
    "                downsample = segment_index > 0 and block_index == 0\n",
    "                blocks.append(Block(current_filters, filters, downsample))\n",
    "                current_filters = filters\n",
    "\n",
    "        self.blocks = nn.Sequential(*blocks)\n",
    "\n",
    "        # Final fc layer. Size = number of filters in last segment.\n",
    "        self.fc = nn.Linear(plan[-1][0], outputs)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        out = self.blocks(out)\n",
    "        out = F.avg_pool2d(out, out.size()[3])\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
    "\n",
    "optimizer = DecoupledSGDW(\n",
    "    model.parameters(), # Model parameters to update\n",
    "    lr=0.05, # Peak learning rate\n",
    "    momentum=0.9,\n",
    "    weight_decay=2.0e-3 # If this looks large, it's because its not scaled by the LR as in non-decoupled weight decay\n",
    ")\n",
    "\n",
    "lr_scheduler = LinearWithWarmupScheduler(\n",
    "    t_warmup=\"1ep\", # Warm up over 1 epoch\n",
    "    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f\n",
    "    alpha_f=1.0\n",
    ")\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    dataloader = eval_dataloader,\n",
    "    label = \"eval\",\n",
    "    metric_names = ['MulticlassAccuracy']\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EarlyStopper\n",
    "<a id=\"Earlystopper\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `EarlyStopper` callback tracks a particular training or evaluation metric and stops training if the metric does not improve within a given time interval. \n",
    "\n",
    "\n",
    "The callback takes the following parameters:\n",
    "\n",
    "-   `monitor`: The name of the metric to track\n",
    "-   `dataloader_label`: This string identifies which dataloader the metric belongs to. By default, the train dataloader is labeled `train`, and the evaluation dataloader is labeled `eval`. (These names can be customized via the `train_dataloader_label` in the [Trainer][trainer] or the `label` argument of the [Evaluator][evaluator], respectively.)\n",
    "-   `patience`: The interval of the time that the callback will wait before stopping training if the metric is not improving. You can use integers to specify the number of epochs or provide a [Time][time] string—e.g., `\"50ba\"` or `\"2ep\"` for 50 batches and 2 epochs, respectively.\n",
    "-   `min_delta`: If non-zero, the change in the tracked metric over the `patience` window must be at least this large.\n",
    "-   `comp`: A comparison operator can be provided to measure the change in the monitored metric. The comparison operator will be called like `comp(current_value, previous_best)`\n",
    "\n",
    "See the [API Reference][api] for more information.\n",
    "\n",
    "[evaluator]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Evaluator.html#evaluator\n",
    "[trainer]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Trainer.html#trainer\n",
    "[time]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.Time.html#time\n",
    "[api]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.callbacks.EarlyStopper.html\n",
    "\n",
    "Here, we'll use our callback to track the MulticlassAccuracy metric over one epoch on the test dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.callbacks import EarlyStopper\n",
    "\n",
    "early_stopper = EarlyStopper(monitor=\"MulticlassAccuracy\", dataloader_label=\"eval\", patience=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our callback, we can instantiate a Composer trainer and train:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.trainer import Trainer\n",
    "\n",
    "# Early stopping should stop training before we reach 100 epochs!\n",
    "train_epochs = \"100ep\"\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=evaluator,\n",
    "    max_duration=train_epochs,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=lr_scheduler,\n",
    "    callbacks=[early_stopper],    # Instruct the trainer to use our early stopping callback\n",
    "    train_subset_num_batches=10,  # Only training on a subset of the data to trigger the callback sooner\n",
    ")\n",
    "\n",
    "# Train!\n",
    "trainer.fit()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ThresholdStopper\n",
    "<a id=\"Thresholdstopper\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [ThresholdStopper][thresholdstopper] callback is similar to the EarlyStopper, but it halts training when the metric crosses a threshold set in the ThresholdStopper callback.\n",
    "\n",
    "This callback takes the following parameters:\n",
    "\n",
    "-   `monitor`, `dataloader_label`, and `comp`: Same as the EarlyStopper callback\n",
    "-   `threshold`: The float threshold that dictates when to halt training.\n",
    "-   `stop_on_batch`: If `True`, training can halt in the middle of an epoch, rather than just add the end.\n",
    "\n",
    "[thresholdstopper]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/generated/composer.callbacks.ThresholdStopper.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will reuse the same setup for the ThresholdStopper example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer.callbacks import ThresholdStopper\n",
    "\n",
    "threshold_stopper = ThresholdStopper(\"MulticlassAccuracy\", \"eval\", threshold=0.3)\n",
    "\n",
    "# Threshold stopping should stop training before we reach 100 epochs!\n",
    "train_epochs = \"100ep\"\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    eval_dataloader=evaluator,\n",
    "    max_duration=train_epochs,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=lr_scheduler,\n",
    "    callbacks=[threshold_stopper],  # Instruct the trainer to use our threshold stopper callback\n",
    "    train_subset_num_batches=10,    # Only training on a subset of the data to trigger the callback sooner\n",
    ")\n",
    "\n",
    "# Train!\n",
    "trainer.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "You've now seen how to implement early stopping in Composer using our `EarlyStopper` and `ThresholdStopper` callbacks.\n",
    "\n",
    "To dig deeper into Composer callbacks check out the [docs][docs] and our [API references][api].\n",
    "\n",
    "In addition, please continue to explore our tutorials! Here's a couple suggestions:\n",
    "\n",
    "* Continue learning about other Composer features like [automatic gradient accumulation][autograd] and [automatic restarting from checkpoints][autoresume]\n",
    "\n",
    "* Give your model life after training with Composer's [export for inference tools][exporting]\n",
    "\n",
    "[docs]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/callbacks.html\n",
    "[api]: https://docs.mosaicml.com/projects/composer/en/stable/api_reference/composer.callbacks.html#module-composer.callbacks\n",
    "[autograd]: https://docs.mosaicml.com/projects/composer/en/stable/examples/auto_microbatching.html\n",
    "[autoresume]: https://docs.mosaicml.com/projects/composer/en/stable/examples/checkpoint_autoresume.html\n",
    "[exporting]: https://docs.mosaicml.com/projects/composer/en/stable/examples/exporting_for_inference.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
