{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🔌 Training with TPUs\n",
    "\n",
    "Composer provides beta support for single core training on TPUs. We integrate with the `torch_xla` backend. For installation instructions and more details, see [here][torch_xla]. \n",
    "\n",
    "### Recommended Background\n",
    "\n",
    "This tutorial is pretty straightforward. It uses the same basic training cycle set up in the [Getting Started][getting_started] tutorial, which you might want to check out first if you haven't already.\n",
    "\n",
    "### Tutorial Goals and Concepts Covered\n",
    "\n",
    "The goal of this tutorial is to show you the steps needed to do Composer training on TPUs. Concretely, we'll train a ResNet-20 on CIFAR10 using a single TPU core.\n",
    "\n",
    "The training setup is exactly the same as with any other device, except the model must be moved to the device before passing to our `Trainer`, where we must also specify `device=tpu` to enable the trainer to use TPUs. We'll touch on these steps below.\n",
    "\n",
    "Let's get started!\n",
    "\n",
    "[torch_xla]: https://github.com/pytorch/xla\n",
    "[getting_started]: https://docs.mosaicml.com/projects/composer/en/stable/examples/getting_started.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "__O91tR7Rh7O"
   },
   "source": [
    "As prerequisites, first install `torch_xla` and the latest Composer version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wc711rY5RxFy"
   },
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install git+https://github.com/mosaicml/composer.git\n",
    "\n",
    "%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl\n",
    "\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# %pip install 'mosaicml @ git+https://github.com/mosaicml/composer.git'\"\n",
    "\n",
    "from composer import Trainer\n",
    "from composer.models import ComposerClassifier\n",
    "from composer.optim import DecoupledSGDW"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AqOqlKMNR3CR"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "\n",
    "Next, we define the model and optimizer. TPUs require the model to be moved to the device _before_ the optimizer is created, which we do here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0J8AUm3HSAQH"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch_xla.core.xla_model as xm\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"A ResNet block.\"\"\"\n",
    "\n",
    "    def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        stride = 2 if downsample else 1\n",
    "        self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(f_out)\n",
    "        self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(f_out)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # No parameters for shortcut connections.\n",
    "        if downsample or f_in != f_out:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
    "                nn.BatchNorm2d(f_out),\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        return self.relu(out)\n",
    "\n",
    "class ResNetCIFAR(nn.Module):\n",
    "    \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
    "\n",
    "    def __init__(self, outputs: int = 10):\n",
    "        super(ResNetCIFAR, self).__init__()\n",
    "\n",
    "        depth = 20\n",
    "        width = 16\n",
    "        num_blocks = (depth - 2) // 6\n",
    "\n",
    "        plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
    "\n",
    "        self.num_classes = outputs\n",
    "\n",
    "        # Initial convolution.\n",
    "        current_filters = plan[0][0]\n",
    "        self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(current_filters)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "        # The subsequent blocks of the ResNet.\n",
    "        blocks = []\n",
    "        for segment_index, (filters, num_blocks) in enumerate(plan):\n",
    "            for block_index in range(num_blocks):\n",
    "                downsample = segment_index > 0 and block_index == 0\n",
    "                blocks.append(Block(current_filters, filters, downsample))\n",
    "                current_filters = filters\n",
    "\n",
    "        self.blocks = nn.Sequential(*blocks)\n",
    "\n",
    "        # Final fc layer. Size = number of filters in last segment.\n",
    "        self.fc = nn.Linear(plan[-1][0], outputs)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        out = self.blocks(out)\n",
    "        out = F.avg_pool2d(out, out.size()[3])\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
    "\n",
    "model = model.to(xm.xla_device())\n",
    "\n",
    "optimizer = DecoupledSGDW(\n",
    "    model.parameters(),\n",
    "    lr=0.02,\n",
    "    momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DqjBwfb8SE0G"
   },
   "source": [
    "### Datasets\n",
    "\n",
    "Creating the CIFAR10 dataset and dataloaders are exactly the same as with other non-TPU devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YwJEDZWxSJUX"
   },
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "data_directory = \"./data\"\n",
    "\n",
    "# Normalization constants\n",
    "mean = (0.507, 0.487, 0.441)\n",
    "std = (0.267, 0.256, 0.276)\n",
    "\n",
    "batch_size = 1024\n",
    "\n",
    "cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])\n",
    "\n",
    "train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)\n",
    "test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jnCqqWAhSafT"
   },
   "source": [
    "## Training\n",
    "\n",
    "Lastly, we train for 20 epochs on the TPU by simply adding `device='tpu'` as an argument to the Trainer.\n",
    "\n",
    "**Note**: we currently only support single-core TPUs in this beta release. Future releases will include multi-core TPU support."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TAZJT9LGSiyB"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=train_dataloader,\n",
    "    device=\"tpu\",\n",
    "    eval_dataloader=test_dataloader,\n",
    "    optimizers=optimizer,\n",
    "    max_duration='20ep',\n",
    "    eval_interval=1,\n",
    ")\n",
    "\n",
    "trainer.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## What next?\n",
    "\n",
    "You've now seen a simple example of how to use the Composer trainer on a TPU. Cool!\n",
    "\n",
    "To get to know Composer more, please continue to explore our tutorials! Here's a couple suggestions:\n",
    "\n",
    "* Continue learning about other Composer features like [automatic gradient accumulation][autograd] and [automatic restarting from checkpoints][autoresume]\n",
    "\n",
    "* Explore more advanced applications of Composer like [applying image segmentation to medical images][image_segmentation_tutorial] or [fine-tuning a transformer for sentiment classification][huggingface_tutorial].\n",
    "\n",
    "* Keep it custom with our [custom speedups][custom_speedups_tutorial] tutorial.\n",
    "\n",
    "[autograd]: https://docs.mosaicml.com/projects/composer/en/stable/examples/auto_microbatching.html\n",
    "[autoresume]: https://docs.mosaicml.com/projects/composer/en/stable/examples/checkpoint_autoresume.html\n",
    "[image_segmentation_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/medical_image_segmentation.html\n",
    "[huggingface_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/huggingface_models.html\n",
    "[custom_speedups_tutorial]: https://docs.mosaicml.com/projects/composer/en/stable/examples/custom_speedup_methods.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Come get involved with MosaicML!\n",
    "\n",
    "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
    "\n",
    "### [Star Composer on GitHub](https://github.com/mosaicml/composer)\n",
    "\n",
    "Help make others aware of our work by [starring Composer on GitHub](https://github.com/mosaicml/composer).\n",
    "\n",
    "### [Join the MosaicML Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg)\n",
    "\n",
    "Head on over to the [MosaicML slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
    "\n",
    "### Contribute to Composer\n",
    "\n",
    "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/composer/issues) or make a [pull request](https://github.com/mosaicml/composer/pulls)!"
   ]
  }
 ],
 "metadata": {
  "gpuClass": "standard",
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
