{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0nrBBPTlmcag"
      },
      "source": [
        "# ResNet-18 on CIFAR-10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-qio0AMbp5mH"
      },
      "source": [
        "### Define ResNet-18"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eN3wolT9Hy1M"
      },
      "source": [
        "ResNet code taken from [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) (MIT License)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ABnM6xBzp5AL"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "    \"\"\"Two-layer 3x3 residual block used by ResNet-18.\n",
        "\n",
        "    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is added to\n",
        "    the shortcut path and passed through a final ReLU. When the stride or the\n",
        "    channel count changes, the shortcut is a strided 1x1 conv followed by BN so\n",
        "    that both tensors have the same shape; otherwise it is the identity.\n",
        "\n",
        "    Args:\n",
        "        in_planes: number of input channels.\n",
        "        planes: number of output channels (``expansion * planes``).\n",
        "        stride: stride of the first 3x3 conv and of the projection shortcut.\n",
        "    \"\"\"\n",
        "\n",
        "    expansion = 1\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(BasicBlock, self).__init__()\n",
        "        out_planes = self.expansion * planes\n",
        "\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,\n",
        "                               stride=stride, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "        # Project the input only when its shape differs from the output's.\n",
        "        needs_projection = stride != 1 or in_planes != out_planes\n",
        "        if needs_projection:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, out_planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(out_planes),\n",
        "            )\n",
        "        else:\n",
        "            self.shortcut = nn.Sequential()\n",
        "\n",
        "    def forward(self, x):\n",
        "        y = F.relu(self.bn1(self.conv1(x)))\n",
        "        y = self.bn2(self.conv2(y))\n",
        "        return F.relu(y + self.shortcut(x))\n",
        "\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "    \"\"\"Three-layer (1x1 -> 3x3 -> 1x1) residual block used by deeper ResNets.\n",
        "\n",
        "    The last 1x1 conv widens the output to ``expansion * planes`` channels.\n",
        "    The shortcut is the identity unless the stride or channel count changes,\n",
        "    in which case it is a strided 1x1 conv followed by BN.\n",
        "\n",
        "    Args:\n",
        "        in_planes: number of input channels.\n",
        "        planes: bottleneck width; the block outputs ``expansion * planes`` channels.\n",
        "        stride: stride of the 3x3 conv and of the projection shortcut.\n",
        "    \"\"\"\n",
        "\n",
        "    expansion = 4\n",
        "\n",
        "    def __init__(self, in_planes, planes, stride=1):\n",
        "        super(Bottleneck, self).__init__()\n",
        "        out_planes = self.expansion * planes\n",
        "\n",
        "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(planes)\n",
        "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n",
        "                               stride=stride, padding=1, bias=False)\n",
        "        self.bn2 = nn.BatchNorm2d(planes)\n",
        "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
        "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
        "\n",
        "        # Identity shortcut when input and output shapes already agree.\n",
        "        if stride == 1 and in_planes == out_planes:\n",
        "            self.shortcut = nn.Sequential()\n",
        "        else:\n",
        "            self.shortcut = nn.Sequential(\n",
        "                nn.Conv2d(in_planes, out_planes,\n",
        "                          kernel_size=1, stride=stride, bias=False),\n",
        "                nn.BatchNorm2d(out_planes),\n",
        "            )\n",
        "\n",
        "    def forward(self, x):\n",
        "        y = F.relu(self.bn1(self.conv1(x)))\n",
        "        y = F.relu(self.bn2(self.conv2(y)))\n",
        "        y = self.bn3(self.conv3(y))\n",
        "        return F.relu(y + self.shortcut(x))\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "    \"\"\"CIFAR-style ResNet: 3x3 stem, four residual stages, global pool, linear head.\n",
        "\n",
        "    The stem conv has stride 1 and stages 2-4 have stride 2, so a 32x32 input\n",
        "    reaches the head as a 4x4 feature map.\n",
        "\n",
        "    Args:\n",
        "        block: residual block class (e.g. BasicBlock or Bottleneck); it must\n",
        "            expose an ``expansion`` attribute.\n",
        "        num_blocks: number of blocks in each of the four stages.\n",
        "        num_classes: size of the classification output.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, block, num_blocks, num_classes=10):\n",
        "        super(ResNet, self).__init__()\n",
        "        # Channel count fed into the next block; advanced by _make_layer.\n",
        "        self.in_planes = 64\n",
        "\n",
        "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,\n",
        "                               stride=1, padding=1, bias=False)\n",
        "        self.bn1 = nn.BatchNorm2d(64)\n",
        "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
        "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
        "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
        "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
        "        self.linear = nn.Linear(512*block.expansion, num_classes)\n",
        "\n",
        "    def _make_layer(self, block, planes, num_blocks, stride):\n",
        "        \"\"\"Build one stage; only its first block uses ``stride``.\"\"\"\n",
        "        strides = [stride] + [1]*(num_blocks-1)\n",
        "        layers = []\n",
        "        for block_stride in strides:\n",
        "            layers.append(block(self.in_planes, planes, block_stride))\n",
        "            self.in_planes = planes * block.expansion\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = F.relu(self.bn1(self.conv1(x)))\n",
        "        out = self.layer1(out)\n",
        "        out = self.layer2(out)\n",
        "        out = self.layer3(out)\n",
        "        out = self.layer4(out)\n",
        "        # Global average pooling: same as avg_pool2d(out, 4) on the 4x4 maps of\n",
        "        # 32x32 inputs, but also correct for other input resolutions.\n",
        "        out = F.adaptive_avg_pool2d(out, 1)\n",
        "        out = torch.flatten(out, 1)\n",
        "        out = self.linear(out)\n",
        "        return out\n",
        "\n",
        "\n",
        "def ResNet18(num_classes=10):\n",
        "    \"\"\"Return a CIFAR-style ResNet-18 (BasicBlock, [2, 2, 2, 2]) with num_classes outputs.\"\"\"\n",
        "    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OSWYRsLQw6Xj"
      },
      "outputs": [],
      "source": [
        "# Defining utility progress bar\n",
        "\n",
        "import os\n",
        "import sys\n",
        "import time\n",
        "import math\n",
        "\n",
        "# Assumed terminal width (characters) used to pad the status line.\n",
        "term_width = 50\n",
        "\n",
        "# Number of characters used to draw the bar itself.\n",
        "TOTAL_BAR_LENGTH = 65.0\n",
        "# Shared timing state: time of the previous call and start of the current bar.\n",
        "begin_time = last_time = time.time()\n",
        "\n",
        "def progress_bar(current, total, msg=None):\n",
        "    \"\"\"Draw a one-line text progress bar on stdout.\n",
        "\n",
        "    Args:\n",
        "        current: zero-based index of the current step.\n",
        "        total: total number of steps.\n",
        "        msg: optional text appended after the step/total timings.\n",
        "\n",
        "    Every call but the last ends with a carriage return so the next call\n",
        "    redraws the same line; the call with current == total - 1 ends the line.\n",
        "    \"\"\"\n",
        "    global last_time, begin_time\n",
        "    if current == 0:\n",
        "        begin_time = time.time()  # Reset for new bar.\n",
        "\n",
        "    filled = int(TOTAL_BAR_LENGTH * current / total)\n",
        "    remaining = int(TOTAL_BAR_LENGTH - filled) - 1\n",
        "    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')\n",
        "\n",
        "    now = time.time()\n",
        "    step_time = now - last_time\n",
        "    last_time = now\n",
        "    tot_time = now - begin_time\n",
        "\n",
        "    status = '  Step: %s' % format_time(step_time)\n",
        "    status += ' | Tot: %s' % format_time(tot_time)\n",
        "    if msg:\n",
        "        status += ' | ' + msg\n",
        "    sys.stdout.write(status)\n",
        "\n",
        "    # Pad with spaces up to the terminal width (nothing if already too long).\n",
        "    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3))\n",
        "\n",
        "    # Go back to the center of the bar and write the step counter there.\n",
        "    sys.stdout.write('\\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))\n",
        "    sys.stdout.write(' %d/%d ' % (current + 1, total))\n",
        "\n",
        "    sys.stdout.write('\\r' if current < total - 1 else '\\n')\n",
        "    sys.stdout.flush()\n",
        "\n",
        "def format_time(seconds):\n",
        "    \"\"\"Format a duration compactly, keeping at most the two largest positive units.\n",
        "\n",
        "    Units are days (D), hours (h), minutes (m), seconds (s) and milliseconds\n",
        "    (ms); each value is truncated, not rounded. Returns '0ms' when no unit is\n",
        "    positive.\n",
        "\n",
        "    Examples: 1.25 -> '1s250ms', 3661.5 -> '1h1m', 90000 -> '1D1h'.\n",
        "    \"\"\"\n",
        "    remaining = seconds\n",
        "    days = int(remaining / 3600 / 24)\n",
        "    remaining -= days * 3600 * 24\n",
        "    hours = int(remaining / 3600)\n",
        "    remaining -= hours * 3600\n",
        "    minutes = int(remaining / 60)\n",
        "    remaining -= minutes * 60\n",
        "    whole_seconds = int(remaining)\n",
        "    remaining -= whole_seconds\n",
        "    millis = int(remaining * 1000)\n",
        "\n",
        "    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),\n",
        "             (whole_seconds, 's'), (millis, 'ms'))\n",
        "    positive = [str(value) + suffix for value, suffix in units if value > 0]\n",
        "    return ''.join(positive[:2]) or '0ms'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a7i6AGYkrRpz"
      },
      "source": [
        "### Experiment Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xR79uyY_rVDu"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "import torch.backends.cudnn as cudnn\n",
        "\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "\n",
        "import os\n",
        "\n",
        "from datetime import datetime\n",
        "\n",
        "torch.manual_seed(1000)\n",
        "\n",
        "\n",
        "def _clipped_step_size(algo, eta, gamma, delta, grad_norm):\n",
        "    \"\"\"Return the step size of one GClip / delta-GClip update.\n",
        "\n",
        "    GClip:       h = min(eta, eta * gamma / ||g||)\n",
        "    delta-GClip: h = min(eta, eta * max(delta, gamma / ||g||))\n",
        "\n",
        "    A zero gradient norm is treated as gamma / ||g|| = +inf, so the unclipped\n",
        "    step size ``eta`` is returned instead of raising ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    ratio = gamma / grad_norm if grad_norm > 0 else float(\"inf\")\n",
        "    if algo == \"gclip\":\n",
        "        return min(eta, eta * ratio)\n",
        "    return min(eta, eta * max(delta, ratio))\n",
        "\n",
        "\n",
        "def _load_cifar10(data_root):\n",
        "    \"\"\"Return (trainloader, testloader) for CIFAR-10, downloading it if needed.\"\"\"\n",
        "    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
        "    transform_train = transforms.Compose([\n",
        "        transforms.RandomCrop(32, padding=4),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        normalize,\n",
        "    ])\n",
        "    transform_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        normalize,\n",
        "    ])\n",
        "\n",
        "    # TLS certificate verification is disabled, presumably to work around\n",
        "    # certificate errors during the download. The default is restored\n",
        "    # afterwards so the insecure setting does not leak into the rest of the\n",
        "    # process (later cells, other downloads).\n",
        "    import ssl\n",
        "    default_https_context = ssl._create_default_https_context\n",
        "    ssl._create_default_https_context = ssl._create_unverified_context\n",
        "    try:\n",
        "        trainset = torchvision.datasets.CIFAR10(\n",
        "            root=data_root, train=True, download=True, transform=transform_train)\n",
        "        testset = torchvision.datasets.CIFAR10(\n",
        "            root=data_root, train=False, download=True, transform=transform_test)\n",
        "    finally:\n",
        "        ssl._create_default_https_context = default_https_context\n",
        "\n",
        "    trainloader = torch.utils.data.DataLoader(\n",
        "        trainset, batch_size=512, shuffle=True, num_workers=4)\n",
        "    testloader = torch.utils.data.DataLoader(\n",
        "        testset, batch_size=1000, shuffle=False, num_workers=2)\n",
        "    return trainloader, testloader\n",
        "\n",
        "\n",
        "def main(experiment):\n",
        "    \"\"\"Train ResNet-18 on CIFAR-10 with the optimizer described by ``experiment``.\n",
        "\n",
        "    Reads ``experiment.initial_settings``: ``algo`` (\"adam\", \"gclip\" or\n",
        "    \"d-gclip\", case-insensitive; anything else trains with plain SGD), ``eta``,\n",
        "    ``gamma``, ``delta``, ``epochs`` and ``etaReduction``. Optional keys:\n",
        "    ``seed`` (default 1000) and ``etaReductionEpochs`` (zero-based epochs after\n",
        "    which eta is multiplied by 0.1; default (99, 149)).\n",
        "\n",
        "    Stores, per epoch, the training loss summed over batches, the step size\n",
        "    used in that epoch, the test accuracy and the test loss summed over\n",
        "    batches via ``experiment.setResults``.\n",
        "    \"\"\"\n",
        "    settings = experiment.initial_settings\n",
        "\n",
        "    # Re-seed for every experiment so all runs start from the same initial\n",
        "    # weights; with a single global seed, later runs would depend on how many\n",
        "    # random numbers the earlier runs consumed.\n",
        "    torch.manual_seed(settings.get(\"seed\", 1000))\n",
        "\n",
        "    # Set processing device. If GPU is available prefer cuda for better performance\n",
        "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "    # Data\n",
        "    print('==> Preparing data..')\n",
        "    trainloader, testloader = _load_cifar10('../resnetexperiment-master/data')\n",
        "\n",
        "    # Model\n",
        "    print('==> Building model..')\n",
        "    net = ResNet18()\n",
        "    net = net.to(device)\n",
        "    if device == 'cuda':\n",
        "        net = torch.nn.DataParallel(net)\n",
        "        cudnn.benchmark = True\n",
        "\n",
        "    # Loss Function\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    # Hyperparameters\n",
        "    algo = settings[\"algo\"].lower()\n",
        "    eta = settings[\"eta\"]\n",
        "    gamma = settings[\"gamma\"]\n",
        "    delta = settings[\"delta\"]\n",
        "    reduce_eta = settings[\"etaReduction\"]\n",
        "    reduction_epochs = settings.get(\"etaReductionEpochs\", (99, 149))\n",
        "    clipping = algo in (\"gclip\", \"d-gclip\")\n",
        "\n",
        "    # Store training information\n",
        "    accuracies = []\n",
        "    losses = []\n",
        "    step_sizes = []\n",
        "    test_losses = []\n",
        "\n",
        "    if algo == \"adam\":\n",
        "        optimizer = optim.Adam(net.parameters(), lr=eta)\n",
        "    else:\n",
        "        optimizer = optim.SGD(net.parameters(), lr=eta)\n",
        "\n",
        "    def test():\n",
        "        \"\"\"Return (accuracy, loss summed over batches) of ``net`` on the test set.\"\"\"\n",
        "        net.eval()\n",
        "        test_loss = 0\n",
        "        correct = 0\n",
        "        total = 0\n",
        "        with torch.no_grad():\n",
        "            for inputs, targets in testloader:\n",
        "                inputs, targets = inputs.to(device), targets.to(device)\n",
        "                outputs = net(inputs)\n",
        "                test_loss += criterion(outputs, targets).item()\n",
        "                _, predicted = outputs.max(1)\n",
        "                total += targets.size(0)\n",
        "                correct += predicted.eq(targets).sum().item()\n",
        "        return correct / total, test_loss\n",
        "\n",
        "    for epoch in range(settings[\"epochs\"]):\n",
        "        print('\\nEpoch: %d' % epoch)\n",
        "        net.train()\n",
        "        train_loss = 0\n",
        "        step_size = eta\n",
        "        norm_grad_f = float(\"nan\")\n",
        "        for inputs, targets in trainloader:\n",
        "            inputs, targets = inputs.to(device), targets.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outputs = net(inputs)\n",
        "            loss = criterion(outputs, targets)\n",
        "            loss.backward()\n",
        "            # max_norm=inf: only measures the global gradient norm, never clips.\n",
        "            norm_grad_f = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=float(\"inf\")).item()\n",
        "\n",
        "            if clipping:\n",
        "                step_size = _clipped_step_size(algo, eta, gamma, delta, norm_grad_f)\n",
        "                for g in optimizer.param_groups:\n",
        "                    g[\"lr\"] = step_size\n",
        "\n",
        "            train_loss += loss.item()\n",
        "            optimizer.step()\n",
        "\n",
        "        print(f\"Epoch {epoch} | Step size: {step_size} | norm =\", norm_grad_f)\n",
        "        # Record the step size used during this epoch (last batch for the\n",
        "        # clipping methods) before an end-of-epoch reduction changes eta.\n",
        "        losses.append(train_loss)\n",
        "        step_sizes.append(step_size)\n",
        "\n",
        "        if reduce_eta and epoch in reduction_epochs:  # If step size reduction has been set\n",
        "            eta *= 0.1\n",
        "            for g in optimizer.param_groups:\n",
        "                g[\"lr\"] = eta\n",
        "            print(\"reduced eta to:\", eta)\n",
        "\n",
        "        if epoch % 10 == 0:\n",
        "            print(\"Current Time =\", datetime.now().strftime(\"%H:%M:%S\"))\n",
        "\n",
        "        accuracy, test_loss = test()\n",
        "        accuracies.append(accuracy)\n",
        "        test_losses.append(test_loss)\n",
        "\n",
        "    # Save data\n",
        "    experiment.setResults(losses, step_sizes, accuracies, test_losses=test_losses)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cB4Uujg_GbXi"
      },
      "source": [
        "### Run Experiments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q-JACil7GdY1"
      },
      "outputs": [],
      "source": [
        "class Experiment:\n",
        "    \"\"\"Settings and per-epoch results of one training run.\n",
        "\n",
        "    initial_settings maps the constructor argument names (algo, eta, gamma,\n",
        "    delta, epochs, etaReduction) to their values; results stays empty until\n",
        "    setResults is called.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, algo, eta, gamma, delta, epochs, etaReduction) -> None:\n",
        "        self.initial_settings = {\n",
        "            \"algo\": algo,\n",
        "            \"eta\": eta,\n",
        "            \"gamma\": gamma,\n",
        "            \"delta\": delta,\n",
        "            \"epochs\": epochs,\n",
        "            \"etaReduction\": etaReduction,\n",
        "        }\n",
        "        self.results = {}\n",
        "\n",
        "    def setResults(self, losses, step_sizes, test_accuracies, gradients=None, test_losses=None):\n",
        "        \"\"\"(Re)set the five result entries; any other keys in results are kept.\"\"\"\n",
        "        self.results.update(\n",
        "            losses=losses,\n",
        "            step_sizes=step_sizes,\n",
        "            gradients=gradients,\n",
        "            test_losses=test_losses,\n",
        "            test_accuracies=test_accuracies,\n",
        "        )\n",
        "\n",
        "    def __str__(self):\n",
        "        return f\"Experiment with variables: {self.initial_settings}.\\nResults: {self.results}\"\n",
        "\n",
        "\n",
        "# DEFINE EXPERIMENTS\n",
        "# The experiments below run one after the other until completion.\n",
        "# Experiment(algo, eta, gamma, delta, epochs, etaReduction):\n",
        "#   algo         -- \"d-gclip\", \"gclip\", \"adam\" or \"gd\" (plain SGD); main() lower-cases it\n",
        "#                   and trains any value other than the first three with plain SGD.\n",
        "#   eta          -- (maximum) step size.\n",
        "#   gamma        -- clipping threshold; used by d-gclip and gclip only.\n",
        "#   delta        -- lower bound on the step-size factor; used by d-gclip only.\n",
        "#   epochs       -- number of training epochs.\n",
        "#   etaReduction -- if True, eta is divided by 10 after the 100th and 150th epochs\n",
        "#                   (zero-based epochs 99 and 149).\n",
        "experiments = [\n",
        "    Experiment(\"gd\", 1, None, None, 200, True),\n",
        "    Experiment(\"d-gclip\", 5, 0.25, 1e-3, 200, True),\n",
        "    Experiment(\"d-gclip\", 5, 0.25, 1e-8, 200, True),\n",
        "    Experiment(\"gclip\", 5, 0.25, None, 200, True),\n",
        "    Experiment(\"adam\", 0.0001, None, None, 200, True),\n",
        "]\n",
        "\n",
        "\n",
        "experiment_results = []\n",
        "for exp in experiments:\n",
        "    print(\"Initiating experiment:\", exp.initial_settings)\n",
        "    main(exp)\n",
        "    print(\"Experiment ended. Results:\")\n",
        "    print([exp.initial_settings, exp.results])\n",
        "\n",
        "    # Save results to later produce graphs\n",
        "    experiment_results.append([exp.initial_settings, exp.results])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2TI4jVxtGIFh"
      },
      "source": [
        "## Generate Results Graphs"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Print each stored [initial_settings, results] pair, preceded by its length.\n",
        "for entry in experiment_results:\n",
        "    print(len(entry), entry)"
      ],
      "metadata": {
        "id": "Qk-nQRI7fRWR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eNiJpG_ipNkg"
      },
      "source": [
        "### Graph Plotting Utility Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fPIeKbyBpKGh"
      },
      "outputs": [],
      "source": [
        "def getExperimentString(exp):\n",
        "    r\"\"\"\n",
        "    Returns the formatted legend label for an experiment's settings dict.\n",
        "\n",
        "    Eg. given the settings of Experiment(\"d-gclip\", 5, 0.25, 1e-8, 5, False)\n",
        "    it returns the string: \"$\\delta$-GClip (5;0.25;1e-08)\".\n",
        "\n",
        "    Algorithm names are matched case-insensitively (as in main()). gamma is\n",
        "    only shown for gclip / d-gclip and delta only for d-gclip, because main()\n",
        "    ignores those values for the other algorithms. LaTeX labels are raw\n",
        "    strings so that \"\\d\" is not parsed as an (invalid) escape sequence.\n",
        "    \"\"\"\n",
        "    algo = exp[\"algo\"].lower()\n",
        "    names = {\n",
        "        \"gclip\": \"GClip\",\n",
        "        \"d-gclip\": r\"$\\delta$-GClip\",\n",
        "        \"gd\": \"SGD\",\n",
        "        \"adam\": \"Adam\",\n",
        "    }\n",
        "    string = names.get(algo, exp[\"algo\"])\n",
        "\n",
        "    string += f\" ({exp['eta']}\"\n",
        "    if algo in (\"gclip\", \"d-gclip\") and exp[\"gamma\"]:\n",
        "        string += f\";{exp['gamma']}\"\n",
        "    if algo == \"d-gclip\" and exp[\"delta\"]:\n",
        "        # f\"{1e-3}\" renders as \"0.001\"; spell it like f\"{1e-8}\" (\"1e-08\") instead.\n",
        "        if exp[\"delta\"] == 1e-3:\n",
        "            string += \";1e-03\"\n",
        "        else:\n",
        "            string += f\";{exp['delta']}\"\n",
        "    return string + \")\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XNjAWSlCpR0D"
      },
      "source": [
        "### Plot Graphs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LDLP9PnIGKm8"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "LAST_N = 50  # number of final epochs shown in the zoomed-in bottom panels\n",
        "\n",
        "fig, axs = plt.subplots(2, 2, figsize=(10, 5))\n",
        "\n",
        "for settings, results in experiment_results:\n",
        "    label = getExperimentString(settings)\n",
        "    losses = results[\"losses\"]\n",
        "    accuracies_pct = [acc * 100 for acc in results[\"test_accuracies\"]]\n",
        "\n",
        "    # Epochs are numbered from 1 on every panel. The x-axis is derived from the\n",
        "    # recorded data instead of assuming exactly 200 epochs, and the zoomed-in\n",
        "    # panels reuse the same numbering (151..200 for a 200-epoch run).\n",
        "    epochs = range(1, len(losses) + 1)\n",
        "    tail_epochs = epochs[-LAST_N:]\n",
        "\n",
        "    # Plot the four graphs\n",
        "    axs[0,0].plot(epochs, losses, label=label)\n",
        "    axs[0,1].plot(epochs, accuracies_pct, label=label)\n",
        "    axs[1,0].plot(tail_epochs, losses[-LAST_N:], label=label)\n",
        "    axs[1,1].plot(tail_epochs, accuracies_pct[-LAST_N:], label=label)\n",
        "\n",
        "\n",
        "# Top left\n",
        "axs[0,0].set_ylabel(\"Training Loss (log)\")\n",
        "axs[0,0].set_xlabel(\"Epochs\")\n",
        "axs[0,0].set_yscale(\"log\")\n",
        "\n",
        "# Top right\n",
        "axs[0,1].set_ylabel(\"Test Accuracy (%)\")\n",
        "axs[0,1].set_xlabel(\"Epochs\")\n",
        "axs[0,1].set_ylim(80)\n",
        "\n",
        "# Bottom left\n",
        "axs[1,0].set_ylabel(\"Training Loss\")\n",
        "axs[1,0].set_xlabel(f\"Epochs (last {LAST_N})\")\n",
        "\n",
        "# Bottom right\n",
        "axs[1,1].set_ylabel(\"Test Accuracy (%)\")\n",
        "axs[1,1].set_xlabel(f\"Epochs (last {LAST_N})\")\n",
        "axs[1,1].set_ylim(89)\n",
        "\n",
        "handles, labels = axs[0, 1].get_legend_handles_labels()\n",
        "# List the 2nd and 3rd experiments in swapped order in the legend. Handles and\n",
        "# labels must be swapped together: swapping only the labels attaches each\n",
        "# label to the other experiment's line colour.\n",
        "if len(labels) >= 3:\n",
        "    handles[1], handles[2] = handles[2], handles[1]\n",
        "    labels[1], labels[2] = labels[2], labels[1]\n",
        "fig.legend(handles, labels, loc=\"lower center\", ncol=5, bbox_to_anchor=(0.5, -0.06))\n",
        "fig.suptitle(\"ResNet-18 on CIFAR-10 (LR scheduling, no weight-decay)\", fontsize=16)\n",
        "fig.subplots_adjust(hspace=0.3, wspace=0.2)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "-qio0AMbp5mH",
        "eNiJpG_ipNkg",
        "XNjAWSlCpR0D"
      ],
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}