{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Thank you to [Yilun Kuang](https://github.com/YilunKuang) for providing this example!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "fa01c6233eb141b18c7a7ffbcd127d14",
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "# 🕹️ Distributed Training with Submitit"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "e90ac6d5536846758101a642b79585d5",
    "deepnote_cell_height": 174.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "Composer is compatible with [submitit](https://github.com/facebookincubator/submitit), a lightweight Python package for submitting and managing jobs on SLURM clusters. To run distributed training on SLURM with submitit, the following environment variables must be set:\n",
    "\n",
    "```\n",
    "RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE, NODE_RANK, MASTER_ADDR, MASTER_PORT, PYTHONUNBUFFERED\n",
    "```\n",
    "\n",
    "In this tutorial, we walk through how to set these environment variables without using the Composer launcher. The example task is standard supervised training of ResNet-18 on CIFAR-10.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "0a9cca48c92f447d93c68b84eeac363a",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "308c2201c44240d7b8a37e460f30fad9",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "First, install the Composer and submitit libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "1a54499f0b9545b5a1501f5a6fb3a93a",
    "deepnote_cell_height": 61,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "%pip install mosaicml\n",
    "%pip install submitit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "a2317dc5560d4c17ba6a292355002508",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown",
    "tags": []
   },
   "source": [
    "## Prepare Dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "b645962d07e94a80acd7376be09996df",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "We will use a standard PyTorch DataLoader for the Composer training pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "84ea4256404349be8869c7d017bdb862",
    "deepnote_cell_height": 115,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "8652e444f7e244febd367a18ee7a14f7",
    "deepnote_cell_height": 97,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "from composer.utils import dist\n",
    "\n",
    "def initialize_dataset():\n",
    "    cifar10_transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),]\n",
    "    )\n",
    "    train_data = torchvision.datasets.CIFAR10(\n",
    "        root=\"./data\", train=True, download=True, transform=cifar10_transform\n",
    "    )\n",
    "    train_loader = DataLoader(\n",
    "        dataset=train_data,\n",
    "        # shard the dataset so that each rank trains on a different subset\n",
    "        sampler=dist.get_sampler(train_data, shuffle=True),\n",
    "        batch_size=1024,\n",
    "        num_workers=10,\n",
    "        pin_memory=True,\n",
    "        drop_last=True,\n",
    "    )\n",
    "\n",
    "    return train_loader  # standard pytorch dataloader\n"
   ]
  },
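  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each training process builds its own DataLoader, so `batch_size=1024` is the batch size per process, not per job. The sketch below works through the resulting numbers for the two-node, four-GPU-per-node setup used later in this tutorial. It assumes the 50,000 CIFAR-10 training images are split evenly across processes, as a distributed sampler would do; the helper name is hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def per_rank_batches(num_samples, world_size, batch_size, drop_last=True):\n",
    "    # Hypothetical helper: batches each process sees per epoch when the\n",
    "    # dataset is split evenly across world_size processes.\n",
    "    samples_per_rank = math.ceil(num_samples / world_size)\n",
    "    if drop_last:\n",
    "        return samples_per_rank // batch_size\n",
    "    return math.ceil(samples_per_rank / batch_size)\n",
    "\n",
    "world_size = 2 * 4  # two nodes, four GPUs each\n",
    "print(per_rank_batches(50000, world_size, 1024))  # batches per process per epoch\n",
    "print(1024 * world_size)  # samples consumed per optimizer step across all processes\n",
    "```\n",
    "\n",
    "Under these assumptions, each process sees six full batches per epoch, and every optimizer step consumes 8192 samples across the job."
   ]
  },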
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "e0c239bc455e4dbb8402c19bb9bfb906",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "## Prepare Model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "fff3b797982b46a689843281f1477cc5",
    "deepnote_cell_height": 74.796875,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "To use the Composer Trainer, our model must subclass the ComposerModel class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "21348c218fe84050b4eeb5bf72c6c972",
    "deepnote_cell_height": 421,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "from typing import Optional, Any\n",
    "\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "from composer.models import ComposerModel\n",
    "\n",
    "class ResNet18(ComposerModel):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = torchvision.models.resnet18(num_classes=10)  # CIFAR-10 has 10 classes\n",
    "\n",
    "    def forward(self, batch): # batch is the output of the dataloader\n",
    "        # specify how batches are passed through the model\n",
    "        inputs, _ = batch\n",
    "        return self.model(inputs)\n",
    "\n",
    "    def loss(self, outputs, batch):\n",
    "        # pass batches and `forward` outputs to the loss\n",
    "        _, targets = batch\n",
    "        return F.cross_entropy(outputs, targets)\n",
    "    \n",
    "    def eval_forward(self, batch, outputs: Optional[Any] = None):\n",
    "        return outputs if outputs is not None else self.forward(batch)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "c0133067f82841e5bc61d3c6f6fe4f9a",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "For more details about model wrapping, see [🛻 ComposerModel](https://docs.mosaicml.com/projects/composer/en/stable/composer_model.html). "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "9f135416fba74637be4d214974d730b6",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "Next, we initialize a standard Adam optimizer from PyTorch for training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "5c343f91013f4c7dbfe52532fecb97e6",
    "deepnote_cell_height": 151,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "def initialize_model_and_optimizer():\n",
    "    model = ResNet18()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
    "    return model, optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "804622d72fd84bb6a430ad53c579b700",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "## Set up Environment Variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "c2f089360bcc4032ae0533c3708a99a6",
    "deepnote_cell_height": 97.1875,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "Before training, we need to set all the environment variables required to initialize the distributed process group. They can be derived from submitit's `JobEnvironment` attributes, PyTorch functions, and environment variables that SLURM sets for each job."
   ]
  },
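  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how these values relate to each other, the hypothetical helper below (not part of Composer or submitit) derives the rank-related variables from a node index and a per-node process index. It assumes every node runs the same number of processes, one per GPU, and that global ranks are assigned node by node; check both against your cluster's configuration.\n",
    "\n",
    "```python\n",
    "def derive_dist_env(node_rank, local_rank, num_nodes, procs_per_node):\n",
    "    # Hypothetical helper, not part of Composer or submitit.\n",
    "    # Assumes every node runs the same number of processes (one per GPU).\n",
    "    world_size = num_nodes * procs_per_node\n",
    "    global_rank = node_rank * procs_per_node + local_rank\n",
    "    return {\n",
    "        'RANK': str(global_rank),\n",
    "        'WORLD_SIZE': str(world_size),\n",
    "        'LOCAL_RANK': str(local_rank),\n",
    "        'LOCAL_WORLD_SIZE': str(procs_per_node),\n",
    "        'NODE_RANK': str(node_rank),\n",
    "    }\n",
    "\n",
    "# Second node (node_rank=1), third GPU (local_rank=2), in a 2-node x 4-GPU job.\n",
    "env = derive_dist_env(node_rank=1, local_rank=2, num_nodes=2, procs_per_node=4)\n",
    "print(env['RANK'], env['WORLD_SIZE'])\n",
    "```\n",
    "\n",
    "With two nodes of four GPUs each, the third process on the second node gets global rank 6 out of a world size of 8."
   ]
  },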
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "94aa92db13bf40408f93adb1f7057846",
    "deepnote_cell_height": 97,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import submitit\n",
    "import subprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "e0f13f39a2754ddb84c949065f717878",
    "deepnote_cell_height": 655,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "def set_up_dist_env():\n",
    "    # 1. RANK\n",
    "    job_env = submitit.JobEnvironment()\n",
    "    global_rank = job_env.global_rank\n",
    "    \n",
    "    # 2. LOCAL_RANK\n",
    "    local_rank = job_env.local_rank\n",
    "    \n",
    "    # 3. LOCAL_WORLD_SIZE\n",
    "    ngpus_per_node = torch.cuda.device_count()\n",
    "    \n",
    "    # 4. WORLD_SIZE\n",
    "    world_size = int(os.getenv(\"SLURM_NNODES\")) * ngpus_per_node\n",
    "    \n",
    "    # 5. NODE_RANK\n",
    "    node_rank = int(os.getenv(\"SLURM_NODEID\"))\n",
    "    \n",
    "    # 6. MASTER_ADDR\n",
    "    cmd = \"scontrol show hostnames \" + os.getenv(\"SLURM_JOB_NODELIST\")\n",
    "    stdout = subprocess.check_output(cmd.split())\n",
    "    host_name = stdout.decode().splitlines()[0]\n",
    "    \n",
    "    # 7. MASTER_PORT\n",
    "    port = 54321\n",
    "    \n",
    "    # Set All the Necessary Environment Variables!\n",
    "    os.environ[\"RANK\"] = str(global_rank)\n",
    "    os.environ[\"LOCAL_RANK\"] = str(local_rank)\n",
    "    os.environ[\"LOCAL_WORLD_SIZE\"] = str(ngpus_per_node)\n",
    "    os.environ[\"WORLD_SIZE\"] = str(world_size)\n",
    "    os.environ[\"NODE_RANK\"] = str(node_rank)\n",
    "    os.environ[\"MASTER_ADDR\"] = host_name\n",
    "    os.environ[\"MASTER_PORT\"] = str(port)\n",
    "    os.environ[\"PYTHONUNBUFFERED\"] = \"1\"\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "0175defd38474f56b60106aa613bfe20",
    "deepnote_cell_height": 399.96875,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "The function above is a bare-minimum version of what the [Composer launcher](https://github.com/mosaicml/composer/blob/dev/composer/cli/launcher.py) does to make distributed training work. For comparison, here is a snippet from the launcher's source code showing how Composer sets the same environment variables. `set_up_dist_env()` does the same thing, but derives the values automatically from the SLURM environment.\n",
    "\n",
    "\n",
    "```python\n",
    "# composer/composer/cli/launcher.py\n",
    "\n",
    "with _patch_env(\n",
    "        RANK=str(global_rank),\n",
    "        WORLD_SIZE=str(world_size),\n",
    "        LOCAL_RANK=str(local_rank),\n",
    "        LOCAL_WORLD_SIZE=str(nproc),\n",
    "        NODE_RANK=str(node_rank),\n",
    "        MASTER_ADDR=master_addr,\n",
    "        MASTER_PORT=str(master_port),\n",
    "        PYTHONUNBUFFERED='1',\n",
    "):\n",
    "\n",
    "```"
   ]
  },
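  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hard-coded `MASTER_PORT` can collide when two jobs share a node. One possible workaround, sketched below with a hypothetical helper, derives the port from the SLURM job ID so that every rank in a job computes the same value while different jobs usually get different ports. The port range and the `SLURM_JOB_ID` lookup are assumptions; adjust them to the ports your cluster allows.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def port_from_job_id(job_id, base=20000, span=20000):\n",
    "    # Hypothetical helper: map a SLURM job ID to a port in [base, base + span).\n",
    "    # Every rank in the job sees the same job ID, so all ranks agree on the port.\n",
    "    return base + int(job_id) % span\n",
    "\n",
    "# SLURM sets SLURM_JOB_ID inside a job; fall back to '0' outside of one.\n",
    "port = port_from_job_id(os.environ.get('SLURM_JOB_ID', '0'))\n",
    "print(port)\n",
    "```\n",
    "\n",
    "The result could replace the fixed `port = 54321` in `set_up_dist_env()`; for example, job ID 123456 maps to port 23456."
   ]
  },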
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "3ddb24c42c3a42fe8377cdb2334bf8fc",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "## Submit Job to the Cluster"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "ec3fea1ee71145828bddb5aecf69f2db",
    "deepnote_cell_height": 111.1875,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "Now for the final step. Assume a multi-node setup with two nodes and four GPUs per node. The same `set_up_dist_env()` function also works for a single node with multiple GPUs.\n",
    "\n",
    "Let's define the `train()` function that each process runs and the `submit_job()` function that submits it to the cluster:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "3f01decae4a84f4d9332c23a68c246a6",
    "deepnote_cell_height": 349,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "from composer.trainer import Trainer\n",
    "\n",
    "def train():\n",
    "    set_up_dist_env()\n",
    "    train_dataloader = initialize_dataset()\n",
    "    model, optimizer = initialize_model_and_optimizer()\n",
    "    \n",
    "    print(\"Trainer\")\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        optimizers=optimizer,\n",
    "        train_dataloader=train_dataloader,\n",
    "        max_duration='10ep',\n",
    "        device='gpu' if torch.cuda.is_available() else 'cpu',\n",
    "    )\n",
    "    print(\"trainer.fit\")\n",
    "    trainer.fit()\n",
    "\n",
    "def submit_job():\n",
    "    slurm_ngpus = 4\n",
    "    slurm_nnodes = 2\n",
    "    slurm_timeout = 1024\n",
    "    workers = 10\n",
    "    \n",
    "    slurm_directory = \".\" # \"<Your Specified Directory>\"\n",
    "    executor = submitit.AutoExecutor(folder=slurm_directory)\n",
    "\n",
    "    executor.update_parameters(\n",
    "            mem_gb=128*slurm_ngpus,\n",
    "            gpus_per_node=slurm_ngpus,\n",
    "            tasks_per_node=slurm_ngpus,\n",
    "            cpus_per_task=workers,\n",
    "            nodes=slurm_nnodes,\n",
    "            timeout_min=slurm_timeout,\n",
    "            slurm_partition=\"gpu\",\n",
    "            # see submitit github repo for details\n",
    "    )\n",
    "\n",
    "    job = executor.submit(train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "56a8162da7b44113a5003db41346153f",
    "deepnote_cell_height": 70,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "## Putting Things Together"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "8439f45c91a34f8dac6241891caba285",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "We can now put everything into a single Python file for job submission:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_id": "97a7e830aad24abc8ef1333bcbc51951",
    "deepnote_cell_height": 907,
    "deepnote_cell_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import submitit\n",
    "import subprocess\n",
    "\n",
    "import torchvision\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from composer.models import ComposerModel\n",
    "from composer import Trainer\n",
    "from composer.utils import dist\n",
    "\n",
    "class ResNet18(ComposerModel):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = torchvision.models.resnet18(num_classes=10)  # CIFAR-10 has 10 classes\n",
    "\n",
    "    def forward(self, batch): # batch is the output of the dataloader\n",
    "        # specify how batches are passed through the model\n",
    "        inputs, _ = batch\n",
    "        return self.model(inputs)\n",
    "\n",
    "    def loss(self, outputs, batch):\n",
    "        # pass batches and `forward` outputs to the loss\n",
    "        _, targets = batch\n",
    "        return F.cross_entropy(outputs, targets)\n",
    "\n",
    "def initialize_model_and_optimizer():\n",
    "    model = ResNet18()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
    "    return model, optimizer\n",
    "\n",
    "def initialize_dataset():\n",
    "    cifar10_transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),]\n",
    "    )\n",
    "    train_data = torchvision.datasets.CIFAR10(\n",
    "        root=\".\", train=True, download=True, transform=cifar10_transform\n",
    "    )\n",
    "    train_loader = DataLoader(\n",
    "        dataset=train_data,\n",
    "        # shard the dataset so that each rank trains on a different subset\n",
    "        sampler=dist.get_sampler(train_data, shuffle=True),\n",
    "        batch_size=1024,\n",
    "        num_workers=10,\n",
    "        pin_memory=True,\n",
    "        drop_last=True,\n",
    "    )\n",
    "\n",
    "    return train_loader  # standard pytorch dataloader\n",
    "\n",
    "def set_up_dist_env():\n",
    "    # 1. RANK\n",
    "    job_env = submitit.JobEnvironment()\n",
    "    global_rank = job_env.global_rank\n",
    "    \n",
    "    # 2. LOCAL_RANK\n",
    "    local_rank = job_env.local_rank\n",
    "    \n",
    "    # 3. LOCAL_WORLD_SIZE\n",
    "    ngpus_per_node = torch.cuda.device_count()\n",
    "    \n",
    "    # 4. WORLD_SIZE\n",
    "    world_size = int(os.getenv(\"SLURM_NNODES\")) * ngpus_per_node\n",
    "    \n",
    "    # 5. NODE_RANK\n",
    "    node_rank = int(os.getenv(\"SLURM_NODEID\"))\n",
    "    \n",
    "    # 6. MASTER_ADDR\n",
    "    cmd = \"scontrol show hostnames \" + os.getenv(\"SLURM_JOB_NODELIST\")\n",
    "    stdout = subprocess.check_output(cmd.split())\n",
    "    host_name = stdout.decode().splitlines()[0]\n",
    "    \n",
    "    # 7. MASTER_PORT\n",
    "    port = 54321\n",
    "    \n",
    "    # Set All the Necessary Environment Variables!\n",
    "    os.environ[\"RANK\"] = str(global_rank)\n",
    "    os.environ[\"LOCAL_RANK\"] = str(local_rank)\n",
    "    os.environ[\"LOCAL_WORLD_SIZE\"] = str(ngpus_per_node)\n",
    "    os.environ[\"WORLD_SIZE\"] = str(world_size)\n",
    "    os.environ[\"NODE_RANK\"] = str(node_rank)\n",
    "    os.environ[\"MASTER_ADDR\"] = host_name\n",
    "    os.environ[\"MASTER_PORT\"] = str(port)\n",
    "    os.environ[\"PYTHONUNBUFFERED\"] = \"1\"\n",
    "\n",
    "def train():\n",
    "    set_up_dist_env()\n",
    "    train_dataloader = initialize_dataset()\n",
    "    model, optimizer = initialize_model_and_optimizer()\n",
    "    \n",
    "    print(\"Trainer\")\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        optimizers=optimizer,\n",
    "        train_dataloader=train_dataloader,\n",
    "        max_duration='10ep',\n",
    "        device='gpu' if torch.cuda.is_available() else 'cpu',\n",
    "    )\n",
    "    print(\"trainer.fit\")\n",
    "    trainer.fit()\n",
    "\n",
    "def submit_job():\n",
    "    slurm_ngpus = 4\n",
    "    slurm_nnodes = 2\n",
    "    slurm_timeout = 1024\n",
    "    workers = 10\n",
    "    \n",
    "    slurm_directory = \".\" # \"<Your Specified Directory>\"\n",
    "    executor = submitit.AutoExecutor(folder=slurm_directory)\n",
    "\n",
    "    executor.update_parameters(\n",
    "            mem_gb=128*slurm_ngpus,\n",
    "            gpus_per_node=slurm_ngpus,\n",
    "            tasks_per_node=slurm_ngpus,\n",
    "            cpus_per_task=workers,\n",
    "            nodes=slurm_nnodes,\n",
    "            timeout_min=slurm_timeout,\n",
    "            slurm_partition=\"gpu\",\n",
    "            # see submitit github repo for details\n",
    "    )\n",
    "\n",
    "    job = executor.submit(train)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    submit_job()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_id": "1f0fb2c92b6e4887a718529479622623",
    "deepnote_cell_height": 52.390625,
    "deepnote_cell_type": "markdown"
   },
   "source": [
    "Run the above Python script from your command shell, and you're all set! The logs for each task are written by submitit to the folder passed to `AutoExecutor`."
   ]
  }
 ],
 "metadata": {
  "deepnote": {},
  "deepnote_execution_queue": [],
  "deepnote_notebook_id": "851418bcf18f4d14b054f1beb28ec800",
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
