{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ugdv4QO7-hs3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 386
        },
        "outputId": "9a683d28-ff00-4660-f5c4-0e937196d618"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7c6fdc4929b0>"
            ]
          },
          "metadata": {},
          "execution_count": 66
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 400x400 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVoAAAFfCAYAAAAPnATFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUyUlEQVR4nO3df2zUB/3H8Vd/2Cts1xO6Ftr0+KFDGeuKQIFgcT9cB1ZGnDHMLF2smGAkV0fXaOb5x/BH5PAPDU6x/AjCjCJoYrc5BQJMSoyrlBIMOMOG2+wxpHVu3JX+cZDeff8wnvbLHezT9t3PfY7nI/n80dvdPq8U8tzl2n0+BalUKiUAgJlCtwcAQL4jtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAseKJPmEymdTFixfl9/tVUFAw0acHgHGTSqU0ODio6upqFRZmf9864aG9ePGigsHgRJ8WAMxEo1HV1NRk/ecTHlq/3y9JqqmpueF/Adzw9NNPuz0hq7a2NrcnZPTmm2+6PSGriooKtydkdPnyZbcnZLV06VK3J2T0yU9+0u0JGSUSCW3dujXdtWwmPLT/+bigsLAw50I7efJktydklasfs5SVlbk9ISu+Z84VFRW5PSEjn8/n9oQbutnftdwqHQDkIUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAsVGFduvWrZo1a5ZKS0u1dOlSnThxYrx3AUDecBza/fv3q729XRs3btSpU6c0f/58rVy5UgMDAxb7AMDzHIf2+9//vtatW6e1a9dq3rx52rZtmyZPnqyf/OQnFvsAwPMchfbq1avq7e1VY2Pjf/8FhYVqbGzUyy+/nPE1iURC8Xh8xAEAtxJHoX377bc1PDysadOmjXh82rRpunTpUsbXRCIRBQKB9MH9wgDcasx/6yAcDisWi6WPaDRqfUoAyCmO7hl2xx13qKioSP39/SMe7+/v1/Tp0zO+xufz5fz9fgDAkqN3tCUlJVq0aJGOHj2afiyZTOro0aNatmzZuI8DgHzg+C647e3tamlpUX19vZYsWaItW7ZoaGhIa9eutdgHAJ7nOLSf/exn9c9//lNPP/20Ll26pI985CM6ePDgdT8gAwD8m+PQSlJra6taW1vHewsA5CWudQAAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYI7QAYKwglUqlJvKE8XhcgUBAzc3NKikpmchT39Ts2bPdnpBVVVWV2xMy+sQnPuH2hKx++tOfuj0ho76+PrcnZFVTU+P2hIy+/vWvuz0ho3g8rilTpigWi6msrCzr83hHCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGDMcWiPHz+u1atXq7q6WgUFBXruuecMZgFA/nAc2qGhIc2fP19bt2612AMAeafY6QuamprU1NRksQUA8pLj0DqVSCSUSCTSX8fjcetTAkBOMf9hWCQSUSAQSB/BYND6lACQU8xDGw6HFYvF0kc0GrU+JQDkFPOPDnw+n3w+n/VpACBn8Xu0AGDM8TvaK1eu6Pz58+mv33jjDZ0+fVpTp07VjBkzxnUcAOQDx6E9efKkHnjggfTX7e3tkqSWlhbt2bNn3IYBQL5wHNr7779fqVTKYgsA5CU+owUAY4QWAIwRWgAwRmgBwBihBQBjhBYAjBFaADBGaAHAGKEFAGOEFgCMEVoAMEZoAcCY+YW/s5k1a5ZKS0vdOn1GuXxn39raWrcnZNTd3e32hKy++tWvuj0ho/7+frcnZLVw4UK3J2T0ta99ze0JGf3v/RBvhHe0AGCM0AKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYchTYSiW
jx4sXy+/2qrKzUI488onPnzlltA4C84Ci0XV1dCoVC6u7u1uHDh3Xt2jWtWLFCQ0NDVvsAwPMc3WHh4MGDI77es2ePKisr1dvbq3vvvXdchwFAvhjTrWxisZgkaerUqVmfk0gkRtzuIR6Pj+WUAOA5o/5hWDKZVFtbmxoaGm54P6tIJKJAIJA+gsHgaE8JAJ406tCGQiGdPXtW+/btu+HzwuGwYrFY+ohGo6M9JQB40qg+OmhtbdWLL76o48ePq6am5obP9fl88vl8oxoHAPnAUWhTqZS+/OUvq7OzU8eOHdPs2bOtdgFA3nAU2lAopL179+r555+X3+/XpUuXJEmBQECTJk0yGQgAXufoM9qOjg7FYjHdf//9qqqqSh/79++32gcAnuf4owMAgDNc6wAAjBFaADBGaAHAGKEFAGOEFgCMEVoAMEZoAcAYoQUAY4QWAIwRWgAwRmgBwNiYbmUzFhcuXFBJSYlbp8+os7PT7QlZ9ff3uz0ho1deecXtCVl95jOfcXtCRuXl5W5PyGrDhg1uT8ho165dbk/I6L1e/4V3tABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGHIW2o6NDdXV1KisrU1lZmZYtW6YDBw5YbQOAvOAotDU1Ndq8ebN6e3t18uRJffzjH9enPvUp/eUvf7HaBwCe5+gOC6tXrx7x9Xe+8x11dHSou7tbd99997gOA4B8Mepb2QwPD+tXv/qVhoaGtGzZsqzPSyQSSiQS6a/j8fhoTwkAnuT4h2FnzpzR7bffLp/Ppy996Uvq7OzUvHnzsj4/EokoEAikj2AwOKbBAOA1jkP74Q9/WKdPn9af/vQnrV+/Xi0tLTe8QV84HFYsFksf0Wh0TIMBwGscf3RQUlKiO++8U5K0aNEi9fT06Ac/+IG2b9+e8fk+n08+n29sKwHAw8b8e7TJZHLEZ7AAgJEcvaMNh8NqamrSjBkzNDg4qL179+rYsWM6dOiQ1T4A8DxHoR0YGNDnPvc5/eMf/1AgEFBdXZ0OHTqkhx56yGofAHieo9Du2rXLagcA5C2udQAAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYI7QAYIzQAoCxUd8zbKx6e3tVVFTk1ukzOnPmjNsTspo5c6bbEzK6evWq2xOy2rJli9sTMvrtb3/r9oSsnnnmGbcn5CXe0QKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYG1NoN2/erIKCArW1tY3THADIP6MObU9Pj7Zv3666urrx3AMAeWdUob1y5Yqam5u1c+dOTZkyZbw3AUBeGVVoQ6GQVq1apcbGxps+N5FIKB6PjzgA4Fbi+J5h+/bt06lTp9TT0/Oenh+JRPTNb37T8TAAyBeO3tFGo1Ft2LBBP//5z1VaWvqeXhMOhxWLxdJHNBod1VAA8CpH72h7e3s1MDCghQsXph8bHh7W8ePH9aMf/UiJROK6O9v6fD75fL7xWQsAHuQotA8++OB1t+Reu3at5s6dq6eeeirnbh8OALnAUWj9fr9qa2tHPHbbbbepvLz8uscBAP/G/xkGAMYc/9bB/3fs2LFxmAEA+Yt3tABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYI7QAYIzQAoAxQgsAxsZ8rYPRynTtWrf9+c9/dntCVt3d3W5P8Jx169a5PSGj5cuXuz0hq5deesntCRn99a9/dXtCRoODg7rzzjtv+jze0QKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYcxTab3zjGyooKBhxzJ0712obAOQFxxf+vvvuu3XkyJH//guKXbt2OAB4guNKFhcXa/r06R
ZbACAvOf6M9rXXXlN1dbU+8IEPqLm5WX19fTd8fiKRUDweH3EAwK3EUWiXLl2qPXv26ODBg+ro6NAbb7yhj33sYxocHMz6mkgkokAgkD6CweCYRwOAlzgKbVNTk9asWaO6ujqtXLlSv/vd73T58mX98pe/zPqacDisWCyWPqLR6JhHA4CXjOknWe9///v1oQ99SOfPn8/6HJ/PJ5/PN5bTAICnjen3aK9cuaK//e1vqqqqGq89AJB3HIX2K1/5irq6uvTmm2/qj3/8oz796U+rqKhIjz32mNU+APA8Rx8dXLhwQY899pj+9a9/qaKiQsuXL1d3d7cqKiqs9gGA5zkK7b59+6x2AEDe4loHAGCM0AKAMUILAMYILQAYI7QAYIzQAoAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCtIpVKpiTxhPB5XIBDQ/v37NXny5Ik89U2tXr3a7QkYR/fee6/bEzJatGiR2xOymjp1qtsTMnrnnXfcnpBRIpHQj3/8Y8ViMZWVlWV9Hu9oAcAYoQUAY4QWAIwRWgAwRmgBwBihBQBjhBYAjBFaADBGaAHAGKEFAGOEFgCMEVoAMEZoAcAYoQUAY4QWAIw5Du1bb72lxx9/XOXl5Zo0aZLuuecenTx50mIbAOSFYidPfvfdd9XQ0KAHHnhABw4cUEVFhV577TVNmTLFah8AeJ6j0H73u99VMBjU7t2704/Nnj173EcBQD5x9NHBCy+8oPr6eq1Zs0aVlZVasGCBdu7cecPXJBIJxePxEQcA3Eochfb1119XR0eH5syZo0OHDmn9+vV64okn9Oyzz2Z9TSQSUSAQSB/BYHDMowHASxyFNplMauHChdq0aZMWLFigL37xi1q3bp22bduW9TXhcFixWCx9RKPRMY8GAC9xFNqqqirNmzdvxGN33XWX+vr6sr7G5/OprKxsxAEAtxJHoW1oaNC5c+dGPPbqq69q5syZ4zoKAPKJo9A++eST6u7u1qZNm3T+/Hnt3btXO3bsUCgUstoHAJ7nKLSLFy9WZ2enfvGLX6i2tlbf/va3tWXLFjU3N1vtAwDPc/R7tJL08MMP6+GHH7bYAgB5iWsdAIAxQgsAxggtABgjtABgjNACgDFCCwDGCC0AGCO0AGCM0AKAMUILAMYILQAYc3ytg/EyZ84c3X777W6dHreAd955x+0JGX3rW99ye0JWjz76qNsTMsrVGwYMDw+/p+fxjhYAjBFaADBGaAHAGKEFAGOEFgCMEVoAMEZoAcAYoQUAY4QWAIwRWgAwRmgBwBihBQBjhBYAjBFaADBGaAHAmKPQzpo1SwUFBdcdoVDIah8AeJ6jC3/39PSMuNDt2bNn9dBDD2nNmjXjPgwA8oWj0FZUVIz4evPmzfrgBz+o++67b1xHAUA+GfWtbK5evaqf/exnam9vV0FBQdbnJRIJJRKJ9NfxeHy0pwQATxr1D8Oee+45Xb58WZ///Odv+LxIJKJAIJA+gsHgaE8JAJ406tDu2rVLTU1Nqq6uvuHzwuGwYrFY+sjVm6wBgJVRfXTw97//XUeOHNGvf/3rmz7X5/PJ5/ON5jQAkBdG9Y529+7dqqys1KpVq8Z7DwDkHcehTSaT2r17t1paWlRcPOqfpQHALcNxaI8cOaK+vj594QtfsNgDAHnH8VvSFStWKJVKWWwBgLzEtQ4AwBihBQBjhBYAjBFaADBGaAHAGKEFAGOEFgCMEVoAMEZoAcAYoQUAY4QWAIxN+OW3/nOdhCtXrkz0qXGL+d8bieaSXL6d07Vr19yekFGu/ln+Z9fNrv9SkJrgK8RcuHCB29kAyCvRaFQ1NTVZ//mEhzaZTOrixYvy+/03vKnjexGPxxUMBhWNRlVWVjZOC/Mb3zPn+J45d6t8z1KplAYHB1VdXa3CwuyfxE74RweFhYU3LP9olJWV5fUfpgW+Z87xPXPuVvieBQKBmz6HH4YBgDFCCwDGPB1an8+njRs3cpddB/ieOcf3zD
m+ZyNN+A/DAOBW4+l3tADgBYQWAIwRWgAwRmgBwBihBQBjng3t1q1bNWvWLJWWlmrp0qU6ceKE25NyViQS0eLFi+X3+1VZWalHHnlE586dc3uWp2zevFkFBQVqa2tze0pOe+utt/T444+rvLxckyZN0j333KOTJ0+6Pct1ngzt/v371d7ero0bN+rUqVOaP3++Vq5cqYGBAben5aSuri6FQiF1d3fr8OHDunbtmlasWKGhoSG3p3lCT0+Ptm/frrq6Oren5LR3331XDQ0Net/73qcDBw7olVde0fe+9z1NmTLF7WnuS3nQkiVLUqFQKP318PBwqrq6OhWJRFxc5R0DAwMpSamuri63p+S8wcHB1Jw5c1KHDx9O3XfffakNGza4PSlnPfXUU6nly5e7PSMnee4d7dWrV9Xb26vGxsb0Y4WFhWpsbNTLL7/s4jLviMVikqSpU6e6vCT3hUIhrVq1asTfN2T2wgsvqL6+XmvWrFFlZaUWLFignTt3uj0rJ3gutG+//baGh4c1bdq0EY9PmzZNly5dcmmVdySTSbW1tamhoUG1tbVuz8lp+/bt06lTpxSJRNye4gmvv/66Ojo6NGfOHB06dEjr16/XE088oWeffdbtaa6b8Mskwl2hUEhnz57VH/7wB7en5LRoNKoNGzbo8OHDKi0tdXuOJySTSdXX12vTpk2SpAULFujs2bPatm2bWlpaXF7nLs+9o73jjjtUVFSk/v7+EY/39/dr+vTpLq3yhtbWVr344ov6/e9/P+7XBM43vb29GhgY0MKFC1VcXKzi4mJ1dXXpmWeeUXFxcc7eWsVNVVVVmjdv3ojH7rrrLvX19bm0KHd4LrQlJSVatGiRjh49mn4smUzq6NGjWrZsmYvLclcqlVJra6s6Ozv10ksvafbs2W5PynkPPvigzpw5o9OnT6eP+vp6NTc36/Tp0yoqKnJ7Ys5paGi47tcGX331Vc2cOdOlRbnDkx8dtLe3q6WlRfX19VqyZIm2bNmioaEhrV271u1pOSkUCmnv3r16/vnn5ff7059lBwIBTZo0yeV1ucnv91/3GfZtt92m8vJyPtvO4sknn9RHP/pRbdq0SY8++qhOnDihHTt2aMeOHW5Pc5/bv/YwWj/84Q9TM2bMSJWUlKSWLFmS6u7udntSzpKU8di9e7fb0zyFX++6ud/85jep2tralM/nS82dOze1Y8cOtyflBK5HCwDGPPcZLQB4DaEFAGOEFgCMEVoAMEZoAcAYoQUAY4QWAIwRWgAwRmgBwBihBQBjhBYAjP0fx2ZX2z3r6EsAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "import random\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import os\n",
        "from PIL import Image, ImageDraw\n",
        "\n",
        "def create_image(colour='white', size=(224,224), side=10, x0=None, y0=None):\n",
        "  \"\"\"Return a greyscale ('L') image of random noise containing one filled square.\n",
        "\n",
        "  Args:\n",
        "    colour: fill colour of the square (a PIL colour spec such as 'white' or 'black').\n",
        "    size: (width, height) of the image in pixels.\n",
        "    side: side length of the square in pixels.\n",
        "    x0, y0: top-left corner of the square. A coordinate left as None is drawn at\n",
        "      random so the square lies fully inside the image; 0 is a valid coordinate.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: (from random.randint) if a coordinate must be drawn but `side`\n",
        "      exceeds the corresponding image dimension.\n",
        "  \"\"\"\n",
        "  width, height = size\n",
        "  ## Random noise background: an independent value in [0, 255] per pixel (row-major)\n",
        "  image = Image.new('L', size)\n",
        "  image.putdata([random.randint(0, 255) for _ in range(width * height)])\n",
        "  ## Adding a square shape\n",
        "  # Compare with None instead of relying on truthiness, so an explicit 0 is honoured.\n",
        "  if x0 is None:\n",
        "    x0 = random.randint(0, width - side)\n",
        "  if y0 is None:\n",
        "    y0 = random.randint(0, height - side)\n",
        "  draw = ImageDraw.Draw(image)\n",
        "  # PIL rectangles include both corner pixels, hence the - 1 for a side x side square.\n",
        "  draw.rectangle((x0, y0, x0 + side - 1, y0 + side - 1), fill=colour)\n",
        "  return image\n",
        "\n",
        "## Quick visual check: an 8x8 noise image with a 3x3 black square.\n",
        "## (A larger variant: create_image(colour='white', size=(39,39), side=16).)\n",
        "image = create_image(colour='black', size=(8, 8), side=3)\n",
        "fig = plt.figure(figsize=(4, 4))\n",
        "ax = fig.add_subplot()\n",
        "ax.imshow(image, cmap='gray')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import numpy as np\n",
        "\n",
        "def generate_dataset(dataset_name, classes=('white', 'black'), image_size=39, side=16, num_images_per_class=20, x0=None, y0=None):\n",
        "  \"\"\"Generate square-on-noise images per class, save them as PNGs and return them.\n",
        "\n",
        "  Each class name is used as the square's fill colour (passed to create_image);\n",
        "  image i of class c is written to f'{dataset_name}/class_{c}/{i}.png'.\n",
        "\n",
        "  Returns:\n",
        "    (images, labels): images is a list of arrays of shape (1, image_size, image_size),\n",
        "    grouped by class in the order of `classes`; labels is 0 for 'white' and 1 for\n",
        "    every other class.\n",
        "  \"\"\"\n",
        "  images = []\n",
        "  labels = []\n",
        "\n",
        "  # exist_ok=True makes the cell safe to re-run; makedirs also creates parents.\n",
        "  os.makedirs(dataset_name, exist_ok=True)\n",
        "  for c in classes:\n",
        "    os.makedirs(f'{dataset_name}/class_{c}', exist_ok=True)\n",
        "\n",
        "  for c in classes:\n",
        "    # NOTE: binary labelling - only 'white' maps to 0, any other class maps to 1.\n",
        "    label = 0 if c == 'white' else 1\n",
        "    for i in range(num_images_per_class):\n",
        "      image = create_image(c, size=(image_size, image_size), side=side, x0=x0, y0=y0)\n",
        "      image.save(f'{dataset_name}/class_{c}/{i}.png')\n",
        "      # Add a leading channel axis of size 1.\n",
        "      images.append(np.array(image).reshape(1, image_size, image_size))\n",
        "      labels.append(label)\n",
        "\n",
        "  return images, labels\n",
        "## Dataset configuration: 4x4 noise images, each with one 2x2 square\n",
        "SEED = 0\n",
        "image_size = 4\n",
        "side = 2\n",
        "# create_image draws from Python's `random` module; seeding it makes the dataset reproducible.\n",
        "random.seed(SEED)\n",
        "images, labels = generate_dataset('dataset', image_size=image_size, side=side, num_images_per_class=300)\n",
        "\n",
        "class SimpleDataset(torch.utils.data.Dataset):\n",
        "  \"\"\"Map-style dataset over in-memory images and integer labels.\n",
        "\n",
        "  Pixel values are converted to float32 and divided by 255. An index may be an int,\n",
        "  a slice, a list/tuple of ints or an index tensor; every form except an int returns\n",
        "  a stacked batch, e.g. `images, labels = dataset[:]`.\n",
        "  \"\"\"\n",
        "  def __init__(self, images, labels):\n",
        "    # Parallel sequences: one array per image and one int label per image.\n",
        "    self.images = images\n",
        "    self.labels = labels\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.images)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    if torch.is_tensor(idx):\n",
        "      idx = idx.tolist()\n",
        "    if isinstance(idx, (list, tuple)):\n",
        "      # Python lists cannot be indexed with a list of positions, so gather explicitly.\n",
        "      image = np.stack([self.images[i] for i in idx])\n",
        "      label = [self.labels[i] for i in idx]\n",
        "    else:\n",
        "      # For a slice, np.asarray stacks the selected arrays into one ndarray, avoiding\n",
        "      # torch's slow tensor-from-a-list-of-ndarrays conversion.\n",
        "      image = np.asarray(self.images[idx])\n",
        "      label = self.labels[idx]\n",
        "    return torch.tensor(image, dtype=torch.float) / 255, torch.tensor(label)\n",
        "\n",
        "dataset = SimpleDataset(images, labels)\n",
        "dataset_images, dataset_labels = dataset[:]\n",
        "\n",
        "## Preview the first ten samples side by side\n",
        "preview_images, preview_labels = dataset_images[:10], dataset_labels[:10]\n",
        "grid = torchvision.utils.make_grid(tensor=preview_images, nrow=10)\n",
        "print(f'image tensor: {preview_images.shape}')\n",
        "print(f'labels: {preview_labels}')\n",
        "plt.figure(figsize=(15,10))\n",
        "plt.axis('off')\n",
        "# make_grid returns channels first; imshow expects channels last.\n",
        "plt.imshow(grid.permute(1, 2, 0), cmap='gray');\n"
      ],
      "metadata": {
        "id": "b3cRRmUC-jqh",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "outputId": "2030a502-fbc4-4126-d15b-376dce64808b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "image tensor: torch.Size([10, 1, 4, 4])\n",
            "labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1500x1000 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAABJ4AAACqCAYAAADst4VCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIiklEQVR4nO3az4vN7R/H8WsYRbGgKTFCkcUkC4lCKJRSspENSSNLmqhpLMjCxmbSLCz8WFiQhVDYKCyUX1nZSFKaFRs/Nixwvv/AXfd88351Dvfjsf706mqaM5/Tc66+TqfTaQAAAABQbFq3DwAAAADA30l4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgon+qD/b19SXPAQAAAMAfpNPp/OszbjwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQ0d/tA/yOc+fOlW2Njo6WbX369Klsa8WKFSU7k5OTJTvVzp49W7b19u3bsq2fP3+W7Bw7dqxkp7XWhoaGyrYqff78uWRnYmKiZKe11l68eFG2deHChbKt+fPnl21V6XQ63T7CP7p27VrJzo4dO0p2Wmtt3rx5ZVuVNm3aVLKzdevWkp3WWrt+/XrZVtV7sLXWbt26VbZVZdmyZWVbY2NjZVvj4+MlO0ePHi3Zaa21w4cPl21V2rJlS9nWw4cPy7aqrF27tmyr8v1cpVffgxcvXizbmjat5q7B8PBwyU613bt3l23t3bu3ZGfbtm0lO621NmvWrLKt2bNnl21V6dXP4LNnz8q2Hj9+XLJz/Pjxkp1ucOMJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACAiP5uH+B37Nq1q2zr6tWrZVtXrlwp21q+fHnJzuTkZMlOtZcvX5ZtXbhwoWxrzpw5ZVt/uwULFpTsbNmypWSntdbu3btXtkV3PH/+vGRn5syZJTu9bGJiomRn1apVJTuttXby5MmyrZ07d5Zt9aLz58+XbQ0NDZVtrV69umTn1atXJTu97MCBA90+QtSzZ8/KtqZN8z/vqTpz5kzZ1o8fP8q2etHBgwfLtubOnVuyMzAwULJD9yxatKhsa3BwsGzrT+WvPwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEf3dPsDvWLp0adnW06dPy7ZOnz5dtjV9+vSyrV60fv36sq1fv36VbTF1ly9fLtn58OFDyQ7/n0uXLpVtPXz4sGxrZGSkZOfLly8lO71syZIl3T5C1L59+8q27t27V7ZVZfv27d0+wj8aHh4u2RkbGyvZ6WUvXrwo26r6ua9bt65kp7XWNmzYULbVi0ZHR8u2Nm7cWLZV+Xt18eLFkp1e/Tw/efKkbGvz5s1lW0zNnj17yrZev35dtjU0NFS29V/4Pvpv3HgCAA
AAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgQngCAAAAIEJ4AgAAACBCeAIAAAAgor/bB/gbzZgxo2zr3bt3ZVu96MGDB2VbS5YsKdtavnx5yc7NmzdLdnrZmzdvSnZu375dstNaa3fv3i3bWrlyZdlWLxodHS3bGhkZKdsaHx8v2fn+/XvJTi/7+vVryc6aNWtKdlprbXBwsGxr4cKFZVtM3f3790t2vn37VrLTy/bv31+2tWjRopKd9+/fl+y01tqjR4/KtnrR1q1by7Y+fvxYtjUwMFC2derUqbKtXnTixImyrSNHjpTsVH7/OHToUNlWL7pz507Z1uLFi8u2+vr6yrZu3LhRsjN79uySnW5w4wkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAICIvk6n05nSg3196bMAAAAA8IeYSlJy4wkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACAiP6pPtjpdJLnAAAAAOAv48YTAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEf8DzBnl+L6dgOQAAAAASUVORK5CYII=\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from sklearn.model_selection import train_test_split\n",
        "from torch.utils.data import Subset\n",
        "\n",
        "# Stratified split (fixed random_state): test_size=0.66 puts ~66% of the samples in\n",
        "# validation and only ~34% in training. NOTE(review): confirm this ratio is intended.\n",
        "all_indices = np.arange(len(dataset))\n",
        "train_idx, validation_idx = train_test_split(\n",
        "    all_indices, test_size=0.66, random_state=999, shuffle=True, stratify=dataset.labels\n",
        ")\n",
        "train_dataset, validation_dataset = Subset(dataset, train_idx), Subset(dataset, validation_idx)\n",
        "\n",
        "class SimpleNetwork(nn.Module): # extend nn.Module class of nn\n",
        "    \"\"\"Two-class classifier: one Conv2d layer, ReLU, then one Linear layer.\n",
        "\n",
        "    Args:\n",
        "        num_filters: number of output channels of the conv layer.\n",
        "        kernel_size: side of the square conv kernel (default 2, the original value).\n",
        "        input_size: side of the square single-channel input; when None, the\n",
        "            notebook-level `image_size` is used (read once, at construction).\n",
        "    \"\"\"\n",
        "    def __init__(self, num_filters, kernel_size=2, input_size=None):\n",
        "        super().__init__() # super class constructor\n",
        "        if input_size is None:\n",
        "            input_size = image_size\n",
        "        self.num_filters = num_filters\n",
        "        # An unpadded, stride-1 conv shrinks each side by kernel_size - 1, so the\n",
        "        # flattened size depends on the kernel, not on the square's `side`.\n",
        "        conv_out_size = input_size - kernel_size + 1\n",
        "        self.flat_features = num_filters * conv_out_size * conv_out_size\n",
        "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(kernel_size, kernel_size))\n",
        "        self.fc1 = nn.Linear(in_features=self.flat_features, out_features=2)\n",
        "\n",
        "    def forward(self, t): # implements the forward method (flow of tensors)\n",
        "        # hidden conv layer\n",
        "        t = F.relu(self.conv1(t))\n",
        "        t = t.reshape(-1, self.flat_features)\n",
        "        return self.fc1(t)\n",
        "\n",
        "    def forward_intermediate(self, t, use_relu=True):\n",
        "        \"\"\"Apply only the classifier head (optional ReLU, flatten, fc1) to `t`.\n",
        "\n",
        "        `t` is presumably a conv1 output captured earlier; the flattened tensor\n",
        "        and the fc1 module are printed for inspection.\n",
        "        \"\"\"\n",
        "        if use_relu:\n",
        "            t = F.relu(t)\n",
        "        t = t.reshape(-1, self.flat_features)\n",
        "        print('t', t)\n",
        "        print('fc1', self.fc1)\n",
        "        return self.fc1(t)\n",
        "\n",
        "def get_item(preds, labels):\n",
        "    \"\"\"Return the number of correct predictions in a batch (a count, not a ratio).\n",
        "\n",
        "    The predicted class is the arg-max of `preds` along dim 1, compared with the\n",
        "    integer class indices in `labels`. Divide the summed counts by the dataset\n",
        "    size to get the accuracy.\n",
        "    \"\"\"\n",
        "    return preds.argmax(dim=1).eq(labels).sum().item()"
      ],
      "metadata": {
        "id": "vPrgfucuEVaR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "# Drop a model left over from an earlier run; only a missing name is expected here.\n",
        "try:\n",
        "  del cnn_model\n",
        "except NameError:\n",
        "  print('no CNN model')\n",
        "\n",
        "num_filters = 1\n",
        "\n",
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "cnn_model = SimpleNetwork(num_filters).to(device)\n",
        "optimizer = torch.optim.SGD(cnn_model.parameters(), lr=0.05)\n",
        "#optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.01)\n",
        "\n",
        "accuracy_epoch = []\n",
        "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=1)\n",
        "test_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=2, shuffle=False, num_workers=1)\n",
        "\n",
        "for epoch in range(100):\n",
        "  start_time = time.time()\n",
        "  total_correct = 0\n",
        "  total_loss = 0\n",
        "\n",
        "  cnn_model.train()\n",
        "  for imgs, lbls in train_dataloader:\n",
        "      imgs, lbls = imgs.to(device), lbls.to(device)\n",
        "      preds = cnn_model(imgs) # get preds\n",
        "      loss = F.cross_entropy(preds, lbls) # compute loss\n",
        "      optimizer.zero_grad() # zero grads\n",
        "      loss.backward() # calculates gradients\n",
        "      optimizer.step() # update the weights\n",
        "      total_loss += loss.item()\n",
        "\n",
        "  # Evaluation needs no autograd graph; accuracy is computed once, after all batches.\n",
        "  cnn_model.eval()\n",
        "  with torch.no_grad():\n",
        "    for imgs, lbls in test_dataloader:\n",
        "        imgs, lbls = imgs.to(device), lbls.to(device)\n",
        "        total_correct += get_item(cnn_model(imgs), lbls)\n",
        "  accuracy = total_correct/len(validation_dataset)\n",
        "\n",
        "  accuracy_epoch.append(accuracy)\n",
        "  end_time = time.time() - start_time\n",
        "  if epoch == 0 or (epoch + 1) % 5 == 0:\n",
        "    print('Epoch no.',epoch+1 ,'|test accuracy: ', round(accuracy*100, 3),'%', '|total_loss: ', total_loss, '| epoch_duration: ', round(end_time,2),'sec')\n"
      ],
      "metadata": {
        "id": "ABTGrP2VGxWo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "836fd673-2eb1-41b0-c307-3fdf8c0f9e53"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch no. 1 |test accuracy:  48.737 % |total_loss:  14.57321846485138 | epoch_duration:  0.7 sec\n",
            "Epoch no. 5 |test accuracy:  50.758 % |total_loss:  14.569975435733795 | epoch_duration:  0.68 sec\n",
            "Epoch no. 10 |test accuracy:  50.253 % |total_loss:  14.562027871608734 | epoch_duration:  0.69 sec\n",
            "Epoch no. 15 |test accuracy:  49.242 % |total_loss:  14.579477548599243 | epoch_duration:  1.0 sec\n",
            "Epoch no. 20 |test accuracy:  49.242 % |total_loss:  14.569651782512665 | epoch_duration:  0.69 sec\n",
            "Epoch no. 25 |test accuracy:  51.263 % |total_loss:  14.555954933166504 | epoch_duration:  0.68 sec\n",
            "Epoch no. 30 |test accuracy:  49.747 % |total_loss:  14.542636454105377 | epoch_duration:  0.74 sec\n",
            "Epoch no. 35 |test accuracy:  50.505 % |total_loss:  14.541406989097595 | epoch_duration:  0.69 sec\n",
            "Epoch no. 40 |test accuracy:  66.162 % |total_loss:  13.527920722961426 | epoch_duration:  0.7 sec\n",
            "Epoch no. 45 |test accuracy:  83.586 % |total_loss:  8.198082581162453 | epoch_duration:  0.7 sec\n",
            "Epoch no. 50 |test accuracy:  90.404 % |total_loss:  5.248928517103195 | epoch_duration:  0.68 sec\n",
            "Epoch no. 55 |test accuracy:  94.192 % |total_loss:  3.6010555401444435 | epoch_duration:  0.68 sec\n",
            "Epoch no. 60 |test accuracy:  94.444 % |total_loss:  2.8058285173028708 | epoch_duration:  0.7 sec\n",
            "Epoch no. 65 |test accuracy:  94.697 % |total_loss:  2.3811375871300697 | epoch_duration:  0.92 sec\n",
            "Epoch no. 70 |test accuracy:  94.949 % |total_loss:  2.1173205859959126 | epoch_duration:  0.69 sec\n",
            "Epoch no. 75 |test accuracy:  94.949 % |total_loss:  1.9241018779575825 | epoch_duration:  0.7 sec\n",
            "Epoch no. 80 |test accuracy:  95.707 % |total_loss:  1.668175695464015 | epoch_duration:  0.69 sec\n",
            "Epoch no. 85 |test accuracy:  95.96 % |total_loss:  1.6005523800849915 | epoch_duration:  0.7 sec\n",
            "Epoch no. 90 |test accuracy:  95.96 % |total_loss:  1.4637939669191837 | epoch_duration:  0.7 sec\n",
            "Epoch no. 95 |test accuracy:  95.96 % |total_loss:  1.3791138725355268 | epoch_duration:  0.69 sec\n",
            "Epoch no. 100 |test accuracy:  94.949 % |total_loss:  1.5754989720880985 | epoch_duration:  0.97 sec\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Intermediate outputs for images with squares in the same position"
      ],
      "metadata": {
        "id": "hIkGvXx3EUFe"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Generating a dataset with squares in the same position"
      ],
      "metadata": {
        "id": "CyLRkFe4EZpZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torchvision\n",
        "import numpy as np\n",
        "\n",
        "# NOTE(review): the section title says 'same position', but x0=None, y0=None places\n",
        "# every square at a random position; the commented-out call pins it at (1, 1).\n",
        "#images, labels = generate_dataset('same_position', image_size=image_size, side=side, num_images_per_class=10, x0=1, y0=1)\n",
        "images, labels = generate_dataset('same_position', image_size=image_size, side=side, num_images_per_class=10, x0=None, y0=None)\n",
        "# Split by label instead of slicing at len(images)//2, so this does not rely on the\n",
        "# order in which generate_dataset emits the classes.\n",
        "white_images = [img for img, lbl in zip(images, labels) if lbl == 0]\n",
        "white_labels = [lbl for lbl in labels if lbl == 0]\n",
        "black_images = [img for img, lbl in zip(images, labels) if lbl == 1]\n",
        "black_labels = [lbl for lbl in labels if lbl == 1]\n",
        "num_white_images = len(white_images)\n",
        "white_dataset = SimpleDataset(white_images, white_labels)\n",
        "black_dataset = SimpleDataset(black_images, black_labels)\n",
        "dataset_images, dataset_labels = black_dataset[:]\n",
        "\n",
        "plt.figure(figsize=(15,10))\n",
        "grid = torchvision.utils.make_grid(nrow=10, tensor=dataset_images[:10])\n",
        "# make_grid returns a 3-channel grid padded with 0 (black); inverting it makes the\n",
        "# padding white but also inverts the pixel values of the images themselves.\n",
        "grid = 1 - grid\n",
        "print(f'image tensor: {dataset_images[:10].shape}')\n",
        "print(f'labels: {dataset_labels[:10]}')\n",
        "plt.gca().set_facecolor('white')\n",
        "plt.axis('off')\n",
        "# imshow ignores cmap for RGB input, so the inversion above fully determines the look.\n",
        "plt.imshow(grid.permute(1, 2, 0));"
      ],
      "metadata": {
        "id": "bKjmomgDEt2o",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "outputId": "de245542-4ecb-4dd2-e851-a9d82c414c35"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "image tensor: torch.Size([10, 1, 4, 4])\n",
            "labels: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1500x1000 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAABJ4AAACqCAYAAADst4VCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIc0lEQVR4nO3a32vO/x/H8fdYOZDtRDEHS+3AAcWRpkg7UJLW4owTaeZEHCmWciCyRO2M0rBQWq3IgZQfcbQcsdKOFEVzMEUptXJ9/oHvge/X89G17z632x/w6HW967quV/feHa1Wq9UAAAAAQLEV7T4AAAAAAMuT8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAEBEZ7sP8Dc+f/5cttXT01O2tWKFnvenLl++XLZ15syZsq0qX79+Ldtat25d2Valubm5kp2nT5+W7DRN0+zatats6+PHj2VbQ0NDZVtVJiYmyrbev39ftrVhw4aSnTt37pTsNE3TvH37tmyr0sDAQMnOw4cPS3aapmmeP39ettXZWXdV2b9/f9lWlZmZmbKtrq6usq0bN26U7IyPj5fsNE3TtFqtsq1K9+/fL9uqust8+vSpZOff4ObNm2Vblf+D8/PzZVuLi4slO1NTUyU7/wabNm0q2zp9+nTZ1vDwcNlWlcOHD5dtLSwslG1NT0+XbVXdIXfs2FGy0w4KCQAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABGd7T7A31izZk3Z1pEjR8q2Jicny7aWu66urnYfIWpwcLBsa2Zmpmyr0tjYWMnO+Ph4yU7TNM2TJ0/Ktg4ePFi2tRQdPXq03Uf4jw4dOlSy8+bNm5KdpezYsWMlO/Pz8yU7TdM0Hz58KNsaHh4u21qKFhYWyrY2btxYtrW4uFiyMzU1VbKzlN2+fbtsa+XKlWVb/JmRkZGyrdHR0bKt7u7usq0HDx6UbfFnjh8/Xrb1+/fvsq2lqLe3t2yr8t73/fv3sq2zZ8+W7Lx8+bJkpx288QQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQERHq9VqtfsQ/6vu7u6yrXv37pVtnTp1qmzrxYsXJTu9vb0lO9XOnz9ftjU/P1+2de7cuZKdzZs3l+w0TdP8+PGjbKvS9PR0yc6BAwdKdqqtXr26bOvnz59lW8vd7OxsyU7lM+/v7y/bqtTR0VGys3PnzpKdpmma169fl20td48fPy7bqnzuY2NjZVvL3cTERNnW4OBgyc7WrVtLdpqmaV69elW21dfXV7ZV5du3b2VbV69eLdu6ePFi2dZyd+LEibKtR48elezs3bu3ZKdpmubSpUtlW2vXri3bWu
6uXbtWttXZ2Vmyc/LkyZKddvDGEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABHCEwAAAAARwhMAAAAAEcITAAAAABGd7T7A37hy5UrZ1pYtW8q2Lly4ULZ19+7dkp3R0dGSnWqrVq0q25qbmyvbun79esnO7t27S3aWsn379rX7CFFDQ0PtPkLU5ORk2dbMzEzZ1q1bt0p2Kn/7+vv7y7Yq7dmzp2Tny5cvJTv8d3p6esq2+vr6yrb4c8+ePSvbevfuXclO5fe56k7UNLV39yqVn2/16tVlW/y5kZGRsq3t27eX7Kxfv75kp2maZmBgoGxrdna2bGu5+/XrV9nWtm3byrb+X3njCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgAjhCQAAAIAI4QkAAACACOEJAAAAgIiOVqvVavchAAAAAFh+vPEEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAECE8AQAAABAhPAEAAAAQITwBAAAAEDEPxw5860fFezHAAAAAElFTkSuQmCC\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from functools import partial\n",
        "\n",
        "## Print each learned parameter: its name and shape, then its values\n",
        "for name, param in cnn_model.named_parameters():\n",
        "    print(f'{name}: {param.shape}', param.data, sep='\\n')\n",
        "\n",
        "def hook_fn(activations, module, input, output):\n",
        "  \"\"\"Forward hook that prints a layer's output and records it in `activations`.\n",
        "\n",
        "  Bind `activations` with functools.partial before registering; each call appends\n",
        "  a new {layer class name: output} dict to it.\n",
        "  \"\"\"\n",
        "  layer_name = type(module).__name__\n",
        "  print(f'Layer {layer_name}:')\n",
        "  print(f'  Output {[output.shape]}: {output}')\n",
        "  activations.append({layer_name: output})\n",
        "\n",
        "# Register hooks\n",
        "activations = []\n",
        "conv1_hook = cnn_model.conv1.register_forward_hook(partial(hook_fn, activations))\n",
        "fc1_hook = cnn_model.fc1.register_forward_hook(partial(hook_fn, activations))\n",
        "\n",
        "# Feed up to five samples per class through the model. try/finally guarantees the hooks\n",
        "# are removed even if a step fails, so re-running this cell never stacks duplicate hooks.\n",
        "try:\n",
        "  with torch.no_grad():\n",
        "    dataloader1 = torch.utils.data.DataLoader(white_dataset, batch_size=1, shuffle=True, num_workers=1)\n",
        "    dataloader2 = torch.utils.data.DataLoader(black_dataset, batch_size=1, shuffle=True, num_workers=1)\n",
        "    for dataloader in [dataloader1, dataloader2]:\n",
        "      count = 0\n",
        "      for batch in dataloader:\n",
        "        print('new batch')\n",
        "        imgs, lbls = batch\n",
        "        imgs, lbls = imgs.to(device), lbls.to(device)\n",
        "        print('imgs shape', imgs.shape, 'img', imgs, 'label', lbls)\n",
        "        preds = cnn_model(imgs) # get preds\n",
        "        print('prediction', preds, 'label', lbls)\n",
        "        count += 1\n",
        "        if count > 4:\n",
        "          break\n",
        "finally:\n",
        "  # Remove hooks\n",
        "  conv1_hook.remove()\n",
        "  fc1_hook.remove()"
      ],
      "metadata": {
        "id": "hMrp9p1aIU0Y",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4510c149-a093-4cf1-c73e-fce0d8a3de87"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "conv1.weight: torch.Size([1, 1, 2, 2])\n",
            "tensor([[[[1.0190, 1.0396],\n",
            "          [1.3159, 0.7619]]]])\n",
            "conv1.bias: torch.Size([1])\n",
            "tensor([-1.6514])\n",
            "fc1.weight: torch.Size([2, 9])\n",
            "tensor([[ 0.0908,  0.4424,  0.5132,  0.7442,  1.2633,  0.7397, -0.0564,  0.7359,\n",
            "          0.1787],\n",
            "        [-0.2317, -0.3863, -0.3313, -0.6362, -1.2957, -0.5972, -0.4600, -0.4999,\n",
            "         -0.0515]])\n",
            "fc1.bias: torch.Size([2])\n",
            "tensor([-1.9681,  2.5986])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.1765, 0.1333, 0.4588, 0.2078],\n",
            "          [0.3020, 0.1843, 1.0000, 1.0000],\n",
            "          [0.3216, 0.6000, 1.0000, 1.0000],\n",
            "          [0.0471, 0.7216, 0.5647, 0.9333]]]]) label tensor([0])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[-0.7952, -0.0341,  1.1101],\n",
            "          [-0.2718,  1.1275,  2.4851],\n",
            "          [-0.0883,  1.3794,  1.8615]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[ 3.2119, -1.4995]])\n",
            "prediction tensor([[ 3.2119, -1.4995]]) label tensor([0])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.3608, 0.5451, 0.8078, 0.0588],\n",
            "          [1.0000, 1.0000, 0.8392, 0.0235],\n",
            "          [1.0000, 1.0000, 0.6784, 0.9804],\n",
            "          [0.8000, 0.8902, 0.8275, 0.8980]]]]) label tensor([0])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[1.3608, 1.6992, 0.3552],\n",
            "          [2.4851, 2.0729, 0.8680],\n",
            "          [2.1382, 1.8748, 1.8323]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[ 5.7863, -5.2913]])\n",
            "prediction tensor([[ 5.7863, -5.2913]]) label tensor([0])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[1.0000, 1.0000, 0.9216, 0.0902],\n",
            "          [1.0000, 1.0000, 0.7765, 0.9059],\n",
            "          [0.0549, 0.9255, 0.5725, 0.9529],\n",
            "          [0.2784, 0.2745, 0.6118, 0.8588]]]]) label tensor([0])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[ 2.4851,  2.2332,  1.0935],\n",
            "          [ 1.1846,  1.8290,  1.5611],\n",
            "          [-0.0578,  0.7143,  1.3821]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[ 4.9263, -3.6861]])\n",
            "prediction tensor([[ 4.9263, -3.6861]]) label tensor([0])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.9451, 0.3451, 0.9882, 0.4627],\n",
            "          [0.8745, 0.9255, 1.0000, 1.0000],\n",
            "          [0.2000, 0.8039, 1.0000, 1.0000],\n",
            "          [0.0863, 0.4588, 0.8980, 0.3176]]]]) label tensor([0])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[ 1.5264,  1.7074,  1.9146],\n",
            "          [ 1.0776,  2.1511,  2.4851],\n",
            "          [-0.1487,  1.4954,  1.8310]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[ 6.6938, -4.8476]])\n",
            "prediction tensor([[ 6.6938, -4.8476]]) label tensor([0])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.2980, 0.3490, 0.8510, 0.7490],\n",
            "          [0.6510, 1.0000, 1.0000, 0.8471],\n",
            "          [0.5451, 1.0000, 1.0000, 0.3216],\n",
            "          [0.3765, 0.1137, 0.8745, 0.4784]]]]) label tensor([0])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[0.6337, 1.6668, 1.9558],\n",
            "          [1.5308, 2.4851, 1.8092],\n",
            "          [0.5257, 1.2232, 1.2173]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[ 6.5355, -5.0303]])\n",
            "prediction tensor([[ 6.5355, -5.0303]]) label tensor([0])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.7451, 0.0549, 0.3137, 0.2941],\n",
            "          [0.3529, 0.1490, 0.4431, 0.0549],\n",
            "          [1.0000, 0.7608, 0.0000, 0.0000],\n",
            "          [0.7255, 0.8275, 0.0000, 0.0000]]]]) label tensor([1])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[-0.2571, -0.7356, -0.4010],\n",
            "          [ 0.7588, -0.0377, -1.1428],\n",
            "          [ 1.7437,  0.2127, -1.6514]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[-1.3452,  1.2074]])\n",
            "prediction tensor([[-1.3452,  1.2074]]) label tensor([1])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.3922, 0.5529, 0.6471, 0.6667],\n",
            "          [0.9922, 0.5137, 0.5882, 0.9294],\n",
            "          [0.0706, 0.5176, 0.0000, 0.0000],\n",
            "          [0.4745, 0.4745, 0.0000, 0.0000]]]]) label tensor([1])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[ 1.0201,  0.7090,  1.1833],\n",
            "          [ 0.3810,  0.1648, -0.0858],\n",
            "          [-0.0554, -0.4995, -1.6514]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[-0.4628,  1.2404]])\n",
            "prediction tensor([[-0.4628,  1.2404]]) label tensor([1])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.6431, 0.7176, 0.0392, 0.2235],\n",
            "          [0.5255, 0.8510, 0.0980, 0.7490],\n",
            "          [0.0000, 0.0000, 0.4471, 0.4314],\n",
            "          [0.0000, 0.0000, 0.4784, 0.2863]]]]) label tensor([1])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[ 1.0899,  0.3152, -0.6794],\n",
            "          [-0.2313, -0.3417,  0.1441],\n",
            "          [-1.6514, -0.8221,  0.1003]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[-1.6051,  2.1331]])\n",
            "prediction tensor([[-1.6051,  2.1331]]) label tensor([1])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.4667, 0.0000, 0.0000, 0.0902],\n",
            "          [0.4118, 0.0000, 0.0000, 0.7843],\n",
            "          [0.5020, 0.5961, 0.1137, 0.9451],\n",
            "          [0.4784, 0.7529, 0.5176, 0.6980]]]]) label tensor([1])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[-0.6340, -1.6514, -0.9600],\n",
            "          [-0.1171, -0.7804,  0.0337],\n",
            "          [ 0.6831,  0.4595,  0.6600]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[-1.5256,  2.0006]])\n",
            "prediction tensor([[-1.5256,  2.0006]]) label tensor([1])\n",
            "new batch\n",
            "imgs shape torch.Size([1, 1, 4, 4]) img tensor([[[[0.0000, 0.0000, 0.8471, 0.3882],\n",
            "          [0.0000, 0.0000, 0.4431, 0.4000],\n",
            "          [0.5765, 0.8824, 0.6000, 0.0784],\n",
            "          [0.1804, 0.3647, 0.2706, 0.8235]]]]) label tensor([1])\n",
            "Layer Conv2d:\n",
            "  Output [torch.Size([1, 1, 3, 3])]: tensor([[[[-1.6514, -0.4332,  0.5033],\n",
            "          [-0.2205,  0.4275,  0.0653],\n",
            "          [ 0.3686,  0.5576,  0.0251]]]])\n",
            "Layer Linear:\n",
            "  Output [torch.Size([1, 2])]: tensor([[-0.7274,  1.3893]])\n",
            "prediction tensor([[-0.7274,  1.3893]]) label tensor([1])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Inspect the layer outputs recorded by the conv1/fc1 hooks; the saved output below\n",
        "# alternates {'Conv2d': ...} and {'Linear': ...} entries, one pair per sample.\n",
        "activations"
      ],
      "metadata": {
        "id": "CDoW8JUcKfkL",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bc85711f-757a-4d66-dc8d-970f1e21dde6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'Conv2d': tensor([[[[-0.7952, -0.0341,  1.1101],\n",
              "            [-0.2718,  1.1275,  2.4851],\n",
              "            [-0.0883,  1.3794,  1.8615]]]])},\n",
              " {'Linear': tensor([[ 3.2119, -1.4995]])},\n",
              " {'Conv2d': tensor([[[[1.3608, 1.6992, 0.3552],\n",
              "            [2.4851, 2.0729, 0.8680],\n",
              "            [2.1382, 1.8748, 1.8323]]]])},\n",
              " {'Linear': tensor([[ 5.7863, -5.2913]])},\n",
              " {'Conv2d': tensor([[[[ 2.4851,  2.2332,  1.0935],\n",
              "            [ 1.1846,  1.8290,  1.5611],\n",
              "            [-0.0578,  0.7143,  1.3821]]]])},\n",
              " {'Linear': tensor([[ 4.9263, -3.6861]])},\n",
              " {'Conv2d': tensor([[[[ 1.5264,  1.7074,  1.9146],\n",
              "            [ 1.0776,  2.1511,  2.4851],\n",
              "            [-0.1487,  1.4954,  1.8310]]]])},\n",
              " {'Linear': tensor([[ 6.6938, -4.8476]])},\n",
              " {'Conv2d': tensor([[[[0.6337, 1.6668, 1.9558],\n",
              "            [1.5308, 2.4851, 1.8092],\n",
              "            [0.5257, 1.2232, 1.2173]]]])},\n",
              " {'Linear': tensor([[ 6.5355, -5.0303]])},\n",
              " {'Conv2d': tensor([[[[-0.2571, -0.7356, -0.4010],\n",
              "            [ 0.7588, -0.0377, -1.1428],\n",
              "            [ 1.7437,  0.2127, -1.6514]]]])},\n",
              " {'Linear': tensor([[-1.3452,  1.2074]])},\n",
              " {'Conv2d': tensor([[[[ 1.0201,  0.7090,  1.1833],\n",
              "            [ 0.3810,  0.1648, -0.0858],\n",
              "            [-0.0554, -0.4995, -1.6514]]]])},\n",
              " {'Linear': tensor([[-0.4628,  1.2404]])},\n",
              " {'Conv2d': tensor([[[[ 1.0899,  0.3152, -0.6794],\n",
              "            [-0.2313, -0.3417,  0.1441],\n",
              "            [-1.6514, -0.8221,  0.1003]]]])},\n",
              " {'Linear': tensor([[-1.6051,  2.1331]])},\n",
              " {'Conv2d': tensor([[[[-0.6340, -1.6514, -0.9600],\n",
              "            [-0.1171, -0.7804,  0.0337],\n",
              "            [ 0.6831,  0.4595,  0.6600]]]])},\n",
              " {'Linear': tensor([[-1.5256,  2.0006]])},\n",
              " {'Conv2d': tensor([[[[-1.6514, -0.4332,  0.5033],\n",
              "            [-0.2205,  0.4275,  0.0653],\n",
              "            [ 0.3686,  0.5576,  0.0251]]]])},\n",
              " {'Linear': tensor([[-0.7274,  1.3893]])}]"
            ]
          },
          "metadata": {},
          "execution_count": 81
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
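        "### Testing the effect of activation values shared between samples\n",
        "\n",
        "Keep the Conv2d activations that two samples share at the same position, replace the rest with random noise, and check how the classifier output changes."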
      ],
      "metadata": {
        "id": "KsRWXspIKoz9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Experiment: keep the Conv2d activations that two samples share at the same\n",
        "# position, replace every other position with uniform noise in [-1, 1), and\n",
        "# see how the classifier head (forward_intermediate) responds.\n",
        "for name, param in cnn_model.named_parameters():\n",
        "    print(f\"{name}: {param.shape}\")\n",
        "    print(param.data)\n",
        "\n",
        "# `activations` alternates {'Conv2d': ...} / {'Linear': ...} entries (see the\n",
        "# `activations` cell output), so indices 0 and 2 are the Conv2d outputs of the\n",
        "# first two samples.\n",
        "ref_act = activations[0]['Conv2d']\n",
        "other_act = activations[2]['Conv2d']\n",
        "# Exact float equality is brittle; isclose also accepts values that differ only\n",
        "# by rounding. The mask does not change inside the loop, so compute it once.\n",
        "shared_mask = torch.isclose(ref_act, other_act)\n",
        "print(f\"{int(shared_mask.sum())} of {shared_mask.numel()} positions shared\")\n",
        "if not shared_mask.any():\n",
        "    # Without shared positions the trials below test pure noise, not shared activations.\n",
        "    print(\"WARNING: no shared positions - every trial below is pure noise\")\n",
        "\n",
        "# Dedicated seeded generator: reproducible noise without touching the global RNG.\n",
        "gen = torch.Generator(device=ref_act.device).manual_seed(0)\n",
        "with torch.no_grad():  # inference only; no autograd graph needed\n",
        "    for _ in range(10):\n",
        "        noise = torch.rand(ref_act.shape, generator=gen, dtype=ref_act.dtype, device=ref_act.device)\n",
        "        new_activation = noise * 2 - 1\n",
        "        new_activation[shared_mask] = ref_act[shared_mask]\n",
        "        print(new_activation)\n",
        "        print(\"output\", cnn_model.forward_intermediate(new_activation))"
      ],
      "metadata": {
        "id": "_H3NmHQtKf7e",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0594448e-48a3-40a7-effc-c48bde72d7ea"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "conv1.weight: torch.Size([1, 1, 2, 2])\n",
            "tensor([[[[1.0190, 1.0396],\n",
            "          [1.3159, 0.7619]]]])\n",
            "conv1.bias: torch.Size([1])\n",
            "tensor([-1.6514])\n",
            "fc1.weight: torch.Size([2, 9])\n",
            "tensor([[ 0.0908,  0.4424,  0.5132,  0.7442,  1.2633,  0.7397, -0.0564,  0.7359,\n",
            "          0.1787],\n",
            "        [-0.2317, -0.3863, -0.3313, -0.6362, -1.2957, -0.5972, -0.4600, -0.4999,\n",
            "         -0.0515]])\n",
            "fc1.bias: torch.Size([2])\n",
            "tensor([-1.9681,  2.5986])\n",
            "tensor([[[[ 0.1941, -0.6537,  0.2643],\n",
            "          [-0.2502,  0.8903, -0.6140],\n",
            "          [ 0.6327,  0.2830,  0.2173]]]])\n",
            "t tensor([[0.1941, 0.0000, 0.2643, 0.0000, 0.8903, 0.0000, 0.6327, 0.2830, 0.2173]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-0.4787,  0.8688]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.1836, -0.4030, -0.2573],\n",
            "          [ 0.7564, -0.2335, -0.7329],\n",
            "          [-0.5944,  0.1500,  0.1044]]]])\n",
            "t tensor([[0.1836, 0.0000, 0.0000, 0.7564, 0.0000, 0.0000, 0.0000, 0.1500, 0.1044]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.2595,  1.9945]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[-0.2427,  0.4926,  0.9944],\n",
            "          [ 0.0947, -0.5249, -0.4358],\n",
            "          [-0.8161,  0.0554,  0.1613]]]])\n",
            "t tensor([[0.0000, 0.4926, 0.9944, 0.0947, 0.0000, 0.0000, 0.0000, 0.0554, 0.1613]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.0998,  1.9826]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.8967,  0.5062, -0.0293],\n",
            "          [-0.3038,  0.2273, -0.0819],\n",
            "          [-0.0991, -0.2660,  0.7672]]]])\n",
            "t tensor([[0.8967, 0.5062, 0.0000, 0.0000, 0.2273, 0.0000, 0.0000, 0.0000, 0.7672]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.2385,  1.8613]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.6853,  0.8673, -0.8468],\n",
            "          [-0.9681, -0.8004,  0.2223],\n",
            "          [-0.3816,  0.9904, -0.1798]]]])\n",
            "t tensor([[0.6853, 0.8673, 0.0000, 0.0000, 0.0000, 0.2223, 0.0000, 0.9904, 0.0000]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-0.6290,  1.4770]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.1950, -0.5107,  0.3063],\n",
            "          [-0.9334, -0.0903,  0.2977],\n",
            "          [ 0.8875, -0.4521,  0.0905]]]])\n",
            "t tensor([[0.1950, 0.0000, 0.3063, 0.0000, 0.0000, 0.2977, 0.8875, 0.0000, 0.0905]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.6069,  1.8613]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.2861, -0.4294, -0.3543],\n",
            "          [-0.5812, -0.8894,  0.3114],\n",
            "          [-0.9508, -0.8474,  0.2058]]]])\n",
            "t tensor([[0.2861, 0.0000, 0.0000, 0.0000, 0.0000, 0.3114, 0.0000, 0.0000, 0.2058]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.6750,  2.3358]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.6775, -0.1504, -0.8624],\n",
            "          [-0.2560, -0.9888, -0.5054],\n",
            "          [ 0.8314, -0.4900,  0.4553]]]])\n",
            "t tensor([[0.6775, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8314, 0.0000, 0.4553]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-1.8721,  2.0358]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.0744, -0.8411,  0.9702],\n",
            "          [ 0.4007,  0.1127,  0.2937],\n",
            "          [-0.7216,  0.0433,  0.3294]]]])\n",
            "t tensor([[0.0744, 0.0000, 0.9702, 0.4007, 0.1127, 0.2937, 0.0000, 0.0433, 0.3294]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-0.7149,  1.6450]], grad_fn=<AddmmBackward0>)\n",
            "tensor([[[[ 0.3117, -0.4547, -0.5759],\n",
            "          [-0.7199,  0.8005, -0.8444],\n",
            "          [-0.3132,  0.1524,  0.6259]]]])\n",
            "t tensor([[0.3117, 0.0000, 0.0000, 0.0000, 0.8005, 0.0000, 0.0000, 0.1524, 0.6259]])\n",
            "fc1 Linear(in_features=9, out_features=2, bias=True)\n",
            "output tensor([[-0.7046,  1.3809]], grad_fn=<AddmmBackward0>)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "lSrXrniTK3Z6"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}