{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ⚡ Migrating from PTL\n",
    "\n",
    "[PyTorch Lightning][ptl] is a popular and well-designed framework for training deep neural networks. You can use Composer's algorithms in your PyTorch Lightning code via the [functional API](https://docs.mosaicml.com/projects/composer/en/stable/functional_api.html) with no additional code changes.\n",
    "\n",
    "However, if you are interested in features like [automatic gradient accumulation](https://docs.mosaicml.com/projects/composer/en/stable/examples/auto_microbatching.html), a [clean time abstraction](https://docs.mosaicml.com/projects/composer/en/stable/trainer/time.html), and the easiest path to trying out different combinations of algorithms, you will need to switch from the PTL trainer to the Composer trainer.\n",
    "\n",
    "Below is a quick guide on how to adapt your `LightningModule` to our simple interface.\n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "This tutorial assumes you are already familiar with PyTorch Lightning (since you're switching from it) and some computer vision basics.\n",
    "\n",
    "To better understand the Composer part, make sure you're comfortable with the material in our [Getting Started][getting_started] tutorial.\n",
    "\n",
    "### Tutorial Goals and Concepts Covered\n",
    "\n",
    "The goal of this tutorial is to illustrate a path from working in PyTorch Lightning to working in Composer.\n",
    "\n",
    "We'll primarily focus on the different ways **models** are structured in each framework, in order to illustrate how one maps on to the other.\n",
    "\n",
    "[getting_started]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html\n",
    "[ptl]: https://www.pytorchlightning.ai/\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We'll first install dependencies and define the data and model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Dependencies\n",
    "\n",
    "If you haven't already, let's install Composer and PyTorch Lightning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install pytorch-lightning\n",
    "\n",
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we'll go through the process of migrating a ResNet-18 model from PTL to Composer. We will be following the PTL example [here][example].\n",
    "\n",
    "First, we add some relevant imports and create the model as in the PTL tutorial.\n",
    "\n",
    "[example]: https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/cifar10-baseline.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.models\n",
    "from pytorch_lightning import LightningModule\n",
    "from torch.optim.lr_scheduler import OneCycleLR\n",
    "\n",
    "def create_model():\n",
    "    model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n",
    "    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
    "    model.maxpool = nn.Identity()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As is standard, we set up the training data for CIFAR-10 using `torchvision` datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "\n",
    "transform = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ]\n",
    ")\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PTL `LightningModule`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following the PTL tutorial, we use the `LitResnet` model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.functional import accuracy\n",
    "\n",
    "class LitResnet(LightningModule):\n",
    "    def __init__(self, lr=0.05):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.model = create_model()\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.model(x)\n",
    "        return F.log_softmax(out, dim=1)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = F.nll_loss(logits, y)\n",
    "        self.log(\"train_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def evaluate(self, batch, stage=None):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = F.nll_loss(logits, y)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n",
    "\n",
    "        if stage:\n",
    "            self.log(f\"{stage}_loss\", loss, prog_bar=True)\n",
    "            self.log(f\"{stage}_acc\", acc, prog_bar=True)\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        self.evaluate(batch, \"val\")\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        self.evaluate(batch, \"test\")\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.SGD(\n",
    "            self.model.parameters(),\n",
    "            lr=self.hparams.lr,\n",
    "            momentum=0.9,\n",
    "            weight_decay=5e-4,\n",
    "        )\n",
    "        steps_per_epoch = 45000 // 256\n",
    "        scheduler_dict = {\n",
    "            \"scheduler\": OneCycleLR(\n",
    "                optimizer,\n",
    "                0.1,\n",
    "                epochs=30,\n",
    "                steps_per_epoch=steps_per_epoch,\n",
    "            ),\n",
    "            \"interval\": \"step\",\n",
    "        }\n",
    "        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler_dict}\n",
    "\n",
    "PTLModel = LitResnet(lr=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `LitResnet` to Composer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that up to here, we have only used PyTorch Lightning code. We will now modify the PTL module to be compatible with Composer. There are a few major differences:\n",
    "\n",
    "* The `training_step` is broken into two parts: the `forward` and `loss` methods. This split is needed because some algorithms (such as label smoothing or selective backprop) need to intercept and modify the loss.\n",
    "* Optimizers and schedulers are passed directly to the `Trainer` during initialization.\n",
    "* Our `forward` method accepts the entire batch as input and is responsible for unpacking it.\n",
    "\n",
    "For more information about the `ComposerModel` format, see our [documentation][composer_model].\n",
    "\n",
    "[composer_model]: https://docs.mosaicml.com/projects/composer/en/stable/composer_model.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.classification import MulticlassAccuracy\n",
    "from composer.models import ComposerModel\n",
    "\n",
    "class MosaicResnet(ComposerModel):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = create_model()\n",
    "        self.acc = MulticlassAccuracy(num_classes=10, average='micro')\n",
    "\n",
    "    def loss(self, outputs, batch, *args, **kwargs):\n",
    "        \"\"\"Accepts the outputs from forward() and the batch\"\"\"\n",
    "        _, y = batch  # unpack the labels\n",
    "        return F.nll_loss(outputs, y)\n",
    "\n",
    "    def get_metrics(self, is_train):\n",
    "        # The same accuracy metric is used for training and evaluation\n",
    "        return {'MulticlassAccuracy': self.acc}\n",
    "\n",
    "    def forward(self, batch):\n",
    "        x, _ = batch  # unpack the inputs\n",
    "        y = self.model(x)\n",
    "        return F.log_softmax(y, dim=1)\n",
    "\n",
    "    def eval_forward(self, batch, outputs=None):\n",
    "        # Reuse the outputs if they were already computed\n",
    "        return outputs if outputs is not None else self.forward(batch)\n",
    "\n",
    "    def update_metric(self, batch, outputs, metric) -> None:\n",
    "        _, targets = batch\n",
    "        metric.update(outputs, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We instantiate the Composer `Trainer` by specifying the model, dataloaders, optimizer, scheduler, and `max_duration` (here, two epochs). For more details on the trainer arguments, see our [Using the Trainer][trainer] guide.\n",
    "\n",
    "Now you are ready to insert your algorithms! As an example, here we add the [BlurPool][blurpool] algorithm.\n",
    "\n",
    "[trainer]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/using_the_trainer.html\n",
    "[blurpool]: https://docs.mosaicml.com/projects/composer/en/stable/method_cards/blurpool.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from composer import Trainer\n",
    "from composer.algorithms import BlurPool\n",
    "from composer.optim import DecoupledSGDW\n",
    "\n",
    "model = MosaicResnet()\n",
    "optimizer = DecoupledSGDW(\n",
    "    model.parameters(),\n",
    "    lr=0.05,\n",
    "    momentum=0.9,\n",
    "    weight_decay=5e-4,\n",
    ")\n",
    "\n",
    "# One scheduler step per batch of the training dataloader defined above\n",
    "steps_per_epoch = len(train_dataloader)\n",
    "\n",
    "scheduler = OneCycleLR(\n",
    "    optimizer,\n",
    "    0.1,\n",
    "    epochs=30,\n",
    "    steps_per_epoch=steps_per_epoch,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    algorithms=[\n",
    "        BlurPool(\n",
    "            replace_convs=True,\n",
    "            replace_maxpools=True,\n",
    "            blur_first=True\n",
    "        ),\n",
    "    ],\n",
    "    train_dataloader=train_dataloader,\n",
    "    device=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
    "    eval_dataloader=test_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    schedulers=scheduler,\n",
    "    step_schedulers_every_batch=True,  # matches the PTL \"interval\": \"step\" setting\n",
    "    max_duration='2ep',\n",
    "    eval_interval=1,\n",
    "    train_subset_num_batches=1,  # train on a single batch per epoch to keep this demo fast\n",
    ")\n",
    "trainer.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "Hopefully this tutorial provides you with some useful intuitions for making the jump from PyTorch Lightning to Composer.\n",
    "\n",
    "To continue learning about Composer, check out our [guide to using the trainer][trainer] and explore more of our tutorials! Here are a couple of suggestions:\n",
    "\n",
    "* Get to know the [functional API][functional_tutorial] for using algorithms outside the Trainer.\n",
    "\n",
    "* Check out more advanced applications of Composer like [applying image segmentation to medical images][image_segmentation_tutorial] or [fine-tuning a transformer for sentiment classification][huggingface_tutorial].\n",
    "\n",
    "* Learn about implementing your own [custom speedup methods in Composer][custom_methods].\n",
    "\n",
    "[trainer]: https://docs.mosaicml.com/projects/composer/en/stable/trainer/using_the_trainer.html\n",
    "[functional_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/functional_api.html\n",
    "[image_segmentation_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/medical_image_segmentation.html\n",
    "[huggingface_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html\n",
    "[custom_methods]: https://docs.mosaicml.com/projects/composer/en/stable/examples/custom_speedup_methods.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
