{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "output_layers.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zqdhd4Z_DJ0E"
      },
      "source": [
        "# Output Layer Designs - Code Sample\n",
        "\n",
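        "This notebook trains a ResNet-50 on a small, stratified 1% subset of the MNIST training set and reports accuracy on the remaining 99%. It implements all of the presented output-layer designs: `trained`, `random`, `sparse`, `scale`, `1to1` and `ensemble`. To try one out, set `layer_type` in the Training Settings section and run all cells."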
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x_lirhBwDkFA"
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torchsummary import summary\n",
        "\n",
        "import torch.backends.cudnn as cudnn\n",
        "seed = 42\n",
        "np.random.seed(seed)\n",
        "random.seed(seed)\n",
        "torch.manual_seed(seed)  # seeds the CPU and all CUDA devices\n",
        "# Deterministic cuDNN kernels; also explicitly disable the autotuner, which may pick different algorithms per run.\n",
        "cudnn.deterministic = True\n",
        "cudnn.benchmark = False"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nqBQerJ2EovW"
      },
      "source": [
        "### Training Settings"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qHmte0x8DoB8"
      },
      "source": [
        "num_input_channels = 1\n",
        "img_size = 32\n",
        "\n",
        "lr = 1e-2\n",
        "num_epochs = 20\n",
        "bs = 32\n",
        "\n",
        "# Output layer settings and hyperparameters\n",
        "layer_type = 'trained' # set one of: trained, random, sparse, scale, 1to1, ensemble\n",
        "K = 10 # number of classes\n",
        "N = 2048 # number of conv channels in last conv layer (512 * Bottleneck.expansion for ResNet-50)\n",
        "\n",
        "H = 10 # number of heads, used for W^ensemble\n",
        "q = 0.9 # sparsity hyperparam (fraction of each class's weights set to zero), used for W^sparse\n",
        "alpha = 0.1 # used for W^scale\n",
        "alpha_head = 1. # used for W^ensemble\n",
        "\n",
        "# Fail fast on typos: an unrecognised layer_type would otherwise silently train like 'trained'.\n",
        "layer_types = ('trained', 'random', 'sparse', 'scale', '1to1', 'ensemble')\n",
        "assert layer_type in layer_types, f'layer_type must be one of {layer_types}, got {layer_type!r}'\n",
        "# The ensemble heads are read directly from the first K*H pooled features.\n",
        "assert layer_type != 'ensemble' or K * H <= N, 'ensemble needs K * H <= N'"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cl-Ru6nTFttT"
      },
      "source": [
        "### Device"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JCCEUmCMFt_p"
      },
      "source": [
        "def to_device(data, device):\n",
        "  \"\"\"Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.\n",
        "\n",
        "  Sequences come back as lists; anything else is moved with `.to(device, non_blocking=True)`.\n",
        "  \"\"\"\n",
        "  if not isinstance(data, (list, tuple)):\n",
        "    return data.to(device, non_blocking=True)\n",
        "  return [to_device(item, device) for item in data]\n",
        "\n",
        "def get_default_device():\n",
        "  \"\"\"Return the CUDA device when one is available, otherwise the CPU.\"\"\"\n",
        "  return torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "device = get_default_device()"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "srNpopJaFyay"
      },
      "source": [
        "### Data Loading"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8rm5kIdhFvcC"
      },
      "source": [
        "class DeviceDataLoader:\n",
        "  \"\"\"Wrap a DataLoader so that every batch it yields has already been moved to `device`.\"\"\"\n",
        "\n",
        "  def __init__(self, dl, device):\n",
        "    self.dl, self.device = dl, device\n",
        "\n",
        "  def __iter__(self):\n",
        "    # Move each batch lazily, as the wrapped loader produces it.\n",
        "    yield from (to_device(batch, self.device) for batch in self.dl)\n",
        "\n",
        "  def __len__(self):\n",
        "    # Number of batches, delegated to the wrapped loader.\n",
        "    return len(self.dl)"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TsHiVbXyF0nI"
      },
      "source": [
        "classes = [str(i) for i in range(K)]\n",
        "# Resize MNIST digits to img_size and map pixel values from [0, 1] (ToTensor) to [-1, 1].\n",
        "data_transform = transforms.Compose([\n",
        "  transforms.Resize(img_size),\n",
        "  transforms.ToTensor(),\n",
        "  # (0.5,) is a 1-tuple with one value per channel; a bare (0.5) is just a float.\n",
        "  transforms.Normalize((0.5,), (0.5,))\n",
        "])\n",
        "dset = torchvision.datasets.MNIST(root='mnist', train=True, download=True, transform=data_transform)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HDXrOMuXJP4y"
      },
      "source": [
        "# Stratified split: keep 1% of MNIST for training and use the remaining 99% as a big dev set.\n",
        "train_idx, val_idx = train_test_split(\n",
        "  np.arange(len(dset.targets)),\n",
        "  test_size=0.99,\n",
        "  shuffle=True,\n",
        "  stratify=dset.targets,\n",
        "  random_state=seed  # the notebook-wide seed, rather than a second hard-coded 42\n",
        ")\n",
        "\n",
        "trainset = torch.utils.data.Subset(dset, train_idx)\n",
        "valset = torch.utils.data.Subset(dset, val_idx)\n",
        "\n",
        "# Pinned host memory only speeds up host-to-GPU copies; on CPU-only machines PyTorch warns about it.\n",
        "use_pin_memory = device.type == 'cuda'\n",
        "trainloader = DeviceDataLoader(torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=2, pin_memory=use_pin_memory), device=device)\n",
        "valloader = DeviceDataLoader(torch.utils.data.DataLoader(valset, batch_size=bs, num_workers=2, pin_memory=use_pin_memory), device=device)"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xCJHRae_G5_X"
      },
      "source": [
        "### ResNet Components"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W_3Z2KdTG6mK"
      },
      "source": [
        "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n",
        "  \"\"\"Bias-free 3x3 convolution; padding equals dilation, so stride 1 preserves the spatial size.\"\"\"\n",
        "  return nn.Conv2d(in_planes, out_planes, 3, stride=stride, padding=dilation,\n",
        "                   dilation=dilation, groups=groups, bias=False)\n",
        "\n",
        "def conv1x1(in_planes, out_planes, stride=1):\n",
        "  \"\"\"Bias-free 1x1 convolution that changes the channel count (and downsamples when stride > 1).\"\"\"\n",
        "  return nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)\n",
        "\n",
        "class Bottleneck(nn.Module):\n",
        "  \"\"\"ResNet bottleneck block: 1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand, plus a residual connection.\n",
        "\n",
        "  The block outputs `planes * expansion` channels. `downsample`, when given, is applied to the input so\n",
        "  that it matches the output shape for the residual addition.\n",
        "  \"\"\"\n",
        "  expansion = 4\n",
        "\n",
        "  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None):\n",
        "    super(Bottleneck, self).__init__()\n",
        "    # The signature advertises norm_layer=None, so fall back to BatchNorm (as torchvision does)\n",
        "    # instead of failing later with `None(width)`.\n",
        "    if norm_layer is None:\n",
        "      norm_layer = nn.BatchNorm2d\n",
        "    width = int(planes * (base_width / 64.)) * groups\n",
        "    self.conv1 = conv1x1(inplanes, width)\n",
        "    self.bn1 = norm_layer(width)\n",
        "    self.conv2 = conv3x3(width, width, stride, groups, dilation)\n",
        "    self.bn2 = norm_layer(width)\n",
        "    self.conv3 = conv1x1(width, planes * self.expansion)\n",
        "    self.bn3 = norm_layer(planes * self.expansion)\n",
        "    self.relu = nn.ReLU(inplace=True)\n",
        "    self.downsample = downsample\n",
        "    self.stride = stride\n",
        "\n",
        "  def forward(self, x):\n",
        "    identity = x\n",
        "\n",
        "    out = self.conv1(x)\n",
        "    out = self.bn1(out)\n",
        "    out = self.relu(out)\n",
        "\n",
        "    out = self.conv2(out)\n",
        "    out = self.bn2(out)\n",
        "    out = self.relu(out)\n",
        "\n",
        "    out = self.conv3(out)\n",
        "    out = self.bn3(out)\n",
        "\n",
        "    # Project the shortcut when the main path changed the channel count or resolution.\n",
        "    if self.downsample is not None:\n",
        "      identity = self.downsample(x)\n",
        "\n",
        "    out += identity\n",
        "    out = self.relu(out)\n",
        "\n",
        "    return out"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fQ3gbMigHCKe"
      },
      "source": [
        "### ResNet definition"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rPk0_AaJHIei"
      },
      "source": [
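        "This ResNet implementation is adapted from [torchvision](https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet50). Unlike the torchvision model, the stem is a 3x3 stride-1 convolution and the max-pooling layer is not applied, which keeps more spatial resolution for the 32x32 inputs. The output layer is selected by `layer_type`.\n"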
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LuVwqfjgG_ZW"
      },
      "source": [
        "class ResNet(nn.Module):\n",
        "  \"\"\"ResNet backbone with a 3x3 stride-1 stem (max-pooling not applied) and a configurable output layer.\n",
        "\n",
        "  The output layer is chosen by the notebook-level `layer_type` setting:\n",
        "    trained  - ordinary trainable linear layer\n",
        "    random   - linear layer frozen at its random initialisation\n",
        "    sparse   - frozen random linear layer with a fraction q of each class's weights and the bias zeroed\n",
        "    scale    - trainable linear layer applied to the pooled features multiplied by alpha\n",
        "    1to1     - no linear layer; the first K pooled features are the logits\n",
        "    ensemble - no linear layer; the first K*H pooled features (times alpha_head) form H groups of K logits\n",
        "  `img_size` is stored but not used by the model, and `device` is accepted but ignored.\n",
        "  \"\"\"\n",
        "  def make_weights_sparse(self, q):\n",
        "    \"\"\"Zero a random fraction `q` of the incoming weights of every output unit, and zero the bias.\n",
        "\n",
        "    Dimensions are taken from `self.fc` itself instead of the global K and N, so the method stays\n",
        "    correct if the backbone width or the number of classes changes.\n",
        "    \"\"\"\n",
        "    num_classes, num_features = self.fc.weight.shape\n",
        "    num_to_remove = int(q * num_features)\n",
        "    with torch.no_grad():\n",
        "      for i in range(num_classes):\n",
        "        # A fresh permutation per class, so each class drops a different random subset of features.\n",
        "        self.fc.weight[i, torch.randperm(num_features)[:num_to_remove]] = 0.\n",
        "      self.fc.bias.fill_(0.)\n",
        "\n",
        "  def __init__(self, block, layers, num_input_channels=1, img_size=32, groups=1, width_per_group=64,\n",
        "               replace_stride_with_dilation=None, device='cuda'):\n",
        "    super(ResNet, self).__init__()\n",
        "    self._norm_layer = nn.BatchNorm2d\n",
        "    self.img_size = img_size\n",
        "    self.inplanes = 64\n",
        "    self.dilation = 1\n",
        "    if replace_stride_with_dilation is None:\n",
        "      replace_stride_with_dilation = [False, False, False]\n",
        "    if len(replace_stride_with_dilation) != 3:\n",
        "      raise ValueError(\"replace_stride_with_dilation should be None \"\n",
        "                          \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n",
        "    self.groups = groups\n",
        "    self.base_width = width_per_group\n",
        "\n",
        "    # define encoder\n",
        "    self.conv1 = nn.Conv2d(num_input_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    self.bn1 = self._norm_layer(self.inplanes)\n",
        "    self.relu = nn.ReLU(inplace=True)\n",
        "    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # defined but not applied in forward\n",
        "    self.layer1 = self._make_layer(block, 64, layers[0])\n",
        "    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])\n",
        "    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])\n",
        "    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])\n",
        "    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
        "\n",
        "    # define output layer ('1to1' and 'ensemble' read logits straight from the pooled features)\n",
        "    if layer_type not in ['1to1', 'ensemble']:\n",
        "      self.fc = nn.Linear(512 * block.expansion, K)\n",
        "\n",
        "      # 'random' and 'sparse' keep the output layer fixed at its initial values.\n",
        "      if layer_type in ['random', 'sparse']:\n",
        "        self.fc.requires_grad_(False)\n",
        "      if layer_type == 'sparse':\n",
        "        self.make_weights_sparse(q)\n",
        "\n",
        "    # weight initialization (convolutions and norm layers only; fc keeps nn.Linear's default init)\n",
        "    for m in self.modules():\n",
        "      if isinstance(m, nn.Conv2d):\n",
        "        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
        "      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n",
        "        nn.init.constant_(m.weight, 1)\n",
        "        nn.init.constant_(m.bias, 0)\n",
        "\n",
        "  def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n",
        "    \"\"\"Build one residual stage of `blocks` blocks; only the first block may stride or project the shortcut.\"\"\"\n",
        "    norm_layer = self._norm_layer\n",
        "    downsample = None\n",
        "    previous_dilation = self.dilation\n",
        "    if dilate:\n",
        "      self.dilation *= stride\n",
        "      stride = 1\n",
        "    if stride != 1 or self.inplanes != planes * block.expansion:\n",
        "      downsample = nn.Sequential(\n",
        "        conv1x1(self.inplanes, planes * block.expansion, stride),\n",
        "        norm_layer(planes * block.expansion),\n",
        "      )\n",
        "\n",
        "    layers = []\n",
        "    layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n",
        "                        self.base_width, previous_dilation, norm_layer))\n",
        "\n",
        "    self.inplanes = planes * block.expansion\n",
        "    for _ in range(1, blocks):\n",
        "      layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width,\n",
        "                          dilation=self.dilation, norm_layer=norm_layer))\n",
        "\n",
        "    return nn.Sequential(*layers)\n",
        "\n",
        "  def forward(self, x):\n",
        "    # Stem (3x3 conv + BN + ReLU) followed by the four residual stages; self.maxpool is not applied.\n",
        "    x = self.conv1(x)\n",
        "    x = self.bn1(x)\n",
        "    x = self.relu(x)\n",
        "    x = self.layer1(x)\n",
        "    x = self.layer2(x)\n",
        "    x = self.layer3(x)\n",
        "    x = self.layer4(x)\n",
        "    x = self.avgpool(x)\n",
        "    x = torch.flatten(x, 1)\n",
        "\n",
        "    # Feature scaling used by the 'scale' and 'ensemble' designs.\n",
        "    if layer_type == 'scale':\n",
        "      x = x * alpha\n",
        "    elif layer_type == 'ensemble':\n",
        "      x = x * alpha_head\n",
        "\n",
        "    # Designs without a linear layer read the logits directly off the pooled features.\n",
        "    if layer_type == '1to1':\n",
        "      return x[:, :K]\n",
        "    elif layer_type == 'ensemble':\n",
        "      return x[:, :K*H]  # H consecutive groups of K logits\n",
        "\n",
        "    pred = self.fc(x)\n",
        "    return pred\n",
        "\n",
        "def resnet50(**kwargs):\n",
        "  \"\"\"ResNet-50: Bottleneck blocks stacked [3, 4, 6, 3] per stage; keyword arguments go to ResNet.\"\"\"\n",
        "  blocks_per_stage = [3, 4, 6, 3]\n",
        "  return ResNet(Bottleneck, blocks_per_stage, **kwargs)"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgzAfQHtIlgL"
      },
      "source": [
        "### Initialization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MJ3Ia41jIghn"
      },
      "source": [
        "net = resnet50(num_input_channels=num_input_channels, img_size=img_size, device=device)\n",
        "# nn.Module.to already moves parameters in place; rebinding just makes the intent explicit.\n",
        "net = to_device(net, device)\n",
        "# Only trainable parameters go to SGD: the frozen 'random'/'sparse' output layers never get gradients.\n",
        "trainable_params = [p for p in net.parameters() if p.requires_grad]\n",
        "optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)\n",
        "# For the sake of simplicity, no scheduler is used here."
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w4a-w2MXNI1u",
        "outputId": "70d90dba-0236-4cae-a6bd-f4262f080f66"
      },
      "source": [
        "# torchsummary defaults to device='cuda'; pass the actual device so this also runs on CPU-only machines.\n",
        "summary(net, (num_input_channels, img_size, img_size), device=device.type)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "----------------------------------------------------------------\n",
            "        Layer (type)               Output Shape         Param #\n",
            "================================================================\n",
            "            Conv2d-1           [-1, 64, 32, 32]             576\n",
            "       BatchNorm2d-2           [-1, 64, 32, 32]             128\n",
            "              ReLU-3           [-1, 64, 32, 32]               0\n",
            "            Conv2d-4           [-1, 64, 32, 32]           4,096\n",
            "       BatchNorm2d-5           [-1, 64, 32, 32]             128\n",
            "              ReLU-6           [-1, 64, 32, 32]               0\n",
            "            Conv2d-7           [-1, 64, 32, 32]          36,864\n",
            "       BatchNorm2d-8           [-1, 64, 32, 32]             128\n",
            "              ReLU-9           [-1, 64, 32, 32]               0\n",
            "           Conv2d-10          [-1, 256, 32, 32]          16,384\n",
            "      BatchNorm2d-11          [-1, 256, 32, 32]             512\n",
            "           Conv2d-12          [-1, 256, 32, 32]          16,384\n",
            "      BatchNorm2d-13          [-1, 256, 32, 32]             512\n",
            "             ReLU-14          [-1, 256, 32, 32]               0\n",
            "       Bottleneck-15          [-1, 256, 32, 32]               0\n",
            "           Conv2d-16           [-1, 64, 32, 32]          16,384\n",
            "      BatchNorm2d-17           [-1, 64, 32, 32]             128\n",
            "             ReLU-18           [-1, 64, 32, 32]               0\n",
            "           Conv2d-19           [-1, 64, 32, 32]          36,864\n",
            "      BatchNorm2d-20           [-1, 64, 32, 32]             128\n",
            "             ReLU-21           [-1, 64, 32, 32]               0\n",
            "           Conv2d-22          [-1, 256, 32, 32]          16,384\n",
            "      BatchNorm2d-23          [-1, 256, 32, 32]             512\n",
            "             ReLU-24          [-1, 256, 32, 32]               0\n",
            "       Bottleneck-25          [-1, 256, 32, 32]               0\n",
            "           Conv2d-26           [-1, 64, 32, 32]          16,384\n",
            "      BatchNorm2d-27           [-1, 64, 32, 32]             128\n",
            "             ReLU-28           [-1, 64, 32, 32]               0\n",
            "           Conv2d-29           [-1, 64, 32, 32]          36,864\n",
            "      BatchNorm2d-30           [-1, 64, 32, 32]             128\n",
            "             ReLU-31           [-1, 64, 32, 32]               0\n",
            "           Conv2d-32          [-1, 256, 32, 32]          16,384\n",
            "      BatchNorm2d-33          [-1, 256, 32, 32]             512\n",
            "             ReLU-34          [-1, 256, 32, 32]               0\n",
            "       Bottleneck-35          [-1, 256, 32, 32]               0\n",
            "           Conv2d-36          [-1, 128, 32, 32]          32,768\n",
            "      BatchNorm2d-37          [-1, 128, 32, 32]             256\n",
            "             ReLU-38          [-1, 128, 32, 32]               0\n",
            "           Conv2d-39          [-1, 128, 16, 16]         147,456\n",
            "      BatchNorm2d-40          [-1, 128, 16, 16]             256\n",
            "             ReLU-41          [-1, 128, 16, 16]               0\n",
            "           Conv2d-42          [-1, 512, 16, 16]          65,536\n",
            "      BatchNorm2d-43          [-1, 512, 16, 16]           1,024\n",
            "           Conv2d-44          [-1, 512, 16, 16]         131,072\n",
            "      BatchNorm2d-45          [-1, 512, 16, 16]           1,024\n",
            "             ReLU-46          [-1, 512, 16, 16]               0\n",
            "       Bottleneck-47          [-1, 512, 16, 16]               0\n",
            "           Conv2d-48          [-1, 128, 16, 16]          65,536\n",
            "      BatchNorm2d-49          [-1, 128, 16, 16]             256\n",
            "             ReLU-50          [-1, 128, 16, 16]               0\n",
            "           Conv2d-51          [-1, 128, 16, 16]         147,456\n",
            "      BatchNorm2d-52          [-1, 128, 16, 16]             256\n",
            "             ReLU-53          [-1, 128, 16, 16]               0\n",
            "           Conv2d-54          [-1, 512, 16, 16]          65,536\n",
            "      BatchNorm2d-55          [-1, 512, 16, 16]           1,024\n",
            "             ReLU-56          [-1, 512, 16, 16]               0\n",
            "       Bottleneck-57          [-1, 512, 16, 16]               0\n",
            "           Conv2d-58          [-1, 128, 16, 16]          65,536\n",
            "      BatchNorm2d-59          [-1, 128, 16, 16]             256\n",
            "             ReLU-60          [-1, 128, 16, 16]               0\n",
            "           Conv2d-61          [-1, 128, 16, 16]         147,456\n",
            "      BatchNorm2d-62          [-1, 128, 16, 16]             256\n",
            "             ReLU-63          [-1, 128, 16, 16]               0\n",
            "           Conv2d-64          [-1, 512, 16, 16]          65,536\n",
            "      BatchNorm2d-65          [-1, 512, 16, 16]           1,024\n",
            "             ReLU-66          [-1, 512, 16, 16]               0\n",
            "       Bottleneck-67          [-1, 512, 16, 16]               0\n",
            "           Conv2d-68          [-1, 128, 16, 16]          65,536\n",
            "      BatchNorm2d-69          [-1, 128, 16, 16]             256\n",
            "             ReLU-70          [-1, 128, 16, 16]               0\n",
            "           Conv2d-71          [-1, 128, 16, 16]         147,456\n",
            "      BatchNorm2d-72          [-1, 128, 16, 16]             256\n",
            "             ReLU-73          [-1, 128, 16, 16]               0\n",
            "           Conv2d-74          [-1, 512, 16, 16]          65,536\n",
            "      BatchNorm2d-75          [-1, 512, 16, 16]           1,024\n",
            "             ReLU-76          [-1, 512, 16, 16]               0\n",
            "       Bottleneck-77          [-1, 512, 16, 16]               0\n",
            "           Conv2d-78          [-1, 256, 16, 16]         131,072\n",
            "      BatchNorm2d-79          [-1, 256, 16, 16]             512\n",
            "             ReLU-80          [-1, 256, 16, 16]               0\n",
            "           Conv2d-81            [-1, 256, 8, 8]         589,824\n",
            "      BatchNorm2d-82            [-1, 256, 8, 8]             512\n",
            "             ReLU-83            [-1, 256, 8, 8]               0\n",
            "           Conv2d-84           [-1, 1024, 8, 8]         262,144\n",
            "      BatchNorm2d-85           [-1, 1024, 8, 8]           2,048\n",
            "           Conv2d-86           [-1, 1024, 8, 8]         524,288\n",
            "      BatchNorm2d-87           [-1, 1024, 8, 8]           2,048\n",
            "             ReLU-88           [-1, 1024, 8, 8]               0\n",
            "       Bottleneck-89           [-1, 1024, 8, 8]               0\n",
            "           Conv2d-90            [-1, 256, 8, 8]         262,144\n",
            "      BatchNorm2d-91            [-1, 256, 8, 8]             512\n",
            "             ReLU-92            [-1, 256, 8, 8]               0\n",
            "           Conv2d-93            [-1, 256, 8, 8]         589,824\n",
            "      BatchNorm2d-94            [-1, 256, 8, 8]             512\n",
            "             ReLU-95            [-1, 256, 8, 8]               0\n",
            "           Conv2d-96           [-1, 1024, 8, 8]         262,144\n",
            "      BatchNorm2d-97           [-1, 1024, 8, 8]           2,048\n",
            "             ReLU-98           [-1, 1024, 8, 8]               0\n",
            "       Bottleneck-99           [-1, 1024, 8, 8]               0\n",
            "          Conv2d-100            [-1, 256, 8, 8]         262,144\n",
            "     BatchNorm2d-101            [-1, 256, 8, 8]             512\n",
            "            ReLU-102            [-1, 256, 8, 8]               0\n",
            "          Conv2d-103            [-1, 256, 8, 8]         589,824\n",
            "     BatchNorm2d-104            [-1, 256, 8, 8]             512\n",
            "            ReLU-105            [-1, 256, 8, 8]               0\n",
            "          Conv2d-106           [-1, 1024, 8, 8]         262,144\n",
            "     BatchNorm2d-107           [-1, 1024, 8, 8]           2,048\n",
            "            ReLU-108           [-1, 1024, 8, 8]               0\n",
            "      Bottleneck-109           [-1, 1024, 8, 8]               0\n",
            "          Conv2d-110            [-1, 256, 8, 8]         262,144\n",
            "     BatchNorm2d-111            [-1, 256, 8, 8]             512\n",
            "            ReLU-112            [-1, 256, 8, 8]               0\n",
            "          Conv2d-113            [-1, 256, 8, 8]         589,824\n",
            "     BatchNorm2d-114            [-1, 256, 8, 8]             512\n",
            "            ReLU-115            [-1, 256, 8, 8]               0\n",
            "          Conv2d-116           [-1, 1024, 8, 8]         262,144\n",
            "     BatchNorm2d-117           [-1, 1024, 8, 8]           2,048\n",
            "            ReLU-118           [-1, 1024, 8, 8]               0\n",
            "      Bottleneck-119           [-1, 1024, 8, 8]               0\n",
            "          Conv2d-120            [-1, 256, 8, 8]         262,144\n",
            "     BatchNorm2d-121            [-1, 256, 8, 8]             512\n",
            "            ReLU-122            [-1, 256, 8, 8]               0\n",
            "          Conv2d-123            [-1, 256, 8, 8]         589,824\n",
            "     BatchNorm2d-124            [-1, 256, 8, 8]             512\n",
            "            ReLU-125            [-1, 256, 8, 8]               0\n",
            "          Conv2d-126           [-1, 1024, 8, 8]         262,144\n",
            "     BatchNorm2d-127           [-1, 1024, 8, 8]           2,048\n",
            "            ReLU-128           [-1, 1024, 8, 8]               0\n",
            "      Bottleneck-129           [-1, 1024, 8, 8]               0\n",
            "          Conv2d-130            [-1, 256, 8, 8]         262,144\n",
            "     BatchNorm2d-131            [-1, 256, 8, 8]             512\n",
            "            ReLU-132            [-1, 256, 8, 8]               0\n",
            "          Conv2d-133            [-1, 256, 8, 8]         589,824\n",
            "     BatchNorm2d-134            [-1, 256, 8, 8]             512\n",
            "            ReLU-135            [-1, 256, 8, 8]               0\n",
            "          Conv2d-136           [-1, 1024, 8, 8]         262,144\n",
            "     BatchNorm2d-137           [-1, 1024, 8, 8]           2,048\n",
            "            ReLU-138           [-1, 1024, 8, 8]               0\n",
            "      Bottleneck-139           [-1, 1024, 8, 8]               0\n",
            "          Conv2d-140            [-1, 512, 8, 8]         524,288\n",
            "     BatchNorm2d-141            [-1, 512, 8, 8]           1,024\n",
            "            ReLU-142            [-1, 512, 8, 8]               0\n",
            "          Conv2d-143            [-1, 512, 4, 4]       2,359,296\n",
            "     BatchNorm2d-144            [-1, 512, 4, 4]           1,024\n",
            "            ReLU-145            [-1, 512, 4, 4]               0\n",
            "          Conv2d-146           [-1, 2048, 4, 4]       1,048,576\n",
            "     BatchNorm2d-147           [-1, 2048, 4, 4]           4,096\n",
            "          Conv2d-148           [-1, 2048, 4, 4]       2,097,152\n",
            "     BatchNorm2d-149           [-1, 2048, 4, 4]           4,096\n",
            "            ReLU-150           [-1, 2048, 4, 4]               0\n",
            "      Bottleneck-151           [-1, 2048, 4, 4]               0\n",
            "          Conv2d-152            [-1, 512, 4, 4]       1,048,576\n",
            "     BatchNorm2d-153            [-1, 512, 4, 4]           1,024\n",
            "            ReLU-154            [-1, 512, 4, 4]               0\n",
            "          Conv2d-155            [-1, 512, 4, 4]       2,359,296\n",
            "     BatchNorm2d-156            [-1, 512, 4, 4]           1,024\n",
            "            ReLU-157            [-1, 512, 4, 4]               0\n",
            "          Conv2d-158           [-1, 2048, 4, 4]       1,048,576\n",
            "     BatchNorm2d-159           [-1, 2048, 4, 4]           4,096\n",
            "            ReLU-160           [-1, 2048, 4, 4]               0\n",
            "      Bottleneck-161           [-1, 2048, 4, 4]               0\n",
            "          Conv2d-162            [-1, 512, 4, 4]       1,048,576\n",
            "     BatchNorm2d-163            [-1, 512, 4, 4]           1,024\n",
            "            ReLU-164            [-1, 512, 4, 4]               0\n",
            "          Conv2d-165            [-1, 512, 4, 4]       2,359,296\n",
            "     BatchNorm2d-166            [-1, 512, 4, 4]           1,024\n",
            "            ReLU-167            [-1, 512, 4, 4]               0\n",
            "          Conv2d-168           [-1, 2048, 4, 4]       1,048,576\n",
            "     BatchNorm2d-169           [-1, 2048, 4, 4]           4,096\n",
            "            ReLU-170           [-1, 2048, 4, 4]               0\n",
            "      Bottleneck-171           [-1, 2048, 4, 4]               0\n",
            "AdaptiveAvgPool2d-172           [-1, 2048, 1, 1]               0\n",
            "          Linear-173                   [-1, 10]          20,490\n",
            "================================================================\n",
            "Total params: 23,519,690\n",
            "Trainable params: 23,519,690\n",
            "Non-trainable params: 0\n",
            "----------------------------------------------------------------\n",
            "Input size (MB): 0.00\n",
            "Forward/backward pass size (MB): 88.58\n",
            "Params size (MB): 89.72\n",
            "Estimated Total Size (MB): 178.30\n",
            "----------------------------------------------------------------\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dBZUdYj-NVM_"
      },
      "source": [
        "### Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3RBbvapJNPDG",
        "outputId": "8b35d764-6397-48f7-c40f-14b6febdf026"
      },
      "source": [
        "criterion = nn.CrossEntropyLoss()  # created once instead of on every step\n",
        "train_len = len(trainloader)\n",
        "\n",
        "net.train()\n",
        "for epoch in range(num_epochs):\n",
        "  running_CE = 0.0\n",
        "  for images, labels in trainloader:\n",
        "    optimizer.zero_grad()\n",
        "    pred = net(images)\n",
        "\n",
        "    if layer_type == 'ensemble':\n",
        "      # pred holds H consecutive groups of K logits; average the per-head cross-entropies.\n",
        "      loss = sum(criterion(pred[:, h*K:(h+1)*K], labels) for h in range(H)) / H\n",
        "    else:\n",
        "      loss = criterion(pred, labels)\n",
        "\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    running_CE += loss.item()\n",
        "\n",
        "  # Mean training loss over this epoch's batches.\n",
        "  print('[Train %d] loss: %.3f' % (epoch + 1, running_CE/train_len))"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[Train 1] loss: 9.882\n",
            "[Train 2] loss: 3.767\n",
            "[Train 3] loss: 2.576\n",
            "[Train 4] loss: 2.392\n",
            "[Train 5] loss: 2.114\n",
            "[Train 6] loss: 1.783\n",
            "[Train 7] loss: 1.293\n",
            "[Train 8] loss: 0.902\n",
            "[Train 9] loss: 0.593\n",
            "[Train 10] loss: 0.414\n",
            "[Train 11] loss: 0.373\n",
            "[Train 12] loss: 0.265\n",
            "[Train 13] loss: 0.136\n",
            "[Train 14] loss: 0.090\n",
            "[Train 15] loss: 0.151\n",
            "[Train 16] loss: 0.139\n",
            "[Train 17] loss: 0.068\n",
            "[Train 18] loss: 0.102\n",
            "[Train 19] loss: 0.104\n",
            "[Train 20] loss: 0.140\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ce4OwCxyOgP3"
      },
      "source": [
        "### Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jYEItncGN0iy",
        "outputId": "8ef04306-0941-4dc9-dbff-690a3ac6580e"
      },
      "source": [
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "net.eval()\n",
        "with torch.no_grad():\n",
        "  for images, labels in valloader:\n",
        "    pred = net(images)\n",
        "    if layer_type == 'ensemble':\n",
        "      # Split the K*H outputs into H heads of K logits and average the logits over heads.\n",
        "      pred = pred.view(images.shape[0], H, K).mean(dim=1)\n",
        "    pred_ids = pred.max(dim=-1)[1]\n",
        "    correct += (pred_ids == labels).float().sum().item()\n",
        "    total += labels.size(0)\n",
        "  acc = 100 * correct / total\n",
        "\n",
        "print(\"Accuracy: \", acc)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Accuracy:  85.3973063973064\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}