{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we provide the reconstruction pipeline code for the RECAST network, where RECAST module has been used in the MLP layers of a Vision Transfer. A similar pipeline can be used for other types of networks as well (RECAST applied to the attention layers, CNN layers etc.) by changing the [model definition](models) and the forward pass in `reconstruction_loss_dynamic()` accordingly."
   ]
  },
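  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before the full model definition, here is a minimal standalone sketch of the idea behind RECAST's template banks: a layer's weight matrix is reconstructed as a coefficient-weighted sum of a small set of shared templates, so the layers in a group share the templates and differ only in their (much smaller) coefficients. The shapes and names in this sketch are illustrative only and are not part of the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative shapes only: 2 shared templates for a 4x3 weight matrix.\n",
    "n_templates, n_out, n_in = 2, 4, 3\n",
    "shared_templates = torch.randn(n_templates, n_out, n_in)  # shared across a group of layers\n",
    "layer_coefficients = torch.randn(n_templates, 1, 1)  # per-layer, tiny compared to the weight\n",
    "\n",
    "# A layer's weight is the coefficient-weighted sum of the shared templates.\n",
    "reconstructed_weight = (shared_templates * layer_coefficients).sum(dim=0)\n",
    "print(reconstructed_weight.shape)  # torch.Size([4, 3])"
   ]
  },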
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import init\n",
    "import torch.nn.functional as F\n",
    "import gc\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "manualSeed = 42\n",
    "DEFAULT_THRESHOLD = 5e-3\n",
    "\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "torch.cuda.manual_seed(manualSeed)\n",
    "np.random.seed(manualSeed)\n",
    "cudnn.benchmark = False\n",
    "torch.backends.cudnn.enabled = False\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device: \", device)\n",
    "import gc\n",
    "FACTORS = 6\n",
    "TEMPLATES = 2\n",
    "MULT = 1\n",
    "num_cf = 2\n",
    "\n",
    "\n",
    "def calculate_parameters(model):\n",
    "    attention_params = 0\n",
    "    template_params = 0\n",
    "    coefficients_params = 0\n",
    "    mlp_params = 0\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    classifer_params = sum(p.numel() for p in model.head.parameters())\n",
    "\n",
    "    for n, p in model.named_parameters():\n",
    "        if \".attn.\" in n:\n",
    "            attention_params += p.numel()\n",
    "            # print(\"Attention params: \", n, p.numel())\n",
    "        if \"template_banks\" in n:\n",
    "            template_params += p.numel()\n",
    "            # print(\"Template params: \", n, p.numel())\n",
    "        if \"mlp\" in n:\n",
    "            mlp_params += p.numel()\n",
    "            # print(\"MLP params: \", n, p.numel())\n",
    "        if \"coefficients\" in n:\n",
    "            coefficients_params += p.numel()\n",
    "            # print(\"Coefficients params: \", n, p.numel())\n",
    "\n",
    "    print(\"Classifier head: \", model.head)\n",
    "    # print(model.fc)\n",
    "    print(\n",
    "        f\"Total parameters: {total_params//1000000}M, Trainable parameters: {trainable_params}, Classifier parameters: {classifer_params}\"\n",
    "    )\n",
    "    print(f\"Attention parameters: {attention_params}\")\n",
    "    print(f\"Templates params: {template_params}\")\n",
    "    print(f\"MLP params: {mlp_params}\")\n",
    "    print(f\"Coefficients params: {coefficients_params}\")\n",
    "\n",
    "\n",
    "class MLPTemplateBank(nn.Module):\n",
    "    def __init__(self, num_templates, in_features, out_features):\n",
    "        super(MLPTemplateBank, self).__init__()\n",
    "        self.num_templates = num_templates\n",
    "        self.coefficient_shape = (num_templates, 1, 1)\n",
    "        templates = [\n",
    "            torch.Tensor(out_features, in_features) for _ in range(num_templates)\n",
    "        ]\n",
    "        for i in range(num_templates):\n",
    "            init.kaiming_normal_(templates[i])\n",
    "        self.templates = nn.Parameter(torch.stack(templates))\n",
    "\n",
    "    def forward(self, coefficients):\n",
    "        params = self.templates * coefficients\n",
    "        summed_params = torch.sum(params, dim=0)\n",
    "        return summed_params\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"MLPTemplateBank(num_templates={self.templates.shape[0]}, in_features={self.templates.shape[1]}, out_features={self.templates.shape[2]}, coefficients={self.coefficient_shape})\"\n",
    "\n",
    "\n",
    "class SharedMLP(nn.Module):\n",
    "    def __init__(\n",
    "        self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.0\n",
    "    ):\n",
    "        super(SharedMLP, self).__init__()\n",
    "        self.bank1 = None\n",
    "        self.bank2 = None\n",
    "\n",
    "        if bank1 != None and bank2 != None:\n",
    "            self.bank1 = bank1\n",
    "            self.bank2 = bank2\n",
    "            self.coefficients1 = nn.ParameterList(\n",
    "                [\n",
    "                    nn.Parameter(\n",
    "                        torch.zeros(bank1.coefficient_shape), requires_grad=True\n",
    "                    )\n",
    "                    for _ in range(num_cf)\n",
    "                ]\n",
    "            )\n",
    "            self.coefficients2 = nn.ParameterList(\n",
    "                [\n",
    "                    nn.Parameter(\n",
    "                        torch.zeros(bank2.coefficient_shape), requires_grad=True\n",
    "                    )\n",
    "                    for _ in range(num_cf)\n",
    "                ]\n",
    "            )\n",
    "            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))\n",
    "            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))\n",
    "\n",
    "        self.act = act_layer()\n",
    "        self.norm = nn.Identity()\n",
    "        self.drop = nn.Dropout(drop)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        if self.bank1 != None:\n",
    "            for cf in self.coefficients1:\n",
    "                nn.init.orthogonal_(cf)\n",
    "        if self.bank2 != None:\n",
    "            for cf in self.coefficients2:\n",
    "                nn.init.orthogonal_(cf)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.bank1 != None:\n",
    "            weight1 = []\n",
    "            for c in self.coefficients1:\n",
    "                w = self.bank1(c)\n",
    "                weight1.append(w)\n",
    "            weights1 = torch.stack(weight1).mean(0)\n",
    "        if self.bank2 != None:\n",
    "            weight2 = []\n",
    "            for c in self.coefficients2:\n",
    "                w = self.bank2(c)\n",
    "                weight2.append(w)\n",
    "            weights2 = torch.stack(weight2).mean(0)\n",
    "\n",
    "        x = F.linear(x, weights1, self.bias1)\n",
    "        x = self.act(x)\n",
    "        x = self.norm(x)\n",
    "        x = F.linear(x, weights2, self.bias2)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# original timm module for vision transformer\n",
    "class Attention(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads=6,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        attn_drop=0.0,\n",
    "        proj_drop=0.0,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = dim // num_heads\n",
    "        self.scale = qk_scale or self.head_dim**-0.5\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, N, C = x.shape\n",
    "        qkv = self.qkv(x).chunk(3, dim=-1)\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "        q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        attn = F.softmax(attn, dim=-1)  # attention proba\n",
    "        attn = self.attn_drop(attn)\n",
    "        x = attn @ v  # attention output\n",
    "        x = x.transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Block(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop=0.0,\n",
    "        attn_drop=0.0,\n",
    "        drop_path=0.0,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "        bank1=None,\n",
    "        bank2=None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = Attention(\n",
    "            dim,\n",
    "            num_heads=num_heads,\n",
    "            qkv_bias=qkv_bias,\n",
    "            qk_scale=qk_scale,\n",
    "            attn_drop=attn_drop,\n",
    "            proj_drop=drop,\n",
    "        )\n",
    "        self.ls1 = nn.Identity()\n",
    "        self.ls2 = nn.Identity()\n",
    "        self.drop_path = nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = SharedMLP(\n",
    "            bank1, bank2, act_layer=act_layer, norm_layer=norm_layer, drop=drop\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))\n",
    "        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))\n",
    "        return x\n",
    "\n",
    "\n",
    "class PatchEmbed(nn.Module):\n",
    "    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n",
    "        super().__init__()\n",
    "        img_size = (img_size, img_size)\n",
    "        patch_size = (patch_size, patch_size)\n",
    "        self.img_size = img_size\n",
    "        self.patch_size = patch_size\n",
    "        self.proj = nn.Conv2d(\n",
    "            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0\n",
    "        )\n",
    "        self.norm = nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.proj(x)\n",
    "        x = nn.Flatten(start_dim=2, end_dim=3)(x).permute(0, 2, 1)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class VisionTransformer(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        num_classes=1000,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        mlp_ratio=4.0,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop_rate=0.0,\n",
    "        attn_drop_rate=0.0,\n",
    "        drop_path_rate=0.0,\n",
    "        norm_layer=nn.LayerNorm,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.img_size = img_size\n",
    "        self.dim = embed_dim\n",
    "        self.num_classes = num_classes\n",
    "        self.embed_dim = embed_dim\n",
    "        self.patch_size = patch_size\n",
    "        self.num_features = self.embed_dim\n",
    "        self.num_prefix_tokens = 1\n",
    "        self.num_patches = (img_size // patch_size) ** 2\n",
    "        self.num_prefix_tokens = 1\n",
    "        self.has_class_token = True\n",
    "        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))\n",
    "        self.patch_embed = PatchEmbed(\n",
    "            img_size=img_size,\n",
    "            patch_size=patch_size,\n",
    "            in_chans=in_chans,\n",
    "            embed_dim=embed_dim,\n",
    "        )\n",
    "\n",
    "        num_patches = (self.img_size // self.patch_size) ** 2\n",
    "        print(\"Num patches: \", num_patches)\n",
    "        embed_len = num_patches + self.num_prefix_tokens\n",
    "        self.pos_embed = nn.Parameter(\n",
    "            torch.ones(1, num_patches + self.num_prefix_tokens, embed_dim) * 0.02,\n",
    "            requires_grad=True,\n",
    "        )\n",
    "\n",
    "        self.pos_drop = nn.Dropout(p=drop_rate)\n",
    "        self.patch_drop = nn.Identity()\n",
    "        self.fc_norm = nn.Identity()\n",
    "        self.head_drop = nn.Dropout(drop_rate)\n",
    "\n",
    "        self.num_groups = FACTORS\n",
    "        self.num_layers_in_group = (\n",
    "            depth // self.num_groups\n",
    "        ) \n",
    "        print(\"Num layers in group: \", self.num_layers_in_group)\n",
    "        self.num_templates = TEMPLATES\n",
    "        mlp_hidden_dim = int(embed_dim * mlp_ratio)\n",
    "        self.template_banks1 = nn.ModuleList(\n",
    "            [\n",
    "                MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim)\n",
    "                for _ in range(self.num_groups)\n",
    "            ]\n",
    "        )\n",
    "        self.template_banks2 = nn.ModuleList(\n",
    "            [\n",
    "                MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim)\n",
    "                for _ in range(self.num_groups)\n",
    "            ]\n",
    "        )\n",
    "        self.depth = depth\n",
    "\n",
    "        dpr = [\n",
    "            x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)\n",
    "        ]  # stochastic depth decay rule\n",
    "        self.blocks = nn.ModuleList()\n",
    "        for i in range(depth):\n",
    "            group_idx = i // self.num_layers_in_group\n",
    "            print(group_idx)\n",
    "            bank1 = self.template_banks1[group_idx]\n",
    "            bank2 = self.template_banks2[group_idx]\n",
    "            self.blocks.append(\n",
    "                Block(\n",
    "                    dim=self.embed_dim,\n",
    "                    num_heads=num_heads,\n",
    "                    mlp_ratio=mlp_ratio,\n",
    "                    qkv_bias=qkv_bias,\n",
    "                    qk_scale=qk_scale,\n",
    "                    drop=drop_rate,\n",
    "                    attn_drop=attn_drop_rate,\n",
    "                    drop_path=dpr[group_idx],\n",
    "                    norm_layer=norm_layer,\n",
    "                    bank1=bank1,\n",
    "                    bank2=bank2,\n",
    "                )\n",
    "            )\n",
    "        print(f\"Num blocks: {len(self.blocks)}\")\n",
    "        self.norm = norm_layer(self.embed_dim)\n",
    "        self.head = (\n",
    "            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n",
    "        )\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                if m.bias is not None:\n",
    "                    nn.init.normal_(m.bias, std=1e-6)\n",
    "            elif isinstance(m, nn.LayerNorm):\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "                nn.init.constant_(m.weight, 1.0)\n",
    "\n",
    "    def _pos_embed(self, x):\n",
    "        to_cat = []\n",
    "        to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))\n",
    "        x = torch.cat(to_cat + [x], dim=1)\n",
    "        x = x + self.pos_embed\n",
    "        return self.pos_drop(x)\n",
    "\n",
    "    def forward_features(self, x):\n",
    "        B = x.shape[0]\n",
    "        x = self.patch_embed(x)\n",
    "        x = self._pos_embed(x)\n",
    "        x = self.patch_drop(x)\n",
    "        for blk in self.blocks:\n",
    "            x = blk(x)\n",
    "        x = self.norm(x)\n",
    "        return x\n",
    "\n",
    "    def forward_head(self, x):\n",
    "        x = x[:, 0]\n",
    "        x = self.fc_norm(x)\n",
    "        x = self.head_drop(x)\n",
    "        x = self.head(x)\n",
    "        return x\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.forward_features(x)\n",
    "        head_output = self.forward_head(features)\n",
    "        return head_output\n",
    "\n",
    "my_model = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=10, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)\n",
    "calculate_parameters(my_model)\n",
    "\n",
    "my_model = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)\n",
    "calculate_parameters(my_model)\n",
    "\n",
    "from pytorch_model_summary import summary\n",
    "my_model.eval()\n",
    "print(summary(my_model, torch.zeros(1, 3,224,224), show_hierarchical=True))\n",
    "del my_model \n",
    "torch.cuda.empty_cache()"
   ]
  },
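  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the shared-MLP machinery defined above (a sketch with illustrative ViT-Small shapes, not part of the pipeline), the next cell builds one bank pair, wraps it in a `SharedMLP`, and confirms the output shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check: one bank pair -> one SharedMLP forward pass.\n",
    "demo_dim, demo_hidden = 384, 1536  # ViT-Small embed dim and 4x MLP hidden dim\n",
    "demo_bank1 = MLPTemplateBank(TEMPLATES, demo_dim, demo_hidden)\n",
    "demo_bank2 = MLPTemplateBank(TEMPLATES, demo_hidden, demo_dim)\n",
    "demo_mlp = SharedMLP(demo_bank1, demo_bank2)\n",
    "print(demo_bank1)\n",
    "\n",
    "demo_x = torch.randn(2, 197, demo_dim)  # (batch, 196 patches + cls token, dim)\n",
    "print(demo_mlp(demo_x).shape)  # expected: torch.Size([2, 197, 384])"
   ]
  },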
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load existing weight from timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load timm vit weights into the custom model\n",
    "import timm \n",
    "vit = timm.create_model(\"vit_small_patch16_224\", pretrained=True)\n",
    "# freeze the weights of the timm model\n",
    "for param in vit.parameters():\n",
    "    param.requires_grad = False\n",
    "FOUND = []\n",
    "# load the weights from the timm model to the custom model\n",
    "def load_weights(model, timm_model):\n",
    "    model_dict = model.state_dict()\n",
    "    timm_dict = timm_model.state_dict()\n",
    "    new_dict = {}\n",
    "    for k, v in timm_dict.items():\n",
    "        if k in model_dict:\n",
    "            # check the shape of the weights\n",
    "            if model_dict[k].shape == timm_dict[k].shape:\n",
    "                new_dict[k] = v \n",
    "                FOUND.append(k)\n",
    "        else:\n",
    "            if \"mlp.fc1.bias\" in k:\n",
    "                new_k = k.replace(\"mlp.fc1.bias\", \"mlp.bias1\") # name in model_dict\n",
    "                print(f\"{k} --> {new_k}\")\n",
    "                if model_dict[new_k].shape == timm_dict[k].shape:\n",
    "                    new_dict[new_k] = v\n",
    "                    FOUND.append(new_k)\n",
    "            elif \"mlp.fc2.bias\" in k:\n",
    "                new_k = k.replace(\"mlp.fc2.bias\", \"mlp.bias2\",)\n",
    "                print(f\"{k} --> {new_k}\")\n",
    "                if model_dict[new_k].shape == timm_dict[k].shape:\n",
    "                    new_dict[new_k] = v\n",
    "                    FOUND.append(new_k)\n",
    "            else:\n",
    "                print(\"Key not found: \", k)\n",
    "    model_dict.update(new_dict)\n",
    "    model.load_state_dict(model_dict)\n",
    "    return model\n",
    "\n",
    "tpbvit = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)\n",
    "\n",
    "tpbvit = load_weights(tpbvit, vit)\n",
    "tpbvit_state_dict = list(tpbvit.state_dict().keys()) \n",
    "print(\"Found: \", len(FOUND))\n",
    "target_params = []\n",
    "\n",
    "for k in tpbvit_state_dict:\n",
    "    if k not in FOUND:\n",
    "        target_params.append(k)\n",
    "\n",
    "for name, param in tpbvit.named_parameters():\n",
    "    if name not in target_params or \"coefficients1\" in name or \"coefficients2\" in name:\n",
    "        param.requires_grad = False\n",
    "\n",
    "calculate_parameters(tpbvit) \n",
    "\n",
    "for n, p in tpbvit.named_parameters():\n",
    "    if p.requires_grad:\n",
    "        print(n, p.requires_grad)"
   ]
  },
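  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough accounting of the parameter sharing achieved by RECAST in the MLP weights (an illustrative sketch: it matches parameters by the names used above and excludes biases)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the dense timm MLP weight count against the RECAST replacement\n",
    "# (templates + coefficients). Name matching assumes the naming used above.\n",
    "timm_mlp_weight_params = sum(\n",
    "    p.numel() for n, p in vit.named_parameters()\n",
    "    if \".mlp.fc\" in n and n.endswith(\"weight\")\n",
    ")\n",
    "recast_mlp_params = sum(\n",
    "    p.numel() for n, p in tpbvit.named_parameters()\n",
    "    if \"template_banks\" in n or \"coefficients\" in n\n",
    ")\n",
    "print(f\"timm MLP weight parameters:      {timm_mlp_weight_params:,}\")\n",
    "print(f\"RECAST templates + coefficients: {recast_mlp_params:,}\")\n",
    "print(f\"ratio:                           {recast_mlp_params / timm_mlp_weight_params:.3f}\")"
   ]
  },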
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main Reconstruction Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from tensorboardX import SummaryWriter\n",
    "import tqdm\n",
    "import copy\n",
    "class GroupWeightReconstructor(nn.Module):\n",
    "    def __init__(self, in_features, out_features, num_templates, hidden_dim, bank):\n",
    "        super(GroupWeightReconstructor, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.num_templates = num_templates\n",
    "        self.num_cf = int(self.num_templates * MULT)\n",
    "        hidden_dim =  hidden_dim \n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(in_features * out_features, hidden_dim),\n",
    "            nn.LeakyReLU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.LeakyReLU(),\n",
    "            nn.Linear(hidden_dim, num_templates * self.num_cf),\n",
    "            # nn.LeakyReLU()\n",
    "        )\n",
    "        self.bank = bank\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        coefficients = self.encoder(x.view(batch_size, -1))\n",
    "        coefficients = coefficients.view(batch_size, self.num_cf, self.num_templates, 1, 1)\n",
    "        \n",
    "        reconstructed_weights = []\n",
    "        for i in range(batch_size):\n",
    "            weight_list = [self.bank(coeff) for coeff in coefficients[i]]\n",
    "            reconstructed = torch.stack(weight_list).mean(0)\n",
    "            reconstructed_weights.append(reconstructed)\n",
    "        \n",
    "        return torch.stack(reconstructed_weights), coefficients\n",
    "\n",
    "def train_group_reconstructor(model, fc_weights, num_epochs=2500, learning_rate=2e-3, train_bank=True):\n",
    "    params_to_optimize = model.parameters()\n",
    "    optimizer = torch.optim.RMSprop(params_to_optimize, lr=learning_rate)\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)\n",
    "    similarity_metric = nn.CosineSimilarity(dim=1)  # Changed to dim=1 for batch processing\n",
    "    best_sim = -float('inf')\n",
    "    writer = SummaryWriter()\n",
    "    best_model = None\n",
    "    for epoch in tqdm.tqdm(range(num_epochs), desc=\"Training Group Reconstructor\"):\n",
    "        optimizer.zero_grad()\n",
    "        recon_batch, _ = model(fc_weights)\n",
    "        loss = F.smooth_l1_loss(recon_batch, fc_weights)\n",
    "        sim = similarity_metric(recon_batch.view(recon_batch.size(0), -1), \n",
    "                                fc_weights.view(fc_weights.size(0), -1)).mean()\n",
    "        \n",
    "        if sim > best_sim:\n",
    "            best_sim = sim\n",
    "            best_model = copy.deepcopy(model)\n",
    "        \n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0)\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        writer.add_scalar('Loss/train', loss.item(), epoch)\n",
    "        writer.add_scalar('Similarity/train', sim.item(), epoch)\n",
    "        if (epoch + 1) % 500 == 0:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Similarity: {sim.item():.4f}')\n",
    "            print(\"Best sim: \", best_sim)\n",
    "\n",
    "    writer.close()\n",
    "    return best_model\n",
    "\n",
    "def prepare_and_train_reconstructors(timm_model, custom_model, train_banks=False):\n",
    "    reconstructors = []\n",
    "    overall_similarity = 0.0\n",
    "    similarity_metric = nn.CosineSimilarity(dim=1)\n",
    "\n",
    "    for group in range(custom_model.num_groups):\n",
    "        print(f\"\\nTraining reconstructors for group {group}\")\n",
    "        bank1 = custom_model.template_banks1[group]\n",
    "        bank2 = custom_model.template_banks2[group]\n",
    "\n",
    "        # Collect weights for all layers in this group\n",
    "        fc1_weights = []\n",
    "        fc2_weights = []\n",
    "        start_layer = group * custom_model.num_layers_in_group\n",
    "        end_layer = min((group + 1) * custom_model.num_layers_in_group, custom_model.depth)\n",
    "        for i in range(start_layer, end_layer):\n",
    "            fc1_weights.append(timm_model.state_dict()[f'blocks.{i}.mlp.fc1.weight'])\n",
    "            fc2_weights.append(timm_model.state_dict()[f'blocks.{i}.mlp.fc2.weight'])\n",
    "        \n",
    "        fc1_weights = torch.stack(fc1_weights)\n",
    "        fc2_weights = torch.stack(fc2_weights)\n",
    "\n",
    "        hidden_dim = 8  # Adjust as needed\n",
    "\n",
    "        rec1 = GroupWeightReconstructor(fc1_weights.shape[2], fc1_weights.shape[1], custom_model.num_templates, hidden_dim, bank1)\n",
    "        rec2 = GroupWeightReconstructor(fc2_weights.shape[2], fc2_weights.shape[1], custom_model.num_templates, hidden_dim, bank2)\n",
    "\n",
    "        print(f\"Training Reconstructor for group {group} FC1 ....\")\n",
    "        trained_rec1 = train_group_reconstructor(rec1, fc1_weights, train_bank=train_banks)\n",
    "        print(f\"Training Reconstructor for group {group} FC2 ....\")\n",
    "        trained_rec2 = train_group_reconstructor(rec2, fc2_weights, train_bank=train_banks)\n",
    "\n",
    "        reconstructors.append((trained_rec1, trained_rec2))\n",
    "\n",
    "        # Compute final similarity for this group\n",
    "        with torch.no_grad():\n",
    "            recon_fc1, _ = trained_rec1(fc1_weights)\n",
    "            recon_fc2, _ = trained_rec2(fc2_weights)\n",
    "            sim_fc1 = similarity_metric(recon_fc1.view(recon_fc1.size(0), -1), \n",
    "                                        fc1_weights.view(fc1_weights.size(0), -1)).mean()\n",
    "            sim_fc2 = similarity_metric(recon_fc2.view(recon_fc2.size(0), -1), \n",
    "                                        fc2_weights.view(fc2_weights.size(0), -1)).mean()\n",
    "            current_similarity = (sim_fc1 + sim_fc2) / 2\n",
    "            print(f\"Group {group} av similarity: {current_similarity.item()}\")\n",
    "            overall_similarity += current_similarity * (end_layer - start_layer)\n",
    "\n",
    "    overall_similarity /= custom_model.depth\n",
    "    print(f\"\\nOverall average similarity across all blocks: {overall_similarity:.4f}\")\n",
    "\n",
    "    return reconstructors, custom_model\n",
    "\n",
    "def set_custom_model_weights(custom_model, trained_reconstructors, timm_model):\n",
    "    device = next(custom_model.parameters()).device\n",
    "    assert len(trained_reconstructors) == custom_model.num_groups, \"Mismatch in number of groups\"\n",
    "    \n",
    "    similarity_metric = nn.CosineSimilarity(dim=0)\n",
    "    \n",
    "    # First, set all weights\n",
    "    for group in range(custom_model.num_groups):\n",
    "        start_layer = group * custom_model.num_layers_in_group\n",
    "        end_layer = min((group + 1) * custom_model.num_layers_in_group, custom_model.depth)\n",
    "        \n",
    "        fc1_weights = []\n",
    "        fc2_weights = []\n",
    "        for i in range(start_layer, end_layer):\n",
    "            fc1_weights.append(timm_model.state_dict()[f'blocks.{i}.mlp.fc1.weight'].to(device))\n",
    "            fc2_weights.append(timm_model.state_dict()[f'blocks.{i}.mlp.fc2.weight'].to(device))\n",
    "        \n",
    "        fc1_weights = torch.stack(fc1_weights)\n",
    "        fc2_weights = torch.stack(fc2_weights)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            _, coefficients1 = trained_reconstructors[group][0](fc1_weights)\n",
    "            _, coefficients2 = trained_reconstructors[group][1](fc2_weights)\n",
    "        \n",
    "        for i, layer_idx in enumerate(range(start_layer, end_layer)):\n",
    "            for j in range(len(custom_model.blocks[layer_idx].mlp.coefficients1)):\n",
    "                custom_model.blocks[layer_idx].mlp.coefficients1[j].data = coefficients1[i, j].to(device)\n",
    "                custom_model.blocks[layer_idx].mlp.coefficients2[j].data = coefficients2[i, j].to(device)\n",
    "            \n",
    "            custom_model.blocks[layer_idx].mlp.bias1.data = timm_model.state_dict()[f'blocks.{layer_idx}.mlp.fc1.bias'].to(device)\n",
    "            custom_model.blocks[layer_idx].mlp.bias2.data = timm_model.state_dict()[f'blocks.{layer_idx}.mlp.fc2.bias'].to(device)\n",
    "\n",
    "    # Now, calculate similarities for all layers\n",
    "    overall_similarity_fc1 = 0.0\n",
    "    overall_similarity_fc2 = 0.0\n",
    "    num_layers = 0\n",
    "\n",
    "    for group in range(custom_model.num_groups):\n",
    "        start_layer = group * custom_model.num_layers_in_group\n",
    "        end_layer = min((group + 1) * custom_model.num_layers_in_group, custom_model.depth)\n",
    "        \n",
    "        for layer_idx in range(start_layer, end_layer):\n",
    "            fc1_weight = timm_model.state_dict()[f'blocks.{layer_idx}.mlp.fc1.weight'].to(device)\n",
    "            fc2_weight = timm_model.state_dict()[f'blocks.{layer_idx}.mlp.fc2.weight'].to(device)\n",
    "\n",
    "            # Reconstruct weights\n",
    "            weight1_list = [custom_model.template_banks1[group](c) for c in custom_model.blocks[layer_idx].mlp.coefficients1]\n",
    "            weight2_list = [custom_model.template_banks2[group](c) for c in custom_model.blocks[layer_idx].mlp.coefficients2]\n",
    "            \n",
    "            my_weight1 = torch.stack(weight1_list).mean(0)\n",
    "            my_weight2 = torch.stack(weight2_list).mean(0)\n",
    "\n",
    "            sim1 = similarity_metric(fc1_weight.view(-1), my_weight1.view(-1))\n",
    "            sim2 = similarity_metric(fc2_weight.view(-1), my_weight2.view(-1))\n",
    "            \n",
    "            overall_similarity_fc1 += sim1.item()\n",
    "            overall_similarity_fc2 += sim2.item()\n",
    "            num_layers += 1\n",
    "\n",
    "            print(f\"Block {layer_idx} - FC1 Similarity: {sim1.item():.4f}, FC2 Similarity: {sim2.item():.4f}\")\n",
    "\n",
    "    avg_similarity_fc1 = overall_similarity_fc1 / num_layers\n",
    "    avg_similarity_fc2 = overall_similarity_fc2 / num_layers\n",
    "    print(f\"\\nOverall average similarity - FC1: {avg_similarity_fc1:.4f}, FC2: {avg_similarity_fc2:.4f}\")\n",
    "\n",
    "    return custom_model\n",
    "\n",
    "# Usage\n",
    "trained_reconstructors, custom_model = prepare_and_train_reconstructors(vit, tpbvit, train_banks=True)\n",
    "custom_model = set_custom_model_weights(custom_model, trained_reconstructors, vit)"
   ]
  },
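  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally persist the reconstructed model so the reconstruction does not have to be re-run (the filename below is illustrative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: save the reconstructed RECAST model (illustrative path).\n",
    "torch.save(custom_model.state_dict(), \"tpbvit_reconstructed.pth\")"
   ]
  },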
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Cosine Similarity and Feature Similarity"
   ]
  },
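  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`\"<batch.pt>\"` in the next cell is a placeholder path for a saved batch of preprocessed images. If no saved batch is available, the sketch below (assumption: a random batch is acceptable for a structural comparison, since both models receive the same input) writes a stand-in to disk; point the placeholder at it or at your real batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hedged fallback: a random stand-in batch of the expected shape, saved so\n",
    "# torch.load(...) in the next cell has something to read. Replace with real\n",
    "# preprocessed images for a meaningful evaluation.\n",
    "torch.save(torch.randn(8, 3, 224, 224), \"batch.pt\")"
   ]
  },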
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# Load the batch from disk\n",
    "batch = torch.load(\"<batch.pt>\")\n",
    "\n",
    "# Modified forward function to return intermediate features\n",
    "def forward_with_intermediate(model, x, cust = True):\n",
    "    B = x.shape[0]\n",
    "    x = model.patch_embed(x)\n",
    "    intermediate_features = []\n",
    "    if not cust:\n",
    "        print(\"classic\")\n",
    "        cls_tokens = model.cls_token.expand(B, -1, -1)\n",
    "        x = torch.cat((cls_tokens, x), dim=1)\n",
    "        x = x + model.pos_embed\n",
    "        x = model.pos_drop(x)\n",
    "        for blk in model.blocks:\n",
    "            x = blk(x)\n",
    "            intermediate_features.append(x)\n",
    "    else: \n",
    "        print(\"Custom\")\n",
    "        x = model._pos_embed(x)\n",
    "        x = model.patch_drop(x)\n",
    "        for blk in model.blocks:\n",
    "            x = blk(x)\n",
    "            intermediate_features.append(x)\n",
    "    x = model.norm(x)\n",
    "    return x, intermediate_features\n",
    "\n",
    "# Function to compare features layerwise\n",
    "def compare_features_layerwise(model1, model2, batch):\n",
    "    batch = batch.to(device)\n",
    "    _, features1 = forward_with_intermediate(model1, batch, True)\n",
    "    _, features2 = forward_with_intermediate(model2, batch, False)\n",
    "    \n",
    "    cos = nn.CosineSimilarity(dim=-1)  # Change dim to -1 for last dimension\n",
    "    similarities = []\n",
    "    \n",
    "    for f1, f2 in zip(features1, features2):\n",
    "        # Flatten the features except for the batch dimension\n",
    "        f1_flat = f1.view(f1.size(0), -1)\n",
    "        f2_flat = f2.view(f2.size(0), -1)\n",
    "        \n",
    "        # Compare the flattened features\n",
    "        similarity = cos(f1_flat, f2_flat).mean()\n",
    "        similarities.append(similarity.item())\n",
    "    \n",
    "    return similarities\n",
    "\n",
    "\n",
    "for i in range(1):\n",
    "    # Compare features layerwise\n",
    "    with torch.no_grad():\n",
    "        similarities = compare_features_layerwise(custom_model.to(device), vit.to(device), batch)\n",
    "        av = (np.mean(similarities))\n",
    "        print(av)\n",
    "        # # Print similarities for each layer\n",
    "        for i, sim in enumerate(similarities):\n",
    "            print(f\"Layer {i+1} similarity: {sim:.4f}\")\n",
    "\n",
    "\n",
    "        plt.figure(figsize=(6, 6))\n",
    "        plt.ylim(-1, 2) \n",
    "        plt.plot(range(1, len(similarities) + 1), similarities, marker='o')\n",
    "        plt.title(f'Layerwise Cosine Similarity factor={FACTORS}, templates={TEMPLATES} Average Sim: {av}')\n",
    "        plt.xlabel(\"Layer\")\n",
    "        plt.ylabel(\"Cosine Similarity\")\n",
    "        plt.grid(True)\n",
    "        plt.show()\n",
    "        \n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
