{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we provide the reconstruction pipeline code for the RECAST network, where RECAST module has been used in the MLP layers of a Vision Transfer. A similar pipeline can be used for other types of networks as well (RECAST applied to the attention layers, CNN layers etc.) by changing the [model definition](models) and the forward pass in `reconstruction_loss_dynamic()` accordingly."
   ]
  },
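  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sketch of the idea (matching `MLPTemplateBank` and `SharedMLP` below): each target weight is reconstructed as a coefficient-weighted sum of shared templates, averaged over the coefficient sets,\n",
    "\n",
    "$$\\hat{W} = \\frac{1}{C}\\sum_{j=1}^{C}\\sum_{i=1}^{T} \\alpha^{(j)}_{i}\\, T_{i},$$\n",
    "\n",
    "where $T$ is the number of templates per bank (`TEMPLATES`), $C$ is the number of coefficient sets per target module (`num_cf`), $T_{i}$ are the shared templates, and $\\alpha^{(j)}_{i}$ are the per-module coefficients."
   ]
  },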
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import init\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import gc\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import random\n",
    "import os\n",
    "import shutil\n",
    "import tqdm\n",
    "manualSeed = 42\n",
    "DEFAULT_THRESHOLD = 5e-3\n",
    "\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "torch.cuda.manual_seed(manualSeed)\n",
    "np.random.seed(manualSeed)\n",
    "cudnn.benchmark = False\n",
    "torch.backends.cudnn.enabled = False\n",
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(\"Device: \", device)\n",
    "batch_size = 256\n",
    "import gc\n",
    "\n",
    "FACTORS = 6 # number of groups\n",
    "TEMPLATES = 2 # number of templates per bank, corresponds to number of layers in a group\n",
    "MULT = 1 # optional multiplier for the number of coefficients set\n",
    "num_cf = 2 # number of coefficients sets per target module\n",
    "\n",
    "def calculate_parameters(model):\n",
    "    attention_params = 0\n",
    "    template_params = 0\n",
    "    coefficients_params = 0\n",
    "    mlp_params = 0\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    classifer_params = sum(p.numel() for p in model.head.parameters())\n",
    " \n",
    "    for n, p in model.named_parameters():\n",
    "        if \".attn.\" in n:\n",
    "            attention_params += p.numel()\n",
    "            # print(\"Attention params: \", n, p.numel())\n",
    "        if \"template_banks\" in n:\n",
    "            template_params += p.numel()\n",
    "            # print(\"Template params: \", n, p.numel())\n",
    "        if \"mlp\" in n:\n",
    "            mlp_params += p.numel()\n",
    "            # print(\"MLP params: \", n, p.numel())\n",
    "        if \"coefficients\" in n:\n",
    "            coefficients_params += p.numel()\n",
    "            # print(\"Coefficients params: \", n, p.numel())\n",
    "\n",
    "\n",
    "    print(\"Classifier head: \", model.head)\n",
    "    print(f\"Total parameters: {total_params//1000000}M, Trainable parameters: {trainable_params}, Classifier parameters: {classifer_params}\")\n",
    "    print(f\"Attention parameters: {attention_params}\")\n",
    "    print(f\"Templates params: {template_params}\")\n",
    "    print(f\"MLP params: {mlp_params}\")\n",
    "    print(f\"Coefficients params: {coefficients_params}\")\n",
    "\n",
    "class MLPTemplateBank(nn.Module):\n",
    "    def __init__(self, num_templates, in_features, out_features):\n",
    "        super(MLPTemplateBank, self).__init__()\n",
    "        self.num_templates = num_templates\n",
    "        self.coefficient_shape = (num_templates, 1, 1)\n",
    "        templates = [torch.Tensor(out_features, in_features) for _ in range(num_templates)]\n",
    "        for i in range(num_templates):\n",
    "            init.kaiming_normal_(templates[i])\n",
    "        self.templates = nn.Parameter(torch.stack(templates))\n",
    "    def forward(self, coefficients):\n",
    "        params = self.templates * coefficients\n",
    "        summed_params = torch.sum(params, dim=0)\n",
    "        return summed_params    \n",
    "    \n",
    "    def __repr__(self):\n",
    "        return f\"MLPTemplateBank(num_templates={self.templates.shape[0]}, in_features={self.templates.shape[1]}, out_features={self.templates.shape[2]}, coefficients={self.coefficient_shape})\"\n",
    "        \n",
    "class SharedMLP(nn.Module):\n",
    "    # TARGET MODULE\n",
    "    def __init__(self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.):\n",
    "        super(SharedMLP, self).__init__()\n",
    "        self.bank1 = None\n",
    "        self.bank2 = None\n",
    "\n",
    "        if bank1 != None and bank2 != None:\n",
    "            self.bank1 = bank1\n",
    "            self.bank2 = bank2\n",
    "            self.coefficients1 = nn.ParameterList([nn.Parameter(torch.zeros(bank1.coefficient_shape), requires_grad = True) for _ in range(num_cf)])\n",
    "            self.coefficients2 = nn.ParameterList([nn.Parameter(torch.zeros(bank2.coefficient_shape), requires_grad = True) for _ in range(num_cf)])\n",
    "            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))\n",
    "            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.norm = nn.Identity()\n",
    "        self.drop = nn.Dropout(drop)\n",
    "        self.init_weights()\n",
    "    def init_weights(self):\n",
    "        if self.bank1 != None:\n",
    "            for cf in self.coefficients1:\n",
    "                nn.init.orthogonal_(cf)\n",
    "        if self.bank2 != None:\n",
    "            for cf in self.coefficients2:\n",
    "                nn.init.orthogonal_(cf)\n",
    "    def forward(self, x):\n",
    "        # print(f\"CF1: \",self.coefficients1)\n",
    "        # print(f\"CF2: \",self.coefficients2)\n",
    "        if self.bank1 != None:\n",
    "            # weights1 = self.bank1(self.coefficients1)\n",
    "            weight1 = []\n",
    "            for c in self.coefficients1:\n",
    "                w = self.bank1(c)\n",
    "                weight1.append(w)\n",
    "            weights1 = torch.stack(weight1).mean(0) # TODO\n",
    "        if self.bank2 != None:\n",
    "            # weights2 = self.bank2(self.coefficients2)\n",
    "            weight2 = []\n",
    "            for c in self.coefficients2:\n",
    "                w = self.bank2(c)\n",
    "                weight2.append(w)\n",
    "            weights2 = torch.stack(weight2).mean(0) # TODO\n",
    "\n",
    "        x = F.linear(x, weights1, self.bias1)\n",
    "        x = self.act(x)\n",
    "        x = self.norm(x)\n",
    "        x = F.linear(x, weights2, self.bias2)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# original timm module for vision transformer\n",
    "class Attention(nn.Module):\n",
    "    def __init__(\n",
    "            self, dim, num_heads=6, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = dim // num_heads\n",
    "        self.scale = qk_scale or self.head_dim ** -0.5\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "\n",
    "    def forward(self,x):\n",
    "        B, N, C = x.shape\n",
    "        qkv = self.qkv(x).chunk(3, dim=-1)\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        attn = F.softmax(attn, dim=-1) # attention proba\n",
    "        attn = self.attn_drop(attn)\n",
    "        x = attn @ v # attention output\n",
    "        x = x.transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, bank1=None, bank2=None):\n",
    "        super().__init__()\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n",
    "        self.ls1 = nn.Identity()\n",
    "        self.ls2 = nn.Identity()\n",
    "        self.drop_path = nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        self.mlp = SharedMLP(bank1, bank2, act_layer=act_layer, norm_layer=norm_layer, drop=drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))\n",
    "        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n",
    "        return x      \n",
    "\n",
    "class PatchEmbed(nn.Module):\n",
    "    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n",
    "        super().__init__()\n",
    "        img_size = (img_size, img_size)\n",
    "        patch_size = (patch_size, patch_size)\n",
    "        self.img_size = img_size\n",
    "        self.patch_size = patch_size\n",
    "        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0)\n",
    "        self.norm = nn.Identity()\n",
    "    def forward(self, x):\n",
    "        x = self.proj(x)\n",
    "        x = nn.Flatten(start_dim=2, end_dim = 3)(x).permute(0, 2, 1)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# create vision transformer with template bank\n",
    "class VisionTransformer(nn.Module):\n",
    "    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm):\n",
    "        super().__init__()\n",
    "        self.img_size = img_size\n",
    "        self.dim = embed_dim\n",
    "        self.num_classes = num_classes\n",
    "        self.embed_dim = embed_dim\n",
    "        self.patch_size = patch_size\n",
    "        self.num_features = self.embed_dim\n",
    "        self.num_prefix_tokens = 1\n",
    "        self.num_patches = (img_size // patch_size) ** 2\n",
    "        self.num_prefix_tokens = 1\n",
    "        self.has_class_token = True\n",
    "        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))\n",
    "        self.patch_embed = PatchEmbed(\n",
    "            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim\n",
    "        )\n",
    "\n",
    "        num_patches =  (self.img_size // self.patch_size) ** 2\n",
    "        print(\"Num patches: \", num_patches)\n",
    "        embed_len = num_patches + self.num_prefix_tokens \n",
    "        self.pos_embed =  nn.Parameter(torch.ones(1, num_patches + self.num_prefix_tokens, embed_dim)*.02, requires_grad=True)\n",
    "\n",
    "        self.pos_drop = nn.Dropout(p=drop_rate)\n",
    "        self.patch_drop = nn.Identity()\n",
    "        self.fc_norm = nn.Identity()\n",
    "        self.head_drop = nn.Dropout(drop_rate)\n",
    "\n",
    "        self.num_groups = FACTORS\n",
    "        self.num_layers_in_group = depth // self.num_groups # how many consective encoder layers share the same template bank\n",
    "        print(\"Num layers in group: \", self.num_layers_in_group)\n",
    "        self.num_templates = TEMPLATES\n",
    "        mlp_hidden_dim = int(embed_dim * mlp_ratio)\n",
    "        self.template_banks1 = nn.ModuleList([MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim) for _ in range(self.num_groups)])\n",
    "        self.template_banks2 = nn.ModuleList([MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim) for _ in range(self.num_groups)])\n",
    "        self.depth = depth\n",
    "\n",
    "        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)]  # stochastic depth decay rule\n",
    "        self.blocks = nn.ModuleList()\n",
    "        for i in range(depth):\n",
    "            group_idx = i //self.num_layers_in_group\n",
    "            print(group_idx)\n",
    "            bank1 = self.template_banks1[group_idx]\n",
    "            bank2 = self.template_banks2[group_idx]\n",
    "            self.blocks.append(Block(\n",
    "                dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n",
    "                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[group_idx], norm_layer=norm_layer, bank1=bank1, bank2=bank2))\n",
    "        print(f\"Num blocks: {len(self.blocks)}\")\n",
    "        self.norm = norm_layer(self.embed_dim)\n",
    "        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                if m.bias is not None:\n",
    "                    nn.init.normal_(m.bias, std=1e-6)\n",
    "            elif isinstance(m, nn.LayerNorm):\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "                nn.init.constant_(m.weight, 1.0)\n",
    "        \n",
    "    def _pos_embed(self, x):\n",
    "        to_cat = []\n",
    "        to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))\n",
    "        x = torch.cat(to_cat + [x], dim=1)\n",
    "        x = x + self.pos_embed\n",
    "        return self.pos_drop(x)\n",
    "    def forward_features(self, x):\n",
    "        B = x.shape[0]\n",
    "        x = self.patch_embed(x)\n",
    "        x = self._pos_embed(x)\n",
    "        x = self.patch_drop(x)\n",
    "        for blk in self.blocks:\n",
    "            x = blk(x)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "    \n",
    "    def forward_head(self, x):\n",
    "        x = x[:, 0]\n",
    "        x = self.fc_norm(x)\n",
    "        x = self.head_drop(x)\n",
    "        x = self.head(x)\n",
    "        return x\n",
    "    \n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.forward_features(x)\n",
    "        head_output = self.forward_head(features)\n",
    "        return head_output\n",
    "\n",
    "my_model = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)\n",
    "calculate_parameters(my_model)\n",
    "\n",
    "from pytorch_model_summary import summary\n",
    "my_model.eval()\n",
    "print(summary(my_model, torch.zeros(1, 3,224,224), show_hierarchical=True))\n",
    "del my_model \n",
    "torch.cuda.empty_cache()"
   ]
  },
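  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative sanity check of the grouping logic (not part of the pipeline): with `depth=12` and `FACTORS=6`, consecutive pairs of blocks belong to the same group and therefore share the same bank objects, while blocks in different groups do not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blocks 0 and 1 fall into group 0, block 2 into group 1, so bank objects\n",
    "# are shared within a group but not across groups.\n",
    "check_model = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, norm_layer=nn.LayerNorm)\n",
    "assert check_model.blocks[0].mlp.bank1 is check_model.blocks[1].mlp.bank1\n",
    "assert check_model.blocks[0].mlp.bank1 is not check_model.blocks[2].mlp.bank1\n",
    "del check_model\n",
    "gc.collect()"
   ]
  },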
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load existing Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load timm vit weights into the custom model\n",
    "import timm \n",
    "vit = timm.create_model(\"vit_small_patch16_224\", pretrained=True)\n",
    "print(\"Base model\")\n",
    "calculate_parameters(vit)\n",
    "# freeze the weights of the timm model\n",
    "for param in vit.parameters():\n",
    "    param.requires_grad = False\n",
    "FOUND = []\n",
    "# load the weights from the timm model to the custom model\n",
    "def load_weights(model, timm_model):\n",
    "    model_dict = model.state_dict()\n",
    "    timm_dict = timm_model.state_dict()\n",
    "    new_dict = {}\n",
    "    for k, v in timm_dict.items():\n",
    "        if k in model_dict:\n",
    "            # check the shape of the weights\n",
    "            if model_dict[k].shape == timm_dict[k].shape:\n",
    "                new_dict[k] = v \n",
    "                FOUND.append(k)\n",
    "        else:\n",
    "            if \"mlp.fc1.bias\" in k:\n",
    "                new_k = k.replace(\"mlp.fc1.bias\", \"mlp.bias1\") # name in model_dict\n",
    "                print(f\"{k} --> {new_k}\")\n",
    "                if model_dict[new_k].shape == timm_dict[k].shape:\n",
    "                    new_dict[new_k] = v\n",
    "                    FOUND.append(new_k)\n",
    "            elif \"mlp.fc2.bias\" in k:\n",
    "                new_k = k.replace(\"mlp.fc2.bias\", \"mlp.bias2\",)\n",
    "                print(f\"{k} --> {new_k}\")\n",
    "                if model_dict[new_k].shape == timm_dict[k].shape:\n",
    "                    new_dict[new_k] = v\n",
    "                    FOUND.append(new_k)\n",
    "            else:\n",
    "                print(\"Key not found: \", k)\n",
    "    model_dict.update(new_dict)\n",
    "    model.load_state_dict(model_dict)\n",
    "    return model\n",
    "\n",
    "tpbvit = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)\n",
    "\n",
    "tpbvit = load_weights(tpbvit, vit)\n",
    "tpbvit_state_dict = list(tpbvit.state_dict().keys()) \n",
    "print(\"Found: \", len(FOUND))\n",
    "target_params = []\n",
    "\n",
    "for k in tpbvit_state_dict:\n",
    "    if k not in FOUND:\n",
    "        target_params.append(k)\n",
    "\n",
    "for name, param in tpbvit.named_parameters():\n",
    "    if name not in target_params :\n",
    "        param.requires_grad = False\n",
    "    # check if the param is layer norm or batch norm\n",
    "    # if \"norm\" in name:\n",
    "    #     param.requires_grad = True\n",
    "\n",
    "calculate_parameters(tpbvit) \n",
    "\n",
    "for n, p in tpbvit.named_parameters():\n",
    "    if p.requires_grad:\n",
    "        print(n, p.requires_grad)"
   ]
  },
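  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After loading, only the RECAST parameters (templates and coefficients) should remain trainable; everything inherited from timm, including the remapped MLP biases, is frozen. A minimal assertion to confirm this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm that only template/coefficient parameters require gradients.\n",
    "trainable = [n for n, p in tpbvit.named_parameters() if p.requires_grad]\n",
    "assert all((\"template_banks\" in n) or (\"coefficients\" in n) for n in trainable), trainable\n",
    "print(f\"{len(trainable)} trainable tensors\")"
   ]
  },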
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstruction Process"
   ]
  },
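  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reconstruction objective compares, block by block, the weights produced by the template banks against the frozen timm weights. Writing $\\hat{W}^{(b)}_{1}$ and $\\hat{W}^{(b)}_{2}$ for the reconstructed fc1/fc2 weights of block $b$ (averaged over coefficient sets), the loss implemented in `reconstruction_loss_dynamic()` is\n",
    "\n",
    "$$\\mathcal{L} = \\sum_{b} \\lambda_{1}\\,\\mathrm{SmoothL1}\\big(\\hat{W}^{(b)}_{1}, W^{(b)}_{1}\\big) + \\lambda_{2}\\,\\mathrm{SmoothL1}\\big(\\hat{W}^{(b)}_{2}, W^{(b)}_{2}\\big),$$\n",
    "\n",
    "with $\\lambda_{1} = \\lambda_{2} = 2$ by default. During training, small Gaussian noise ($\\sigma = 10^{-4}$) is added to the coefficients before reconstruction as a light regularizer. Note that no image data is needed: the objective depends only on the two sets of weights."
   ]
  },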
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import timm \n",
    "import copy\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import timm \n",
    "import copy\n",
    "def reconstruction_loss_dynamic(current_model, pretrained_model, criterion=nn.SmoothL1Loss(), w1_weight=2.0, w2_weight=2.0):\n",
    "    corr_state_dict = pretrained_model.state_dict()\n",
    "    loss_dict = {}\n",
    "    total_loss = 0.0\n",
    "    w1_loss = 0.0\n",
    "    w2_loss = 0.0\n",
    "    ortho_loss = 0.0\n",
    "    \n",
    "    # Determine the device of the current model\n",
    "    device = next(current_model.parameters()).device\n",
    "    \n",
    "    for id, block in enumerate(current_model.blocks):  \n",
    "        if block.mlp.bank1 is not None:\n",
    "            mlp_cf1 = block.mlp.coefficients1\n",
    "            mlp_bank1 = block.mlp.bank1\n",
    "            weights1 = []\n",
    "            noise_std1 = 1e-4 # set to zero for no noise\n",
    "            for c in mlp_cf1:\n",
    "                if current_model.training:\n",
    "                    noise = torch.randn_like(c) * noise_std1\n",
    "                    c = c + noise\n",
    "                w = mlp_bank1(c)\n",
    "                weights1.append(w)\n",
    "    \n",
    "            _weights1 = torch.stack(weights1).mean(0) # TODO\n",
    "            corr_weight1 = corr_state_dict[f'blocks.{id}.mlp.fc1.weight'].to(device)\n",
    "            w1_l = criterion(_weights1, corr_weight1) * w1_weight\n",
    "\n",
    "\n",
    "        if block.mlp.bank2 is not None:\n",
    "            mlp_cf2 = block.mlp.coefficients2\n",
    "            mlp_bank2 = block.mlp.bank2\n",
    "            noise_std2 = 1e-4  # set to zero for no noise\n",
    "            weights2 = []\n",
    "            for c in mlp_cf2:\n",
    "                if current_model.training:\n",
    "                    noise = torch.randn_like(c) * noise_std2\n",
    "                    c = c + noise\n",
    "                w = mlp_bank2(c)\n",
    "                weights2.append(w)\n",
    "            _weights2 = torch.stack(weights2).mean(0) # TODO\n",
    "            corr_weight2 = corr_state_dict[f'blocks.{id}.mlp.fc2.weight'].to(device)\n",
    "            w2_l = criterion(_weights2, corr_weight2) * w2_weight\n",
    "\n",
    "        \n",
    "        total_loss += w1_l + w2_l\n",
    "        w1_loss += w1_l.item()\n",
    "        w2_loss += w2_l.item()\n",
    "\n",
    "\n",
    "    return loss_dict, total_loss, w1_loss, w2_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.RMSprop(\n",
    "         tpbvit.parameters(),\n",
    "        lr=1e-1,\n",
    "    )\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)\n",
    "best_model = None\n",
    "best_model_loss = 1e9\n",
    "for epoch in range(1000):\n",
    "    optimizer.zero_grad()\n",
    "    losses_dict, total_loss, w1, w2 = reconstruction_loss_dynamic(tpbvit.to(device), vit)\n",
    "    if total_loss < best_model_loss:\n",
    "        best_model_loss = total_loss.item()\n",
    "        best_model = copy.deepcopy(tpbvit)\n",
    "        # print(\"New Best Loss: \", best_model_loss)\n",
    "    total_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(tpbvit.parameters(), max_norm=1.0)\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    if (epoch+1) %10 == 0:\n",
    "        print(\"Epoch: \", epoch, \"Loss: \", total_loss.item(), \"W1: \", w1, \"W2: \", w2)\n",
    "        print(\"LR: \", optimizer.param_groups[0]['lr'])"
   ]
  },
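  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the reconstructed model is to be reused later, `best_model` can be checkpointed (a sketch; the filename below is a placeholder, not a path used elsewhere in this notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: checkpoint the best reconstruction for later fine-tuning.\n",
    "# The filename is a placeholder.\n",
    "torch.save(best_model.state_dict(), \"recast_vit_reconstructed.pth\")"
   ]
  },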
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Reconstruction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing Cosine Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_dynamic_reconstruction(current_model, pretrained_model):\n",
    "    similarity = 0.0\n",
    "    metric = nn.CosineSimilarity(dim=0)\n",
    "    corr_state_dict = pretrained_model.state_dict()\n",
    "    w1_loss = 0.0\n",
    "    w2_loss = 0.0\n",
    "    b1_loss = 0.0\n",
    "    b2_loss = 0.0\n",
    "    w1_l = 0\n",
    "    w2_l = 0\n",
    "    b1_l = 0\n",
    "    b2_l = 0\n",
    "    similarity_dict = {}\n",
    "    for id, block in enumerate(current_model.blocks):\n",
    "        if block.mlp.bank1 != None :\n",
    "            mlp_cf1 = block.mlp.coefficients1\n",
    "            mlp_bank1 = block.mlp.bank1\n",
    "            weights1 = torch.stack([mlp_bank1(cf) for cf in mlp_cf1]).mean(0) # TODO\n",
    "            corr_weight1 = corr_state_dict[f'blocks.{id}.mlp.fc1.weight']\n",
    "            w1_l = metric(weights1.view(-1), corr_weight1.view(-1))\n",
    "            w1_loss += w1_l.item()\n",
    "\n",
    "            bias1 = block.mlp.bias1\n",
    "            corr_bias1 = corr_state_dict[f'blocks.{id}.mlp.fc1.bias']\n",
    "            b1_l = metric(bias1.view(-1), corr_bias1.view(-1))\n",
    "            b1_loss += b1_l.item()\n",
    "            group_idx = id // current_model.num_layers_in_group\n",
    "            # print(\"Group: \", group_idx)\n",
    "            similarity_dict[f\"template_banks1.{group_idx}.templates\"] = w1_l.item()\n",
    "            similarity_dict[f\"blocks.{id}.mlp.bias1\"] = b1_l.item()\n",
    "            for i, cf in enumerate(mlp_cf1):\n",
    "                similarity_dict[f\"blocks.{id}.mlp.coefficients1.{i}\"] = w1_l.item()\n",
    "        if block.mlp.bank2 != None:\n",
    "            mlp_cf2 = block.mlp.coefficients2\n",
    "            mlp_bank2 = block.mlp.bank2\n",
    "            weights2 = torch.stack([mlp_bank2(cf) for cf in mlp_cf2]).mean(0) # TODO\n",
    "            corr_weight2 = corr_state_dict[f'blocks.{id}.mlp.fc2.weight']\n",
    "            w2_l = metric(weights2.view(-1), corr_weight2.view(-1))\n",
    "            w2_loss += w2_l.item()\n",
    "\n",
    "            bias2 = block.mlp.bias2\n",
    "            corr_bias2 = corr_state_dict[f'blocks.{id}.mlp.fc2.bias']\n",
    "            b2_l = metric(bias2.view(-1), corr_bias2.view(-1))\n",
    "            b2_loss += b2_l.item()\n",
    "            group_idx = id // current_model.num_layers_in_group\n",
    "            # print(\"Group: \", group_idx)\n",
    "            similarity_dict[f\"template_banks2.{group_idx}.templates\"] = w2_l.item()\n",
    "            similarity_dict[f\"blocks.{id}.mlp.bias2\"] = b2_l.item()\n",
    "            for i, cf in enumerate(mlp_cf2):\n",
    "                similarity_dict[f\"blocks.{id}.mlp.coefficients2.{i}\"] = w2_l.item()\n",
    "            \n",
    "        print(\"Block: \", id, \"W1: \", w1_l, \"W2: \", w2_l, \"B1: \", b1_l, \"B2: \", b2_l)\n",
    "\n",
    "    print(\"W1: \", w1_loss/len(current_model.blocks), \"W2: \", w2_loss/len(current_model.blocks), \"B1: \", b1_loss/len(current_model.blocks), \"B2: \", b2_loss/len(current_model.blocks))\n",
    "    return similarity_dict\n",
    "\n",
    "evaluate_dynamic_reconstruction(best_model.to(\"cpu\"), vit)\n",
    "calculate_parameters(best_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing layerwise feature similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "# Load the batch from disk\n",
    "batch = torch.load(\"<DATALOADER_BATCH_PATH>\")\n",
    "\n",
    "# Modified forward function to return intermediate features\n",
    "def forward_with_intermediate(model, x, cust = True):\n",
    "    B = x.shape[0]\n",
    "    x = model.patch_embed(x)\n",
    "    intermediate_features = []\n",
    "    if not cust:\n",
    "        print(\"classic\")\n",
    "        cls_tokens = model.cls_token.expand(B, -1, -1)\n",
    "        x = torch.cat((cls_tokens, x), dim=1)\n",
    "        x = x + model.pos_embed\n",
    "        x = model.pos_drop(x)\n",
    "        for blk in model.blocks:\n",
    "            x = blk(x)\n",
    "            intermediate_features.append(x)\n",
    "    else: \n",
    "        print(\"Custom\")\n",
    "        x = model._pos_embed(x)\n",
    "        x = model.patch_drop(x)\n",
    "        for blk in model.blocks:\n",
    "            x = blk(x)\n",
    "            intermediate_features.append(x)\n",
    "    x = model.norm(x)\n",
    "    return x, intermediate_features\n",
    "\n",
    "# Function to compare features layerwise\n",
    "def compare_features_layerwise(model1, model2, batch):\n",
    "    batch = batch.to(device)\n",
    "    _, features1 = forward_with_intermediate(model1.to(\"cpu\"), batch.to(\"cpu\"), True)\n",
    "    _, features2 = forward_with_intermediate(model2.to(\"cpu\"), batch.to(\"cpu\"), False)\n",
    "    \n",
    "    cos = nn.CosineSimilarity(dim=-1)  # Change dim to -1 for last dimension\n",
    "    similarities = []\n",
    "    \n",
    "    for f1, f2 in zip(features1, features2):\n",
    "        # print(f1.shape, f2.shape)\n",
    "        # Flatten the features except for the batch dimension\n",
    "        f1_flat = f1.view(f1.size(0), -1)\n",
    "        f2_flat = f2.view(f2.size(0), -1)\n",
    "        \n",
    "        # Compare the flattened features\n",
    "        similarity = cos(f1_flat, f2_flat).mean()\n",
    "        similarities.append(similarity.item())\n",
    "    \n",
    "    return similarities\n",
    "\n",
    "# # # Make sure both models are on the same device\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "best_model = best_model.to(device)\n",
    "vit = vit.to(device)\n",
    "best_model.eval()\n",
    "vit.eval()\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "for i in range(1):\n",
    "    # Compare features layerwise\n",
    "    with torch.no_grad():\n",
    "        similarities = compare_features_layerwise(best_model.to(device), vit.to(device), batch.to(device))\n",
    "        av = (np.mean(similarities))\n",
    "        print(av)\n",
    "        # # Print similarities for each layer\n",
    "        for i, sim in enumerate(similarities):\n",
    "            print(f\"Layer {i+1} similarity: {sim:.4f}\")\n",
    "\n",
    "\n",
    "        plt.figure(figsize=(6, 6))\n",
    "        plt.ylim(-1, 2) \n",
    "        plt.plot(range(1, len(similarities) + 1), similarities, marker='o')\n",
    "        plt.title(f'Layerwise Cosine Similarity factor={FACTORS}, templates={TEMPLATES} num_cf = {num_cf}Average Sim: {av:.3f}')\n",
    "        plt.xlabel(\"Layer\")\n",
    "        plt.ylabel(\"Cosine Similarity\")\n",
    "        plt.grid(True)\n",
    "        plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
