{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Description\n",
    "A modified Vision Transformer (ViT) in which each block's MLP layer has been replaced with a `SharedMLP` layer that generates its weights from shared template banks and learned coefficients. The same script can be used to train the other RECAST-ViT models by swapping in the corresponding model definition and loading the appropriate weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "manualSeed = 42\n",
    "DEFAULT_THRESHOLD = 5e-3\n",
    "\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "torch.cuda.manual_seed(manualSeed)\n",
    "np.random.seed(manualSeed)\n",
    "cudnn.benchmark = False\n",
    "torch.backends.cudnn.enabled = False\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device: \", device)\n",
    "FACTORS = 6  # number of groups\n",
    "TEMPLATES = (\n",
    "    2  # number of templates per bank, corresponds to number of layers in a group\n",
    ")\n",
    "MULT = 1  # optional multiplier for the number of coefficients set\n",
    "num_cf = 2  # number of coefficients sets per target module\n",
    "def calculate_parameters(model):\n",
    "    attention_params = 0\n",
    "    template_params = 0\n",
    "    coefficients_params = 0\n",
    "    mlp_params = 0\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    classifer_params = sum(p.numel() for p in model.head.parameters())\n",
    "\n",
    "    for n, p in model.named_parameters():\n",
    "        if \".attn.\" in n:\n",
    "            attention_params += p.numel()\n",
    "            # print(\"Attention params: \", n, p.numel())\n",
    "        if \"template_banks\" in n:\n",
    "            template_params += p.numel()\n",
    "            # print(\"Template params: \", n, p.numel())\n",
    "        if \"mlp\" in n:\n",
    "            mlp_params += p.numel()\n",
    "            # print(\"MLP params: \", n, p.numel())\n",
    "        if \"coefficients\" in n:\n",
    "            coefficients_params += p.numel()\n",
    "            # print(\"Coefficients params: \", n, p.numel())\n",
    "\n",
    "    print(\"Classifier head: \", model.head)\n",
    "    print(\n",
    "        f\"Total parameters: {total_params//1000000}M, Trainable parameters: {trainable_params}, Classifier parameters: {classifer_params}\"\n",
    "    )\n",
    "    print(f\"Attention parameters: {attention_params}\")\n",
    "    print(f\"Templates params: {template_params}\")\n",
    "    print(f\"MLP params: {mlp_params}\")\n",
    "    print(f\"Coefficients params: {coefficients_params}\")\n",
    "\n",
    "\n",
    "class MLPTemplateBank(nn.Module):\n",
    "    def __init__(self, num_templates, in_features, out_features):\n",
    "        super(MLPTemplateBank, self).__init__()\n",
    "        self.num_templates = num_templates\n",
    "        self.coefficient_shape = (num_templates, 1, 1)\n",
    "        templates = [\n",
    "            torch.Tensor(out_features, in_features) for _ in range(num_templates)\n",
    "        ]\n",
    "        for i in range(num_templates):\n",
    "            nn.init.kaiming_normal_(templates[i])\n",
    "        self.templates = nn.Parameter(torch.stack(templates))\n",
    "\n",
    "    def forward(self, coefficients):\n",
    "        params = self.templates * coefficients\n",
    "        summed_params = torch.sum(params, dim=0)\n",
    "        return summed_params\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"MLPTemplateBank(num_templates={self.templates.shape[0]}, in_features={self.templates.shape[1]}, out_features={self.templates.shape[2]}, coefficients={self.coefficient_shape})\"\n",
    "\n",
    "\n",
    "class SharedMLP(nn.Module):\n",
    "    def __init__(\n",
    "        self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.0\n",
    "    ):\n",
    "        super(SharedMLP, self).__init__()\n",
    "        self.bank1 = None\n",
    "        self.bank2 = None\n",
    "\n",
    "        if bank1 != None and bank2 != None:\n",
    "            self.bank1 = bank1\n",
    "            self.bank2 = bank2\n",
    "            self.coefficients1 = nn.ParameterList(\n",
    "                [\n",
    "                    nn.Parameter(\n",
    "                        torch.zeros(bank1.coefficient_shape), requires_grad=True\n",
    "                    )\n",
    "                    for _ in range(num_cf)\n",
    "                ]\n",
    "            )\n",
    "            self.coefficients2 = nn.ParameterList(\n",
    "                [\n",
    "                    nn.Parameter(\n",
    "                        torch.zeros(bank2.coefficient_shape), requires_grad=True\n",
    "                    )\n",
    "                    for _ in range(num_cf)\n",
    "                ]\n",
    "            )\n",
    "            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))\n",
    "            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.norm = nn.Identity()\n",
    "        self.drop = nn.Dropout(drop)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        if self.bank1 != None:\n",
    "            for cf in self.coefficients1:\n",
    "                nn.init.orthogonal_(cf)\n",
    "        if self.bank2 != None:\n",
    "            for cf in self.coefficients2:\n",
    "                nn.init.orthogonal_(cf)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.bank1 != None:\n",
    "            weight1 = []\n",
    "            for c in self.coefficients1:\n",
    "                w = self.bank1(c)\n",
    "                weight1.append(w)\n",
    "            weights1 = torch.stack(weight1).mean(0)\n",
    "        if self.bank2 != None:\n",
    "            weight2 = []\n",
    "            for c in self.coefficients2:\n",
    "                w = self.bank2(c)\n",
    "                weight2.append(w)\n",
    "            weights2 = torch.stack(weight2).mean(0)\n",
    "\n",
    "        x = F.linear(x, weights1, self.bias1)\n",
    "        x = self.act(x)\n",
    "        x = self.norm(x)\n",
    "        x = F.linear(x, weights2, self.bias2)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# original timm module for vision transformer\n",
    "class Attention(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads=6,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        attn_drop=0.0,\n",
    "        proj_drop=0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = dim // num_heads\n",
    "        self.scale = qk_scale or self.head_dim**-0.5\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, N, C = x.shape\n",
    "        qkv = self.qkv(x).chunk(3, dim=-1)\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        attn = F.softmax(attn, dim=-1)  # attention proba\n",
    "        attn = self.attn_drop(attn)\n",
    "        x = attn @ v  # attention output\n",
    "        x = x.transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\"Pre-norm transformer block: self-attention followed by a ``SharedMLP``.\n",
    "\n",
    "    Each sub-layer is applied as ``x + drop_path(ls(sublayer(norm(x))))``;\n",
    "    ``ls1``/``ls2`` and ``drop_path`` are ``nn.Identity`` placeholders, so the\n",
    "    ``drop_path`` argument is accepted but ignored. ``mlp_ratio`` is unused as\n",
    "    well: the MLP weights (and hence its width) come from ``bank1``/``bank2``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop=0.0,\n",
    "        attn_drop=0.0,\n",
    "        drop_path=0.0,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "        bank1=None,\n",
    "        bank2=None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = Attention(\n",
    "            dim,\n",
    "            num_heads=num_heads,\n",
    "            qkv_bias=qkv_bias,\n",
    "            qk_scale=qk_scale,\n",
    "            attn_drop=attn_drop,\n",
    "            proj_drop=drop,\n",
    "        )\n",
    "        self.ls1 = nn.Identity()\n",
    "        self.ls2 = nn.Identity()\n",
    "        self.drop_path = nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        # The hidden width is fixed by the shared banks, so no mlp_hidden_dim is needed here.\n",
    "        self.mlp = SharedMLP(\n",
    "            bank1, bank2, act_layer=act_layer, norm_layer=norm_layer, drop=drop\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))\n",
    "        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n",
    "        return x\n",
    "\n",
    "\n",
    "class PatchEmbed(nn.Module):\n",
    "    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n",
    "        super().__init__()\n",
    "        img_size = (img_size, img_size)\n",
    "        patch_size = (patch_size, patch_size)\n",
    "        self.img_size = img_size\n",
    "        self.patch_size = patch_size\n",
    "        self.proj = nn.Conv2d(\n",
    "            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0\n",
    "        )\n",
    "        self.norm = nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.proj(x)\n",
    "        x = nn.Flatten(start_dim=2, end_dim=3)(x).permute(0, 2, 1)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class VisionTransformer(nn.Module):\n",
    "    \"\"\"ViT whose block MLPs get their weights from per-group template banks.\n",
    "\n",
    "    The ``depth`` blocks are split into ``num_groups`` runs of consecutive\n",
    "    blocks; every block in a run shares one ``template_banks1[g]`` /\n",
    "    ``template_banks2[g]`` pair and owns only its ``SharedMLP`` coefficients\n",
    "    and biases. Blocks left over when ``depth`` is not divisible by\n",
    "    ``num_groups`` join the last group.\n",
    "\n",
    "    Extra args (optional, backward compatible):\n",
    "        num_groups: number of template-bank groups; ``None`` uses ``FACTORS``.\n",
    "        num_templates: templates per bank; ``None`` uses ``TEMPLATES``.\n",
    "\n",
    "    Note: per-group ``drop_path`` rates are computed and passed to ``Block``,\n",
    "    whose ``drop_path`` is an ``nn.Identity``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        num_classes=1000,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop_rate=0.0,\n",
    "        attn_drop_rate=0.0,\n",
    "        drop_path_rate=0.0,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "        num_groups=None,\n",
    "        num_templates=None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.img_size = img_size\n",
    "        self.dim = embed_dim\n",
    "        self.num_classes = num_classes\n",
    "        self.embed_dim = embed_dim\n",
    "        self.patch_size = patch_size\n",
    "        self.num_features = self.embed_dim\n",
    "        self.num_prefix_tokens = 1  # only the class token\n",
    "        self.num_patches = (img_size // patch_size) ** 2\n",
    "        self.has_class_token = True\n",
    "        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))\n",
    "        self.patch_embed = PatchEmbed(\n",
    "            img_size=img_size,\n",
    "            patch_size=patch_size,\n",
    "            in_chans=in_chans,\n",
    "            embed_dim=embed_dim,\n",
    "        )\n",
    "\n",
    "        print(\"Num patches: \", self.num_patches)\n",
    "        self.pos_embed = nn.Parameter(\n",
    "            torch.ones(1, self.num_patches + self.num_prefix_tokens, embed_dim) * 0.02,\n",
    "            requires_grad=True,\n",
    "        )\n",
    "\n",
    "        self.pos_drop = nn.Dropout(p=drop_rate)\n",
    "        self.patch_drop = nn.Identity()\n",
    "        self.fc_norm = nn.Identity()\n",
    "        self.head_drop = nn.Dropout(drop_rate)\n",
    "\n",
    "        self.num_groups = FACTORS if num_groups is None else num_groups\n",
    "        # How many consecutive blocks share one bank pair. Clamped to >= 1 so a\n",
    "        # depth smaller than num_groups does not divide by zero below.\n",
    "        self.num_layers_in_group = max(1, depth // self.num_groups)\n",
    "        print(\"Num layers in group: \", self.num_layers_in_group)\n",
    "        self.num_templates = TEMPLATES if num_templates is None else num_templates\n",
    "        mlp_hidden_dim = int(embed_dim * mlp_ratio)\n",
    "        self.template_banks1 = nn.ModuleList(\n",
    "            [\n",
    "                MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim)\n",
    "                for _ in range(self.num_groups)\n",
    "            ]\n",
    "        )\n",
    "        self.template_banks2 = nn.ModuleList(\n",
    "            [\n",
    "                MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim)\n",
    "                for _ in range(self.num_groups)\n",
    "            ]\n",
    "        )\n",
    "        self.depth = depth\n",
    "\n",
    "        dpr = [\n",
    "            x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)\n",
    "        ]  # stochastic depth decay rule (one rate per group)\n",
    "        self.blocks = nn.ModuleList()\n",
    "        for i in range(depth):\n",
    "            # Clamp so leftover blocks (depth not divisible by num_groups) join\n",
    "            # the last group instead of indexing past the end of the bank lists.\n",
    "            group_idx = min(i // self.num_layers_in_group, self.num_groups - 1)\n",
    "            print(group_idx)\n",
    "            self.blocks.append(\n",
    "                Block(\n",
    "                    dim=self.embed_dim,\n",
    "                    num_heads=num_heads,\n",
    "                    mlp_ratio=mlp_ratio,\n",
    "                    qkv_bias=qkv_bias,\n",
    "                    qk_scale=qk_scale,\n",
    "                    drop=drop_rate,\n",
    "                    attn_drop=attn_drop_rate,\n",
    "                    drop_path=dpr[group_idx],\n",
    "                    norm_layer=norm_layer,\n",
    "                    bank1=self.template_banks1[group_idx],\n",
    "                    bank2=self.template_banks2[group_idx],\n",
    "                )\n",
    "            )\n",
    "        print(f\"Num blocks: {len(self.blocks)}\")\n",
    "        self.norm = norm_layer(self.embed_dim)\n",
    "        self.head = (\n",
    "            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n",
    "        )\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        \"\"\"Xavier-uniform Linear weights with near-zero biases; LayerNorm to weight 1, bias 0.\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                if m.bias is not None:\n",
    "                    nn.init.normal_(m.bias, std=1e-6)\n",
    "            elif isinstance(m, nn.LayerNorm):\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "                nn.init.constant_(m.weight, 1.0)\n",
    "\n",
    "    def _pos_embed(self, x):\n",
    "        \"\"\"Prepend the class token to the patch tokens and add ``pos_embed``.\"\"\"\n",
    "        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n",
    "        x = torch.cat([cls_tokens, x], dim=1)\n",
    "        x = x + self.pos_embed\n",
    "        return self.pos_drop(x)\n",
    "\n",
    "    def forward_features(self, x):\n",
    "        \"\"\"Images -> normalized tokens, (B, 1 + num_patches, embed_dim) for img_size inputs.\"\"\"\n",
    "        x = self.patch_embed(x)\n",
    "        x = self._pos_embed(x)\n",
    "        x = self.patch_drop(x)\n",
    "        for blk in self.blocks:\n",
    "            x = blk(x)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "\n",
    "    def forward_head(self, x):\n",
    "        \"\"\"Classify from the class-token (position 0) features.\"\"\"\n",
    "        x = x[:, 0]\n",
    "        x = self.fc_norm(x)\n",
    "        x = self.head_drop(x)\n",
    "        x = self.head(x)\n",
    "        return x\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.forward_features(x)\n",
    "        head_output = self.forward_head(features)\n",
    "        return head_output\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataloaders\n",
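    "I used the following dataset classes and loaders for my experiments; feel free to substitute any dataloader of your choice. The classes expect `BASE_PATH` and the per-dataset `*_DATA` paths (e.g. `CARS_DATA`) to be defined before they are instantiated."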
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "import tqdm\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset\n",
    "from PIL import Image\n",
    "import torch\n",
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "# Write a base dataloader class for image classification\n",
    "class ImageDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        self.data_path = \"\"\n",
    "        self.data_name = \"\"\n",
    "        self.num_classes = 0\n",
    "        self.train_transform = None\n",
    "        self.train_csv_path = \"\"\n",
    "        self.image_paths = []\n",
    "        self.labels = []\n",
    "\n",
    "    def get_num_classes(self):\n",
    "        return self.num_classes\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img_path = self.image_paths[index]\n",
    "        label = self.labels[index]\n",
    "        img = Image.open(img_path).convert(\"RGB\")\n",
    "\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_paths)\n",
    "\n",
    "    @property\n",
    "    def label_dict(self):\n",
    "        return {i: self.class_map[i] for i in range(self.num_classes)}\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"ImageDataset({self.data_name}) with {self.__len__} instances\"\n",
    "\n",
    "\n",
    "class CARS(ImageDataset):\n",
    "    \"\"\"Cars dataset (196 classes) listed in ``BASE_PATH/cars.csv``.\n",
    "\n",
    "    The CSV is read for its ``fname`` and ``class`` columns; class names are\n",
    "    loaded from ``BASE_PATH/CARS.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = CARS_DATA\n",
    "        self.data_name = \"cars\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0639, 0.0145, 0.2118], std=[1.2796, 1.3035, 1.3343]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0639, 0.0145, 0.2118], std=[1.2796, 1.3035, 1.3343]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"cars.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values.tolist()\n",
    "        self.num_classes = 196\n",
    "        self.split = None\n",
    "        # json file that contains the class names\n",
    "        self.class_json = os.path.join(BASE_PATH, \"CARS.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class AIRCRAFT(ImageDataset):\n",
    "    \"\"\"Aircraft dataset (55 classes) listed in ``BASE_PATH/aircrafts.csv``.\n",
    "\n",
    "    The CSV is read for its ``fname`` and ``class`` columns; class names are\n",
    "    loaded from ``BASE_PATH/AIRCRAFTS.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = AIRCRAFT_DATA\n",
    "        self.data_name = \"aircraft\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0266, 0.2407, 0.5663], std=[0.9745, 0.9684, 1.1040]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0266, 0.2407, 0.5663], std=[0.9745, 0.9684, 1.1040]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"aircrafts.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 55\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"AIRCRAFTS.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class FLOWERS(ImageDataset):\n",
    "    \"\"\"Flowers dataset listed in ``BASE_PATH/flowers.csv`` (labels not 0-indexed).\n",
    "\n",
    "    The CSV is read for its ``fname`` and ``class`` columns; class names are\n",
    "    loaded from ``BASE_PATH/FLOWERS.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = FLOWERS_DATA\n",
    "        self.data_name = \"flowers\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(\n",
    "                    mean=[0.5642, 0.7694, 0.8410], std=[0.2560, 0.2589, 0.2783]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(\n",
    "                    mean=[0.5642, 0.7694, 0.8410], std=[0.2560, 0.2589, 0.2783]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"flowers.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 103  # not 0 indexed\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"FLOWERS.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class SCENES(ImageDataset):\n",
    "    \"\"\"Scenes dataset (67 classes) listed in ``BASE_PATH/scenes.csv``.\n",
    "\n",
    "    The CSV is read for its ``fname`` and ``class`` columns; class names are\n",
    "    loaded from ``BASE_PATH/SCENES.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = SCENES_DATA\n",
    "        self.data_name = \"scenes\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0081, -0.1473, -0.1866], std=[1.1616, 1.1583, 1.1599]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.Normalize(\n",
    "                    mean=[-0.0081, -0.1473, -0.1866], std=[1.1616, 1.1583, 1.1599]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"scenes.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 67\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"SCENES.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class CHARS(ImageDataset):\n",
    "    \"\"\"Characters dataset listed in ``BASE_PATH/chars.csv`` (labels not 0-indexed).\n",
    "\n",
    "    No normalization is applied to train or test images (both Normalize steps\n",
    "    are commented out). Class names are loaded from ``BASE_PATH/CHARS.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = CHARS_DATA\n",
    "        self.data_name = \"chars\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"chars.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 63  # not 0 indexed\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"CHARS.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class BIRDS(ImageDataset):\n",
    "    \"\"\"Birds dataset listed in ``BASE_PATH/birds.csv`` (labels not 0-indexed).\n",
    "\n",
    "    The CSV is read for its ``fname`` and ``class`` columns; class names are\n",
    "    loaded from ``BASE_PATH/BIRDS.json``. Train and test images share the\n",
    "    same normalization statistics.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.data_path = BIRDS_DATA\n",
    "        self.data_name = \"birds\"\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ColorJitter(\n",
    "                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2\n",
    "                ),\n",
    "                transforms.RandomAffine(\n",
    "                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)\n",
    "                ),\n",
    "                transforms.Normalize(\n",
    "                    mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.test_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                # Must match the train-time normalization, otherwise evaluation\n",
    "                # inputs are on a different scale than what the model was trained on.\n",
    "                transforms.Normalize(\n",
    "                    mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"birds.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 201  # not 0 indexed\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"BIRDS.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class ACTION(ImageDataset):\n",
    "    \"\"\"Action dataset listed in ``BASE_PATH/action.csv`` (labels not 0-indexed).\n",
    "\n",
    "    Uses ImageNet normalization statistics and no random augmentation.\n",
    "    Class names are loaded from ``BASE_PATH/ACTION.json``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.data_path = ACTION_DATA\n",
    "        self.data_name = \"actions\"\n",
    "        # NOTE(review): unlike the other datasets, no test_transform is defined here.\n",
    "        self.train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                # transforms.RandomHorizontalFlip(),\n",
    "                transforms.Normalize(\n",
    "                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
    "                ),\n",
    "            ]\n",
    "        )\n",
    "        self.train_csv_path = os.path.join(BASE_PATH, \"action.csv\")\n",
    "        # Parse the CSV once and take both columns from the same frame.\n",
    "        annotations = pd.read_csv(self.train_csv_path)\n",
    "        self.image_paths = annotations[\"fname\"].values\n",
    "        self.labels = annotations[\"class\"].values\n",
    "        self.num_classes = 20  # not 0 indexed\n",
    "        self.split = None\n",
    "        self.class_json = os.path.join(BASE_PATH, \"ACTION.json\")\n",
    "        # Close the JSON file deterministically instead of leaking the handle.\n",
    "        with open(self.class_json) as class_file:\n",
    "            self.class_map = json.load(class_file)\n",
    "\n",
    "\n",
    "class SVHN(ImageDataset):\n",
    "    \"\"\"Street View House Numbers via ``torchvision.datasets.SVHN`` (10 classes).\n",
    "\n",
    "    Unlike the CSV-backed datasets, ``__getitem__`` applies ``self.transform``\n",
    "    itself and returns ``(img, label, task_id)``.\n",
    "\n",
    "    Args:\n",
    "        split: SVHN split name forwarded to torchvision (e.g. ``\"train\"``).\n",
    "        transform: optional transform; when ``None`` a default\n",
    "            ToTensor/Resize/CenterCrop/Normalize pipeline is used.\n",
    "    \"\"\"\n",
    "\n",
    "    # TODO: this one is a bit tricky\n",
    "    def __init__(self, split=\"train\", transform=None):\n",
    "        super().__init__()\n",
    "        self.data_path = SVHN_DATA\n",
    "        self.data_name = \"svhn\"\n",
    "        self.task_id = 6  # Assign a unique task_id for SVHN\n",
    "        self.split = split\n",
    "        # Honor a caller-supplied transform; previously the argument was ignored.\n",
    "        if transform is None:\n",
    "            transform = transforms.Compose(\n",
    "                [\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Resize((224, 224)),\n",
    "                    transforms.CenterCrop(224),\n",
    "                    # transforms.RandomHorizontalFlip(),\n",
    "                    transforms.Normalize(\n",
    "                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
    "                    ),\n",
    "                ]\n",
    "            )\n",
    "        self.transform = transform\n",
    "        self.dataset = datasets.SVHN(root=SVHN_DATA, split=split, download=True)\n",
    "        self.num_classes = 10\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img, label = self.dataset[index]\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return img, label, self.task_id\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Stack ``(image, label, task_id)`` samples into an ``(images, labels)`` batch.\n",
    "\n",
    "    Images must be equally shaped tensors and are stacked along a new dim 0;\n",
    "    labels are converted with ``torch.tensor``. Task ids are discarded.\n",
    "    \"\"\"\n",
    "    images, labels, _task_ids = zip(*batch)\n",
    "    return torch.stack(images, dim=0), torch.tensor(labels)\n",
    "\n",
    "\n",
    "class TransformedDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Apply ``transform`` lazily to the images of another dataset.\n",
    "\n",
    "    The wrapped dataset must yield ``(image, label)`` pairs. Tensor images are\n",
    "    first turned into numpy arrays with axis 0 moved last (assumed CHW -> HWC)\n",
    "    before ``transform`` is applied.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, transform):\n",
    "        self.dataset = dataset  # source of (image, label) pairs\n",
    "        self.transform = transform  # callable, or None/falsy to skip\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sample, target = self.dataset[index]\n",
    "        if isinstance(sample, torch.Tensor):\n",
    "            sample = sample.numpy().transpose(1, 2, 0)\n",
    "        if not self.transform:\n",
    "            return sample, target\n",
    "        return self.transform(sample), target\n",
    "\n",
    "\n",
    "def _split_sizes(num_items, train_size, val_size):\n",
    "    \"\"\"Convert train/val fractions into absolute lengths for ``random_split``.\n",
    "\n",
    "    Fractions are truncated with ``int``; the remainder becomes the test split,\n",
    "    so the three lengths always sum to ``num_items``.\n",
    "    \"\"\"\n",
    "    n_train = int(train_size * num_items)\n",
    "    n_val = int(val_size * num_items)\n",
    "    return n_train, n_val, num_items - n_train - n_val\n",
    "\n",
    "\n",
    "def _cifar_loaders(dataset_cls, num_classes, train_transform, train_size, val_size, batch_size):\n",
    "    \"\"\"Build train/val/test loaders from the CIFAR *training* split.\n",
    "\n",
    "    NOTE(review): the official CIFAR test split is not used; the \"test\" loader\n",
    "    is a held-out slice of the training split and therefore also goes through\n",
    "    ``train_transform`` (including its random flips).\n",
    "    \"\"\"\n",
    "    cifar_train = dataset_cls(\n",
    "        root=CIFAR_DATA, train=True, download=True, transform=train_transform\n",
    "    )\n",
    "    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
    "        cifar_train, list(_split_sizes(len(cifar_train), train_size, val_size))\n",
    "    )\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset, batch_size=batch_size, shuffle=True\n",
    "    )\n",
    "    val_loader = torch.utils.data.DataLoader(\n",
    "        val_dataset, batch_size=batch_size, shuffle=False\n",
    "    )\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_dataset, batch_size=batch_size, shuffle=False\n",
    "    )\n",
    "    return train_loader, val_loader, test_loader, num_classes\n",
    "\n",
    "\n",
    "def get_dataloaders(\n",
    "    dataset_name, train_size=0.8, val_size=0.0, batch_size=32, mode=\"train\"\n",
    "):\n",
    "    \"\"\"Return ``(train_loader, val_loader, test_loader, num_classes)`` for a dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: one of \"cars\", \"aircraft\", \"flowers\", \"scenes\", \"chars\",\n",
    "            \"birds\", \"cifar10\", \"cifar100\".\n",
    "        train_size: fraction of the samples used for training.\n",
    "        val_size: fraction used for validation; the rest is the test split.\n",
    "        batch_size: batch size of every returned loader.\n",
    "        mode: currently unused; kept for backward compatibility.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dataset_name`` is not recognised.\n",
    "    \"\"\"\n",
    "    if dataset_name == \"cifar10\":\n",
    "        # NOTE(review): unlike the cifar100 branch and the custom datasets in this\n",
    "        # notebook, no Normalize step is applied here; confirm this is intended.\n",
    "        train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "            ]\n",
    "        )\n",
    "        return _cifar_loaders(\n",
    "            datasets.CIFAR10, 10, train_transform, train_size, val_size, batch_size\n",
    "        )\n",
    "\n",
    "    if dataset_name == \"cifar100\":\n",
    "        # NOTE(review): these mean/std values are the ones commonly quoted for\n",
    "        # CIFAR-10, not CIFAR-100; confirm they are intended here.\n",
    "        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "        train_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Resize((224, 224)),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.Normalize(*stats, inplace=True),\n",
    "            ]\n",
    "        )\n",
    "        return _cifar_loaders(\n",
    "            datasets.CIFAR100, 100, train_transform, train_size, val_size, batch_size\n",
    "        )\n",
    "\n",
    "    # Lambdas defer the class lookup so only the requested dataset is touched.\n",
    "    dataset_factories = {\n",
    "        \"cars\": lambda: CARS(),\n",
    "        \"aircraft\": lambda: AIRCRAFT(),\n",
    "        \"flowers\": lambda: FLOWERS(),\n",
    "        \"scenes\": lambda: SCENES(),\n",
    "        \"chars\": lambda: CHARS(),\n",
    "        \"birds\": lambda: BIRDS(),\n",
    "    }\n",
    "    if dataset_name not in dataset_factories:\n",
    "        raise ValueError(f\"Dataset {dataset_name} not found\")\n",
    "    dataset = dataset_factories[dataset_name]()\n",
    "\n",
    "    # split the dataset into train, val, and test\n",
    "    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
    "        dataset, list(_split_sizes(len(dataset), train_size, val_size))\n",
    "    )\n",
    "    print(\n",
    "        f\"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}\"\n",
    "    )\n",
    "    print(dataset.test_transform)\n",
    "    # Create transformed datasets for each split\n",
    "    train_dataset = TransformedDataset(train_dataset, dataset.train_transform)\n",
    "    val_dataset = TransformedDataset(val_dataset, dataset.test_transform)\n",
    "    test_dataset = TransformedDataset(test_dataset, dataset.test_transform)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "        drop_last=True,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    # Evaluation loaders keep the final partial batch so no samples are skipped.\n",
    "    val_loader = torch.utils.data.DataLoader(\n",
    "        val_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        test_dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        drop_last=False,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "    return train_loader, val_loader, test_loader, dataset.get_num_classes()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training loop\n",
    "The dataloaders were built beforehand and stored in a pickle file (`dataloader_dict`); the training loop looks them up by task name.\n",
    "`SHARED_WEIGHT` is the path to the checkpoint that holds the model's reconstructed (shared) weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_parameters(model):\n",
    "    \"\"\"Print a breakdown of ``model``'s parameter counts and return it as a dict.\n",
    "\n",
    "    Parameters are bucketed by substring match on their qualified names\n",
    "    (``.attn.``, ``template_banks``, ``mlp``, ``coefficients``); a name that\n",
    "    matches several patterns is counted in every matching bucket.\n",
    "\n",
    "    NOTE(review): the setup cell at the top of this notebook defines a function\n",
    "    with the same name; when cells run in order this definition shadows it.\n",
    "    Consider keeping a single copy.\n",
    "\n",
    "    Args:\n",
    "        model: an ``nn.Module`` with a ``head`` submodule (the classifier).\n",
    "\n",
    "    Returns:\n",
    "        dict with keys ``total``, ``trainable``, ``classifier``, ``attention``,\n",
    "        ``templates``, ``mlp`` and ``coefficients``.\n",
    "    \"\"\"\n",
    "    patterns = {\n",
    "        \"attention\": \".attn.\",\n",
    "        \"templates\": \"template_banks\",\n",
    "        \"mlp\": \"mlp\",\n",
    "        \"coefficients\": \"coefficients\",\n",
    "    }\n",
    "    counts = {\n",
    "        \"total\": 0,\n",
    "        \"trainable\": 0,\n",
    "        \"classifier\": sum(p.numel() for p in model.head.parameters()),\n",
    "    }\n",
    "    counts.update({bucket: 0 for bucket in patterns})\n",
    "\n",
    "    for name, param in model.named_parameters():\n",
    "        numel = param.numel()\n",
    "        counts[\"total\"] += numel\n",
    "        if param.requires_grad:\n",
    "            counts[\"trainable\"] += numel\n",
    "        for bucket, pattern in patterns.items():\n",
    "            if pattern in name:\n",
    "                counts[bucket] += numel\n",
    "\n",
    "    print(\"Classifier head: \", model.head)\n",
    "    # Two decimals: integer division would truncate e.g. 21.9M to \"21M\".\n",
    "    print(\n",
    "        f\"Total parameters: {counts['total'] / 1e6:.2f}M, Trainable parameters: {counts['trainable']}, Classifier parameters: {counts['classifier']}\"\n",
    "    )\n",
    "    print(f\"Attention parameters: {counts['attention']}\")\n",
    "    print(f\"Templates params: {counts['templates']}\")\n",
    "    print(f\"MLP params: {counts['mlp']}\")\n",
    "    print(f\"Coefficients params: {counts['coefficients']}\")\n",
    "    return counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# ---- Configuration ----\n",
    "TASK_NAME = [\"cars\", \"aircraft\", \"flowers\", \"scenes\", \"chars\", \"birds\", \"cifar10\", \"cifar100\"]\n",
    "num_classes = [196, 55, 103, 67, 63, 201, 10, 100]\n",
    "SHARED_WEIGHT = \"<SharedWeight.pt>\"\n",
    "# Optional path to a pickled {task: {\"train\": DataLoader, \"test\": DataLoader}} dict.\n",
    "# CIFAR tasks, and any task missing from it, are built with get_dataloaders instead.\n",
    "DATALOADER_PICKLE = None\n",
    "CIFAR_TASKS = (\"cifar10\", \"cifar100\")\n",
    "BATCH_SIZE = 256  # batch size for loaders built with get_dataloaders\n",
    "MINI_BATCH_SIZE = 128  # each loader batch is processed in slices of this size\n",
    "NUM_EPOCHS = 100  # TODO\n",
    "LR = 2e-3\n",
    "WEIGHT_DECAY = 1e-6\n",
    "\n",
    "if DATALOADER_PICKLE is not None:\n",
    "    import pickle\n",
    "\n",
    "    # pickle.load can execute arbitrary code: only load files you produced yourself.\n",
    "    with open(DATALOADER_PICKLE, \"rb\") as fh:\n",
    "        dataloader_dict = pickle.load(fh)\n",
    "else:\n",
    "    dataloader_dict = {}\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Every task starts from the same shared checkpoint, so read it from disk once.\n",
    "shared_state_dict = torch.load(SHARED_WEIGHT, map_location=\"cpu\")[\"state_dict\"]\n",
    "\n",
    "\n",
    "def _train_one_epoch(model, loader, optimizer, criterion, epoch, mini_batch_size, printing_iter):\n",
    "    \"\"\"Train ``model`` for one pass over ``loader``.\n",
    "\n",
    "    Each loader batch is split into ``mini_batch_size`` slices; every slice gets\n",
    "    its own forward/backward pass, gradient-norm clipping (max 1.0) and optimizer\n",
    "    step. Every ``printing_iter`` batches the mean loss per step so far is printed.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    num_steps = 0\n",
    "    for i, (images, labels) in tqdm.tqdm(\n",
    "        enumerate(loader), total=len(loader), desc=\"training\"\n",
    "    ):\n",
    "        for j in range(0, images.size(0), mini_batch_size):\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(images[j : j + mini_batch_size].to(device))\n",
    "            loss = criterion(outputs, labels[j : j + mini_batch_size].to(device))\n",
    "            loss.backward()\n",
    "            # gradient clipping\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "            optimizer.step()\n",
    "            # .item() gives a Python float, so no autograd graph is kept alive.\n",
    "            running_loss += loss.item()\n",
    "            num_steps += 1\n",
    "\n",
    "        if i % printing_iter == 0:\n",
    "            print(\n",
    "                f\"Epoch {epoch}, Iteration {i}, LR: {optimizer.param_groups[0]['lr']}, Loss: {running_loss / max(num_steps, 1)}\"\n",
    "            )\n",
    "\n",
    "\n",
    "def _evaluate(model, loader, criterion, mini_batch_size):\n",
    "    \"\"\"Return ``(accuracy in %, mean per-sample loss)`` of ``model`` on ``loader``.\"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    loss_sum = 0.0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in tqdm.tqdm(loader, total=len(loader), desc=\"Evaluating\"):\n",
    "            for j in range(0, images.size(0), mini_batch_size):\n",
    "                sub_images = images[j : j + mini_batch_size].to(device)\n",
    "                sub_labels = labels[j : j + mini_batch_size].to(device)\n",
    "                outputs = model(sub_images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                slice_count = sub_labels.size(0)\n",
    "                correct += (predicted == sub_labels).sum().item()\n",
    "                total += slice_count\n",
    "                # CrossEntropyLoss() averages over the slice; re-weight by the slice\n",
    "                # size so the final value is a true per-sample mean.\n",
    "                loss_sum += criterion(outputs, sub_labels).item() * slice_count\n",
    "    if total == 0:\n",
    "        return 0.0, 0.0\n",
    "    return 100 * correct / total, loss_sum / total\n",
    "\n",
    "\n",
    "for task, num_class in zip(TASK_NAME, num_classes):\n",
    "    print(f\"Task: {task}, Num classes: {num_class}\")\n",
    "    if task in CIFAR_TASKS or task not in dataloader_dict:\n",
    "        if task in CIFAR_TASKS:\n",
    "            print(\"Using CIFAR\")\n",
    "        train_loader, _, test_loader, num_class = get_dataloaders(\n",
    "            task, train_size=0.8, val_size=0.0, batch_size=BATCH_SIZE, mode=\"train\"\n",
    "        )\n",
    "    else:\n",
    "        train_loader = dataloader_dict[task][\"train\"]\n",
    "        test_loader = dataloader_dict[task][\"test\"]\n",
    "    print(\n",
    "        f\"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}\"\n",
    "    )\n",
    "\n",
    "    training_model = VisionTransformer(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        num_classes=1000,\n",
    "        embed_dim=384,\n",
    "        depth=12,\n",
    "        num_heads=6,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=True,\n",
    "        qk_scale=None,\n",
    "        drop_rate=0.0,\n",
    "        attn_drop_rate=0.0,\n",
    "        drop_path_rate=0.0,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "    )\n",
    "    print(training_model.load_state_dict(shared_state_dict, strict=False))\n",
    "    # Freeze everything; only the new head and the coefficients are re-enabled below.\n",
    "    for param in training_model.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    training_model.head = nn.Linear(training_model.embed_dim, num_class, bias=True)\n",
    "    training_model.head.weight.requires_grad = True\n",
    "    training_model.head.bias.requires_grad = True\n",
    "\n",
    "    for n, p in training_model.named_parameters():\n",
    "        if \"coefficients\" in n:\n",
    "            p.requires_grad = True\n",
    "        if p.requires_grad:\n",
    "            print(f\"Trainable: {n}\")\n",
    "    calculate_parameters(training_model)\n",
    "    training_model.to(device)\n",
    "    trainable_params = [p for p in training_model.parameters() if p.requires_grad]\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        trainable_params, lr=LR, weight_decay=WEIGHT_DECAY\n",
    "    )  # ViT\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=33, gamma=0.1)\n",
    "    print(optimizer)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # max(1, ...) avoids a modulo-by-zero when the loader has fewer than 4 batches.\n",
    "    printing_iter = max(1, len(train_loader) // 4)\n",
    "    best_acc = 0.0\n",
    "    best_loss = 0.0\n",
    "    best_model = None\n",
    "    for epoch in range(NUM_EPOCHS):\n",
    "        _train_one_epoch(\n",
    "            training_model, train_loader, optimizer, criterion, epoch, MINI_BATCH_SIZE, printing_iter\n",
    "        )\n",
    "        accuracy, val_loss = _evaluate(training_model, test_loader, criterion, MINI_BATCH_SIZE)\n",
    "        print(f\"Validation Accuracy: {accuracy}, Average Loss: {val_loss}\")\n",
    "\n",
    "        if accuracy > best_acc:\n",
    "            best_acc = accuracy\n",
    "            best_loss = val_loss\n",
    "            # state_dict() returns references to the live tensors; clone them so\n",
    "            # later epochs cannot overwrite the best snapshot.\n",
    "            best_model = {\n",
    "                k: v.detach().cpu().clone() for k, v in training_model.state_dict().items()\n",
    "            }\n",
    "            print(f\"New best accuracy: {best_acc}, and current Loss: {best_loss}\")\n",
    "        scheduler.step()\n",
    "\n",
    "    print(f\"Best accuracy: {best_acc}, Loss: {best_loss}\")\n",
    "\n",
    "    del (\n",
    "        training_model,\n",
    "        optimizer,\n",
    "        criterion,\n",
    "        scheduler,\n",
    "        train_loader,\n",
    "        test_loader,\n",
    "        best_model,\n",
    "    )\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
