{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "# One seed shared by Python's random module, PyTorch and NumPy below.\n",
    "manualSeed = 42\n",
    "# NOTE(review): not referenced in the cells visible here; presumably used later - confirm.\n",
    "DEFAULT_THRESHOLD = 5e-3\n",
    "\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "torch.cuda.manual_seed(manualSeed)\n",
    "np.random.seed(manualSeed)\n",
    "# Turn off cuDNN autotuning and then cuDNN itself (presumably for run-to-run\n",
    "# reproducibility - confirm the slowdown is acceptable).\n",
    "cudnn.benchmark = False\n",
    "torch.backends.cudnn.enabled = False\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device: \", device)\n",
    "# Kernel size handed to every TemplateBank created in ResNetTPB._make_layer.\n",
    "GEN_KERNEL = 3\n",
    "# Number of coefficient tensors each SConv2d allocates.\n",
    "num_cf = 2\n",
    "\n",
    "\n",
    "class TemplateBank(nn.Module):\n",
    "    def __init__(self, num_templates, in_planes, out_planes, kernel_size):\n",
    "        super(TemplateBank, self).__init__()\n",
    "        self.in_planes = in_planes\n",
    "        self.out_planes = out_planes\n",
    "        self.coefficient_shape = (num_templates, 1, 1, 1, 1)\n",
    "        self.kernel_size = kernel_size\n",
    "        templates = [\n",
    "            torch.Tensor(out_planes, in_planes, kernel_size, kernel_size)\n",
    "            for _ in range(num_templates)\n",
    "        ]\n",
    "        for i in range(num_templates):\n",
    "            nn.init.kaiming_normal_(templates[i])\n",
    "        self.templates = nn.Parameter(\n",
    "            torch.stack(templates)\n",
    "        )  # this is what we will freeze later\n",
    "\n",
    "    def forward(self, coefficients):\n",
    "        weights = (self.templates * coefficients).sum(0)\n",
    "        return weights\n",
    "\n",
    "    def __repr__(self):\n",
    "        return (\n",
    "            self.__class__.__name__\n",
    "            + \" (\"\n",
    "            + \"num_templates=\"\n",
    "            + str(self.coefficient_shape[0])\n",
    "            + \", kernel_size=\"\n",
    "            + str(self.kernel_size)\n",
    "            + \")\"\n",
    "            + \", in_planes=\"\n",
    "            + str(self.in_planes)\n",
    "            + \", out_planes=\"\n",
    "            + str(self.out_planes)\n",
    "        )\n",
    "\n",
    "\n",
    "class SConv2d(nn.Module):\n",
    "    \"\"\"Convolution whose kernel is generated from a shared template bank.\n",
    "\n",
    "    The layer owns num_cf coefficient tensors. On every forward pass the bank\n",
    "    is queried once per coefficient tensor and the input is convolved with the\n",
    "    mean of the resulting kernels.\n",
    "    \"\"\"\n",
    "\n",
    "    # TARGET MODULE\n",
    "    def __init__(self, bank, stride=1, padding=1):\n",
    "        super(SConv2d, self).__init__()\n",
    "        self.stride, self.padding = stride, padding\n",
    "        self.bank = bank\n",
    "        self.num_templates = bank.coefficient_shape[0]\n",
    "\n",
    "        # Zero-initialised here; num_cf is read from module scope.\n",
    "        coefficient_sets = []\n",
    "        for _ in range(num_cf):\n",
    "            coefficient_sets.append(nn.Parameter(torch.zeros(bank.coefficient_shape)))\n",
    "        self.coefficients = nn.ParameterList(coefficient_sets)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # One generated kernel per coefficient tensor, averaged into the final kernel.\n",
    "        generated = [self.bank(coefficient) for coefficient in self.coefficients]\n",
    "        kernel = torch.stack(generated).mean(0)\n",
    "        return F.conv2d(input, kernel, stride=self.stride, padding=self.padding)\n",
    "\n",
    "\n",
    "class CustomResidualBlock(nn.Module):\n",
    "    \"\"\"Two-convolution residual block with optional template-bank convolutions.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels of the block input.\n",
    "        out_channels: channels produced by both convolutions.\n",
    "        stride: stride of the first convolution (and of the built-in shortcut).\n",
    "        downsample: optional module applied to the input to form the shortcut.\n",
    "            When omitted, a 1x1 convolution + batch norm is created whenever\n",
    "            the stride or the channel count changes.\n",
    "        bank1: optional TemplateBank backing the first convolution.\n",
    "        bank2: optional TemplateBank backing the second convolution.\n",
    "            Only when both banks are given are the convolutions built as\n",
    "            SConv2d layers; otherwise plain 3x3 nn.Conv2d layers are used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        stride=1,\n",
    "        downsample=None,\n",
    "        bank1=None,\n",
    "        bank2=None,\n",
    "    ):\n",
    "        super(CustomResidualBlock, self).__init__()\n",
    "        self.bank1 = bank1\n",
    "        self.bank2 = bank2\n",
    "\n",
    "        # Ensure padding is always 1 for 3x3 convolutions.\n",
    "        # Explicit None checks: truth-testing a module would consult __len__\n",
    "        # if a bank ever defined one.\n",
    "        if bank1 is not None and bank2 is not None:\n",
    "            self.conv1 = SConv2d(bank1, stride=stride, padding=1)\n",
    "            self.conv2 = SConv2d(bank2, stride=1, padding=1)\n",
    "        else:\n",
    "            self.conv1 = nn.Conv2d(\n",
    "                in_channels,\n",
    "                out_channels,\n",
    "                kernel_size=3,\n",
    "                stride=stride,\n",
    "                padding=1,\n",
    "                bias=False,\n",
    "            )\n",
    "            self.conv2 = nn.Conv2d(\n",
    "                out_channels,\n",
    "                out_channels,\n",
    "                kernel_size=3,\n",
    "                stride=1,\n",
    "                padding=1,\n",
    "                bias=False,\n",
    "            )\n",
    "\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "        # A caller-supplied shortcut used to be ignored; it now takes\n",
    "        # precedence. Without one, fall back to a 1x1 convolution when the\n",
    "        # stride or the channel count changes.\n",
    "        if downsample is not None:\n",
    "            self.downsample = downsample\n",
    "        elif stride != 1 or in_channels != out_channels:\n",
    "            self.downsample = nn.Sequential(\n",
    "                nn.Conv2d(\n",
    "                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False\n",
    "                ),\n",
    "                nn.BatchNorm2d(out_channels),\n",
    "            )\n",
    "        else:\n",
    "            self.downsample = None\n",
    "\n",
    "        # Initialize weights\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        \"\"\"Initialise every submodule reachable from this block.\n",
    "\n",
    "        Plain convolutions get Kaiming-normal weights, batch norms are reset to\n",
    "        weight 1 / bias 0, and SConv2d coefficient tensors are re-initialised\n",
    "        with nn.init.orthogonal_. The bank templates themselves are not touched.\n",
    "        \"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, SConv2d):\n",
    "                for coefficient in m.coefficients:\n",
    "                    nn.init.orthogonal_(coefficient)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply conv-bn-relu-conv-bn, add the shortcut and apply the final ReLU.\"\"\"\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNetTPB(nn.Module):\n",
    "    \"\"\"ResNet-style classifier in which the later blocks of a stage share template banks.\n",
    "\n",
    "    Args:\n",
    "        block: residual block class used to build the stages.\n",
    "        layers: number of blocks in each of the four stages.\n",
    "        num_classes: size of the final linear layer output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block, layers, num_classes=10):\n",
    "        super(ResNetTPB, self).__init__()\n",
    "        self.inplanes = 64\n",
    "        self.layers = layers\n",
    "\n",
    "        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n",
    "        )\n",
    "        self.bn1 = nn.BatchNorm2d(self.inplanes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "\n",
    "        # (width, stride) of the four stages, registered as layer1..layer4.\n",
    "        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))\n",
    "        for position, (width, stage_stride) in enumerate(stage_plan, start=1):\n",
    "            stage = self._make_layer(\n",
    "                block, width, layers[position - 1], stride=stage_stride\n",
    "            )\n",
    "            setattr(self, f\"layer{position}\", stage)\n",
    "\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(512, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        \"\"\"Build one stage: a regular first block, then bank-backed blocks.\"\"\"\n",
    "        shortcut = None\n",
    "        if not (stride == 1 and self.inplanes == planes):\n",
    "            shortcut = nn.Sequential(\n",
    "                nn.Conv2d(\n",
    "                    self.inplanes, planes, kernel_size=1, stride=stride, bias=False\n",
    "                ),\n",
    "                nn.BatchNorm2d(planes),\n",
    "            )\n",
    "\n",
    "        stage = [block(self.inplanes, planes, stride, shortcut)]\n",
    "\n",
    "        # Size the banks so that they hold as many 3x3 kernels as the remaining\n",
    "        # blocks would need on their own (never fewer than one).\n",
    "        conv_params = 9 * planes * planes\n",
    "        template_params = 9 * planes * planes\n",
    "        bank_size1 = max(1, int((blocks - 1) * conv_params / template_params))\n",
    "        bank_size2 = bank_size1  # You could potentially use a different calculation here\n",
    "\n",
    "        print(\n",
    "            f\"Layer with {planes} planes, {blocks} blocks, using {bank_size1} templates for conv1 and {bank_size2} for conv2\"\n",
    "        )\n",
    "\n",
    "        # Separate banks for the first and second convolution of the blocks.\n",
    "        shared_bank1 = TemplateBank(bank_size1, planes, planes, GEN_KERNEL)\n",
    "        shared_bank2 = TemplateBank(bank_size2, planes, planes, GEN_KERNEL)\n",
    "\n",
    "        self.inplanes = planes\n",
    "        stage.extend(\n",
    "            block(\n",
    "                in_channels=self.inplanes,\n",
    "                out_channels=planes,\n",
    "                bank1=shared_bank1,\n",
    "                bank2=shared_bank2,\n",
    "            )\n",
    "            for _ in range(1, blocks)\n",
    "        )\n",
    "\n",
    "        return nn.Sequential(*stage)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))\n",
    "\n",
    "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
    "            x = stage(x)\n",
    "\n",
    "        x = torch.flatten(self.avgpool(x), 1)\n",
    "        return self.fc(x)\n",
    "\n",
    "\n",
    "def test():\n",
    "    \"\"\"Smoke-test the network on one random 32x32 RGB image and print the output size.\"\"\"\n",
    "    net = ResNetTPB(CustomResidualBlock, [2, 2, 2, 2])\n",
    "    # A 32x32 input is reduced to 1x1 by the last stage. With a batch of one,\n",
    "    # BatchNorm2d in training mode rejects such a tensor (it needs more than\n",
    "    # one value per channel), so run the check in eval mode without autograd.\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        y = net(torch.randn(1, 3, 32, 32))\n",
    "    print(y.size())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloaders\n",
    "Feel free to use any dataloader of your choice. I have used the following dataloader for my experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "import tqdm\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset\n",
    "from PIL import Image\n",
    "import torch\n",
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "# Write a base dataloader class for image classification\n",
    "class ImageDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        self.data_path = \"\"\n",
    "        self.data_name = \"\"\n",
    "        self.num_classes = 0\n",
    "        self.train_transform = None\n",
    "        self.train_csv_path = \"\"\n",
    "        self.image_paths = []\n",
    "        self.labels = []\n",
    "\n",
    "    def get_num_classes(self):\n",
    "        return self.num_classes\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img_path = self.image_paths[index]\n",
    "        label = self.labels[index]\n",
    "        img = Image.open(img_path).convert(\"RGB\")\n",
    "\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_paths)\n",
    "\n",
    "    @property\n",
    "    def label_dict(self):\n",
    "        return {i: self.class_map[i] for i in range(self.num_classes)}\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"ImageDataset({self.data_name}) with {self.__len__} instances\"\n",
    "\n",
    "\n",
    "class CARS(ImageDataset):\n",
    "    \"\"\"Cars dataset indexed by cars.csv, with class names from CARS.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = CARS_DATA\n",
    "        self.data_name = \"cars\"\n",
    "        # Normalisation statistics shared by the train and test pipelines.\n",
    "        mean, std = [-0.0639, 0.0145, 0.2118], [1.2796, 1.3035, 1.3343]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"cars.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values.tolist()\n",
    "        self.num_classes = 196\n",
    "        self.split = None\n",
    "        # json file that contains the class names\n",
    "        self.class_json = os.path.join(BASE_PATH, \"CARS.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class AIRCRAFT(ImageDataset):\n",
    "    \"\"\"Aircraft dataset indexed by aircrafts.csv, with class names from AIRCRAFTS.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = AIRCRAFT_DATA\n",
    "        self.data_name = \"aircraft\"\n",
    "        # Normalisation statistics shared by the train and test pipelines.\n",
    "        mean, std = [-0.0266, 0.2407, 0.5663], [0.9745, 0.9684, 1.1040]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"aircrafts.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        self.num_classes = 55\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"AIRCRAFTS.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class FLOWERS(ImageDataset):\n",
    "    \"\"\"Flowers dataset indexed by flowers.csv, with class names from FLOWERS.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = FLOWERS_DATA\n",
    "        self.data_name = \"flowers\"\n",
    "        # Normalisation statistics shared by the train and test pipelines.\n",
    "        mean, std = [0.5642, 0.7694, 0.8410], [0.2560, 0.2589, 0.2783]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"flowers.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        # not 0 indexed (presumably labels start at 1, hence the extra slot - TODO confirm)\n",
    "        self.num_classes = 103\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"FLOWERS.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class SCENES(ImageDataset):\n",
    "    \"\"\"Scenes dataset indexed by scenes.csv, with class names from SCENES.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = SCENES_DATA\n",
    "        self.data_name = \"scenes\"\n",
    "        # Normalisation statistics shared by the train and test pipelines.\n",
    "        mean, std = [-0.0081, -0.1473, -0.1866], [1.1616, 1.1583, 1.1599]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"scenes.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        self.num_classes = 67\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"SCENES.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class CHARS(ImageDataset):\n",
    "    \"\"\"Chars dataset indexed by chars.csv, with class names from CHARS.json (both under BASE_PATH).\n",
    "\n",
    "    Neither pipeline normalises its output; the Normalize step is commented\n",
    "    out in both.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = CHARS_DATA\n",
    "        self.data_name = \"chars\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"chars.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        # not 0 indexed (presumably labels start at 1, hence the extra slot - TODO confirm)\n",
    "        self.num_classes = 63\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"CHARS.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class BIRDS(ImageDataset):\n",
    "    \"\"\"Birds dataset indexed by birds.csv, with class names from BIRDS.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.data_path = BIRDS_DATA\n",
    "        self.data_name = \"birds\"\n",
    "        # Normalisation statistics shared by the train and test pipelines.\n",
    "        mean, std = [0.0049, 0.1962, 0.1152], [1.0027, 1.0053, 1.1734]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        # The evaluation pipeline used to skip the Normalize step that the\n",
    "        # training pipeline applies, so the model was evaluated on inputs\n",
    "        # scaled differently from the ones it was trained on.\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"birds.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        # not 0 indexed (presumably labels start at 1, hence the extra slot - TODO confirm)\n",
    "        self.num_classes = 201\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"BIRDS.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class ACTION(ImageDataset):\n",
    "    \"\"\"Actions dataset indexed by action.csv, with class names from ACTION.json (both under BASE_PATH).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = ACTION_DATA\n",
    "        self.data_name = \"actions\"\n",
    "        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                # transforms.RandomHorizontalFlip(),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        # No test_transform was defined here, unlike the sibling datasets.\n",
    "        # It mirrors train_transform, which contains no random augmentation.\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(mean=mean, std=std),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"action.csv\")\n",
    "        # Parse the CSV once instead of once per column.\n",
    "        records = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = records[\"fname\"].values\n",
    "        self.labels = records[\"class\"].values\n",
    "        # not 0 indexed (presumably labels start at 1 - TODO confirm)\n",
    "        self.num_classes = 20\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"ACTION.json\")\n",
    "        # Context manager so the file handle is closed after parsing.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class SVHN(ImageDataset):\n",
    "    \"\"\"Wrapper around torchvision's SVHN dataset that also yields a task id.\n",
    "\n",
    "    Args:\n",
    "        split: forwarded to datasets.SVHN.\n",
    "        transform: optional per-image transform. When omitted, the default\n",
    "            ToTensor / Resize / CenterCrop / Normalize pipeline is used.\n",
    "    \"\"\"\n",
    "\n",
    "    # TODO: this one is a bit tricky\n",
    "    def __init__(self, split=\"train\", transform=None):\n",
    "        super().__init__()\n",
    "        self.data_path = SVHN_DATA\n",
    "        self.data_name = \"svhn\"\n",
    "        self.task_id = 6  # Assign a unique task_id for SVHN\n",
    "        self.split = split\n",
    "        # The transform argument used to be discarded in favour of the\n",
    "        # hard-coded pipeline; the pipeline is now only the default.\n",
    "        if transform is None:\n",
    "            transform = transforms.Compose(\n",
    "                [\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Resize((224, 224)),\n",
    "                    transforms.CenterCrop(224),\n",
    "                    # transforms.RandomHorizontalFlip(),\n",
    "                    transforms.Normalize(\n",
    "                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
    "                    ),\n",
    "                ]\n",
    "            )\n",
    "        self.transform = transform\n",
    "        self.dataset = datasets.SVHN(root=SVHN_DATA, split=split, download=True)\n",
    "        self.num_classes = 10\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (transformed image, label, task_id) for the sample at index.\"\"\"\n",
    "        img, label = self.dataset[index]\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        return img, label, self.task_id\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    images, labels, task_ids = zip(*batch)\n",
    "    images = torch.stack(images, dim=0)\n",
    "    labels = torch.tensor(labels)\n",
    "    task_ids = task_ids[0]\n",
    "    return images, labels\n",
    "\n",
    "\n",
    "class TransformedDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, dataset, transform):\n",
    "        self.dataset = dataset\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img, label = self.dataset[index]\n",
    "        if isinstance(img, torch.Tensor):\n",
    "            img = img.numpy().transpose(1, 2, 0)\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "\n",
    "def _split_lengths(total, train_fraction, val_fraction):\n",
    "    \"\"\"Turn train/val fractions into [train, val, test] lengths; test takes the remainder.\"\"\"\n",
    "    n_train = int(train_fraction * total)\n",
    "    n_val = int(val_fraction * total)\n",
    "    return [n_train, n_val, total - n_train - n_val]\n",
    "\n",
    "\n",
    "def _cifar_loaders(dataset_cls, train_transform, train_size, val_size, batch_size):\n",
    "    \"\"\"Split a CIFAR training set into (train, val, test) DataLoaders.\n",
    "\n",
    "    NOTE(review): all three splits are carved out of the training set and\n",
    "    therefore share train_transform, including its random flip. The official\n",
    "    test set is not used.\n",
    "    \"\"\"\n",
    "    cifar_train = dataset_cls(\n",
    "        root=CIFAR_DATA, train=True, download=True, transform=train_transform\n",
    "    )\n",
    "    splits = torch.utils.data.random_split(\n",
    "        cifar_train, _split_lengths(len(cifar_train), train_size, val_size)\n",
    "    )\n",
    "    # Only the first (training) split is shuffled.\n",
    "    return tuple(\n",
    "        torch.utils.data.DataLoader(\n",
    "            split, batch_size=batch_size, shuffle=(position == 0)\n",
    "        )\n",
    "        for position, split in enumerate(splits)\n",
    "    )\n",
    "\n",
    "\n",
    "def get_dataloaders(\n",
    "    dataset_name, train_size=0.8, val_size=0.0, batch_size=32, mode=\"train\"\n",
    "):\n",
    "    \"\"\"Build train/val/test DataLoaders for a named dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: one of cars, aircraft, flowers, scenes, chars, birds,\n",
    "            cifar10 or cifar100.\n",
    "        train_size: fraction of the samples used for training.\n",
    "        val_size: fraction of the samples used for validation; the test split\n",
    "            receives whatever remains.\n",
    "        batch_size: batch size of all three loaders.\n",
    "        mode: currently unused.\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, val_loader, test_loader, num_classes)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if dataset_name is not recognised.\n",
    "    \"\"\"\n",
    "    if dataset_name == \"cifar10\":\n",
    "        train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "            ]\n",
    "        )\n",
    "        loaders = _cifar_loaders(\n",
    "            datasets.CIFAR10, train_transform, train_size, val_size, batch_size\n",
    "        )\n",
    "        return (*loaders, 10)\n",
    "\n",
    "    if dataset_name == \"cifar100\":\n",
    "        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "        train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.Normalize(*stats, inplace=True),\n",
    "            ]\n",
    "        )\n",
    "        loaders = _cifar_loaders(\n",
    "            datasets.CIFAR100, train_transform, train_size, val_size, batch_size\n",
    "        )\n",
    "        return (*loaders, 100)\n",
    "\n",
    "    # CSV-backed datasets, looked up by name.\n",
    "    dataset_classes = {\n",
    "        \"cars\": CARS,\n",
    "        \"aircraft\": AIRCRAFT,\n",
    "        \"flowers\": FLOWERS,\n",
    "        \"scenes\": SCENES,\n",
    "        \"chars\": CHARS,\n",
    "        \"birds\": BIRDS,\n",
    "    }\n",
    "    if dataset_name not in dataset_classes:\n",
    "        raise ValueError(f\"Dataset {dataset_name} not found\")\n",
    "    dataset = dataset_classes[dataset_name]()\n",
    "\n",
    "    # split the dataset into train, val, and test\n",
    "    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
    "        dataset, _split_lengths(len(dataset), train_size, val_size)\n",
    "    )\n",
    "    print(\n",
    "        f\"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}\"\n",
    "    )\n",
    "    # print(dataset.train_transform)\n",
    "    print(dataset.test_transform)\n",
    "    # Create transformed datasets for each split\n",
    "    train_dataset = TransformedDataset(train_dataset, dataset.train_transform)\n",
    "    val_dataset = TransformedDataset(val_dataset, dataset.test_transform)\n",
    "    test_dataset = TransformedDataset(test_dataset, dataset.test_transform)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "        drop_last=True,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    # Evaluation loaders keep the final partial batch: dropping it silently\n",
    "    # excluded samples from validation and test metrics.\n",
    "    val_loader = torch.utils.data.DataLoader(\n",
    "        val_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    return train_loader, val_loader, test_loader, dataset.get_num_classes()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_parameters(model):\n",
    "    \"\"\"Print and return a breakdown of the parameter counts of ``model``.\n",
    "\n",
    "    Counts come from ``Tensor.numel()``, so they are plain Python ints even\n",
    "    for 0-dim parameters (``np.prod`` of an empty shape yields the float 1.0).\n",
    "\n",
    "    Args:\n",
    "        model: ``nn.Module`` whose parameters are counted.\n",
    "\n",
    "    Returns:\n",
    "        dict: integer counts under the keys ``total``, ``trainable``,\n",
    "        ``untrainable``, ``templates``, ``coefficients`` and ``batch_norm``.\n",
    "    \"\"\"\n",
    "    params = 0\n",
    "    trainable_params = 0\n",
    "    template_params = 0\n",
    "    coefficients = 0\n",
    "    for name, param in model.named_parameters():\n",
    "        count = param.numel()\n",
    "        params += count\n",
    "        if param.requires_grad:\n",
    "            trainable_params += count\n",
    "        # Template / coefficient buckets are matched by substring of the name.\n",
    "        if \"templates\" in name:\n",
    "            template_params += count\n",
    "        if \"coefficients\" in name:\n",
    "            coefficients += count\n",
    "    untrainable_params = params - trainable_params\n",
    "\n",
    "    # Count weight and bias of every BatchNorm2d; layers built with\n",
    "    # affine=False own no parameters (their weight is None) and add nothing.\n",
    "    batch_norm_params = sum(\n",
    "        p.numel()\n",
    "        for module in model.modules()\n",
    "        if isinstance(module, nn.BatchNorm2d)\n",
    "        for p in module.parameters(recurse=False)\n",
    "    )\n",
    "\n",
    "    print(\"* number of parameters: {}\".format(params))\n",
    "    print(\"* untrainable params: {}\".format(untrainable_params))\n",
    "    print(\"* trainable params: {}\".format(trainable_params))\n",
    "\n",
    "    print(\"* template params: {}\".format(template_params))\n",
    "    print(\"* coefficients: {}\".format(coefficients))\n",
    "    print(\"* batch norm params: {}\".format(batch_norm_params))\n",
    "\n",
    "    return {\n",
    "        \"total\": params,\n",
    "        \"trainable\": trainable_params,\n",
    "        \"untrainable\": untrainable_params,\n",
    "        \"templates\": template_params,\n",
    "        \"coefficients\": coefficients,\n",
    "        \"batch_norm\": batch_norm_params,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# Fine-tune one model per task: the shared checkpoint is loaded, everything is\n",
    "# frozen except the new classifier head and the \"coefficients\" parameters, and\n",
    "# the best test accuracy per task is reported.\n",
    "\n",
    "# Optional prebuilt loaders: {task: {\"train\": loader, \"test\": loader}}.\n",
    "# Tasks missing from this dict are loaded on demand with get_dataloaders().\n",
    "dataloader_dict = {}\n",
    "TASK_NAME = [\"cars\", \"aircraft\", \"flowers\", \"scenes\", \"chars\", \"birds\", \"cifar10\", \"cifar100\"]\n",
    "num_classes = [196, 55, 103, 67, 63, 201, 10, 100]\n",
    "SHARED_WEIGHT = \"<SharedWeight.pt>\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Each loader batch is processed in chunks of this size (one optimizer step each).\n",
    "mini_batch_size = 64  # TODO\n",
    "for task, num_class in zip(TASK_NAME, num_classes):\n",
    "    print(f\"Task: {task}, Num classes: {num_class}\")\n",
    "    if task == \"cifar100\" or task == \"cifar10\":\n",
    "        print(\"Using CIFAR\")\n",
    "        train_loader, _, test_loader, num_class = get_dataloaders(\n",
    "            task, train_size=0.8, val_size=0.0, batch_size=256, mode=\"train\"\n",
    "        )\n",
    "    elif task in dataloader_dict:\n",
    "        train_loader = dataloader_dict[task][\"train\"]\n",
    "        test_loader = dataloader_dict[task][\"test\"]\n",
    "    else:\n",
    "        # dataloader_dict starts empty in this cell, so indexing it directly\n",
    "        # raised KeyError on the first task; build the loaders on demand instead.\n",
    "        # NOTE(review): assumes get_dataloaders() accepts this task name - confirm.\n",
    "        # The class count from num_classes above is kept for these tasks.\n",
    "        train_loader, _, test_loader, _ = get_dataloaders(\n",
    "            task, train_size=0.8, val_size=0.0, batch_size=256, mode=\"train\"\n",
    "        )\n",
    "    print(\n",
    "        f\"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}\"\n",
    "    )\n",
    "\n",
    "    training_model = ResNetTPB(CustomResidualBlock, [3, 4, 6, 3], num_classes=1000)\n",
    "    # map_location keeps the load working on CPU-only machines even when the\n",
    "    # checkpoint was saved from a CUDA device.\n",
    "    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
    "    checkpoint = torch.load(SHARED_WEIGHT, map_location=\"cpu\")\n",
    "    print(\n",
    "        training_model.load_state_dict(checkpoint[\"state_dict\"], strict=True)  # TODO\n",
    "    )\n",
    "\n",
    "    # Freeze everything, then re-enable only the new head and the coefficients.\n",
    "    for param in training_model.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    resnet_embeddim = training_model.fc.in_features\n",
    "    training_model.fc = nn.Linear(resnet_embeddim, num_class, bias=True)\n",
    "    training_model.fc.weight.requires_grad = True\n",
    "    training_model.fc.bias.requires_grad = True\n",
    "    cf_parameters = []\n",
    "    for n, p in training_model.named_parameters():\n",
    "        if \"coefficients\" in n:\n",
    "            p.requires_grad = True\n",
    "            cf_parameters.append(p)\n",
    "\n",
    "        if p.requires_grad:\n",
    "            print(f\"Trainable: {n}\")\n",
    "    calculate_parameters(training_model)\n",
    "    training_model.to(device)\n",
    "    # Only parameters that can receive gradients are optimized and clipped.\n",
    "    trainable_params = [p for p in training_model.parameters() if p.requires_grad]\n",
    "    # NOTE(review): weight_decay=1e-46 is effectively zero - possibly a typo; confirm.\n",
    "    optimizer = torch.optim.AdamW(trainable_params, lr=2e-3, weight_decay=1e-46)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=33, gamma=0.1)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # At least 1, so `i % printing_iter` cannot divide by zero on loaders with\n",
    "    # fewer than 6 batches.\n",
    "    printing_iter = max(1, len(train_loader) // 6)\n",
    "    num_epochs = 100\n",
    "    best_acc = 0.0\n",
    "    best_loss = 0.0\n",
    "    best_model = None\n",
    "    for epoch in range(num_epochs):\n",
    "        training_model.train()\n",
    "        classifier_loss = 0.0\n",
    "        num_steps = 0\n",
    "        for i, (images, labels) in tqdm.tqdm(\n",
    "            enumerate(train_loader), total=len(train_loader), desc=\"training\"\n",
    "        ):\n",
    "            for j in range(0, images.size(0), mini_batch_size):\n",
    "                optimizer.zero_grad()\n",
    "                classifier_output = training_model(\n",
    "                    images[j : j + mini_batch_size].to(device)\n",
    "                )\n",
    "                loss = criterion(\n",
    "                    classifier_output, labels[j : j + mini_batch_size].to(device)\n",
    "                )\n",
    "                # .item() stores a float, so no autograd graph is kept alive\n",
    "                # across the epoch.\n",
    "                classifier_loss += loss.item()\n",
    "                num_steps += 1\n",
    "                loss.backward()\n",
    "                # gradient clipping\n",
    "                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)\n",
    "                optimizer.step()\n",
    "\n",
    "            if i % printing_iter == 0:\n",
    "                # Average over optimizer steps (several per loader batch).\n",
    "                print(\n",
    "                    f\"Epoch {epoch}, Iteration {i}, LR: {optimizer.param_groups[0]['lr']}, Loss: {classifier_loss / max(num_steps, 1)}\"\n",
    "                )\n",
    "\n",
    "        training_model.eval()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            total = 0\n",
    "            correct = 0\n",
    "            val_loss = 0.0\n",
    "            for images, labels in tqdm.tqdm(\n",
    "                test_loader, total=len(test_loader), desc=\"Evaluating\"\n",
    "            ):\n",
    "                for j in range(0, images.size(0), mini_batch_size):\n",
    "                    sub_images = images[j : j + mini_batch_size].to(device)\n",
    "                    sub_labels = labels[j : j + mini_batch_size].to(device)\n",
    "                    classifier_output = training_model(sub_images)\n",
    "                    predicted = classifier_output.argmax(dim=1)\n",
    "                    total += sub_labels.size(0)\n",
    "                    correct += (predicted == sub_labels).sum().item()\n",
    "                    # criterion returns the chunk mean; weight it by the chunk\n",
    "                    # size so the final value is a true per-sample average.\n",
    "                    sub_val_loss = criterion(classifier_output, sub_labels)\n",
    "                    val_loss += sub_val_loss.item() * sub_labels.size(0)\n",
    "\n",
    "            # An empty test loader (possible with drop_last=True) must not crash.\n",
    "            accuracy = 100 * correct / total if total else 0.0\n",
    "            avg_val_loss = val_loss / total if total else float(\"nan\")\n",
    "            print(f\"Validation Accuracy: {accuracy}, Average Loss: {avg_val_loss}\")\n",
    "\n",
    "            if accuracy > best_acc:\n",
    "                best_acc = accuracy\n",
    "                best_loss = avg_val_loss\n",
    "                # state_dict() returns references to the live tensors; clone them\n",
    "                # so later epochs cannot overwrite the snapshot.\n",
    "                # NOTE(review): best_model is discarded at the end of the task\n",
    "                # without being saved - persist it here if the weights are needed.\n",
    "                best_model = {\n",
    "                    key: value.detach().cpu().clone()\n",
    "                    for key, value in training_model.state_dict().items()\n",
    "                }\n",
    "                print(f\"New best accuracy: {best_acc}, and current Loss: {best_loss}\")\n",
    "        scheduler.step()\n",
    "\n",
    "    print(f\"Best accuracy: {best_acc}, Loss: {best_loss}\")\n",
    "\n",
    "    # Release per-task objects before the next task to limit memory growth.\n",
    "    del (\n",
    "        training_model,\n",
    "        optimizer,\n",
    "        criterion,\n",
    "        scheduler,\n",
    "        train_loader,\n",
    "        test_loader,\n",
    "        best_model,\n",
    "        trainable_params,\n",
    "    )\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
