{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# General E(2)-Equivariant Steerable CNNs  -  A concrete example\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from e2cnn import gspaces\n",
    "from e2cnn import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we build a **Steerable CNN** and try it on MNIST.\n",
    "\n",
    "Let's also use a slightly larger group: we now build a model equivariant to $8$ rotations.\n",
    "We indicate the group of $N$ discrete rotations as $C_N$, i.e. the **cyclic group** of order $N$.\n",
    "In this case, we will use $C_8$.\n",
    "\n",
    "Because the inputs are still gray-scale images, the input type of the model is again a *scalar field*.\n",
    "\n",
    "However, internally we use *regular fields*: this is equivalent to a *group-equivariant convolutional neural network*.\n",
    "\n",
    "Finally, we build *invariant* features for the final classification task by pooling over the group using *Group Pooling*.\n",
    "\n",
    "The final classification is performed by two fully connected layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The model\n",
    "\n",
    "Here is the definition of our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class C8SteerableCNN(torch.nn.Module):\n",
    "    \"\"\"Steerable CNN equivariant to the 8 rotations of C8, with an invariant classifier head.\n",
    "\n",
    "    Six convolution blocks over regular feature fields are followed by spatial pooling,\n",
    "    pooling over the group and two fully connected layers producing `n_classes` logits.\n",
    "    The input mask is built for 29x29 images.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_classes=10):\n",
    "        \"\"\"Build the equivariant blocks and the fully connected classifier.\n",
    "\n",
    "        Args:\n",
    "            n_classes: number of logits produced by the last linear layer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # the model is equivariant under rotations by 45 degrees, modelled by C8\n",
    "        self.r2_act = gspaces.Rot2dOnR2(N=8)\n",
    "\n",
    "        # the input image is a scalar field, corresponding to the trivial representation\n",
    "        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])\n",
    "\n",
    "        # we store the input type for wrapping the images into a geometric tensor during the forward pass\n",
    "        self.input_type = in_type\n",
    "\n",
    "        # convolution 1\n",
    "        # first specify the output type of the convolutional layer:\n",
    "        # we choose 24 feature fields, each transforming under the regular representation of C8\n",
    "        out_type = self._regular_type(24)\n",
    "        self.block1 = nn.SequentialModule(\n",
    "            nn.MaskModule(in_type, 29, margin=1),\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=7, padding=1)\n",
    "        )\n",
    "\n",
    "        # convolution 2\n",
    "        # the old output type is the input type to the next layer\n",
    "        in_type = self.block1.out_type\n",
    "        # the output type of the second convolution layer are 48 regular feature fields of C8\n",
    "        out_type = self._regular_type(48)\n",
    "        self.block2 = nn.SequentialModule(\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=5, padding=2)\n",
    "        )\n",
    "        self.pool1 = nn.SequentialModule(\n",
    "            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)\n",
    "        )\n",
    "\n",
    "        # convolution 3: 48 regular feature fields of C8\n",
    "        in_type = self.block2.out_type\n",
    "        out_type = self._regular_type(48)\n",
    "        self.block3 = nn.SequentialModule(\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=5, padding=2)\n",
    "        )\n",
    "\n",
    "        # convolution 4: 96 regular feature fields of C8, followed by the second downsampling\n",
    "        in_type = self.block3.out_type\n",
    "        out_type = self._regular_type(96)\n",
    "        self.block4 = nn.SequentialModule(\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=5, padding=2)\n",
    "        )\n",
    "        self.pool2 = nn.SequentialModule(\n",
    "            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)\n",
    "        )\n",
    "\n",
    "        # convolution 5: 96 regular feature fields of C8\n",
    "        in_type = self.block4.out_type\n",
    "        out_type = self._regular_type(96)\n",
    "        self.block5 = nn.SequentialModule(\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=5, padding=2)\n",
    "        )\n",
    "\n",
    "        # convolution 6: 64 regular feature fields of C8 (note the smaller padding)\n",
    "        in_type = self.block5.out_type\n",
    "        out_type = self._regular_type(64)\n",
    "        self.block6 = nn.SequentialModule(\n",
    "            *self._conv_layers(in_type, out_type, kernel_size=5, padding=1)\n",
    "        )\n",
    "        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)\n",
    "\n",
    "        self.gpool = nn.GroupPooling(out_type)\n",
    "\n",
    "        # number of output channels\n",
    "        c = self.gpool.out_type.size\n",
    "\n",
    "        # Fully Connected\n",
    "        self.fully_net = torch.nn.Sequential(\n",
    "            torch.nn.Linear(c, 64),\n",
    "            torch.nn.BatchNorm1d(64),\n",
    "            torch.nn.ELU(inplace=True),\n",
    "            torch.nn.Linear(64, n_classes),\n",
    "        )\n",
    "\n",
    "    def _regular_type(self, n_fields):\n",
    "        \"\"\"Return a new field type made of `n_fields` regular representations of C8.\"\"\"\n",
    "        return nn.FieldType(self.r2_act, n_fields*[self.r2_act.regular_repr])\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_layers(in_type, out_type, kernel_size, padding):\n",
    "        \"\"\"Return the convolution, batch-norm and ReLU layers shared by every block.\"\"\"\n",
    "        return [\n",
    "            nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),\n",
    "            nn.InnerBatchNorm(out_type),\n",
    "            nn.ReLU(out_type, inplace=True),\n",
    "        ]\n",
    "\n",
    "    def forward(self, input: torch.Tensor):\n",
    "        \"\"\"Compute the class logits for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            input: raw tensor, wrapped here into a GeometricTensor of type `self.input_type`.\n",
    "\n",
    "        Returns:\n",
    "            The output of the fully connected head: one row of `n_classes` logits per sample.\n",
    "        \"\"\"\n",
    "        # wrap the input tensor in a GeometricTensor\n",
    "        # (associate it with the input type)\n",
    "        x = nn.GeometricTensor(input, self.input_type)\n",
    "\n",
    "        # apply each equivariant block\n",
    "\n",
    "        # Each layer has an input and an output type\n",
    "        # A layer takes a GeometricTensor in input.\n",
    "        # This tensor needs to be associated with the same representation of the layer's input type\n",
    "        #\n",
    "        # The Layer outputs a new GeometricTensor, associated with the layer's output type.\n",
    "        # As a result, consecutive layers need to have matching input/output types\n",
    "        x = self.block1(x)\n",
    "        x = self.block2(x)\n",
    "        x = self.pool1(x)\n",
    "\n",
    "        x = self.block3(x)\n",
    "        x = self.block4(x)\n",
    "        x = self.pool2(x)\n",
    "\n",
    "        x = self.block5(x)\n",
    "        x = self.block6(x)\n",
    "\n",
    "        # pool over the spatial dimensions\n",
    "        x = self.pool3(x)\n",
    "\n",
    "        # pool over the group\n",
    "        x = self.gpool(x)\n",
    "\n",
    "        # unwrap the output GeometricTensor\n",
    "        # (take the Pytorch tensor and discard the associated representation)\n",
    "        x = x.tensor\n",
    "\n",
    "        # classify with the final fully connected layers\n",
    "        x = self.fully_net(x.reshape(x.shape[0], -1))\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try the model on *rotated* MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File ‘mnist_rotation_new.zip’ already there; not retrieving.\r\n",
      "\r\n",
      "Archive:  mnist_rotation_new.zip\r\n"
     ]
    }
   ],
   "source": [
    "# download the dataset\n",
    "!wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip\n",
    "# uncompress the zip file\n",
    "!unzip -n mnist_rotation_new.zip -d mnist_rotation_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "from torchvision.transforms import RandomRotation\n",
    "from torchvision.transforms import Pad\n",
    "from torchvision.transforms import Resize\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision.transforms import Compose\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MnistRotDataset(Dataset):\n",
    "    \"\"\"Rotated-MNIST split read from the `.amat` text files in `mnist_rotation_new/`.\n",
    "\n",
    "    Items are `(image, label)` pairs: the image is a PIL image built from a 28x28\n",
    "    float32 array and, when a `transform` is given, passed through it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mode, transform=None):\n",
    "        \"\"\"Load the whole split into memory.\n",
    "\n",
    "        Args:\n",
    "            mode: either 'train' or 'test'; selects the file to read.\n",
    "            transform: optional callable applied to every image in `__getitem__`.\n",
    "        \"\"\"\n",
    "        assert mode in ['train', 'test']\n",
    "\n",
    "        path = (\n",
    "            \"mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat\"\n",
    "            if mode == \"train\"\n",
    "            else \"mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat\"\n",
    "        )\n",
    "\n",
    "        self.transform = transform\n",
    "\n",
    "        table = np.loadtxt(path, delimiter=' ')\n",
    "\n",
    "        # in every row, the last column is the label and all the others are the pixels\n",
    "        pixels, targets = table[:, :-1], table[:, -1]\n",
    "        self.images = pixels.reshape(-1, 28, 28).astype(np.float32)\n",
    "        self.labels = targets.astype(np.int64)\n",
    "        self.num_samples = len(self.labels)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the (optionally transformed) image and the label stored at `index`.\"\"\"\n",
    "        pixels, target = self.images[index], self.labels[index]\n",
    "        picture = Image.fromarray(pixels)\n",
    "        if self.transform is None:\n",
    "            return picture, target\n",
    "        return self.transform(picture), target\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples in the split.\"\"\"\n",
    "        return len(self.labels)\n",
    "\n",
    "# images are padded to have shape 29x29.\n",
    "# this allows using odd-size filters with stride 2 when downsampling a feature map in the model\n",
    "pad = Pad((0, 0, 1, 1), fill=0)\n",
    "\n",
    "# to reduce interpolation artifacts (e.g. when testing the model on rotated images),\n",
    "# we upsample an image by a factor of 3, rotate it and finally downsample it again\n",
    "resize1 = Resize(87)\n",
    "resize2 = Resize(29)\n",
    "\n",
    "totensor = ToTensor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = C8SteerableCNN().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is now randomly initialized. \n",
    "Therefore, we do not expect it to produce the right class probabilities.\n",
    "\n",
    "However, the model should still produce the same output for rotated versions of the same image.\n",
    "This is true for rotations by multiples of $\\frac{\\pi}{2}$, but is only approximate for rotations by odd multiples of $\\frac{\\pi}{4}$.\n",
    "\n",
    "Let's test it on an image from the test set:\n",
    "we feed eight rotated versions of the first image in the test set and print the output logits of the model for each of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model: torch.nn.Module, x: Image.Image):\n",
    "    \"\"\"Print the logits of `model` for 8 rotated versions of the input image `x`.\n",
    "\n",
    "    The image is padded and upsampled once; for every multiple of 45 degrees it is then\n",
    "    rotated, downsampled to 29x29 and fed to the model. The model is switched to\n",
    "    evaluation mode and is left in that mode.\n",
    "\n",
    "    Args:\n",
    "        model: the network to evaluate; it must accept a tensor of shape (1, 1, 29, 29).\n",
    "        x: the PIL image to rotate.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # warm-up forward pass on random noise; the result is discarded and,\n",
    "        # since no gradient is needed here, no autograd graph is built for it\n",
    "        model(torch.randn(1, 1, 29, 29).to(device))\n",
    "\n",
    "    x = resize1(pad(x))\n",
    "\n",
    "    print()\n",
    "    print('##########################################################################################')\n",
    "    header = 'angle |  ' + '  '.join([\"{:6d}\".format(d) for d in range(10)])\n",
    "    print(header)\n",
    "    with torch.no_grad():\n",
    "        for r in range(8):\n",
    "            x_transformed = totensor(resize2(x.rotate(r*45., Image.BILINEAR))).reshape(1, 1, 29, 29)\n",
    "            x_transformed = x_transformed.to(device)\n",
    "\n",
    "            y = model(x_transformed)\n",
    "            y = y.to('cpu').numpy().squeeze()\n",
    "\n",
    "            angle = r * 45\n",
    "            print(\"{:5d} : {}\".format(angle, y))\n",
    "    print('##########################################################################################')\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the test set    \n",
    "raw_mnist_test = MnistRotDataset(mode='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "##########################################################################################\n",
      "angle |       0       1       2       3       4       5       6       7       8       9\n",
      "    0 : [-0.0209 -0.0258  0.033   0.0623 -0.0476 -0.0614 -0.0077 -0.0545 -0.0947 -0.0124]\n",
      "   45 : [-0.0206 -0.0258  0.0329  0.0624 -0.0477 -0.0613 -0.0076 -0.0544 -0.0946 -0.0124]\n",
      "   90 : [-0.0209 -0.0258  0.033   0.0623 -0.0476 -0.0614 -0.0077 -0.0545 -0.0947 -0.0124]\n",
      "  135 : [-0.0206 -0.0258  0.0329  0.0624 -0.0477 -0.0613 -0.0076 -0.0544 -0.0946 -0.0124]\n",
      "  180 : [-0.0209 -0.0258  0.033   0.0623 -0.0476 -0.0614 -0.0077 -0.0545 -0.0947 -0.0124]\n",
      "  225 : [-0.0206 -0.0258  0.0329  0.0624 -0.0477 -0.0613 -0.0076 -0.0544 -0.0946 -0.0124]\n",
      "  270 : [-0.0209 -0.0258  0.033   0.0623 -0.0476 -0.0614 -0.0077 -0.0545 -0.0947 -0.0124]\n",
      "  315 : [-0.0206 -0.0258  0.0329  0.0624 -0.0477 -0.0613 -0.0076 -0.0544 -0.0946 -0.0124]\n",
      "##########################################################################################\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# retrieve the first image from the test set\n",
    "x, y = next(iter(raw_mnist_test))\n",
    "\n",
    "# evaluate the model\n",
    "test_model(model, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output of the model is already almost invariant.\n",
    "However, we still observe small fluctuations in the outputs.\n",
    "\n",
    "This is because the model contains some operations which might break equivariance.\n",
    "For instance, most convolutions include a padding of $2$ pixels per side. This adds information about the actual orientation of the grid where the image/feature map is sampled because the padding is not rotated with the image.\n",
    "\n",
    "During training, the model will observe rotated patterns and will learn to ignore the noise coming from the padding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, let's train the model now.\n",
    "The training procedure is exactly the same as the one used to train a normal *PyTorch* architecture:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training images are padded, upsampled, randomly rotated and downsampled again\n",
    "train_transform = Compose([\n",
    "    pad,\n",
    "    resize1,\n",
    "    # the resampling filter is passed positionally: newer torchvision releases renamed\n",
    "    # the keyword `resample` to `interpolation`, but it stays the second positional argument\n",
    "    RandomRotation(180, Image.BILINEAR, expand=False),\n",
    "    resize2,\n",
    "    totensor,\n",
    "])\n",
    "\n",
    "mnist_train = MnistRotDataset(mode='train', transform=train_transform)\n",
    "# reshuffle the training samples at every epoch, so that the batches differ between epochs\n",
    "train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
    "\n",
    "\n",
    "# test images are only padded: the test set is evaluated as it is, in a fixed order\n",
    "test_transform = Compose([\n",
    "    pad,\n",
    "    totensor,\n",
    "])\n",
    "mnist_test = MnistRotDataset(mode='test', transform=test_transform)\n",
    "test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64)\n",
    "\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 | test accuracy: 89.628\n",
      "epoch 10 | test accuracy: 97.468\n",
      "epoch 20 | test accuracy: 98.016\n",
      "epoch 30 | test accuracy: 98.24000000000001\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(31):\n",
    "    # one optimization pass over the training set\n",
    "    model.train()\n",
    "    for images, targets in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        images, targets = images.to(device), targets.to(device)\n",
    "\n",
    "        loss = loss_function(model(images), targets)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "    # the accuracy on the test set is only measured every 10 epochs\n",
    "    if epoch % 10 != 0:\n",
    "        continue\n",
    "\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        model.eval()\n",
    "        for images, targets in test_loader:\n",
    "            images, targets = images.to(device), targets.to(device)\n",
    "\n",
    "            # the predicted class is the index of the largest logit\n",
    "            prediction = torch.max(model(images), 1)[1]\n",
    "            total += targets.shape[0]\n",
    "            correct += (prediction == targets).sum().item()\n",
    "    print(f\"epoch {epoch} | test accuracy: {correct/total*100.}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "##########################################################################################\n",
      "angle |       0       1       2       3       4       5       6       7       8       9\n",
      "    0 : [-1.36   -1.2538 -1.9034 -1.6643 -0.2062 -0.7859  7.2849 -0.3175 -1.3725 -0.8086]\n",
      "   45 : [-1.2159 -0.8674 -1.889  -1.7495  0.1019 -0.7801  7.4423 -0.5904 -1.5792 -1.4166]\n",
      "   90 : [-1.36   -1.2539 -1.9034 -1.6643 -0.2062 -0.7859  7.2849 -0.3175 -1.3725 -0.8086]\n",
      "  135 : [-1.2159 -0.8675 -1.889  -1.7495  0.1019 -0.7801  7.4423 -0.5904 -1.5792 -1.4166]\n",
      "  180 : [-1.36   -1.2539 -1.9034 -1.6643 -0.2062 -0.7859  7.2849 -0.3175 -1.3725 -0.8086]\n",
      "  225 : [-1.2159 -0.8675 -1.889  -1.7495  0.1019 -0.7801  7.4423 -0.5904 -1.5792 -1.4166]\n",
      "  270 : [-1.36   -1.2539 -1.9034 -1.6643 -0.2062 -0.7859  7.2849 -0.3175 -1.3725 -0.8086]\n",
      "  315 : [-1.2159 -0.8675 -1.889  -1.7495  0.1019 -0.7801  7.4423 -0.5904 -1.5792 -1.4166]\n",
      "##########################################################################################\n",
      "\n"
     ]
    }
   ],
   "source": [
    "    \n",
    "# retrieve the first image from the test set\n",
    "x, y = next(iter(raw_mnist_test))\n",
    "\n",
    "\n",
    "# evaluate the model\n",
    "test_model(model, x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}