{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "9689d8da-fa03-4f5b-bd19-2b50568bcbdd"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install --quiet ftfy regex tqdm\n",
        "!pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "!pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define paths\n",
        "data_dir = '/content/coco2014'\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# URLs for datasets and annotations\n",
        "datasets = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"val2014\": \"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "# Download helper function with progress bar\n",
        "def download_file(url, dest_path):\n",
        "    response = requests.get(url, stream=True)\n",
        "    total_size = int(response.headers.get('content-length', 0))\n",
        "    with open(dest_path, 'wb') as f, tqdm(\n",
        "        desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "        total=total_size,\n",
        "        unit='B',\n",
        "        unit_scale=True,\n",
        "        unit_divisor=1024\n",
        "    ) as bar:\n",
        "        for data in response.iter_content(chunk_size=1024):\n",
        "            f.write(data)\n",
        "            bar.update(len(data))\n",
        "\n",
        "# Download and extract datasets\n",
        "for name, url in datasets.items():\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    # Download the dataset\n",
        "    download_file(url, zip_path)\n",
        "\n",
        "    # Unzip the dataset\n",
        "    with ZipFile(zip_path, 'r') as zip_ref:\n",
        "        zip_ref.extractall(data_dir)\n",
        "\n",
        "    # Remove the zip file to save space\n",
        "    os.remove(zip_path)\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XMDkrWBhLGg",
        "outputId": "0873584d-e257-4f3b-8e8f-cca2f857282b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [05:56<00:00, 37.9MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing val2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [03:01<00:00, 36.6MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:06<00:00, 36.4MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP ViT-B/16 Model"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "\n",
        "# Load the CLIP model\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "\n",
        "# after: model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "model = model.float()                     # unify everything to fp32\n",
        "for p in model.parameters():              # (LoRA gets injected later; it will also be fp32)\n",
        "    p.requires_grad = p.requires_grad     # no-op; keeps current flags\n",
        "\n",
        "\n",
        "#model.eval()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bcbcf2ca-ef97-465f-8caa-07612b73eb13"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 335M/335M [00:18<00:00, 18.7MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 149,620,737\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the MSCOCO Data Loaders"
      ],
      "metadata": {
        "id": "h_cfc6FcltX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision import transforms, datasets\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "transform = transforms.Compose([\n",
        "    transforms.Resize((input_resolution, input_resolution)),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(\n",
        "        mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "        std=(0.26862954, 0.26130258, 0.27577711)\n",
        "    )\n",
        "])\n",
        "\n",
        "class CocoDataset(Dataset):\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=None)\n",
        "        self.transform = transform\n",
        "        self.data = []\n",
        "        # Flatten (image, captions) so each item is (image, single_caption)\n",
        "        for img_idx in range(len(self.dataset)):\n",
        "            image, captions = self.dataset[img_idx]\n",
        "            for caption in captions:\n",
        "                self.data.append((img_idx, caption))\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img_idx, caption = self.data[idx]\n",
        "        image, _ = self.dataset[img_idx]  # get the image\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        text = clip.tokenize(caption, context_length=context_length)[0]\n",
        "        return image, text\n",
        "\n",
        "train_img_dir = os.path.join(data_dir, 'train2014')\n",
        "train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')\n",
        "\n",
        "train_dataset = CocoDataset(root=train_img_dir, annFile=train_ann_file, transform=transform)\n",
        "\n",
        "train_dataloader = DataLoader(\n",
        "    train_dataset,\n",
        "    batch_size=64,\n",
        "    shuffle=True,\n",
        "    num_workers=2,\n",
        "    pin_memory=True\n",
        ")\n"
      ],
      "metadata": {
        "id": "dGosFCtHhMMi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ccd25034-466a-4467-bff2-99512d189a18"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.64s)\n",
            "creating index...\n",
            "index created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ftKgCZzre6B-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "#  CLIP + LoRA + TE Pruning\n",
        "#  (NaN-safe training + COCO eval) - Perturbation TE (no more TE=0)\n",
        "# =========================\n",
        "\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "from typing import List, Tuple, Union\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import AdamW\n",
        "from torch.nn.utils import clip_grad_norm_\n",
        "\n",
        "# -------------------------\n",
        "# Utils: safe numerics\n",
        "# -------------------------\n",
        "def safe_l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:\n",
        "    return x / x.norm(dim=dim, keepdim=True).clamp_min(eps)\n",
        "\n",
        "def get_safe_logit_scale(model, clamp_low: float = -5.0, clamp_high: float = 5.0) -> torch.Tensor:\n",
        "    # Clamp BEFORE exp to avoid overflow, keep on model device\n",
        "    if hasattr(model, \"logit_scale\"):\n",
        "        return model.logit_scale.float().clamp(clamp_low, clamp_high).exp()\n",
        "    return torch.tensor(1.0, device=next(model.parameters()).device)\n",
        "\n",
        "# -------------------------\n",
        "# LoRA wrapper (MHA-compatible)\n",
        "# -------------------------\n",
        "class LoRALinear(nn.Module):\n",
        "    \"\"\"\n",
        "    Wraps an nn.Linear and exposes a read-only merged .weight/.bias\n",
        "    so nn.MultiheadAttention can read out_proj.weight directly.\n",
        "    \"\"\"\n",
        "    def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        assert isinstance(base_linear, nn.Linear), \"LoRALinear expects an nn.Linear as base\"\n",
        "        self.base = base_linear\n",
        "        for p in self.base.parameters():\n",
        "            p.requires_grad = False\n",
        "\n",
        "        self.r = int(r)\n",
        "        self.scaling = float(alpha) / float(r)\n",
        "\n",
        "        dev, dt = self.base.weight.device, self.base.weight.dtype\n",
        "        self.lora_A = nn.Parameter(torch.zeros(r, self.base.in_features,  device=dev, dtype=dt))\n",
        "        self.lora_B = nn.Parameter(torch.zeros(self.base.out_features, r, device=dev, dtype=dt))\n",
        "        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n",
        "        nn.init.zeros_(self.lora_B)\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n",
        "\n",
        "    @property\n",
        "    def weight(self) -> torch.Tensor:\n",
        "        # Used directly by nn.MultiheadAttention\n",
        "        return self.base.weight + self.scaling * (self.lora_B @ self.lora_A)\n",
        "\n",
        "    @property\n",
        "    def bias(self):\n",
        "        return self.base.bias\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.base(x)\n",
        "        delta = (self.dropout(x) @ self.lora_A.t()) @ self.lora_B.t()\n",
        "        return out + self.scaling * delta\n",
        "\n",
        "# -------------------------\n",
        "# Injection / upgrade helpers\n",
        "# -------------------------\n",
        "def _wrap_or_upgrade_linear(mod: nn.Module, attr: str, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    If mod.<attr> is:\n",
        "      - nn.Linear: wrap with LoRALinear\n",
        "      - old LoRALinear (without .weight): upgrade to new LoRALinear and copy A/B\n",
        "      - new LoRALinear: leave as-is\n",
        "    \"\"\"\n",
        "    cur = getattr(mod, attr, None)\n",
        "\n",
        "    # Already the new wrapper?\n",
        "    if isinstance(cur, LoRALinear) and isinstance(getattr(type(cur), \"weight\", None), property):\n",
        "        return\n",
        "\n",
        "    # Old wrapper (no .weight property) -> upgrade\n",
        "    if isinstance(cur, nn.Module) and cur.__class__.__name__ == \"LoRALinear\" and not hasattr(cur, \"weight\"):\n",
        "        base = cur.base\n",
        "        new = LoRALinear(base, r=r, alpha=alpha, dropout=dropout)\n",
        "        if hasattr(cur, \"lora_A\") and hasattr(cur, \"lora_B\"):\n",
        "            with torch.no_grad():\n",
        "                if new.lora_A.shape == cur.lora_A.shape:\n",
        "                    new.lora_A.copy_(cur.lora_A.data)\n",
        "                if new.lora_B.shape == cur.lora_B.shape:\n",
        "                    new.lora_B.copy_(cur.lora_B.data)\n",
        "        setattr(mod, attr, new)\n",
        "        return\n",
        "\n",
        "    # Fresh wrap\n",
        "    if isinstance(cur, nn.Linear):\n",
        "        setattr(mod, attr, LoRALinear(cur, r=r, alpha=alpha, dropout=dropout))\n",
        "\n",
        "\n",
        "\n",
        "# --- helpers to iterate resblocks safely ---\n",
        "def _iter_resblocks(stack_module: nn.Module):\n",
        "    # Works for model.visual and for the text tower (model itself)\n",
        "    if hasattr(stack_module, \"transformer\") and hasattr(stack_module.transformer, \"resblocks\"):\n",
        "        return list(stack_module.transformer.resblocks)\n",
        "    return []\n",
        "\n",
        "def add_lora_to_clip_vit_b16(model, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Add/upgrade LoRA at:\n",
        "      - Vision blocks: attn.out_proj, mlp.c_fc, mlp.c_proj\n",
        "      - Text blocks:   attn.out_proj, mlp.c_fc, mlp.c_proj\n",
        "    Skips pruned/identity blocks or any block missing the expected attrs.\n",
        "    \"\"\"\n",
        "    # Vision\n",
        "    for blk in _iter_resblocks(model.visual):\n",
        "        # skip identity/pruned or non-standard blocks\n",
        "        attn = getattr(blk, \"attn\", None)\n",
        "        if attn is not None and hasattr(attn, \"out_proj\"):\n",
        "            _wrap_or_upgrade_linear(attn, \"out_proj\", r, alpha, dropout)\n",
        "        mlp = getattr(blk, \"mlp\", None)\n",
        "        if mlp is not None:\n",
        "            if hasattr(mlp, \"c_fc\"):\n",
        "                _wrap_or_upgrade_linear(mlp, \"c_fc\", r, alpha, dropout)\n",
        "            if hasattr(mlp, \"c_proj\"):\n",
        "                _wrap_or_upgrade_linear(mlp, \"c_proj\", r, alpha, dropout)\n",
        "\n",
        "    # Text\n",
        "    for blk in _iter_resblocks(model):\n",
        "        attn = getattr(blk, \"attn\", None)\n",
        "        if attn is not None and hasattr(attn, \"out_proj\"):\n",
        "            _wrap_or_upgrade_linear(attn, \"out_proj\", r, alpha, dropout)\n",
        "        mlp = getattr(blk, \"mlp\", None)\n",
        "        if mlp is not None:\n",
        "            if hasattr(mlp, \"c_fc\"):\n",
        "                _wrap_or_upgrade_linear(mlp, \"c_fc\", r, alpha, dropout)\n",
        "            if hasattr(mlp, \"c_proj\"):\n",
        "                _wrap_or_upgrade_linear(mlp, \"c_proj\", r, alpha, dropout)\n",
        "\n",
        "    # Freeze all, then enable ONLY LoRA params (and keep logit_scale frozen)\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad = False\n",
        "    if hasattr(model, \"logit_scale\") and isinstance(model.logit_scale, torch.Tensor):\n",
        "        model.logit_scale.requires_grad_(False)\n",
        "    for n, p in model.named_parameters():\n",
        "        if n.endswith(\"lora_A\") or n.endswith(\"lora_B\"):\n",
        "            p.requires_grad = True\n",
        "\n",
        "    return [p for p in model.parameters() if p.requires_grad]\n",
        "\n",
        "\n",
        "def _assert_outproj_has_weight(model):\n",
        "    \"\"\"\n",
        "    Only verifies blocks that actually have attn.out_proj.\n",
        "    Skips Identity/pruned blocks cleanly.\n",
        "    \"\"\"\n",
        "    def _check_tower(tower, tower_name: str):\n",
        "        for i, blk in enumerate(_iter_resblocks(tower)):\n",
        "            attn = getattr(getattr(blk, \"attn\", None), \"out_proj\", None)\n",
        "            if attn is None:\n",
        "                # likely Identity/pruned or non-standard; skip\n",
        "                continue\n",
        "            w = getattr(attn, \"weight\", None)\n",
        "            if not isinstance(w, torch.Tensor):\n",
        "                raise AssertionError(\n",
        "                    f\"{tower_name} resblock {i} out_proj has no Tensor `.weight` \"\n",
        "                    f\"(got {type(w)} from {type(attn)})\"\n",
        "                )\n",
        "    _check_tower(model.visual, \"vision\")\n",
        "    _check_tower(model, \"text\")\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Hooks & TE computation (PERTURBATION-BASED)\n",
        "# -------------------------\n",
        "class BlockHook:\n",
        "    \"\"\"\n",
        "    Captures each transformer block output.\n",
        "    Auto-detects orientation (LND vs NLD) by choosing the version\n",
        "    with higher between-sample variance after token pooling.\n",
        "    Stores CPU float tensors.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        self.buf: List[torch.Tensor] = []\n",
        "        self.handles: List[torch.utils.hooks.RemovableHandle] = []\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _choose_nld(self, y: torch.Tensor) -> torch.Tensor:\n",
        "        # Candidate 1: assume input is LND (typical inside CLIP blocks)\n",
        "        cand1 = y.permute(1, 0, 2)  # NLD\n",
        "        # Candidate 2: assume input is already NLD\n",
        "        cand2 = y\n",
        "\n",
        "        def score(nld: torch.Tensor) -> torch.Tensor:\n",
        "            z = nld.mean(dim=1)  # [N, D] pooled over tokens\n",
        "            z = torch.nan_to_num(z).float()\n",
        "            return z.var(dim=0, unbiased=False).mean()  # scalar\n",
        "\n",
        "        s1 = score(cand1)\n",
        "        s2 = score(cand2)\n",
        "        return cand1 if (s1 >= s2) else cand2\n",
        "\n",
        "    def _hook(self, module, inp, out):\n",
        "        if isinstance(out, torch.Tensor) and out.ndim == 3:\n",
        "            nld = self._choose_nld(out)\n",
        "            self.buf.append(torch.nan_to_num(nld).float().cpu())\n",
        "        elif isinstance(out, (tuple, list)):\n",
        "            for x in out:\n",
        "                if isinstance(x, torch.Tensor) and x.ndim == 3:\n",
        "                    nld = self._choose_nld(x)\n",
        "                    self.buf.append(torch.nan_to_num(nld).float().cpu())\n",
        "                    break\n",
        "\n",
        "    def register_on_blocks(self, blocks: nn.ModuleList):\n",
        "        self.clear()\n",
        "        for blk in blocks:\n",
        "            self.handles.append(blk.register_forward_hook(self._hook))\n",
        "\n",
        "    def clear(self):\n",
        "        for h in self.handles:\n",
        "            h.remove()\n",
        "        self.handles = []\n",
        "        self.buf = []\n",
        "\n",
        "def pool_sequence(x: torch.Tensor, kind: str = \"cls\") -> torch.Tensor:\n",
        "    # x: [B, T, D]\n",
        "    if x.ndim != 3:\n",
        "        raise ValueError(\"Expected [B, T, D] tensor from blocks.\")\n",
        "    if kind == \"cls\":\n",
        "        return x[:, 0, :]\n",
        "    return x.mean(dim=1)\n",
        "\n",
        "def _batch_variance_debug(layer_outputs: List[torch.Tensor], pool=\"cls\") -> List[float]:\n",
        "    vals = []\n",
        "    for x in layer_outputs:\n",
        "        z = pool_sequence(x, kind=pool)  # [B, D]\n",
        "        z = torch.nan_to_num(z).float()\n",
        "        vals.append(float(z.var(dim=0, unbiased=False).mean().item()))\n",
        "    return vals\n",
        "\n",
        "def _collect_block_outputs_with_noise(\n",
        "    model,\n",
        "    images: torch.Tensor,\n",
        "    texts: torch.Tensor,\n",
        "    device: str,\n",
        "    sigma_v: float,\n",
        "    sigma_t: float,\n",
        ") -> Tuple[List[torch.Tensor], List[torch.Tensor]]:\n",
        "    \"\"\"\n",
        "    One forward pass with additive Gaussian noise injected:\n",
        "      - Vision: add noise to input pixels (already normalized) => images + noise\n",
        "      - Text:   add noise to token embeddings via a forward hook on model.token_embedding\n",
        "    Returns lists of per-block outputs (vision_NLD_list, text_NLD_list)\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "\n",
        "    # Vision hooks\n",
        "    v_hook = BlockHook()\n",
        "    v_blocks = model.visual.transformer.resblocks\n",
        "    v_hook.register_on_blocks(v_blocks)\n",
        "\n",
        "    # Text hooks + embedding noise hook\n",
        "    t_hook = BlockHook()\n",
        "    t_blocks = model.transformer.resblocks\n",
        "    t_hook.register_on_blocks(t_blocks)\n",
        "\n",
        "    emb_handle = None\n",
        "    if hasattr(model, \"token_embedding\") and sigma_t > 0:\n",
        "        def _emb_noise(_m, _inp, out):\n",
        "            return out + sigma_t * torch.randn_like(out)\n",
        "        emb_handle = model.token_embedding.register_forward_hook(_emb_noise)\n",
        "\n",
        "    try:\n",
        "        # Vision pass (noisy pixels)\n",
        "        if sigma_v > 0:\n",
        "            img_noisy = images + sigma_v * torch.randn_like(images)\n",
        "        else:\n",
        "            img_noisy = images\n",
        "        _ = model.encode_image(img_noisy.to(device, non_blocking=True))\n",
        "\n",
        "        # Text pass (embedding hook kicks in automatically)\n",
        "        _ = model.encode_text(texts.to(device, non_blocking=True))\n",
        "    finally:\n",
        "        if emb_handle is not None:\n",
        "            emb_handle.remove()\n",
        "\n",
        "    v_out = v_hook.buf[:]  # copy lists\n",
        "    t_out = t_hook.buf[:]\n",
        "    v_hook.clear()\n",
        "    t_hook.clear()\n",
        "    return v_out, t_out\n",
        "\n",
        "def _te_from_two_runs(\n",
        "    out1: List[torch.Tensor],\n",
        "    out2: List[torch.Tensor],\n",
        "    pool: str = \"cls\",\n",
        ") -> List[float]:\n",
        "    \"\"\"\n",
        "    Given two noisy forward lists of per-block outputs, compute TE using\n",
        "    Δz_l = z_l^(2) - z_l^(1) per sample (same batch), then cosine^2 across samples.\n",
        "    \"\"\"\n",
        "    assert len(out1) == len(out2) and len(out1) >= 2\n",
        "    # pooled shapes: [B, D]\n",
        "    pooled1 = [pool_sequence(x, kind=pool) for x in out1]\n",
        "    pooled2 = [pool_sequence(x, kind=pool) for x in out2]\n",
        "\n",
        "    # deltas per layer: [B, D]\n",
        "    deltas = [torch.nan_to_num(pooled2[l] - pooled1[l]) for l in range(len(pooled1))]\n",
        "\n",
        "    te_vals = []\n",
        "    eps = 1e-6\n",
        "    for l in range(len(deltas) - 1):\n",
        "        A = deltas[l]      # [B, D]\n",
        "        Bn = deltas[l + 1] # [B, D]\n",
        "        num = (A * Bn).sum(dim=1)\n",
        "        den = (A.norm(dim=1) * Bn.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "    return te_vals\n",
        "\n",
        "\n",
        "\n",
        "# ---- Explicit per-block forwards (no hooks) ----\n",
        "\n",
        "def _first_param_dtype(module, default=torch.float32):\n",
        "    for p in module.parameters(recurse=True):\n",
        "        return p.dtype\n",
        "    return default\n",
        "\n",
        "@torch.no_grad()\n",
        "def _vision_block_outputs(model, images, sigma_v: float, device: str):\n",
        "    v = model.visual\n",
        "    pdev   = next(v.parameters()).device\n",
        "    v_dtype = _first_param_dtype(v, torch.float32)  # now fp32 after model.float()\n",
        "\n",
        "    x = images.to(pdev, non_blocking=True).to(v_dtype)   # << always fp32\n",
        "    if sigma_v > 0:\n",
        "        x = x + sigma_v * torch.randn_like(x, dtype=v_dtype)\n",
        "\n",
        "    x = v.conv1(x)\n",
        "    x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)\n",
        "    x = torch.cat([v.class_embedding[None, None, :].to(v_dtype).expand(x.shape[0], 1, -1), x], dim=1)\n",
        "    x = x + v.positional_embedding.to(v_dtype)\n",
        "    x = v.ln_pre(x.to(_first_param_dtype(v.ln_pre, v_dtype)))\n",
        "\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in v.transformer.resblocks:\n",
        "        # after model.float() every block is fp32, so this cast is a no-op\n",
        "        if x.dtype != v_dtype:\n",
        "            x = x.to(v_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())\n",
        "    return outs\n",
        "\n",
        "@torch.no_grad()\n",
        "def _text_block_outputs(model, texts, sigma_t: float, device: str):\n",
        "    tr = model.transformer\n",
        "    tr_dtype = _first_param_dtype(tr, torch.float32)\n",
        "\n",
        "    x = model.token_embedding(texts.to(device, non_blocking=True))\n",
        "    x = x.to(tr_dtype) + model.positional_embedding.to(tr_dtype)\n",
        "    if sigma_t > 0:\n",
        "        x = x + sigma_t * torch.randn_like(x, dtype=tr_dtype)\n",
        "\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in tr.resblocks:\n",
        "        if x.dtype != tr_dtype:\n",
        "            x = x.to(tr_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())\n",
        "    return outs\n",
        "\n",
        "\n",
        "\n",
        "def get_probe_loader(train_dataset, batch_size=64, num_batches=2):\n",
        "    from torch.utils.data import Subset, DataLoader\n",
        "    total = len(train_dataset)\n",
        "    need = min(total, batch_size * num_batches)\n",
        "    if need == 0:\n",
        "        raise ValueError(\"Empty train_dataset for probe loader.\")\n",
        "    idxs = random.sample(range(total), need)\n",
        "    probe = Subset(train_dataset, idxs)\n",
        "    bs = min(batch_size, len(probe))\n",
        "    return DataLoader(\n",
        "        probe,\n",
        "        batch_size=bs,\n",
        "        shuffle=False,\n",
        "        num_workers=2,\n",
        "        pin_memory=True,\n",
        "        drop_last=True,   # <— avoid B=1 (variance==0) batches in TE\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def _pooled(list_of_tensors, kind: str):\n",
        "    if kind not in (\"cls\", \"mean\"):\n",
        "        raise ValueError(\"pool must be 'cls' or 'mean'\")\n",
        "    pooled = []\n",
        "    for x in list_of_tensors:  # x: [B, T, D]\n",
        "        if kind == \"cls\":\n",
        "            pooled.append(x[:, 0, :])\n",
        "        else:\n",
        "            pooled.append(x.mean(dim=1))\n",
        "    return pooled\n",
        "\n",
        "def _te_from_two_runs_lists(pooled1, pooled2):\n",
        "    # pooled*: list of [B, D]\n",
        "    assert len(pooled1) == len(pooled2) and len(pooled1) >= 2\n",
        "    eps = 1e-6\n",
        "    te_vals = []\n",
        "    for l in range(len(pooled1) - 1):\n",
        "        A  = torch.nan_to_num(pooled2[l]   - pooled1[l]).float()     # [B, D]\n",
        "        Bn = torch.nan_to_num(pooled2[l+1] - pooled1[l+1]).float()   # [B, D]\n",
        "        num = (A * Bn).sum(dim=1)\n",
        "        den = (A.norm(dim=1) * Bn.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "    return te_vals\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_te_per_layer(\n",
        "    model,\n",
        "    images,\n",
        "    texts,\n",
        "    pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    device: str = \"cuda\",\n",
        "    sigma_v: float = 1e-2,\n",
        "    sigma_t: float = 5e-2,\n",
        "    verbose_once: bool = True,\n",
        "):\n",
        "    if isinstance(pool, str):\n",
        "        pool_v, pool_t = pool, pool\n",
        "    else:\n",
        "        pool_v, pool_t = pool\n",
        "\n",
        "    # Run two independent noisy passes, explicitly collecting per-block outputs\n",
        "    v1 = _vision_block_outputs(model, images, sigma_v, device)\n",
        "    t1 = _text_block_outputs(model, texts,  sigma_t, device)\n",
        "    v2 = _vision_block_outputs(model, images, sigma_v, device)\n",
        "    t2 = _text_block_outputs(model, texts,  sigma_t, device)\n",
        "\n",
        "    # Sanity: batch size must be >= 2 for the variance debug to be meaningful\n",
        "    if verbose_once:\n",
        "        Bv = v1[0].shape[0] if len(v1) else 0\n",
        "        Bt = t1[0].shape[0] if len(t1) else 0\n",
        "        if Bv < 2 or Bt < 2:\n",
        "            print(f\"[TE Debug] Probe batch too small (vision B={Bv}, text B={Bt}). \"\n",
        "                  f\"Use drop_last=True or a bigger probe batch.\")\n",
        "\n",
        "    pv1, pv2 = _pooled(v1, pool_v), _pooled(v2, pool_v)\n",
        "    pt1, pt2 = _pooled(t1, pool_t), _pooled(t2, pool_t)\n",
        "\n",
        "    if verbose_once:\n",
        "        # mean absolute delta magnitude as a quick wake-up check\n",
        "        md_v = float(torch.stack([ (b-a).abs().mean()\n",
        "                                   for a,b in zip(pv1, pv2) ]).mean().item())\n",
        "        md_t = float(torch.stack([ (b-a).abs().mean()\n",
        "                                   for a,b in zip(pt1, pt2) ]).mean().item())\n",
        "        print(f\"[TE Debug] mean|Δ| vision={md_v:.3e}, text={md_t:.3e}  \"\n",
        "              f\"(sigmas: v={sigma_v:.2e}, t={sigma_t:.2e})\")\n",
        "\n",
        "    te_v = _te_from_two_runs_lists(pv1, pv2)\n",
        "    te_t = _te_from_two_runs_lists(pt1, pt2)\n",
        "    return te_v, te_t\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Pruning\n",
        "# -------------------------\n",
        "class IdentityBlock(nn.Module):\n",
        "    def forward(self, x):\n",
        "        return x\n",
        "\n",
        "def prune_topk_layers_by_te(model, te_v: List[float], te_t: List[float], k_v: int = 4, k_t: int = 4,\n",
        "                            guard_skip_first_last: bool = True, avoid_adjacent: bool = True) -> Tuple[List[int], List[int]]:\n",
        "    # Vision\n",
        "    v_blocks = model.visual.transformer.resblocks\n",
        "    V = len(v_blocks)\n",
        "    cand_v = list(range(1, V))  # prune block index i corresponds to TE at i-1\n",
        "    if guard_skip_first_last:\n",
        "        cand_v = [idx for idx in cand_v if idx not in (0, V-1)]\n",
        "    order_v = sorted(cand_v, key=lambda idx: te_v[idx - 1], reverse=True)\n",
        "    chosen_v = []\n",
        "    for idx in order_v:\n",
        "        if len(chosen_v) >= k_v:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen_v):\n",
        "            continue\n",
        "        chosen_v.append(idx)\n",
        "    for idx in chosen_v:\n",
        "        v_blocks[idx] = IdentityBlock()\n",
        "\n",
        "    # Text\n",
        "    t_blocks = model.transformer.resblocks\n",
        "    T = len(t_blocks)\n",
        "    cand_t = list(range(1, T))\n",
        "    if guard_skip_first_last:\n",
        "        cand_t = [idx for idx in cand_t if idx not in (0, T-1)]\n",
        "    order_t = sorted(cand_t, key=lambda idx: te_t[idx - 1], reverse=True)\n",
        "    chosen_t = []\n",
        "    for idx in order_t:\n",
        "        if len(chosen_t) >= k_t:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen_t):\n",
        "            continue\n",
        "        chosen_t.append(idx)\n",
        "    for idx in chosen_t:\n",
        "        t_blocks[idx] = IdentityBlock()\n",
        "\n",
        "    return chosen_v, chosen_t\n",
        "\n",
        "# -------------------------\n",
        "# Contrastive loss (NaN-safe)\n",
        "# -------------------------\n",
        "def clip_contrastive_loss(model, image_features, text_features):\n",
        "    # Cast to fp32 for numerics, then normalize safely\n",
        "    image_features = safe_l2_normalize(image_features.float(), dim=-1, eps=1e-6)\n",
        "    text_features  = safe_l2_normalize(text_features.float(),  dim=-1, eps=1e-6)\n",
        "\n",
        "    # B×B cosine sims scaled by a SAFE logit scale\n",
        "    logit_scale = get_safe_logit_scale(model)  # scalar tensor on correct device\n",
        "    logits_per_image = (logit_scale * (image_features @ text_features.t()))\n",
        "    logits_per_image = torch.nan_to_num(logits_per_image, nan=0.0, posinf=1e4, neginf=-1e4)\n",
        "    logits_per_text = logits_per_image.t()\n",
        "\n",
        "    B = image_features.size(0)\n",
        "    labels = torch.arange(B, device=logits_per_image.device)\n",
        "\n",
        "    loss_i = F.cross_entropy(logits_per_image, labels)\n",
        "    loss_t = F.cross_entropy(logits_per_text,  labels)\n",
        "    return 0.5 * (loss_i + loss_t)\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Training (NaN-safe; TE via perturbations)\n",
        "# -------------------------\n",
        "def train_with_te_prune_lora(\n",
        "    model,\n",
        "    train_dataloader,\n",
        "    train_dataset,\n",
        "    device,\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "    te_pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    probe_batches=2,\n",
        "    te_sigma_v: float = 1e-3,\n",
        "    te_sigma_t: float = 1e-3,\n",
        "):\n",
        "    # LoRA injection/upgrade\n",
        "    lora_params = add_lora_to_clip_vit_b16(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)\n",
        "    _assert_outproj_has_weight(model)\n",
        "    model.to(device)\n",
        "\n",
        "    optimizer = AdamW(lora_params, lr=lr, weight_decay=weight_decay)\n",
        "    probe_loader = get_probe_loader(train_dataset, batch_size=64, num_batches=probe_batches)\n",
        "\n",
        "    history = {\"vision\": [], \"text\": []}\n",
        "\n",
        "    def epoch_pass(epoch_idx):\n",
        "        model.train()\n",
        "        total_loss, num_steps = 0.0, 0\n",
        "        for images, texts in train_dataloader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            texts  = texts.to(device, non_blocking=True)\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            img_feat = model.encode_image(images)\n",
        "            txt_feat = model.encode_text(texts)\n",
        "\n",
        "            loss = clip_contrastive_loss(model, img_feat, txt_feat)\n",
        "\n",
        "            # Skip non-finite batches\n",
        "            if not torch.isfinite(loss):\n",
        "                with torch.no_grad():\n",
        "                    bad = {\n",
        "                        \"loss\": float(loss.detach().float().item()),\n",
        "                        \"img_feat_has_nan\": bool(torch.isnan(img_feat).any().item()),\n",
        "                        \"txt_feat_has_nan\": bool(torch.isnan(txt_feat).any().item()),\n",
        "                    }\n",
        "                print(\"[Warn] Skipping non-finite batch:\", bad)\n",
        "                continue\n",
        "\n",
        "            loss.backward()\n",
        "            clip_grad_norm_(lora_params, max_norm=1.0)\n",
        "            optimizer.step()\n",
        "\n",
        "            total_loss += float(loss.item())\n",
        "            num_steps += 1\n",
        "\n",
        "        avg_loss = total_loss / max(1, num_steps)\n",
        "\n",
        "        # --- TE on probe (eval mode) via perturbations ---\n",
        "        model.eval()\n",
        "        te_v_all, te_t_all = [], []\n",
        "        with torch.no_grad():\n",
        "            first = True\n",
        "            for images, texts in probe_loader:\n",
        "                tv, tt = compute_te_per_layer(\n",
        "                    model, images, texts,\n",
        "                    pool=te_pool,\n",
        "                    device=device,\n",
        "                    sigma_v=te_sigma_v,\n",
        "                    sigma_t=te_sigma_t,\n",
        "                    verbose_once=first\n",
        "                )\n",
        "                first = False\n",
        "                te_v_all.append(torch.tensor(tv))\n",
        "                te_t_all.append(torch.tensor(tt))\n",
        "        te_v_mean = torch.stack(te_v_all).nanmean(dim=0).tolist()\n",
        "        te_t_mean = torch.stack(te_t_all).nanmean(dim=0).tolist()\n",
        "\n",
        "        history[\"vision\"].append(te_v_mean)\n",
        "        history[\"text\"].append(te_t_mean)\n",
        "\n",
        "        print(f\"[Epoch {epoch_idx}] loss={avg_loss:.6f}\")\n",
        "        print(f\"  Vision TE per layer (ℓ=0..{len(te_v_mean)-1} -> prune ℓ+1):\")\n",
        "        print(\"   \", [round(x, 6) for x in te_v_mean])\n",
        "        print(f\"  Text   TE per layer (ℓ=0..{len(te_t_mean)-1} -> prune ℓ+1):\")\n",
        "        print(\"   \", [round(x, 6) for x in te_t_mean])\n",
        "\n",
        "    # --- Phase 1 ---\n",
        "    for ep in range(1, epochs_before_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    # --- Prune based on TE ---\n",
        "    last_te_v = history[\"vision\"][-1]\n",
        "    last_te_t = history[\"text\"][-1]\n",
        "    chosen_v, chosen_t = prune_topk_layers_by_te(\n",
        "        model, last_te_v, last_te_t,\n",
        "        k_v=prune_k_vision, k_t=prune_k_text,\n",
        "        guard_skip_first_last=True, avoid_adjacent=True\n",
        "    )\n",
        "    print(f\"Pruned Vision blocks at indices: {chosen_v} (Identity)\")\n",
        "    print(f\"Pruned Text   blocks at indices: {chosen_t} (Identity)\")\n",
        "\n",
        "    # --- Phase 2 ---\n",
        "    for ep in range(epochs_before_prune + 1, epochs_before_prune + epochs_after_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    return history\n",
        "\n",
        "# ======================================================\n",
        "# OPTIONAL: COCO Retrieval Evaluation (I2T / T2I)\n",
        "# ======================================================\n",
        "def evaluate_coco_retrieval(model, data_dir: str, batch_size: int = 64, input_resolution: int = 224, context_length: int = 77):\n",
        "    \"\"\"\n",
        "    Evaluate trained model on COCO val2014 captions.\n",
        "    Assumes `pip install pycocotools` and images/annotations present in data_dir.\n",
        "    \"\"\"\n",
        "    import clip\n",
        "    import numpy as np\n",
        "    from torchvision import datasets, transforms\n",
        "    from torchvision.transforms import InterpolationMode\n",
        "    from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "    device = next(model.parameters()).device\n",
        "\n",
        "    eval_transform = transforms.Compose([\n",
        "        transforms.Resize((input_resolution, input_resolution), interpolation=InterpolationMode.BICUBIC),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "                             std=(0.26862954, 0.26130258, 0.27577711))\n",
        "    ])\n",
        "\n",
        "    class CocoEvalDataset(Dataset):\n",
        "        def __init__(self, root, annFile, transform=None):\n",
        "            self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)\n",
        "            self.transform = transform\n",
        "        def __len__(self):\n",
        "            return len(self.dataset)\n",
        "        def __getitem__(self, idx):\n",
        "            image, captions = self.dataset[idx]\n",
        "            return image, captions\n",
        "\n",
        "    def coco_collate_fn(batch):\n",
        "        images = [img for img, _ in batch]\n",
        "        captions = [caps for _, caps in batch]  # list[list[str]]\n",
        "        images = torch.stack(images, dim=0)\n",
        "        return images, captions\n",
        "\n",
        "    val_img_dir  = os.path.join(data_dir, 'val2014')\n",
        "    val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "    val_dataset  = CocoEvalDataset(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,\n",
        "                              pin_memory=torch.cuda.is_available(), collate_fn=coco_collate_fn)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    all_image_features = []\n",
        "    all_text_features  = []\n",
        "    image_to_text_indices = []\n",
        "    all_captions_flat = []\n",
        "\n",
        "    with torch.no_grad():\n",
        "        text_count = 0\n",
        "        for images, batch_captions in val_loader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "\n",
        "            image_feats = model.encode_image(images)\n",
        "            image_feats = safe_l2_normalize(image_feats.float(), dim=-1, eps=1e-6)\n",
        "            all_image_features.append(image_feats.cpu())\n",
        "\n",
        "            flat_captions = []\n",
        "            image_to_text_map_for_batch = []\n",
        "            for caps in batch_captions:\n",
        "                start_idx = text_count + len(flat_captions)\n",
        "                flat_captions.extend(caps)\n",
        "                end_idx = text_count + len(flat_captions)\n",
        "                image_to_text_map_for_batch.append((start_idx, end_idx))\n",
        "\n",
        "            if len(flat_captions) == 0:\n",
        "                image_to_text_indices.extend([[] for _ in batch_captions])\n",
        "                continue\n",
        "\n",
        "            texts = clip.tokenize(flat_captions, context_length=context_length).to(device)\n",
        "            text_feats = model.encode_text(texts)\n",
        "            text_feats = safe_l2_normalize(text_feats.float(), dim=-1, eps=1e-6)\n",
        "            all_text_features.append(text_feats.cpu())\n",
        "            all_captions_flat.extend(flat_captions)\n",
        "\n",
        "            for (start_idx, end_idx) in image_to_text_map_for_batch:\n",
        "                image_to_text_indices.append(list(range(start_idx, end_idx)))\n",
        "\n",
        "            text_count += len(flat_captions)\n",
        "\n",
        "    all_image_features = torch.cat(all_image_features, dim=0)  # [N_img, D]\n",
        "    all_text_features  = torch.cat(all_text_features,  dim=0)  # [N_txt, D]\n",
        "\n",
        "    sim_matrix = all_image_features @ all_text_features.t()    # [N_img, N_txt]\n",
        "\n",
        "    def compute_recall_with_multiple_captions(sim_matrix: torch.Tensor,\n",
        "                                              image_to_text_indices,\n",
        "                                              k: int = 1) -> float:\n",
        "        n = sim_matrix.size(0)\n",
        "        successes = 0\n",
        "        for i in range(n):\n",
        "            scores = sim_matrix[i]\n",
        "            sorted_idx = torch.argsort(scores, descending=True)\n",
        "            correct = image_to_text_indices[i]\n",
        "            if not correct:\n",
        "                continue\n",
        "            min_rank = None\n",
        "            for cidx in correct:\n",
        "                pos = (sorted_idx == cidx).nonzero(as_tuple=False)\n",
        "                if pos.numel() > 0:\n",
        "                    rank = int(pos[0, 0].item())\n",
        "                    if (min_rank is None) or (rank < min_rank):\n",
        "                        min_rank = rank\n",
        "            if (min_rank is not None) and (min_rank < k):\n",
        "                successes += 1\n",
        "        return successes / max(1, n)\n",
        "\n",
        "    # I2T\n",
        "    r1  = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1)\n",
        "    r5  = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=5)\n",
        "    r10 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=10)\n",
        "    print(\"Image-to-Text Retrieval:\")\n",
        "    print(f\"Recall@1:  {r1*100:.2f}%\")\n",
        "    print(f\"Recall@5:  {r5*100:.2f}%\")\n",
        "    print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "    # T2I\n",
        "    text_to_image = [None] * all_text_features.size(0)\n",
        "    for img_idx, tinds in enumerate(image_to_text_indices):\n",
        "        for t in tinds:\n",
        "            text_to_image[t] = img_idx\n",
        "\n",
        "    sim_matrix_t2i = sim_matrix.t()  # [N_txt, N_img]\n",
        "\n",
        "    def compute_recall_text_to_image(sim_matrix_t2i: torch.Tensor,\n",
        "                                     text_to_image_list,\n",
        "                                     k: int = 1) -> float:\n",
        "        m = sim_matrix_t2i.size(0)\n",
        "        successes = 0\n",
        "        for j in range(m):\n",
        "            scores = sim_matrix_t2i[j]\n",
        "            sorted_indices = torch.argsort(scores, descending=True)\n",
        "            correct_img = text_to_image_list[j]\n",
        "            if correct_img is None:\n",
        "                continue\n",
        "            pos = (sorted_indices == correct_img).nonzero(as_tuple=False)\n",
        "            if pos.numel() > 0 and int(pos[0, 0].item()) < k:\n",
        "                successes += 1\n",
        "        return successes / max(1, m)\n",
        "\n",
        "    r1_t2i  = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1)\n",
        "    r5_t2i  = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=5)\n",
        "    r10_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=10)\n",
        "    print(\"Text-to-Image Retrieval:\")\n",
        "    print(f\"Recall@1:  {r1_t2i*100:.2f}%\")\n",
        "    print(f\"Recall@5:  {r5_t2i*100:.2f}%\")\n",
        "    print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "id": "s_IDtdvU9yLv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========= Short-LVLM-style pruning for CLIP ViT-B/16 (vision & text) =========\n",
        "# Adds: cosine-based layer localization using CLS tokens + Subspace-Compensated Pruning (SCP)\n",
        "# Note: Uses your existing helpers: IdentityBlock, get_probe_loader, _vision_block_outputs, _text_block_outputs\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "@torch.no_grad()\n",
        "def _gather_cls_by_layer(model, probe_loader, device, stack=\"vision\"):\n",
        "    \"\"\"\n",
        "    Returns a list of tensors [N, D] per layer with CLS representations.\n",
        "    stack: \"vision\" -> model.visual; \"text\" -> model\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    cls_all = None  # we'll aggregate batches\n",
        "    outs_list = None\n",
        "\n",
        "    for images, texts in probe_loader:\n",
        "        if stack == \"vision\":\n",
        "            v_outs = _vision_block_outputs(model, images, sigma_v=0.0, device=device)  # list of [B,T,D]\n",
        "            # CLS = token 0\n",
        "            batch_cls = [x[:, 0, :].float().cpu() for x in v_outs]\n",
        "        else:\n",
        "            # texts must be token ids already; if your probe loader yields raw strings, tokenize them before\n",
        "            t_outs = _text_block_outputs(model, texts, sigma_t=0.0, device=device)\n",
        "            batch_cls = [x[:, 0, :].float().cpu() for x in t_outs]\n",
        "\n",
        "        if outs_list is None:\n",
        "            outs_list = [c for c in batch_cls]\n",
        "        else:\n",
        "            for i in range(len(outs_list)):\n",
        "                outs_list[i] = torch.cat([outs_list[i], batch_cls[i]], dim=0)\n",
        "\n",
        "    # Normalize to be safe\n",
        "    outs_list = [torch.nan_to_num(x) for x in outs_list]\n",
        "    return outs_list  # list of [N,D], length = #blocks\n",
        "\n",
        "@torch.no_grad()\n",
        "def _cosine_scores_adjacent(cls_list):\n",
        "    \"\"\"\n",
        "    Given [X^0, X^1, ..., X^{L-1}] with shapes [N,D], compute mean cosine(X^ℓ, X^{ℓ+1}) per ℓ.\n",
        "    Returns a list scores of length L-1; pruning layer (ℓ+1) is suggested when score[ℓ] is high.\n",
        "    \"\"\"\n",
        "    scores = []\n",
        "    eps = 1e-6\n",
        "    for l in range(len(cls_list) - 1):\n",
        "        A = cls_list[l]\n",
        "        B = cls_list[l+1]\n",
        "        num = (A * B).sum(dim=1)\n",
        "        den = (A.norm(dim=1) * B.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "        scores.append(float(cos.mean().item()))\n",
        "    return scores  # index l -> suggests pruning layer l+1\n",
        "\n",
        "def _choose_k_layers(scores, total_blocks, k, guard_skip_first_last=True, avoid_adjacent=True):\n",
        "    \"\"\"\n",
        "    scores: length = total_blocks-1, where index l scores adjacency between l and l+1 and suggests pruning (l+1).\n",
        "    Returns sorted unique block indices to prune.\n",
        "    \"\"\"\n",
        "    cand = list(range(1, total_blocks))  # pruning target idx = l+1 maps to indices 1..L-1\n",
        "    if guard_skip_first_last:\n",
        "        cand = [i for i in cand if i not in (0, total_blocks-1)]\n",
        "    # order by descending score\n",
        "    order = sorted(cand, key=lambda i: scores[i-1], reverse=True)\n",
        "    chosen = []\n",
        "    for idx in order:\n",
        "        if len(chosen) >= k:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen):\n",
        "            continue\n",
        "        chosen.append(idx)\n",
        "    chosen.sort()\n",
        "    return chosen\n",
        "\n",
        "def _get_blocks(model, stack):\n",
        "    return model.visual.transformer.resblocks if stack == \"vision\" else model.transformer.resblocks\n",
        "\n",
        "def _apply_identity(model, indices, stack):\n",
        "    blocks = _get_blocks(model, stack)\n",
        "    for i in indices:\n",
        "        blocks[i] = IdentityBlock()\n",
        "\n",
        "def _right_project_weight_inplace(lin: nn.Linear, V: torch.Tensor):\n",
        "    \"\"\"\n",
        "    In-place: W <- W @ (I + V V^T), where W is [out, in], V is [in, k].\n",
        "    \"\"\"\n",
        "    if not isinstance(lin, nn.Linear):\n",
        "        return\n",
        "    W = lin.weight.data  # [out, in]\n",
        "    in_dim = W.shape[1]\n",
        "    I = torch.eye(in_dim, device=W.device, dtype=W.dtype)\n",
        "    P = I + V @ V.t()  # [in, in]\n",
        "    W.copy_(W @ P)  # right-multiply to transform inputs: y = x (I+VV^T) W^T\n",
        "\n",
        "def _project_block_weights(block: nn.Module, V: torch.Tensor):\n",
        "    # Project key linear layers inside a CLIP-style block (attn.out_proj, mlp.c_fc, mlp.c_proj)\n",
        "    attn = getattr(block, \"attn\", None)\n",
        "    if attn is not None and hasattr(attn, \"out_proj\"):\n",
        "        _right_project_weight_inplace(attn.out_proj, V)\n",
        "    mlp = getattr(block, \"mlp\", None)\n",
        "    if mlp is not None:\n",
        "        if hasattr(mlp, \"c_fc\"):\n",
        "            _right_project_weight_inplace(mlp.c_fc, V)\n",
        "        if hasattr(mlp, \"c_proj\"):\n",
        "            _right_project_weight_inplace(mlp.c_proj, V)\n",
        "\n",
        "@torch.no_grad()\n",
        "def _scp_project_neighbors(model, probe_loader, device, pruned_idxs, stack=\"vision\", rank=64):\n",
        "    \"\"\"\n",
        "    For each pruned layer ℓ_p, pick nearest retained ℓ_r, collect CLS features X^{ℓ_p}, X^{ℓ_r} over probe,\n",
        "    compute H = X_p - X_r, take top-k right singular vectors V_k (on feature dim D), then project ℓ_r weights.\n",
        "    Each retained layer used at most once.\n",
        "    \"\"\"\n",
        "    # Gather CLS per layer\n",
        "    cls_list = _gather_cls_by_layer(model, probe_loader, device, stack=stack)  # list length L\n",
        "    L = len(cls_list)\n",
        "    retained = sorted(set(range(L)) - set(pruned_idxs))\n",
        "    if not retained:\n",
        "        return\n",
        "\n",
        "    # Map each pruned to nearest retained (unique assignment)\n",
        "    assigned = {}\n",
        "    used = set()\n",
        "    for p in pruned_idxs:\n",
        "        # nearest by absolute distance\n",
        "        cand = sorted(retained, key=lambda r: (abs(r - p), r))\n",
        "        r_pick = None\n",
        "        for r in cand:\n",
        "            if r not in used:\n",
        "                r_pick = r\n",
        "                break\n",
        "        if r_pick is None:\n",
        "            # fallback allow reuse\n",
        "            r_pick = cand[0]\n",
        "        assigned[p] = r_pick\n",
        "        used.add(r_pick)\n",
        "\n",
        "    # Do SVD and weight projection\n",
        "    blocks = _get_blocks(model, stack)\n",
        "    for p, r in assigned.items():\n",
        "        Xp = cls_list[p]   # [N,D]\n",
        "        Xr = cls_list[r]   # [N,D]\n",
        "        H  = (Xp - Xr).to(torch.float32)  # [N,D]\n",
        "        # compute top-k right singular vectors on feature dim D\n",
        "        # torch.linalg.svd(H, full_matrices=False): H = U S Vh ; V = Vh^T ; take first k columns of V\n",
        "        U, S, Vh = torch.linalg.svd(H, full_matrices=False)\n",
        "        k = min(rank, Vh.shape[0], Vh.shape[1])\n",
        "        Vk = Vh[:k, :].T.contiguous()  # [D,k]\n",
        "        Vk = Vk.to(next(model.parameters()).device).to(_first_param_dtype(model, torch.float32))\n",
        "        # Project nearest retained block weights\n",
        "        _project_block_weights(blocks[r], Vk)\n",
        "\n",
        "def shortlvlm_prune_clip(model, train_dataset, device, k_v=4, k_t=4, subspace_rank=64, probe_batches=2):\n",
        "    \"\"\"\n",
        "    Training-free pass:\n",
        "      1) Localize 4 vision & 4 text layers via adjacent-layer mean cosine on CLS.\n",
        "      2) Replace those blocks with identities.\n",
        "      3) SCP: project nearest retained blocks using top-k subspaces from CLS feature differences.\n",
        "    \"\"\"\n",
        "    probe_loader = get_probe_loader(train_dataset, batch_size=64, num_batches=probe_batches)\n",
        "\n",
        "    # ----- Vision -----\n",
        "    v_cls = _gather_cls_by_layer(model, probe_loader, device, stack=\"vision\")\n",
        "    v_scores = _cosine_scores_adjacent(v_cls)\n",
        "    V = len(v_cls)\n",
        "    prune_v = _choose_k_layers(v_scores, V, k_v, guard_skip_first_last=True, avoid_adjacent=True)\n",
        "    _apply_identity(model, prune_v, stack=\"vision\")\n",
        "    _scp_project_neighbors(model, probe_loader, device, prune_v, stack=\"vision\", rank=subspace_rank)\n",
        "    print(f\"[Short-LVLM] Vision pruned indices: {prune_v}\")\n",
        "\n",
        "    # ----- Text -----\n",
        "    t_cls = _gather_cls_by_layer(model, probe_loader, device, stack=\"text\")\n",
        "    t_scores = _cosine_scores_adjacent(t_cls)\n",
        "    T = len(t_cls)\n",
        "    prune_t = _choose_k_layers(t_scores, T, k_t, guard_skip_first_last=True, avoid_adjacent=True)\n",
        "    _apply_identity(model, prune_t, stack=\"text\")\n",
        "    _scp_project_neighbors(model, probe_loader, device, prune_t, stack=\"text\", rank=subspace_rank)\n",
        "    print(f\"[Short-LVLM] Text   pruned indices: {prune_t}\")\n",
        "\n",
        "    return {\"vision_pruned\": prune_v, \"text_pruned\": prune_t}\n",
        "\n",
        "# =========================\n",
        "# Usage (call once before LoRA/training/eval)\n",
        "# =========================\n",
        "# history = shortlvlm_prune_clip(model, train_dataset, device, k_v=4, k_t=4, subspace_rank=64)\n",
        "# Then continue with your existing LoRA injection and training/eval pipeline.\n"
      ],
      "metadata": {
        "id": "ashw2_19e6k6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# -------- Short-LVLM pruning helpers --------\n",
        "class IdentityBlock(nn.Module):\n",
        "    def forward(self, x): return x\n",
        "\n",
        "def _shortlvlm_indices(num_layers: int, k: int,\n",
        "                       guard_skip_first_last: bool = True,\n",
        "                       avoid_adjacent: bool = True):\n",
        "    \"\"\"\n",
        "    Choose upper-middle, evenly spaced block indices. For L=12, k=4 => [9,7,5,3].\n",
        "    \"\"\"\n",
        "    # candidate pool (skip first/last if requested)\n",
        "    start = 1 if guard_skip_first_last else 0\n",
        "    end   = num_layers - (2 if guard_skip_first_last else 1)\n",
        "    cands = list(range(start, end))            # e.g., 1..10 for L=12\n",
        "    cands = cands[::-1]                        # prefer top first\n",
        "\n",
        "    chosen = []\n",
        "    for idx in cands:\n",
        "        if len(chosen) >= k:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen):\n",
        "            continue\n",
        "        chosen.append(idx)\n",
        "\n",
        "    # If spacing filtered too many, fill remaining by skipping every other\n",
        "    if len(chosen) < k:\n",
        "        for idx in cands:\n",
        "            if idx in chosen: continue\n",
        "            if len(chosen) >= k: break\n",
        "            chosen.append(idx)\n",
        "\n",
        "    return chosen[:k]\n",
        "\n",
        "def prune_shortlvlm(model, k_v: int = 4, k_t: int = 4,\n",
        "                    guard_skip_first_last: bool = True,\n",
        "                    avoid_adjacent: bool = True):\n",
        "    v_blocks = model.visual.transformer.resblocks\n",
        "    t_blocks = model.transformer.resblocks\n",
        "\n",
        "    pick_v = _shortlvlm_indices(len(v_blocks), k_v, guard_skip_first_last, avoid_adjacent)\n",
        "    pick_t = _shortlvlm_indices(len(t_blocks), k_t, guard_skip_first_last, avoid_adjacent)\n",
        "\n",
        "    for i in pick_v:\n",
        "        v_blocks[i] = IdentityBlock()\n",
        "    for i in pick_t:\n",
        "        t_blocks[i] = IdentityBlock()\n",
        "\n",
        "    return pick_v, pick_t\n",
        "\n",
        "# -------- LoRA injection (reuse your LoRA utilities) --------\n",
        "# (Assumes you already defined LoRALinear, _wrap_or_upgrade_linear, add_lora_to_clip_vit_b16, etc.)\n",
        "\n",
        "# -------- Training with Short-LVLM pruning --------\n",
        "def train_with_shortlvlm_prune_lora(\n",
        "    model,\n",
        "    train_dataloader,\n",
        "    train_dataset,   # unused here but kept for API symmetry with your TE version\n",
        "    device,\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "):\n",
        "    # Inject/upgrade LoRA and freeze base\n",
        "    lora_params = add_lora_to_clip_vit_b16(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)\n",
        "    model.to(device)\n",
        "    optimizer = AdamW(lora_params, lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "    def one_epoch(ep_idx: int):\n",
        "        model.train()\n",
        "        total, steps = 0.0, 0\n",
        "        for images, texts in train_dataloader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            texts  = texts.to(device,  non_blocking=True)\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "            img_f = model.encode_image(images)\n",
        "            txt_f = model.encode_text(texts)\n",
        "            loss = clip_contrastive_loss(model, img_f, txt_f)\n",
        "\n",
        "            if not torch.isfinite(loss):\n",
        "                continue\n",
        "            loss.backward()\n",
        "            clip_grad_norm_(lora_params, 1.0)\n",
        "            optimizer.step()\n",
        "            total += float(loss.item()); steps += 1\n",
        "        print(f\"[Epoch {ep_idx}] loss={total/max(1,steps):.6f}\")\n",
        "\n",
        "    # Phase 1: warmup\n",
        "    for ep in range(1, epochs_before_prune + 1):\n",
        "        one_epoch(ep)\n",
        "\n",
        "    # Prune by Short-LVLM\n",
        "    pick_v, pick_t = prune_shortlvlm(\n",
        "        model,\n",
        "        k_v=prune_k_vision, k_t=prune_k_text,\n",
        "        guard_skip_first_last=True, avoid_adjacent=True\n",
        "    )\n",
        "    print(f\"[Short-LVLM] Pruned Vision blocks: {pick_v}\")\n",
        "    print(f\"[Short-LVLM] Pruned Text   blocks: {pick_t}\")\n",
        "\n",
        "    # Phase 2: post-prune LoRA fine-tune\n",
        "    for ep in range(epochs_before_prune + 1, epochs_before_prune + epochs_after_prune + 1):\n",
        "        one_epoch(ep)\n",
        "\n",
        "    return {\"shortlvlm_pruned_vision\": pick_v, \"shortlvlm_pruned_text\": pick_t}\n",
        "\n",
        "# -------- Example usage --------\n",
        "# history = train_with_shortlvlm_prune_lora(\n",
        "#     model=model,\n",
        "#     train_dataloader=train_dataloader,\n",
        "#     train_dataset=train_dataset,\n",
        "#     device=device,\n",
        "#     lora_rank=8, lora_alpha=16, lora_dropout=0.0,\n",
        "#     lr=1e-4, weight_decay=0.0,\n",
        "#     epochs_before_prune=2,\n",
        "#     epochs_after_prune=3,\n",
        "#     prune_k_vision=4,\n",
        "#     prune_k_text=4,\n",
        "# )\n",
        "\n",
        "# evaluate_coco_retrieval(model, data_dir=data_dir, batch_size=64)\n"
      ],
      "metadata": {
        "id": "38rUmMBwOU0a"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "history = train_with_shortlvlm_prune_lora(\n",
        "    model=model,\n",
        "    train_dataloader=train_dataloader,\n",
        "    train_dataset=train_dataset,\n",
        "    device=device,\n",
        "    lora_rank=8, lora_alpha=16, lora_dropout=0.0,\n",
        "    lr=1e-4, weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=2,\n",
        "    prune_k_text=2,\n",
        ")\n",
        "\n",
        "# evaluate_coco_retrieval(model, data_dir=data_dir, batch_size=64)\n"
      ],
      "metadata": {
        "id": "yjJQc2HANkGQ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a4c1e569-c6fd-43d5-dbf5-87d4a2397494"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[Epoch 1] loss=0.161328\n",
            "[Epoch 2] loss=0.126671\n",
            "[Short-LVLM] Pruned Vision blocks: [9, 7]\n",
            "[Short-LVLM] Pruned Text   blocks: [9, 7]\n",
            "[Epoch 3] loss=0.248910\n",
            "[Epoch 4] loss=0.171020\n",
            "[Epoch 5] loss=0.147786\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Put model in eval mode and run zero-shot retrieval on COCO val2014\n",
        "evaluate_coco_retrieval(model, data_dir=data_dir, batch_size=64)\n"
      ],
      "metadata": {
        "id": "E2GFl-OyNkfD",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5fd2a887-2d68-4a63-99e1-f210575426e1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.32s)\n",
            "creating index...\n",
            "index created!\n",
            "Image-to-Text Retrieval:\n",
            "Recall@1:  25.65%\n",
            "Recall@5:  49.34%\n",
            "Recall@10: 60.63%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1:  15.91%\n",
            "Recall@5:  35.16%\n",
            "Recall@10: 45.65%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "etHTsgFsNDdf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "RPv7HIqxkjON"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LQXgZ_alw9gc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "d42i-P_Yrq-Y"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}