{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "1a524d55-f751-4e07-c4e7-26c0875b0a24"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install --quiet ftfy regex tqdm\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "%pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define paths\n",
        "data_dir = '/content/coco2014'\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# URLs for the COCO 2014 image archives and the annotation archive.\n",
        "# Deliberately NOT named `datasets`: a later cell does\n",
        "# `from torchvision import transforms, datasets`, so sharing the name means that\n",
        "# re-running cells out of order silently swaps this dict for that module (or back).\n",
        "coco_archives = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"val2014\": \"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):\n",
        "    \"\"\"Stream `url` to `dest_path`, showing a tqdm progress bar.\n",
        "\n",
        "    Raises requests.HTTPError on a non-2xx response instead of writing the\n",
        "    server's error page to disk, where it would only surface later as a\n",
        "    confusing BadZipFile. `timeout` (seconds) bounds the connect and each read\n",
        "    so a stalled connection fails instead of hanging the cell. The 1 MiB default\n",
        "    chunk keeps per-chunk Python overhead low on multi-GB files.\n",
        "    \"\"\"\n",
        "    with requests.get(url, stream=True, timeout=timeout) as response:\n",
        "        response.raise_for_status()\n",
        "        total_size = int(response.headers.get('content-length', 0))\n",
        "        with open(dest_path, 'wb') as f, tqdm(\n",
        "            desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "            total=total_size,\n",
        "            unit='B',\n",
        "            unit_scale=True,\n",
        "            unit_divisor=1024\n",
        "        ) as bar:\n",
        "            for data in response.iter_content(chunk_size=chunk_size):\n",
        "                f.write(data)\n",
        "                bar.update(len(data))\n",
        "\n",
        "# Download and extract each archive once. A marker file is written only after a\n",
        "# successful extraction, so re-running this cell skips the ~19 GB of downloads,\n",
        "# while an interrupted run is simply retried.\n",
        "for name, url in coco_archives.items():\n",
        "    done_marker = os.path.join(data_dir, f\".{name}.done\")\n",
        "    if os.path.exists(done_marker):\n",
        "        print(f\"{name} already downloaded and extracted; skipping.\")\n",
        "        continue\n",
        "\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    # Download the archive\n",
        "    download_file(url, zip_path)\n",
        "\n",
        "    # Unzip it into data_dir\n",
        "    with ZipFile(zip_path, 'r') as zip_ref:\n",
        "        zip_ref.extractall(data_dir)\n",
        "\n",
        "    # Remove the zip file to save space, then record completion\n",
        "    os.remove(zip_path)\n",
        "    open(done_marker, 'w').close()\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XMDkrWBhLGg",
        "outputId": "c687a171-642a-4cd0-aeda-e4cced0ee629"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [08:12<00:00, 27.4MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing val2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [03:30<00:00, 31.6MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:07<00:00, 34.3MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP ViT-B/16"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Load the CLIP teacher. clip.load may hand back fp16 weights on CUDA; cast the\n",
        "# whole model to fp32 so it and the LoRA adapters injected later (which are\n",
        "# created in the wrapped layer's dtype) share a single dtype.\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device=device)\n",
        "model = model.float()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "print(\"Model parameters:\", f\"{sum(p.numel() for p in model.parameters()):,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "2c62f200-2394-4509-81d6-85de6a4851e4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 149,620,737\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the MS COCO Data Loaders"
      ],
      "metadata": {
        "id": "h_cfc6FcltX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision import transforms, datasets\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "# Per-channel mean / std that CLIP inputs are normalised with.\n",
        "CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)\n",
        "CLIP_STD = (0.26862954, 0.26130258, 0.27577711)\n",
        "\n",
        "# Resize straight to input_resolution x input_resolution, tensorise, normalise.\n",
        "# NOTE(review): CLIP's own `preprocess` resizes the short side (bicubic) and then\n",
        "# center-crops, whereas this squashes non-square images. Confirm this is intended.\n",
        "transform = transforms.Compose(\n",
        "    [\n",
        "        transforms.Resize((input_resolution, input_resolution)),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),\n",
        "    ]\n",
        ")\n",
        "\n",
        "class CocoDataset(Dataset):\n",
        "    \"\"\"COCO captions flattened to one (image, caption-tokens) sample per caption.\n",
        "\n",
        "    Wraps torchvision's CocoCaptions: an image with k captions yields k samples\n",
        "    that share the same image.\n",
        "\n",
        "    Args:\n",
        "        root: directory containing the image files.\n",
        "        annFile: path to a COCO captions JSON file.\n",
        "        transform: optional callable applied to each PIL image.\n",
        "\n",
        "    __getitem__ returns (image, text), where text is the CLIP token-id row for\n",
        "    the caption, tokenized with the global `context_length`.\n",
        "    \"\"\"\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=None)\n",
        "        self.transform = transform\n",
        "        # Flatten (image, captions) into (image_index, single_caption) pairs.\n",
        "        # Captions are read straight from the COCO annotation index: going through\n",
        "        # self.dataset[i] would open and decode every image file just to list its\n",
        "        # captions. Same captions, same order as CocoCaptions itself returns.\n",
        "        coco = self.dataset.coco\n",
        "        self.data = []\n",
        "        for img_idx, img_id in enumerate(self.dataset.ids):\n",
        "            for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):\n",
        "                self.data.append((img_idx, ann[\"caption\"]))\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img_idx, caption = self.data[idx]\n",
        "        image, _ = self.dataset[img_idx]  # loads and decodes the image\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        # truncate=True: a caption longer than the context window is cut instead of\n",
        "        # raising inside (and killing) the DataLoader worker.\n",
        "        text = clip.tokenize(caption, context_length=context_length, truncate=True)[0]\n",
        "        return image, text\n",
        "\n",
        "# Training split: COCO 2014 train images plus their caption annotations.\n",
        "train_img_dir = os.path.join(data_dir, 'train2014')\n",
        "train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')\n",
        "train_dataset = CocoDataset(root=train_img_dir, annFile=train_ann_file, transform=transform)\n",
        "\n",
        "# Shuffled batches of (image, caption-tokens) pairs for training.\n",
        "train_dataloader = DataLoader(\n",
        "    train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True\n",
        ")\n"
      ],
      "metadata": {
        "id": "dGosFCtHhMMi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e04ad06d-6f55-4ef9-c1f3-819c340f6bd6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.69s)\n",
            "creating index...\n",
            "index created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ftKgCZzre6B-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ashw2_19e6k6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "#  CLIP + LoRA + TE Pruning\n",
        "#  (NaN-safe training + COCO eval) - Perturbation TE (no more TE=0)\n",
        "# =========================\n",
        "\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "from typing import List, Tuple, Union\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import AdamW\n",
        "from torch.nn.utils import clip_grad_norm_\n",
        "\n",
        "# -------------------------\n",
        "# Utils: safe numerics\n",
        "# -------------------------\n",
        "def safe_l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:\n",
        "    \"\"\"Divide `x` by its L2 norm along `dim`; norms below `eps` are floored to avoid div-by-zero.\"\"\"\n",
        "    norms = x.norm(dim=dim, keepdim=True)\n",
        "    return x / norms.clamp_min(eps)\n",
        "\n",
        "def get_safe_logit_scale(model, clamp_low: float = -5.0, clamp_high: float = 5.0) -> torch.Tensor:\n",
        "    \"\"\"Return exp(clamp(model.logit_scale)) as a float tensor.\n",
        "\n",
        "    The log-scale is clamped to [clamp_low, clamp_high] *before* exp, so the\n",
        "    result is bounded (e^5 ~ 148 with the defaults) and cannot overflow.\n",
        "    Models without a `logit_scale` attribute get a constant 1.0 on the device\n",
        "    of their first parameter.\n",
        "    \"\"\"\n",
        "    if not hasattr(model, \"logit_scale\"):\n",
        "        return torch.tensor(1.0, device=next(model.parameters()).device)\n",
        "    log_scale = model.logit_scale.float().clamp(clamp_low, clamp_high)\n",
        "    return log_scale.exp()\n",
        "\n",
        "# -------------------------\n",
        "# LoRA wrapper (MHA-compatible)\n",
        "# -------------------------\n",
        "class LoRALinear(nn.Module):\n",
        "    \"\"\"\n",
        "    Frozen nn.Linear plus a trainable low-rank update:\n",
        "        y = base(x) + (alpha / r) * (dropout(x) @ A^T) @ B^T\n",
        "\n",
        "    Also exposes read-only merged `.weight` / `.bias` so nn.MultiheadAttention,\n",
        "    which reads out_proj.weight / out_proj.bias directly instead of calling\n",
        "    out_proj(x), still sees the LoRA update.\n",
        "\n",
        "    Args:\n",
        "        base_linear: layer to adapt; its parameters are frozen in place.\n",
        "        r: LoRA rank, must be >= 1.\n",
        "        alpha: the low-rank update is scaled by alpha / r.\n",
        "        dropout: dropout on the LoRA input in `forward` only; the merged\n",
        "            `.weight` path (used by MultiheadAttention) applies no dropout.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if r < 1.\n",
        "    \"\"\"\n",
        "    def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        assert isinstance(base_linear, nn.Linear), \"LoRALinear expects an nn.Linear as base\"\n",
        "        if int(r) < 1:\n",
        "            raise ValueError(f\"LoRA rank r must be >= 1, got {r!r}\")\n",
        "        self.base = base_linear\n",
        "        for p in self.base.parameters():\n",
        "            p.requires_grad = False\n",
        "\n",
        "        self.r = int(r)\n",
        "        self.scaling = float(alpha) / float(self.r)\n",
        "\n",
        "        # Adapters share the base weight's device/dtype so `forward` and `weight`\n",
        "        # can add them to the base output/weight without casts.\n",
        "        dev, dt = self.base.weight.device, self.base.weight.dtype\n",
        "        self.lora_A = nn.Parameter(torch.zeros(self.r, self.base.in_features,  device=dev, dtype=dt))\n",
        "        self.lora_B = nn.Parameter(torch.zeros(self.base.out_features, self.r, device=dev, dtype=dt))\n",
        "        # A random, B zero: B @ A starts at 0, so the wrapped layer initially\n",
        "        # computes exactly base(x).\n",
        "        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n",
        "        nn.init.zeros_(self.lora_B)\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n",
        "\n",
        "    @property\n",
        "    def weight(self) -> torch.Tensor:\n",
        "        # Merged W + (alpha/r) * B @ A, read directly by nn.MultiheadAttention\n",
        "        return self.base.weight + self.scaling * (self.lora_B @ self.lora_A)\n",
        "\n",
        "    @property\n",
        "    def bias(self):\n",
        "        # The bias is not adapted; expose the frozen base bias unchanged\n",
        "        return self.base.bias\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.base(x)\n",
        "        delta = (self.dropout(x) @ self.lora_A.t()) @ self.lora_B.t()\n",
        "        return out + self.scaling * delta\n",
        "\n",
        "# -------------------------\n",
        "# Injection / upgrade helpers\n",
        "# -------------------------\n",
        "def _wrap_or_upgrade_linear(mod: nn.Module, attr: str, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Make sure mod.<attr> is a current LoRALinear, converting it in place if needed.\n",
        "\n",
        "      - current LoRALinear (class exposes a `weight` property) -> unchanged\n",
        "      - legacy object named \"LoRALinear\" without `.weight`    -> rebuilt around its\n",
        "        `.base`; lora_A / lora_B are copied over when shapes match\n",
        "      - plain nn.Linear                                       -> wrapped fresh\n",
        "      - anything else                                         -> unchanged\n",
        "    \"\"\"\n",
        "    cur = getattr(mod, attr, None)\n",
        "\n",
        "    if isinstance(cur, LoRALinear) and isinstance(getattr(type(cur), \"weight\", None), property):\n",
        "        return\n",
        "\n",
        "    legacy = (\n",
        "        isinstance(cur, nn.Module)\n",
        "        and cur.__class__.__name__ == \"LoRALinear\"\n",
        "        and not hasattr(cur, \"weight\")\n",
        "    )\n",
        "    if legacy:\n",
        "        upgraded = LoRALinear(cur.base, r=r, alpha=alpha, dropout=dropout)\n",
        "        if hasattr(cur, \"lora_A\") and hasattr(cur, \"lora_B\"):\n",
        "            with torch.no_grad():\n",
        "                for pname in (\"lora_A\", \"lora_B\"):\n",
        "                    old_p, new_p = getattr(cur, pname), getattr(upgraded, pname)\n",
        "                    if new_p.shape == old_p.shape:\n",
        "                        new_p.copy_(old_p.data)\n",
        "        setattr(mod, attr, upgraded)\n",
        "    elif isinstance(cur, nn.Linear):\n",
        "        setattr(mod, attr, LoRALinear(cur, r=r, alpha=alpha, dropout=dropout))\n",
        "\n",
        "\n",
        "\n",
        "# --- helpers to iterate resblocks safely ---\n",
        "def _iter_resblocks(stack_module: nn.Module):\n",
        "    \"\"\"List of `stack_module.transformer.resblocks`, or [] when that path doesn't exist.\n",
        "\n",
        "    Works for model.visual (vision tower) and for the CLIP model itself (text tower).\n",
        "    \"\"\"\n",
        "    transformer = getattr(stack_module, \"transformer\", None)\n",
        "    if transformer is None or not hasattr(transformer, \"resblocks\"):\n",
        "        return []\n",
        "    return list(transformer.resblocks)\n",
        "\n",
        "def add_lora_to_clip_vit_b16(model, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Add/upgrade LoRA adapters on a CLIP model and make them its only trainable params.\n",
        "\n",
        "    In every vision block, then every text block, wraps attn.out_proj, mlp.c_fc and\n",
        "    mlp.c_proj, each only if present, so pruned/Identity blocks are skipped.\n",
        "\n",
        "    Returns:\n",
        "        List of parameters left with requires_grad=True (the lora_A / lora_B tensors).\n",
        "    \"\"\"\n",
        "    # Vision tower first, then text: each fresh adapter draws a random init for\n",
        "    # lora_A, so the visiting order determines which layer gets which draw.\n",
        "    for tower in (model.visual, model):\n",
        "        for blk in _iter_resblocks(tower):\n",
        "            attn = getattr(blk, \"attn\", None)\n",
        "            if attn is not None and hasattr(attn, \"out_proj\"):\n",
        "                _wrap_or_upgrade_linear(attn, \"out_proj\", r, alpha, dropout)\n",
        "            mlp = getattr(blk, \"mlp\", None)\n",
        "            if mlp is None:\n",
        "                continue\n",
        "            for lin_name in (\"c_fc\", \"c_proj\"):\n",
        "                if hasattr(mlp, lin_name):\n",
        "                    _wrap_or_upgrade_linear(mlp, lin_name, r, alpha, dropout)\n",
        "\n",
        "    # Freeze everything (logit_scale explicitly too, in case it is a tensor that\n",
        "    # model.parameters() does not yield), then re-enable only the LoRA matrices.\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad = False\n",
        "    if isinstance(getattr(model, \"logit_scale\", None), torch.Tensor):\n",
        "        model.logit_scale.requires_grad_(False)\n",
        "    for name, p in model.named_parameters():\n",
        "        if name.endswith((\"lora_A\", \"lora_B\")):\n",
        "            p.requires_grad = True\n",
        "\n",
        "    return [p for p in model.parameters() if p.requires_grad]\n",
        "\n",
        "\n",
        "def _assert_outproj_has_weight(model):\n",
        "    \"\"\"\n",
        "    Raise AssertionError if any attn.out_proj in either tower lacks a Tensor `.weight`.\n",
        "\n",
        "    Blocks without attn.out_proj (Identity/pruned/non-standard) are skipped.\n",
        "    The vision tower is checked before the text tower.\n",
        "    \"\"\"\n",
        "    for tower, tower_name in ((model.visual, \"vision\"), (model, \"text\")):\n",
        "        for i, blk in enumerate(_iter_resblocks(tower)):\n",
        "            attn = getattr(getattr(blk, \"attn\", None), \"out_proj\", None)\n",
        "            if attn is None:\n",
        "                continue\n",
        "            w = getattr(attn, \"weight\", None)\n",
        "            if isinstance(w, torch.Tensor):\n",
        "                continue\n",
        "            raise AssertionError(\n",
        "                f\"{tower_name} resblock {i} out_proj has no Tensor `.weight` \"\n",
        "                f\"(got {type(w)} from {type(attn)})\"\n",
        "            )\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Hooks & TE computation (PERTURBATION-BASED)\n",
        "# -------------------------\n",
        "class BlockHook:\n",
        "    \"\"\"\n",
        "    Collects the output of every hooked transformer block.\n",
        "\n",
        "    For a 3-D output (or the first 3-D tensor in a tuple/list output) it guesses\n",
        "    the layout (LND vs NLD) by picking the orientation whose token-pooled\n",
        "    features vary more across the batch, and appends it to `buf` as a detached\n",
        "    CPU float32 NLD tensor: one entry per hooked call, in call order.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        self.buf: List[torch.Tensor] = []\n",
        "        self.handles: List[torch.utils.hooks.RemovableHandle] = []\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def _choose_nld(self, y: torch.Tensor) -> torch.Tensor:\n",
        "        # Candidate 1: assume input is LND (typical inside CLIP blocks)\n",
        "        cand1 = y.permute(1, 0, 2)  # NLD\n",
        "        # Candidate 2: assume input is already NLD\n",
        "        cand2 = y\n",
        "\n",
        "        def score(nld: torch.Tensor) -> torch.Tensor:\n",
        "            # mean over features of the across-sample variance of token-pooled features\n",
        "            z = nld.mean(dim=1)  # [N, D] pooled over tokens\n",
        "            z = torch.nan_to_num(z).float()\n",
        "            return z.var(dim=0, unbiased=False).mean()  # scalar\n",
        "\n",
        "        s1 = score(cand1)\n",
        "        s2 = score(cand2)\n",
        "        return cand1 if (s1 >= s2) else cand2\n",
        "\n",
        "    def _store(self, y: torch.Tensor) -> None:\n",
        "        # detach(): if the hooked forward is recording autograd (e.g. trainable\n",
        "        # LoRA params), a non-detached copy would pin that forward's whole graph\n",
        "        # in memory for as long as `buf` holds it.\n",
        "        nld = self._choose_nld(y.detach())\n",
        "        self.buf.append(torch.nan_to_num(nld).float().cpu())\n",
        "\n",
        "    def _hook(self, module, inp, out):\n",
        "        if isinstance(out, torch.Tensor) and out.ndim == 3:\n",
        "            self._store(out)\n",
        "        elif isinstance(out, (tuple, list)):\n",
        "            for x in out:\n",
        "                if isinstance(x, torch.Tensor) and x.ndim == 3:\n",
        "                    self._store(x)\n",
        "                    break\n",
        "\n",
        "    def register_on_blocks(self, blocks: nn.ModuleList):\n",
        "        \"\"\"Remove any previous hooks, reset `buf`, and hook every module in `blocks`.\"\"\"\n",
        "        self.clear()\n",
        "        for blk in blocks:\n",
        "            self.handles.append(blk.register_forward_hook(self._hook))\n",
        "\n",
        "    def clear(self):\n",
        "        \"\"\"Remove all registered hooks and drop captured outputs.\"\"\"\n",
        "        for h in self.handles:\n",
        "            h.remove()\n",
        "        self.handles = []\n",
        "        self.buf = []\n",
        "\n",
        "def pool_sequence(x: torch.Tensor, kind: str = \"cls\") -> torch.Tensor:\n",
        "    \"\"\"Reduce a [B, T, D] tensor to [B, D]: token 0 when kind == 'cls', else the mean over T.\"\"\"\n",
        "    if x.ndim == 3:\n",
        "        return x[:, 0, :] if kind == \"cls\" else x.mean(dim=1)\n",
        "    raise ValueError(\"Expected [B, T, D] tensor from blocks.\")\n",
        "\n",
        "def _batch_variance_debug(layer_outputs: List[torch.Tensor], pool=\"cls\") -> List[float]:\n",
        "    \"\"\"Per layer output: across-batch variance of its pooled features, averaged over D.\n",
        "\n",
        "    A value near zero means every sample produced (almost) the same pooled vector.\n",
        "    \"\"\"\n",
        "    def _pooled_var(x: torch.Tensor) -> float:\n",
        "        feats = torch.nan_to_num(pool_sequence(x, kind=pool)).float()  # [B, D]\n",
        "        return float(feats.var(dim=0, unbiased=False).mean().item())\n",
        "\n",
        "    return [_pooled_var(x) for x in layer_outputs]\n",
        "\n",
        "def _collect_block_outputs_with_noise(\n",
        "    model,\n",
        "    images: torch.Tensor,\n",
        "    texts: torch.Tensor,\n",
        "    device: str,\n",
        "    sigma_v: float,\n",
        "    sigma_t: float,\n",
        ") -> Tuple[List[torch.Tensor], List[torch.Tensor]]:\n",
        "    \"\"\"\n",
        "    One hooked forward pass per tower with additive Gaussian noise injected:\n",
        "      - Vision: noise of std sigma_v added to the input pixels (already normalized)\n",
        "      - Text:   noise of std sigma_t added to token embeddings via a forward hook\n",
        "                on model.token_embedding\n",
        "    Returns (vision_outputs, text_outputs): per-block NLD CPU float tensors as\n",
        "    captured by BlockHook.\n",
        "\n",
        "    Runs in eval mode under torch.no_grad(). All hooks are removed and the\n",
        "    model's previous train/eval mode is restored even if a forward raises, so a\n",
        "    failure here cannot leave stray hooks accumulating on the model.\n",
        "    \"\"\"\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "\n",
        "    v_hook = BlockHook()\n",
        "    t_hook = BlockHook()\n",
        "    emb_handle = None\n",
        "    try:\n",
        "        v_hook.register_on_blocks(model.visual.transformer.resblocks)\n",
        "        t_hook.register_on_blocks(model.transformer.resblocks)\n",
        "\n",
        "        if hasattr(model, \"token_embedding\") and sigma_t > 0:\n",
        "            def _emb_noise(_m, _inp, out):\n",
        "                return out + sigma_t * torch.randn_like(out)\n",
        "            emb_handle = model.token_embedding.register_forward_hook(_emb_noise)\n",
        "\n",
        "        # no_grad: only the captured activations are needed; recording a graph\n",
        "        # through trainable (e.g. LoRA) parameters would just waste memory.\n",
        "        with torch.no_grad():\n",
        "            # Vision pass (noisy pixels)\n",
        "            if sigma_v > 0:\n",
        "                img_noisy = images + sigma_v * torch.randn_like(images)\n",
        "            else:\n",
        "                img_noisy = images\n",
        "            _ = model.encode_image(img_noisy.to(device, non_blocking=True))\n",
        "\n",
        "            # Text pass (the embedding hook injects noise automatically)\n",
        "            _ = model.encode_text(texts.to(device, non_blocking=True))\n",
        "\n",
        "        v_out = v_hook.buf[:]  # copy before clear() resets the buffers\n",
        "        t_out = t_hook.buf[:]\n",
        "    finally:\n",
        "        if emb_handle is not None:\n",
        "            emb_handle.remove()\n",
        "        v_hook.clear()\n",
        "        t_hook.clear()\n",
        "        model.train(was_training)\n",
        "    return v_out, t_out\n",
        "\n",
        "def _te_from_two_runs(\n",
        "    out1: List[torch.Tensor],\n",
        "    out2: List[torch.Tensor],\n",
        "    pool: str = \"cls\",\n",
        ") -> List[float]:\n",
        "    \"\"\"\n",
        "    TE between consecutive blocks from two noisy runs over the same batch.\n",
        "\n",
        "    For each block l, dz_l = pool(out2[l]) - pool(out1[l]) per sample; entry l of\n",
        "    the result is 0.5 * mean_over_batch(cos(dz_l, dz_{l+1}) ** 2).\n",
        "    Returns len(out1) - 1 floats.\n",
        "    \"\"\"\n",
        "    assert len(out1) == len(out2) and len(out1) >= 2\n",
        "    deltas = [\n",
        "        torch.nan_to_num(pool_sequence(b, kind=pool) - pool_sequence(a, kind=pool))\n",
        "        for a, b in zip(out1, out2)\n",
        "    ]  # each [B, D]\n",
        "\n",
        "    eps = 1e-6\n",
        "    te_vals = []\n",
        "    for cur, nxt in zip(deltas, deltas[1:]):\n",
        "        dots = (cur * nxt).sum(dim=1)\n",
        "        norms = (cur.norm(dim=1) * nxt.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (dots / norms).clamp(-1.0, 1.0)\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "    return te_vals\n",
        "\n",
        "\n",
        "\n",
        "# ---- Explicit per-block forwards (no hooks) ----\n",
        "\n",
        "def _first_param_dtype(module, default=torch.float32):\n",
        "    \"\"\"dtype of the first parameter yielded by module.parameters(), or `default` if there is none.\"\"\"\n",
        "    first = next(module.parameters(recurse=True), None)\n",
        "    return default if first is None else first.dtype\n",
        "\n",
        "@torch.no_grad()\n",
        "def _vision_block_outputs(model, images, sigma_v: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the vision tower block by block and return each block's output.\n",
        "\n",
        "    Recomputes the stem of model.visual (conv1 patch embedding, class token,\n",
        "    positional embedding, ln_pre) and then feeds each resblock in turn, so\n",
        "    per-block outputs are collected without hooks. When sigma_v > 0, Gaussian\n",
        "    noise of std sigma_v is added to the pixels before conv1.\n",
        "\n",
        "    Args:\n",
        "        images: image batch, presumably [B, 3, H, W] already normalized (TODO confirm).\n",
        "        sigma_v: pixel-noise std; 0 disables noise.\n",
        "        device: unused here; inputs go to the device of model.visual's parameters.\n",
        "\n",
        "    Returns:\n",
        "        List with one batch-first [B, T, D] tensor per resblock, where T is\n",
        "        1 class token + the patch tokens.\n",
        "    \"\"\"\n",
        "    v = model.visual\n",
        "    pdev   = next(v.parameters()).device\n",
        "    v_dtype = _first_param_dtype(v, torch.float32)  # fp32 once model.float() has been applied\n",
        "\n",
        "    x = images.to(pdev, non_blocking=True).to(v_dtype)   # match the tower's device and dtype\n",
        "    if sigma_v > 0:\n",
        "        x = x + sigma_v * torch.randn_like(x, dtype=v_dtype)\n",
        "\n",
        "    # Patch embedding, then flatten the spatial dims: [B, C', ...] -> [B, num_patches, C']\n",
        "    x = v.conv1(x)\n",
        "    x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)\n",
        "    # Prepend the learned class token to every sample, then add positional embeddings\n",
        "    x = torch.cat([v.class_embedding[None, None, :].to(v_dtype).expand(x.shape[0], 1, -1), x], dim=1)\n",
        "    x = x + v.positional_embedding.to(v_dtype)\n",
        "    x = v.ln_pre(x.to(_first_param_dtype(v.ln_pre, v_dtype)))\n",
        "\n",
        "    # Blocks are fed sequence-first [T, B, D]; stored outputs are batch-first\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in v.transformer.resblocks:\n",
        "        # after model.float() every block is fp32, so this cast is a no-op\n",
        "        if x.dtype != v_dtype:\n",
        "            x = x.to(v_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())\n",
        "    return outs\n",
        "\n",
        "@torch.no_grad()\n",
        "def _text_block_outputs(model, texts, sigma_t: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the text tower block by block and return each block's output.\n",
        "\n",
        "    Embeds the tokens, adds the positional embedding, optionally adds Gaussian\n",
        "    noise of std sigma_t to that sum, then applies each resblock in turn.\n",
        "\n",
        "    Args:\n",
        "        texts: token ids, presumably [B, context_length] from clip.tokenize (TODO confirm).\n",
        "        sigma_t: embedding-noise std; 0 disables noise.\n",
        "        device: kept for backward compatibility. Token ids are moved to the\n",
        "            device of model.token_embedding, mirroring how _vision_block_outputs\n",
        "            uses the vision tower's parameter device, so a `device` argument that\n",
        "            differs from where the model lives cannot cause a cross-device error.\n",
        "\n",
        "    Returns:\n",
        "        List with one batch-first [B, T, D] tensor per resblock.\n",
        "    \"\"\"\n",
        "    tr = model.transformer\n",
        "    tr_dtype = _first_param_dtype(tr, torch.float32)\n",
        "    emb_device = model.token_embedding.weight.device\n",
        "\n",
        "    x = model.token_embedding(texts.to(emb_device, non_blocking=True))\n",
        "    x = x.to(tr_dtype) + model.positional_embedding.to(tr_dtype)\n",
        "    if sigma_t > 0:\n",
        "        x = x + sigma_t * torch.randn_like(x, dtype=tr_dtype)\n",
        "\n",
        "    # Blocks are fed sequence-first [T, B, D]; stored outputs are batch-first\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in tr.resblocks:\n",
        "        if x.dtype != tr_dtype:\n",
        "            x = x.to(tr_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())\n",
        "    return outs\n",
        "\n",
        "\n",
        "\n",
        "def get_probe_loader(train_dataset, batch_size=64, num_batches=2, seed=None, num_workers=2):\n",
        "    \"\"\"\n",
        "    DataLoader over a random subset of train_dataset, used as the TE probe set.\n",
        "\n",
        "    Draws min(len(train_dataset), batch_size * num_batches) distinct indices\n",
        "    without replacement and serves them in a fixed (unshuffled) order.\n",
        "\n",
        "    Args:\n",
        "        batch_size: probe batch size (capped at the subset size).\n",
        "        num_batches: number of full batches to draw.\n",
        "        seed: if not None, indices come from a private random.Random(seed), so the\n",
        "            probe set is reproducible and the global `random` state is untouched;\n",
        "            None samples from the global `random` module (previous behaviour).\n",
        "        num_workers: DataLoader worker processes.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if train_dataset is empty, or batch_size / num_batches < 1.\n",
        "    \"\"\"\n",
        "    from torch.utils.data import Subset, DataLoader\n",
        "    total = len(train_dataset)\n",
        "    if total == 0:\n",
        "        raise ValueError(\"Empty train_dataset for probe loader.\")\n",
        "    if batch_size < 1 or num_batches < 1:\n",
        "        raise ValueError(\n",
        "            f\"batch_size and num_batches must be >= 1 (got {batch_size}, {num_batches}).\"\n",
        "        )\n",
        "    need = min(total, batch_size * num_batches)\n",
        "    rng = random if seed is None else random.Random(seed)\n",
        "    idxs = rng.sample(range(total), need)\n",
        "    probe = Subset(train_dataset, idxs)\n",
        "    bs = min(batch_size, len(probe))\n",
        "    return DataLoader(\n",
        "        probe,\n",
        "        batch_size=bs,\n",
        "        shuffle=False,\n",
        "        num_workers=num_workers,\n",
        "        pin_memory=True,\n",
        "        # Drop a trailing partial batch (avoids tiny, low-variance TE batches).\n",
        "        # This cannot prevent B=1 if the whole subset has a single element.\n",
        "        drop_last=True,\n",
        "    )\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def _pooled(list_of_tensors, kind: str):\n",
        "    \"\"\"Pool every [B, T, D] tensor in the list to [B, D] ('cls' = token 0, 'mean' = token mean).\"\"\"\n",
        "    if kind not in (\"cls\", \"mean\"):\n",
        "        raise ValueError(\"pool must be 'cls' or 'mean'\")\n",
        "    if kind == \"cls\":\n",
        "        return [x[:, 0, :] for x in list_of_tensors]\n",
        "    return [x.mean(dim=1) for x in list_of_tensors]\n",
        "\n",
        "def _te_from_two_runs_lists(pooled1, pooled2):\n",
        "    \"\"\"\n",
        "    TE between consecutive layers from two runs of already-pooled features.\n",
        "\n",
        "    pooled1 / pooled2: equal-length lists (>= 2) of [B, D] tensors. With\n",
        "    d_l = pooled2[l] - pooled1[l] (NaN-scrubbed, fp32), entry l of the result is\n",
        "    0.5 * mean_over_batch(cos(d_l, d_{l+1}) ** 2); len(pooled1) - 1 entries total.\n",
        "    \"\"\"\n",
        "    assert len(pooled1) == len(pooled2) and len(pooled1) >= 2\n",
        "    eps = 1e-6\n",
        "    diffs = [torch.nan_to_num(p2 - p1).float() for p1, p2 in zip(pooled1, pooled2)]\n",
        "\n",
        "    def _half_mean_cos_sq(a, b):\n",
        "        cos = ((a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1)).clamp_min(eps)).clamp(-1.0, 1.0)\n",
        "        return 0.5 * float((cos * cos).mean().item())\n",
        "\n",
        "    return [_half_mean_cos_sq(a, b) for a, b in zip(diffs, diffs[1:])]\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_te_per_layer(\n",
        "    model,\n",
        "    images,\n",
        "    texts,\n",
        "    pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    device: str = \"cuda\",\n",
        "    sigma_v: float = 1e-2,\n",
        "    sigma_t: float = 5e-2,\n",
        "    verbose_once: bool = True,\n",
        "):\n",
        "    \"\"\"Per-layer TE for both CLIP towers, estimated from two noisy passes.\n",
        "\n",
        "    Args:\n",
        "        model: CLIP model handed to the block-output collectors.\n",
        "        images, texts: one probe batch for the vision / text tower.\n",
        "        pool: one pooling kind for both towers, or a (vision, text) pair;\n",
        "            each must be \"cls\" or \"mean\".\n",
        "        device: forwarded to the block-output collectors.\n",
        "        sigma_v, sigma_t: noise scales forwarded to the vision / text collectors.\n",
        "        verbose_once: print batch-size and delta-magnitude diagnostics.\n",
        "\n",
        "    Returns:\n",
        "        (te_v, te_t): lists of floats, one per adjacent block pair.\n",
        "    \"\"\"\n",
        "    pool_v, pool_t = (pool, pool) if isinstance(pool, str) else pool\n",
        "\n",
        "    # Two independent noisy passes per tower. The call order (v, t, v, t) is kept\n",
        "    # as-is in case the collectors draw their noise from a shared RNG.\n",
        "    vis_run1 = _vision_block_outputs(model, images, sigma_v, device)\n",
        "    txt_run1 = _text_block_outputs(model, texts, sigma_t, device)\n",
        "    vis_run2 = _vision_block_outputs(model, images, sigma_v, device)\n",
        "    txt_run2 = _text_block_outputs(model, texts, sigma_t, device)\n",
        "\n",
        "    if verbose_once:\n",
        "        # With B < 2 the per-batch statistics are not meaningful; warn, don't fail.\n",
        "        Bv = vis_run1[0].shape[0] if len(vis_run1) else 0\n",
        "        Bt = txt_run1[0].shape[0] if len(txt_run1) else 0\n",
        "        if Bv < 2 or Bt < 2:\n",
        "            print(f\"[TE Debug] Probe batch too small (vision B={Bv}, text B={Bt}). \"\n",
        "                  f\"Use drop_last=True or a bigger probe batch.\")\n",
        "\n",
        "    vis_p1, vis_p2 = _pooled(vis_run1, pool_v), _pooled(vis_run2, pool_v)\n",
        "    txt_p1, txt_p2 = _pooled(txt_run1, pool_t), _pooled(txt_run2, pool_t)\n",
        "\n",
        "    if verbose_once:\n",
        "        def _mean_abs_delta(first, second):\n",
        "            # mean over layers of mean |second - first|: is the noise reaching this tower at all?\n",
        "            per_layer = [(after - before).abs().mean() for before, after in zip(first, second)]\n",
        "            return float(torch.stack(per_layer).mean().item())\n",
        "\n",
        "        md_v = _mean_abs_delta(vis_p1, vis_p2)\n",
        "        md_t = _mean_abs_delta(txt_p1, txt_p2)\n",
        "        print(f\"[TE Debug] mean|Δ| vision={md_v:.3e}, text={md_t:.3e}  \"\n",
        "              f\"(sigmas: v={sigma_v:.2e}, t={sigma_t:.2e})\")\n",
        "\n",
        "    te_v = _te_from_two_runs_lists(vis_p1, vis_p2)\n",
        "    te_t = _te_from_two_runs_lists(txt_p1, txt_p2)\n",
        "    return te_v, te_t\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Pruning\n",
        "# -------------------------\n",
        "class IdentityBlock(nn.Module):\n",
        "    def forward(self, x):\n",
        "        return x\n",
        "\n",
        "def prune_topk_layers_by_te(model, te_v: List[float], te_t: List[float], k_v: int = 4, k_t: int = 4,\n",
        "                            guard_skip_first_last: bool = True, avoid_adjacent: bool = True) -> Tuple[List[int], List[int]]:\n",
        "    # Vision\n",
        "    v_blocks = model.visual.transformer.resblocks\n",
        "    V = len(v_blocks)\n",
        "    cand_v = list(range(1, V))  # prune block index i corresponds to TE at i-1\n",
        "    if guard_skip_first_last:\n",
        "        cand_v = [idx for idx in cand_v if idx not in (0, V-1)]\n",
        "    order_v = sorted(cand_v, key=lambda idx: te_v[idx - 1], reverse=True)\n",
        "    chosen_v = []\n",
        "    for idx in order_v:\n",
        "        if len(chosen_v) >= k_v:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen_v):\n",
        "            continue\n",
        "        chosen_v.append(idx)\n",
        "    for idx in chosen_v:\n",
        "        v_blocks[idx] = IdentityBlock()\n",
        "\n",
        "    # Text\n",
        "    t_blocks = model.transformer.resblocks\n",
        "    T = len(t_blocks)\n",
        "    cand_t = list(range(1, T))\n",
        "    if guard_skip_first_last:\n",
        "        cand_t = [idx for idx in cand_t if idx not in (0, T-1)]\n",
        "    order_t = sorted(cand_t, key=lambda idx: te_t[idx - 1], reverse=True)\n",
        "    chosen_t = []\n",
        "    for idx in order_t:\n",
        "        if len(chosen_t) >= k_t:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen_t):\n",
        "            continue\n",
        "        chosen_t.append(idx)\n",
        "    for idx in chosen_t:\n",
        "        t_blocks[idx] = IdentityBlock()\n",
        "\n",
        "    return chosen_v, chosen_t\n",
        "\n",
        "# -------------------------\n",
        "# Contrastive loss (NaN-safe)\n",
        "# -------------------------\n",
        "def clip_contrastive_loss(model, image_features, text_features):\n",
        "    \"\"\"Symmetric CLIP contrastive (InfoNCE) loss computed in fp32.\n",
        "\n",
        "    Row i of image_features is the positive for row i of text_features. Both\n",
        "    sides are L2-normalized, similarities are scaled by the model's safe logit\n",
        "    scale, and non-finite logits are replaced (nan -> 0, +/-inf -> +/-1e4)\n",
        "    before the image->text and text->image cross-entropies are averaged.\n",
        "    \"\"\"\n",
        "    img = safe_l2_normalize(image_features.float(), dim=-1, eps=1e-6)\n",
        "    txt = safe_l2_normalize(text_features.float(), dim=-1, eps=1e-6)\n",
        "\n",
        "    scale = get_safe_logit_scale(model)\n",
        "    i2t_logits = torch.nan_to_num(scale * (img @ txt.t()), nan=0.0, posinf=1e4, neginf=-1e4)\n",
        "    targets = torch.arange(img.size(0), device=i2t_logits.device)\n",
        "\n",
        "    loss_img_to_txt = F.cross_entropy(i2t_logits, targets)\n",
        "    loss_txt_to_img = F.cross_entropy(i2t_logits.t(), targets)\n",
        "    return 0.5 * (loss_img_to_txt + loss_txt_to_img)\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Training (NaN-safe; TE via perturbations)\n",
        "# -------------------------\n",
        "def train_with_te_prune_lora(\n",
        "    model,\n",
        "    train_dataloader,\n",
        "    train_dataset,\n",
        "    device,\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "    te_pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    probe_batches=2,\n",
        "    te_sigma_v: float = 1e-3,\n",
        "    te_sigma_t: float = 1e-3,\n",
        "    probe_batch_size: int = 64,\n",
        "):\n",
        "    \"\"\"LoRA fine-tune CLIP, prune the highest-TE blocks, then keep fine-tuning.\n",
        "\n",
        "    Phase 1 trains for `epochs_before_prune` epochs and, after each one,\n",
        "    measures per-layer TE on a fixed probe subset of `train_dataset`. The TE of\n",
        "    the last phase-1 epoch (or a fresh measurement if `epochs_before_prune` is 0)\n",
        "    picks up to `prune_k_vision` / `prune_k_text` blocks to replace with\n",
        "    IdentityBlock. Phase 2 trains `epochs_after_prune` more epochs. Only the\n",
        "    parameters returned by the LoRA injection are optimized; batches with a\n",
        "    non-finite loss are skipped.\n",
        "\n",
        "    Args:\n",
        "        te_pool: pooling for TE (\"cls\"/\"mean\", or a (vision, text) pair).\n",
        "        probe_batches: number of probe batches averaged per TE measurement.\n",
        "        te_sigma_v, te_sigma_t: perturbation scales for the vision / text TE passes.\n",
        "        probe_batch_size: batch size of the probe loader (formerly fixed at 64).\n",
        "        (the remaining arguments configure LoRA, AdamW and the epoch schedule)\n",
        "\n",
        "    Returns:\n",
        "        history: {\"vision\": [...], \"text\": [...]}, one per-layer TE list per epoch.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: if the probe loader yields no batches.\n",
        "    \"\"\"\n",
        "    # LoRA injection/upgrade\n",
        "    lora_params = add_lora_to_clip_vit_b16(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)\n",
        "    _assert_outproj_has_weight(model)\n",
        "    model.to(device)\n",
        "\n",
        "    optimizer = AdamW(lora_params, lr=lr, weight_decay=weight_decay)\n",
        "    probe_loader = get_probe_loader(train_dataset, batch_size=probe_batch_size, num_batches=probe_batches)\n",
        "\n",
        "    history = {\"vision\": [], \"text\": []}\n",
        "\n",
        "    def train_one_epoch():\n",
        "        \"\"\"One pass over train_dataloader; returns the mean loss of the steps taken.\"\"\"\n",
        "        model.train()\n",
        "        total_loss, num_steps = 0.0, 0\n",
        "        for images, texts in train_dataloader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            texts  = texts.to(device, non_blocking=True)\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            img_feat = model.encode_image(images)\n",
        "            txt_feat = model.encode_text(texts)\n",
        "\n",
        "            loss = clip_contrastive_loss(model, img_feat, txt_feat)\n",
        "\n",
        "            # Skip non-finite batches\n",
        "            if not torch.isfinite(loss):\n",
        "                with torch.no_grad():\n",
        "                    bad = {\n",
        "                        \"loss\": float(loss.detach().float().item()),\n",
        "                        \"img_feat_has_nan\": bool(torch.isnan(img_feat).any().item()),\n",
        "                        \"txt_feat_has_nan\": bool(torch.isnan(txt_feat).any().item()),\n",
        "                    }\n",
        "                print(\"[Warn] Skipping non-finite batch:\", bad)\n",
        "                continue\n",
        "\n",
        "            loss.backward()\n",
        "            clip_grad_norm_(lora_params, max_norm=1.0)\n",
        "            optimizer.step()\n",
        "\n",
        "            total_loss += float(loss.item())\n",
        "            num_steps += 1\n",
        "\n",
        "        if num_steps == 0:\n",
        "            print(\"[Warn] No finite training steps this epoch; reported loss is 0.\")\n",
        "        return total_loss / max(1, num_steps)\n",
        "\n",
        "    def measure_te():\n",
        "        \"\"\"Average per-layer TE over the probe loader (eval mode, no grad).\"\"\"\n",
        "        model.eval()\n",
        "        te_v_all, te_t_all = [], []\n",
        "        with torch.no_grad():\n",
        "            for batch_idx, (images, texts) in enumerate(probe_loader):\n",
        "                tv, tt = compute_te_per_layer(\n",
        "                    model, images, texts,\n",
        "                    pool=te_pool,\n",
        "                    device=device,\n",
        "                    sigma_v=te_sigma_v,\n",
        "                    sigma_t=te_sigma_t,\n",
        "                    verbose_once=(batch_idx == 0),  # diagnostics for the first probe batch only\n",
        "                )\n",
        "                te_v_all.append(torch.tensor(tv))\n",
        "                te_t_all.append(torch.tensor(tt))\n",
        "        if not te_v_all:\n",
        "            raise RuntimeError(\"TE probe loader yielded no batches; \"\n",
        "                               \"check probe_batches / probe_batch_size\")\n",
        "        # nanmean: a NaN layer score in one probe batch does not poison the average\n",
        "        te_v_mean = torch.stack(te_v_all).nanmean(dim=0).tolist()\n",
        "        te_t_mean = torch.stack(te_t_all).nanmean(dim=0).tolist()\n",
        "        return te_v_mean, te_t_mean\n",
        "\n",
        "    def epoch_pass(epoch_idx):\n",
        "        avg_loss = train_one_epoch()\n",
        "        te_v_mean, te_t_mean = measure_te()\n",
        "\n",
        "        history[\"vision\"].append(te_v_mean)\n",
        "        history[\"text\"].append(te_t_mean)\n",
        "\n",
        "        print(f\"[Epoch {epoch_idx}] loss={avg_loss:.6f}\")\n",
        "        print(f\"  Vision TE per layer (ℓ=0..{len(te_v_mean)-1} -> prune ℓ+1):\")\n",
        "        print(\"   \", [round(x, 6) for x in te_v_mean])\n",
        "        print(f\"  Text   TE per layer (ℓ=0..{len(te_t_mean)-1} -> prune ℓ+1):\")\n",
        "        print(\"   \", [round(x, 6) for x in te_t_mean])\n",
        "\n",
        "    # --- Phase 1 ---\n",
        "    for ep in range(1, epochs_before_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    # --- Prune based on TE ---\n",
        "    if history[\"vision\"]:\n",
        "        last_te_v, last_te_t = history[\"vision\"][-1], history[\"text\"][-1]\n",
        "    else:\n",
        "        # No phase-1 epochs: measure TE on the current model instead of indexing an empty history.\n",
        "        last_te_v, last_te_t = measure_te()\n",
        "    chosen_v, chosen_t = prune_topk_layers_by_te(\n",
        "        model, last_te_v, last_te_t,\n",
        "        k_v=prune_k_vision, k_t=prune_k_text,\n",
        "        guard_skip_first_last=True, avoid_adjacent=True\n",
        "    )\n",
        "    print(f\"Pruned Vision blocks at indices: {chosen_v} (Identity)\")\n",
        "    print(f\"Pruned Text   blocks at indices: {chosen_t} (Identity)\")\n",
        "\n",
        "    # --- Phase 2 ---\n",
        "    for ep in range(epochs_before_prune + 1, epochs_before_prune + epochs_after_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    return history\n",
        "\n",
        "# ======================================================\n",
        "# OPTIONAL: COCO Retrieval Evaluation (I2T / T2I)\n",
        "# ======================================================\n",
        "def evaluate_coco_retrieval(model, data_dir: str, batch_size: int = 64, input_resolution: int = 224, context_length: int = 77):\n",
        "    \"\"\"\n",
        "    Evaluate trained model on COCO val2014 captions (I2T and T2I Recall@1/5/10).\n",
        "    Assumes `pip install pycocotools` and images/annotations present in data_dir\n",
        "    (val2014/ and annotations/captions_val2014.json).\n",
        "\n",
        "    A query's rank is the number of candidates scoring strictly higher than its\n",
        "    best correct match (for image->text, the best of the image's captions), so\n",
        "    exact score ties resolve in the query's favour. Queries with no ground\n",
        "    truth count as misses.\n",
        "\n",
        "    Returns:\n",
        "        dict with keys i2t_r1, i2t_r5, i2t_r10, t2i_r1, t2i_r5, t2i_r10\n",
        "        (fractions in [0, 1]); the same values are printed as percentages.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: if no images or no captions were encoded.\n",
        "    \"\"\"\n",
        "    import clip\n",
        "    from torchvision import datasets, transforms\n",
        "    from torchvision.transforms import InterpolationMode\n",
        "    from torch.utils.data import DataLoader\n",
        "\n",
        "    device = next(model.parameters()).device\n",
        "\n",
        "    # CLIP mean/std normalization; images are resized (not center-cropped) to a square.\n",
        "    eval_transform = transforms.Compose([\n",
        "        transforms.Resize((input_resolution, input_resolution), interpolation=InterpolationMode.BICUBIC),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "                             std=(0.26862954, 0.26130258, 0.27577711))\n",
        "    ])\n",
        "\n",
        "    def coco_collate_fn(batch):\n",
        "        # Stack images; keep captions as list[list[str]] since the count varies per image.\n",
        "        images = torch.stack([img for img, _ in batch], dim=0)\n",
        "        captions = [caps for _, caps in batch]\n",
        "        return images, captions\n",
        "\n",
        "    val_img_dir  = os.path.join(data_dir, 'val2014')\n",
        "    val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "    # CocoCaptions yields (transformed image, list of caption strings) directly.\n",
        "    val_dataset  = datasets.CocoCaptions(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,\n",
        "                              pin_memory=torch.cuda.is_available(), collate_fn=coco_collate_fn)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    image_chunks, text_chunks = [], []\n",
        "    image_to_text_indices = []  # image i -> global indices of its captions\n",
        "    text_count = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for images, batch_captions in val_loader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            image_feats = safe_l2_normalize(model.encode_image(images).float(), dim=-1, eps=1e-6)\n",
        "            image_chunks.append(image_feats.cpu())\n",
        "\n",
        "            flat_captions = []\n",
        "            for caps in batch_captions:\n",
        "                start_idx = text_count + len(flat_captions)\n",
        "                flat_captions.extend(caps)\n",
        "                image_to_text_indices.append(list(range(start_idx, text_count + len(flat_captions))))\n",
        "\n",
        "            if not flat_captions:\n",
        "                continue\n",
        "\n",
        "            # truncate=True clips over-long captions instead of raising mid-evaluation.\n",
        "            texts = clip.tokenize(flat_captions, context_length=context_length, truncate=True).to(device)\n",
        "            text_feats = safe_l2_normalize(model.encode_text(texts).float(), dim=-1, eps=1e-6)\n",
        "            text_chunks.append(text_feats.cpu())\n",
        "            text_count += len(flat_captions)\n",
        "\n",
        "    if not image_chunks or not text_chunks:\n",
        "        raise RuntimeError(\"COCO retrieval: no images or captions were encoded; check data_dir\")\n",
        "\n",
        "    all_image_features = torch.cat(image_chunks, dim=0)  # [N_img, D]\n",
        "    all_text_features  = torch.cat(text_chunks,  dim=0)  # [N_txt, D]\n",
        "    sim_matrix = all_image_features @ all_text_features.t()  # [N_img, N_txt]\n",
        "    n_img, n_txt = sim_matrix.shape\n",
        "\n",
        "    def count_higher(scores, thresholds, chunk=1024):\n",
        "        \"\"\"Row r -> number of entries in scores[r] strictly greater than thresholds[r].\"\"\"\n",
        "        counts = torch.empty(scores.size(0), dtype=torch.long)\n",
        "        for start in range(0, scores.size(0), chunk):\n",
        "            # Row chunks bound the size of the temporary boolean matrix.\n",
        "            block = scores[start:start + chunk]\n",
        "            counts[start:start + chunk] = (block > thresholds[start:start + chunk, None]).sum(dim=1)\n",
        "        return counts\n",
        "\n",
        "    def recall_at(ranks, valid, k):\n",
        "        # Denominator counts every query; queries without ground truth are misses.\n",
        "        return float((valid & (ranks < k)).sum().item()) / max(1, ranks.numel())\n",
        "\n",
        "    def report(title, prefix, ranks, valid, metrics):\n",
        "        print(title)\n",
        "        for k in (1, 5, 10):\n",
        "            value = recall_at(ranks, valid, k)\n",
        "            metrics[f\"{prefix}_r{k}\"] = value\n",
        "            label = f\"Recall@{k}:\"\n",
        "            print(f\"{label:<11}{value*100:.2f}%\")\n",
        "\n",
        "    metrics = {}\n",
        "\n",
        "    # I2T: threshold = best score among the image's own captions.\n",
        "    i2t_best = torch.full((n_img,), float(\"-inf\"))\n",
        "    i2t_valid = torch.zeros(n_img, dtype=torch.bool)\n",
        "    for img_idx, tinds in enumerate(image_to_text_indices):\n",
        "        if tinds:\n",
        "            i2t_best[img_idx] = sim_matrix[img_idx, tinds].max()\n",
        "            i2t_valid[img_idx] = True\n",
        "    report(\"Image-to-Text Retrieval:\", \"i2t\", count_higher(sim_matrix, i2t_best), i2t_valid, metrics)\n",
        "\n",
        "    # T2I: threshold = score of the caption's own image.\n",
        "    text_to_image = torch.full((n_txt,), -1, dtype=torch.long)\n",
        "    for img_idx, tinds in enumerate(image_to_text_indices):\n",
        "        if tinds:\n",
        "            text_to_image[tinds] = img_idx\n",
        "    t2i_valid = text_to_image >= 0\n",
        "    sim_matrix_t2i = sim_matrix.t()  # [N_txt, N_img]\n",
        "    t2i_correct = sim_matrix_t2i[torch.arange(n_txt), text_to_image.clamp_min(0)]\n",
        "    report(\"Text-to-Image Retrieval:\", \"t2i\", count_higher(sim_matrix_t2i, t2i_correct), t2i_valid, metrics)\n",
        "\n",
        "    return metrics\n"
      ],
      "metadata": {
        "id": "s_IDtdvU9yLv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: conv1 should be fp32 on the model's device and accept an fp32 batch.\n",
        "with torch.no_grad():\n",
        "    print(\"conv1:\", model.visual.conv1.weight.dtype, model.visual.conv1.weight.device)\n",
        "    dummy = torch.randn(\n",
        "        2, 3, model.visual.input_resolution, model.visual.input_resolution,\n",
        "        dtype=torch.float32,\n",
        "        device=next(model.parameters()).device,\n",
        "    )\n",
        "    _ = model.visual.conv1(dummy)  # a dtype/device mismatch with the conv weights would raise here\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x64roaWQlWgk",
        "outputId": "5f49fd56-518a-423c-ccb6-8916312e5487"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "conv1: torch.float32 cuda:0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# LoRA fine-tune for 2 epochs, prune by TE, then fine-tune 3 more epochs.\n",
        "history = train_with_te_prune_lora(\n",
        "    # model and data\n",
        "    model=model,\n",
        "    train_dataset=train_dataset,\n",
        "    train_dataloader=train_dataloader,\n",
        "    device=device,\n",
        "    # LoRA adapters and optimizer\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    # schedule: pruning happens between the two phases\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "    # TE probe: pooling, number of batches, perturbation scales\n",
        "    te_pool=\"mean\",          # or \"cls\"\n",
        "    probe_batches=2,\n",
        "    te_sigma_v=1e-2,         # overrides the function default (1e-3)\n",
        "    te_sigma_t=5e-2,         # overrides the function default (1e-3)\n",
        ")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cEBXmg4aNCdo",
        "outputId": "061f6970-56d8-4dd0-fc96-a33ba6c18605"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[TE Debug] mean|Δ| vision=2.193e-03, text=5.285e-02  (sigmas: v=1.00e-02, t=5.00e-02)\n",
            "[Epoch 1] loss=0.162197\n",
            "  Vision TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.11443, 0.122704, 0.039134, 0.047108, 0.046906, 0.04845, 0.179606, 0.367432, 0.347021, 0.359968, 0.219194]\n",
            "  Text   TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.390287, 0.407227, 0.421385, 0.420529, 0.409381, 0.410627, 0.433728, 0.40956, 0.416847, 0.413812, 0.228143]\n",
            "[TE Debug] mean|Δ| vision=2.302e-03, text=5.254e-02  (sigmas: v=1.00e-02, t=5.00e-02)\n",
            "[Epoch 2] loss=0.125690\n",
            "  Vision TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.100597, 0.126949, 0.043077, 0.043075, 0.048225, 0.043256, 0.148033, 0.373181, 0.359663, 0.369577, 0.237096]\n",
            "  Text   TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.39782, 0.405314, 0.415949, 0.412757, 0.412124, 0.416727, 0.430464, 0.417764, 0.41907, 0.416669, 0.220526]\n",
            "Pruned Vision blocks at indices: [8, 10, 2, 5] (Identity)\n",
            "Pruned Text   blocks at indices: [7, 9, 3, 5] (Identity)\n",
            "[TE Debug] mean|Δ| vision=1.863e-03, text=5.112e-02  (sigmas: v=1.00e-02, t=5.00e-02)\n",
            "[Epoch 3] loss=0.481578\n",
            "  Vision TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.047928, 0.5, 0.026146, 0.036271, 0.5, 0.043085, 0.150449, 0.5, 0.36021, 0.5, 0.209292]\n",
            "  Text   TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.38933, 0.399941, 0.5, 0.381214, 0.5, 0.40086, 0.5, 0.363126, 0.5, 0.380782, 0.234233]\n",
            "[TE Debug] mean|Δ| vision=1.792e-03, text=5.080e-02  (sigmas: v=1.00e-02, t=5.00e-02)\n",
            "[Epoch 4] loss=0.282571\n",
            "  Vision TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.047581, 0.5, 0.025697, 0.033292, 0.5, 0.040057, 0.176318, 0.5, 0.36569, 0.5, 0.194366]\n",
            "  Text   TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.3828, 0.406065, 0.5, 0.390422, 0.5, 0.410173, 0.5, 0.371239, 0.5, 0.382732, 0.229406]\n",
            "[TE Debug] mean|Δ| vision=1.858e-03, text=5.284e-02  (sigmas: v=1.00e-02, t=5.00e-02)\n",
            "[Epoch 5] loss=0.233113\n",
            "  Vision TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.049338, 0.5, 0.026685, 0.031542, 0.5, 0.041793, 0.139241, 0.5, 0.347695, 0.5, 0.197217]\n",
            "  Text   TE per layer (ℓ=0..10 -> prune ℓ+1):\n",
            "    [0.380875, 0.388526, 0.5, 0.372775, 0.5, 0.400832, 0.5, 0.365856, 0.5, 0.386058, 0.231311]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "etHTsgFsNDdf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prints image-to-text and text-to-image Recall@1/5/10 on COCO val2014.\n",
        "evaluate_coco_retrieval(model=model, data_dir=data_dir, batch_size=64)\n"
      ],
      "metadata": {
        "id": "cBmSHD9-kipW",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e984b281-bf74-4bed-a845-c65e4320096c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.33s)\n",
            "creating index...\n",
            "index created!\n",
            "Image-to-Text Retrieval:\n",
            "Recall@1:  15.95%\n",
            "Recall@5:  34.81%\n",
            "Recall@10: 45.51%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1:  9.53%\n",
            "Recall@5:  24.08%\n",
            "Recall@10: 33.11%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "RPv7HIqxkjON"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LQXgZ_alw9gc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "d42i-P_Yrq-Y"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}