{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "manualSeed = 42\n",
    "DEFAULT_THRESHOLD = 5e-3\n",
    "\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "torch.cuda.manual_seed(manualSeed)\n",
    "np.random.seed(manualSeed)\n",
    "cudnn.benchmark = False\n",
    "torch.backends.cudnn.enabled = False\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device: \", device)\n",
    "GEN_KERNEL = 3\n",
    "num_cf = 2\n",
    "\n",
    "\n",
    "class TemplateBank(nn.Module):\n",
    "    def __init__(self, num_templates, in_planes, out_planes, kernel_size):\n",
    "        super(TemplateBank, self).__init__()\n",
    "        self.in_planes = in_planes\n",
    "        self.out_planes = out_planes\n",
    "        self.coefficient_shape = (num_templates, 1, 1, 1, 1)\n",
    "        self.kernel_size = kernel_size\n",
    "        templates = [\n",
    "            torch.Tensor(out_planes, in_planes, kernel_size, kernel_size)\n",
    "            for _ in range(num_templates)\n",
    "        ]\n",
    "        for i in range(num_templates):\n",
    "            nn.init.kaiming_normal_(templates[i])\n",
    "        self.templates = nn.Parameter(\n",
    "            torch.stack(templates)\n",
    "        )  # this is what we will freeze later\n",
    "\n",
    "    def forward(self, coefficients):\n",
    "        weights = (self.templates * coefficients).sum(0)\n",
    "        return weights\n",
    "\n",
    "    def __repr__(self):\n",
    "        return (\n",
    "            self.__class__.__name__\n",
    "            + \" (\"\n",
    "            + \"num_templates=\"\n",
    "            + str(self.coefficient_shape[0])\n",
    "            + \", kernel_size=\"\n",
    "            + str(self.kernel_size)\n",
    "            + \")\"\n",
    "            + \", in_planes=\"\n",
    "            + str(self.in_planes)\n",
    "            + \", out_planes=\"\n",
    "            + str(self.out_planes)\n",
    "        )\n",
    "\n",
    "\n",
    "class SConv2d(nn.Module):\n",
    "    # TARGET MODULE\n",
    "    def __init__(self, bank, stride=1, padding=1):\n",
    "        super(SConv2d, self).__init__()\n",
    "        self.stride = stride\n",
    "        self.padding = padding\n",
    "        self.bank = bank\n",
    "        self.num_templates = bank.coefficient_shape[0]\n",
    "\n",
    "        self.coefficients = nn.ParameterList(\n",
    "            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        param_list = []\n",
    "        for i in range(len(self.coefficients)):\n",
    "            params = self.bank(self.coefficients[i])\n",
    "            param_list.append(params)\n",
    "\n",
    "        final_params = torch.stack(param_list).mean(0)\n",
    "        return F.conv2d(input, final_params, stride=self.stride, padding=self.padding)\n",
    "\n",
    "\n",
    "class CustomResidualBlock(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        stride=1,\n",
    "        downsample=None,\n",
    "        bank1=None,\n",
    "        bank2=None,\n",
    "    ):\n",
    "        super(CustomResidualBlock, self).__init__()\n",
    "        self.bank1 = bank1\n",
    "        self.bank2 = bank2\n",
    "\n",
    "        # Ensure padding is always 1 for 3x3 convolutions\n",
    "        if self.bank1 and self.bank2:\n",
    "            self.conv1 = SConv2d(bank1, stride=stride, padding=1)\n",
    "            self.conv2 = SConv2d(bank2, stride=1, padding=1)\n",
    "        else:\n",
    "            self.conv1 = nn.Conv2d(\n",
    "                in_channels,\n",
    "                out_channels,\n",
    "                kernel_size=3,\n",
    "                stride=stride,\n",
    "                padding=1,\n",
    "                bias=False,\n",
    "            )\n",
    "            self.conv2 = nn.Conv2d(\n",
    "                out_channels,\n",
    "                out_channels,\n",
    "                kernel_size=3,\n",
    "                stride=1,\n",
    "                padding=1,\n",
    "                bias=False,\n",
    "            )\n",
    "\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "\n",
    "        # Implement downsample as 1x1 convolution when needed\n",
    "        if stride != 1 or in_channels != out_channels:\n",
    "            self.downsample = nn.Sequential(\n",
    "                nn.Conv2d(\n",
    "                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False\n",
    "                ),\n",
    "                nn.BatchNorm2d(out_channels),\n",
    "            )\n",
    "        else:\n",
    "            self.downsample = None\n",
    "\n",
    "        # Initialize weights\n",
    "        self._init_weights()\n",
    "\n",
    "    def _init_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, SConv2d):\n",
    "                for coefficient in m.coefficients:\n",
    "                    nn.init.orthogonal_(coefficient)\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNetTPB(nn.Module):\n",
    "    def __init__(self, block, layers, num_classes=10):\n",
    "        super(ResNetTPB, self).__init__()\n",
    "        self.inplanes = 64\n",
    "        self.layers = layers\n",
    "        self.conv1 = nn.Conv2d(\n",
    "            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False\n",
    "        )\n",
    "        self.bn1 = nn.BatchNorm2d(self.inplanes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(512, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv2d(\n",
    "                    self.inplanes, planes, kernel_size=1, stride=stride, bias=False\n",
    "                ),\n",
    "                nn.BatchNorm2d(planes),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "\n",
    "        # DYNAMICALLY CALCULATE THE NUMBER OF TEMPLATES TO USE FOR EACH RESIDUAL BLOCK\n",
    "        # Calculate parameters for remaining blocks\n",
    "        params_per_conv = 9 * planes * planes\n",
    "        params_per_template = 9 * planes * planes\n",
    "        num_templates1 = max(\n",
    "            1, int((blocks - 1) * params_per_conv / params_per_template)\n",
    "        )\n",
    "        num_templates2 = (\n",
    "            num_templates1  # You could potentially use a different calculation here\n",
    "        )\n",
    "\n",
    "        print(\n",
    "            f\"Layer with {planes} planes, {blocks} blocks, using {num_templates1} templates for conv1 and {num_templates2} for conv2\"\n",
    "        )\n",
    "\n",
    "        # Create separate TemplateBanks for conv1 and conv2\n",
    "        tpbank1 = TemplateBank(num_templates1, planes, planes, GEN_KERNEL)\n",
    "        tpbank2 = TemplateBank(num_templates2, planes, planes, GEN_KERNEL)\n",
    "\n",
    "        self.inplanes = planes\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(\n",
    "                block(\n",
    "                    in_channels=self.inplanes,\n",
    "                    out_channels=planes,\n",
    "                    bank1=tpbank1,\n",
    "                    bank2=tpbank2,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Weight for non-target module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet34, resnet18\n",
    "resnet34 = resnet34(pretrained=True)\n",
    "# resnet34 = resnet18(pretrained=True)\n",
    "\n",
    "FOUND = []\n",
    "\n",
    "def load_weights(current_model, target_model):\n",
    "    current_dict = current_model.state_dict()\n",
    "    target_dict = target_model.state_dict()\n",
    "\n",
    "    new_dict = {}\n",
    "    for k, v in target_dict.items():\n",
    "        if k in current_dict:\n",
    "            if current_dict[k].shape == target_dict[k].shape:\n",
    "                new_dict[k] = v\n",
    "                FOUND.append(k)\n",
    "            else:\n",
    "                print(f\"Shape mismatch for key: {k}\")\n",
    "        else:\n",
    "            print(f\"Key not found: {k}\")\n",
    "\n",
    "    current_dict.update(new_dict)\n",
    "    current_model.load_state_dict(current_dict)\n",
    "    return current_model\n",
    "\n",
    "my_model = ResNetTPB(CustomResidualBlock, [3,4,6,3], num_classes=1000)\n",
    "my_model = load_weights(my_model, resnet34)\n",
    "my_model_state = list(my_model.state_dict().keys())\n",
    "print(f\"Found: {len(FOUND)}\")\n",
    "print(f\"Total: {len(my_model_state)}\")\n",
    "\n",
    "target_params = []\n",
    "for k in my_model_state:\n",
    "    if k not in FOUND:\n",
    "        target_params.append(k)\n",
    "\n",
    "print(f\"Target params: {len(target_params)}\")\n",
    "for name, param in my_model.named_parameters():\n",
    "    if name in target_params:\n",
    "        param.requires_grad = True\n",
    "    else:\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reconstruction loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reconstruction_loss_dynamic(current_model, pretrained_model, criterion=nn.SmoothL1Loss(), w1_weight=3.5, w2_weight=2.5):\n",
    "    corr_state_dict = pretrained_model.state_dict()\n",
    "    loss_dict = {}\n",
    "    total_loss = 0.0\n",
    "    w1_loss = 0.0\n",
    "    w2_loss = 0.0\n",
    "    \n",
    "    # Determine the device of the current model\n",
    "    device = next(current_model.parameters()).device\n",
    "    \n",
    "    for layer_idx, layer in enumerate([current_model.layer1, current_model.layer2, current_model.layer3, current_model.layer4]):\n",
    "        for block_idx, block in enumerate(layer):\n",
    "            if isinstance(block, CustomResidualBlock) and block.bank1 is not None:\n",
    "                conv1_cf = block.conv1.coefficients\n",
    "                conv1_bank = block.bank1\n",
    "                weights1 = []\n",
    "                noise_std1 = 0.0\n",
    "                for c in conv1_cf:\n",
    "                    if current_model.training:\n",
    "                        noise = torch.randn_like(c) * noise_std1\n",
    "                        c = c + noise\n",
    "                    w = conv1_bank(c)\n",
    "                    weights1.append(w)\n",
    "        \n",
    "                _weights1 = torch.stack(weights1).mean(0)\n",
    "                corr_weight1 = corr_state_dict[f'layer{layer_idx+1}.{block_idx}.conv1.weight'].to(device)\n",
    "                w1_l = criterion(_weights1, corr_weight1) * w1_weight\n",
    "\n",
    "                loss_dict[f'layer{layer_idx+1}.{block_idx}.bank1.templates'] = w1_l\n",
    "                for i, cf in enumerate(conv1_cf):\n",
    "                    loss_dict[f'layer{layer_idx+1}.{block_idx}.conv1.coefficients.{i}'] = w1_l\n",
    "\n",
    "                w1_loss += w1_l.item()\n",
    "                total_loss += w1_l\n",
    "\n",
    "            if isinstance(block, CustomResidualBlock) and block.bank2 is not None:\n",
    "                conv2_cf = block.conv2.coefficients\n",
    "                conv2_bank = block.bank2\n",
    "                noise_std2 = 0.0\n",
    "                weights2 = []\n",
    "                for c in conv2_cf:\n",
    "                    if current_model.training:\n",
    "                        noise = torch.randn_like(c) * noise_std2\n",
    "                        c = c + noise\n",
    "                    w = conv2_bank(c)\n",
    "                    weights2.append(w)\n",
    "                _weights2 = torch.stack(weights2).mean(0)\n",
    "                corr_weight2 = corr_state_dict[f'layer{layer_idx+1}.{block_idx}.conv2.weight'].to(device)\n",
    "                w2_l = criterion(_weights2, corr_weight2) * w2_weight\n",
    "\n",
    "                loss_dict[f'layer{layer_idx+1}.{block_idx}.bank2.templates'] = w2_l\n",
    "                for i, cf in enumerate(conv2_cf):\n",
    "                    loss_dict[f'layer{layer_idx+1}.{block_idx}.conv2.coefficients.{i}'] = w2_l\n",
    "\n",
    "                w2_loss += w2_l.item()\n",
    "                total_loss += w2_l\n",
    "\n",
    "    return loss_dict, total_loss, w1_loss, w2_loss\n",
    "import copy\n",
    "optimizer = torch.optim.RMSprop(my_model.parameters(), lr=0.2)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)\n",
    "best_model = None\n",
    "best_model_loss = 1e9\n",
    "my_model.train()\n",
    "for epoch in range(3000):\n",
    "    loss_dict, total_loss, conv1_loss, conv2_loss = reconstruction_loss_dynamic(my_model, resnet34)\n",
    "    if total_loss < best_model_loss:\n",
    "        best_model_loss = total_loss.item()\n",
    "        best_model = copy.deepcopy(my_model)\n",
    "    total_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(my_model.parameters(), 1)\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "    scheduler.step()\n",
    "    if epoch % 100 == 0:\n",
    "        print(f\"Epoch: {epoch} Total Loss: {total_loss.item()} \")\n",
    "        print(f\"Conv1 Loss: {conv1_loss} Conv2 Loss: {conv2_loss}\")\n",
    "        print(f\"LR: {scheduler.get_last_lr()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Reconstruction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Comparing Cosine Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_dynamic_reconstruction(current_model, pretrained_model):\n",
    "    similarity = 0.0\n",
    "    metric = nn.CosineSimilarity(dim=0)\n",
    "    corr_state_dict = pretrained_model.state_dict()\n",
    "    w1_loss = 0.0\n",
    "    w2_loss = 0.0\n",
    "    total_blocks = 0\n",
    "\n",
    "    for layer_idx, layer in enumerate([current_model.layer1, current_model.layer2, current_model.layer3, current_model.layer4], 1):\n",
    "        for block_idx, block in enumerate(layer):\n",
    "            w1_l = 0\n",
    "            w2_l = 0\n",
    "            \n",
    "            if block_idx == 0 or not isinstance(block, CustomResidualBlock):\n",
    "                # Handle first block or non-CustomResidualBlock\n",
    "                if hasattr(block, 'conv1') and hasattr(block.conv1, 'weight'):\n",
    "                    w1 = block.conv1.weight\n",
    "                    corr_w1 = corr_state_dict[f'layer{layer_idx}.{block_idx}.conv1.weight']\n",
    "                    w1_l = metric(w1.view(-1), corr_w1.view(-1))\n",
    "                    w1_loss += w1_l.item()\n",
    "\n",
    "                if hasattr(block, 'conv2') and hasattr(block.conv2, 'weight'):\n",
    "                    w2 = block.conv2.weight\n",
    "                    corr_w2 = corr_state_dict[f'layer{layer_idx}.{block_idx}.conv2.weight']\n",
    "                    w2_l = metric(w2.view(-1), corr_w2.view(-1))\n",
    "                    w2_loss += w2_l.item()\n",
    "\n",
    "            else:\n",
    "                # Handle CustomResidualBlock with template banks\n",
    "                if hasattr(block, 'bank1') and block.bank1 is not None:\n",
    "                    conv1_cf = block.conv1.coefficients\n",
    "                    conv1_bank = block.bank1\n",
    "                    weights1 = torch.stack([conv1_bank(cf) for cf in conv1_cf]).mean(0)\n",
    "                    corr_weight1 = corr_state_dict[f'layer{layer_idx}.{block_idx}.conv1.weight']\n",
    "                    w1_l = metric(weights1.view(-1), corr_weight1.view(-1))\n",
    "                    w1_loss += w1_l.item()\n",
    "\n",
    "                if hasattr(block, 'bank2') and block.bank2 is not None:\n",
    "                    conv2_cf = block.conv2.coefficients\n",
    "                    conv2_bank = block.bank2\n",
    "                    weights2 = torch.stack([conv2_bank(cf) for cf in conv2_cf]).mean(0)\n",
    "                    corr_weight2 = corr_state_dict[f'layer{layer_idx}.{block_idx}.conv2.weight']\n",
    "                    w2_l = metric(weights2.view(-1), corr_weight2.view(-1))\n",
    "                    w2_loss += w2_l.item()\n",
    "\n",
    "            print(f\"Layer: {layer_idx}, Block: {block_idx}, W1: {w1_l:.4f}, W2: {w2_l:.4f}\")\n",
    "            total_blocks += 1\n",
    "\n",
    "    avg_w1 = w1_loss / total_blocks\n",
    "    avg_w2 = w2_loss / total_blocks\n",
    "\n",
    "    print(f\"Average - W1: {avg_w1:.4f}, W2: {avg_w2:.4f}\")\n",
    "    return w1_loss, w2_loss\n",
    "\n",
    "similarity_dict = evaluate_dynamic_reconstruction(best_model, resnet34)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Comparing layerwise feature similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "calculate_parameters(my_model)\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "# Load the batch from disk\n",
    "batch = torch.load(\"<DATALOADER_BATCH_PATH>\")\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Modified forward function to return intermediate features\n",
    "def forward_with_intermediate(model, x, is_custom=True):\n",
    "    if is_custom:\n",
    "        print(\"Using custom model\")\n",
    "    else:\n",
    "        print(\"Using default model\")\n",
    "    intermediate_features = []\n",
    "    x = model.conv1(x)\n",
    "    x = model.bn1(x)\n",
    "    x = model.relu(x)\n",
    "    x = model.maxpool(x)\n",
    "\n",
    "    layer1 = model.layer1(x)\n",
    "    intermediate_features.append(layer1)\n",
    "    layer2 = model.layer2(layer1)\n",
    "    intermediate_features.append(layer2)\n",
    "    layer3 = model.layer3(layer2)\n",
    "    intermediate_features.append(layer3)\n",
    "    layer4 = model.layer4(layer3)\n",
    "    intermediate_features.append(layer4)\n",
    "\n",
    "    x = model.avgpool(layer4)\n",
    "    x = torch.flatten(x, 1)\n",
    "    x = model.fc(x)\n",
    "\n",
    "    return x, intermediate_features\n",
    "\n",
    "# Function to compare features layerwise\n",
    "def compare_features_layerwise(model1, model2, batch):\n",
    "    batch = batch.to(\"cpu\")  # Assuming CPU for computation\n",
    "    _, features1 = forward_with_intermediate(model1.to(\"cpu\"), batch, is_custom=True)\n",
    "    _, features2 = forward_with_intermediate(model2.to(\"cpu\"), batch, is_custom=False)\n",
    "    \n",
    "    cos = nn.CosineSimilarity(dim=1)  # Change dim to 1 for channel dimension\n",
    "    similarities = []\n",
    "    \n",
    "    for f1, f2 in zip(features1, features2):\n",
    "        # Flatten the features except for the batch and channel dimensions\n",
    "        f1_flat = f1.view(f1.size(0), f1.size(1), -1)\n",
    "        f2_flat = f2.view(f2.size(0), f2.size(1), -1)\n",
    "        \n",
    "        # Compare the flattened features\n",
    "        similarity = cos(f1_flat.mean(dim=2), f2_flat.mean(dim=2)).mean()\n",
    "        similarities.append(similarity.item())\n",
    "    \n",
    "    return similarities\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "similarities = compare_features_layerwise(best_model, resnet34, batch)\n",
    "av_sim = sum(similarities) / len(similarities)\n",
    "print(f\"Average similarity: {av_sim:.4f}\")\n",
    "# Print or plot the similarities\n",
    "for i, sim in enumerate(similarities):\n",
    "    print(f\"Layer {i+1} similarity: {sim:.4f}\")\n",
    "\n",
    "# You can also plot these similarities\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.plot(range(1, len(similarities) + 1), similarities, marker='o')\n",
    "plt.title(\"Layer-wise Feature Similarity\")\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Cosine Similarity\")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
