{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "e3cc56b4-73a8-48f4-b82c-95adafa10c25"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# %pip (unlike !pip) installs into the environment of the running kernel\n",
        "%pip install --quiet ftfy regex tqdm\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "%pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Everything is downloaded to, and extracted under, this directory\n",
        "data_dir = '/content/coco2014'\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# Archive name -> download URL, in download order.\n",
        "# NOTE: a later cell imports torchvision's `datasets` under this same name, so\n",
        "# re-running this cell afterwards rebinds `datasets` to this dict.\n",
        "datasets = dict(\n",
        "    train2014=\"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    val2014=\"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    annotations_trainval2014=\"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        ")\n",
        "\n",
        "# Download helper function with progress bar\n",
        "def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):\n",
        "    \"\"\"Stream `url` to `dest_path`, showing a tqdm progress bar.\n",
        "\n",
        "    Args:\n",
        "        url: Address of the file to fetch.\n",
        "        dest_path: Local path the bytes are written to (overwritten if present).\n",
        "        chunk_size: Number of bytes requested per read from the response.\n",
        "        timeout: Seconds to wait on the server, passed to `requests.get`.\n",
        "\n",
        "    Raises:\n",
        "        requests.HTTPError: If the server answers with an error status, so an\n",
        "            error page is never saved as if it were the archive.\n",
        "    \"\"\"\n",
        "    # The context manager releases the connection even if writing fails\n",
        "    with requests.get(url, stream=True, timeout=timeout) as response:\n",
        "        response.raise_for_status()\n",
        "        # Without a Content-Length header the bar is open-ended (total=None)\n",
        "        total_size = int(response.headers.get('content-length', 0)) or None\n",
        "        with open(dest_path, 'wb') as f, tqdm(\n",
        "            desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "            total=total_size,\n",
        "            unit='B',\n",
        "            unit_scale=True,\n",
        "            unit_divisor=1024\n",
        "        ) as bar:\n",
        "            for data in response.iter_content(chunk_size=chunk_size):\n",
        "                f.write(data)\n",
        "                bar.update(len(data))\n",
        "\n",
        "# Download each archive, unpack it into data_dir, then delete the archive\n",
        "for name, url in datasets.items():\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    try:\n",
        "        # Download the dataset\n",
        "        download_file(url, zip_path)\n",
        "\n",
        "        # Unzip the dataset\n",
        "        with ZipFile(zip_path, 'r') as zip_ref:\n",
        "            zip_ref.extractall(data_dir)\n",
        "    finally:\n",
        "        # Remove the archive even when the download or extraction fails, so a\n",
        "        # partial file does not keep occupying disk space\n",
        "        if os.path.exists(zip_path):\n",
        "            os.remove(zip_path)\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "id": "9XMDkrWBhLGg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the CLIP ViT-B/16 Model"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "# Prefer the GPU when one is available\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Load the CLIP model together with its matching image preprocessing\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "\n",
        "# Cast every weight to fp32 so the whole model runs in a single dtype\n",
        "# (LoRA is injected later and takes its dtype from the wrapped weights)\n",
        "model = model.float()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "print(\"Model parameters:\", f\"{sum(p.numel() for p in model.parameters()):,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the MSCOCO Data Loaders"
      ],
      "metadata": {
        "id": "h_cfc6FcltX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision import transforms, datasets\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "# Image pipeline: resize to the model's square input size (aspect ratio is not\n",
        "# preserved), convert to a tensor, then normalize each channel.\n",
        "transform = transforms.Compose([\n",
        "    transforms.Resize(size=(input_resolution, input_resolution)),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(\n",
        "        (0.48145466, 0.4578275, 0.40821073),   # per-channel mean\n",
        "        (0.26862954, 0.26130258, 0.27577711),  # per-channel std\n",
        "    ),\n",
        "])\n",
        "\n",
        "class CocoDataset(Dataset):\n",
        "    \"\"\"COCO captions flattened to one (image, caption) pair per item.\n",
        "\n",
        "    Args:\n",
        "        root: Directory holding the COCO images.\n",
        "        annFile: Path to the captions annotation JSON.\n",
        "        transform: Optional callable applied to each loaded image.\n",
        "\n",
        "    Each item is `(image, text)`, where `text` is the caption tokenized by\n",
        "    `clip.tokenize` to `context_length` tokens.\n",
        "    \"\"\"\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=None)\n",
        "        self.transform = transform\n",
        "        self.data = []\n",
        "        # Build (image index, caption) pairs straight from the annotation index.\n",
        "        # Indexing self.dataset here would open and decode every image just to\n",
        "        # read its captions.\n",
        "        coco = self.dataset.coco\n",
        "        for img_idx, img_id in enumerate(self.dataset.ids):\n",
        "            for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):\n",
        "                self.data.append((img_idx, ann[\"caption\"]))\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img_idx, caption = self.data[idx]\n",
        "        image, _ = self.dataset[img_idx]  # the caption comes from self.data\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        # truncate=True shortens an over-long caption instead of raising\n",
        "        text = clip.tokenize(caption, context_length=context_length, truncate=True)[0]\n",
        "        return image, text\n",
        "\n",
        "# Locations of the COCO 2014 training images and their caption annotations\n",
        "train_img_dir = os.path.join(data_dir, 'train2014')\n",
        "train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')\n",
        "\n",
        "# One (image, tokenized caption) item per caption\n",
        "train_dataset = CocoDataset(train_img_dir, train_ann_file, transform=transform)\n",
        "\n",
        "# Shuffled batches of 64, loaded by two workers into pinned memory\n",
        "train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)\n"
      ],
      "metadata": {
        "id": "dGosFCtHhMMi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "#  CLIP + LoRA + TE Head Pruning (ViT-B/16 vision)\n",
        "#  (NaN-safe training + COCO eval) - Perturbation CosTE per head\n",
        "#  Role-aware (benign-first) head pruning\n",
        "# =========================\n",
        "\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "from typing import List, Tuple, Union, Optional\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import AdamW\n",
        "from torch.nn.utils import clip_grad_norm_\n",
        "\n",
        "# -------------------------\n",
        "# Utils: safe numerics\n",
        "# -------------------------\n",
        "def safe_l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:\n",
        "    \"\"\"Scale `x` to unit L2 norm along `dim`; norms below `eps` are treated as `eps`.\"\"\"\n",
        "    lengths = x.norm(dim=dim, keepdim=True)\n",
        "    return x / lengths.clamp_min(eps)\n",
        "\n",
        "def get_safe_logit_scale(model, clamp_low: float = -5.0, clamp_high: float = 5.0) -> torch.Tensor:\n",
        "    \"\"\"Return exp(logit_scale) with the log-scale clamped to [clamp_low, clamp_high].\n",
        "\n",
        "    A model without a `logit_scale` attribute gets a constant 1.0 placed on the\n",
        "    device of its first parameter.\n",
        "    \"\"\"\n",
        "    if not hasattr(model, \"logit_scale\"):\n",
        "        return torch.tensor(1.0, device=next(model.parameters()).device)\n",
        "    # Clamp in log space first so that exp() cannot overflow\n",
        "    log_scale = model.logit_scale.float()\n",
        "    return log_scale.clamp(clamp_low, clamp_high).exp()\n",
        "\n",
        "# -------------------------\n",
        "# LoRA wrapper (MHA-compatible)\n",
        "# -------------------------\n",
        "class LoRALinear(nn.Module):\n",
        "    \"\"\"\n",
        "    Low-rank adapter around a frozen nn.Linear.\n",
        "\n",
        "    The wrapped layer's parameters are frozen; only `lora_A` ([r, in_features])\n",
        "    and `lora_B` ([out_features, r]) are trainable. Read-only `.weight` and\n",
        "    `.bias` properties expose the merged parameters so nn.MultiheadAttention\n",
        "    can read out_proj.weight directly.\n",
        "\n",
        "    Args:\n",
        "        base_linear: The nn.Linear to wrap.\n",
        "        r: Rank of the update; must be a positive integer.\n",
        "        alpha: Scaling numerator; the update is multiplied by alpha / r.\n",
        "        dropout: Dropout probability applied to the input of the low-rank path\n",
        "            in forward() (it is not reflected in the merged `.weight`).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If r is not positive.\n",
        "    \"\"\"\n",
        "    def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        assert isinstance(base_linear, nn.Linear), \"LoRALinear expects an nn.Linear as base\"\n",
        "        r = int(r)\n",
        "        if r <= 0:\n",
        "            # alpha / r below would otherwise fail with a bare ZeroDivisionError\n",
        "            raise ValueError(f\"LoRA rank r must be a positive integer, got {r}\")\n",
        "\n",
        "        self.base = base_linear\n",
        "        for p in self.base.parameters():\n",
        "            p.requires_grad = False\n",
        "\n",
        "        self.r = r\n",
        "        self.scaling = float(alpha) / float(r)\n",
        "\n",
        "        # Adapter matrices live on the same device and in the same dtype as the base weight\n",
        "        dev, dt = self.base.weight.device, self.base.weight.dtype\n",
        "        self.lora_A = nn.Parameter(torch.zeros(r, self.base.in_features,  device=dev, dtype=dt))\n",
        "        self.lora_B = nn.Parameter(torch.zeros(self.base.out_features, r, device=dev, dtype=dt))\n",
        "        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n",
        "        # B starts at zero, so the wrapped layer initially behaves exactly like the base\n",
        "        nn.init.zeros_(self.lora_B)\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n",
        "\n",
        "    @property\n",
        "    def weight(self) -> torch.Tensor:\n",
        "        # Merged weight: base + scaling * (B @ A); used directly by nn.MultiheadAttention\n",
        "        return self.base.weight + self.scaling * (self.lora_B @ self.lora_A)\n",
        "\n",
        "    @property\n",
        "    def bias(self):\n",
        "        # LoRA adds no bias term, so the base bias is exposed unchanged\n",
        "        return self.base.bias\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = self.base(x)\n",
        "        delta = (self.dropout(x) @ self.lora_A.t()) @ self.lora_B.t()\n",
        "        return out + self.scaling * delta\n",
        "\n",
        "# -------------------------\n",
        "# Injection / upgrade helpers\n",
        "# -------------------------\n",
        "def _wrap_or_upgrade_linear(mod: nn.Module, attr: str, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Make sure mod.<attr> is a current LoRALinear, modifying `mod` in place.\n",
        "\n",
        "      - current LoRALinear (its class exposes a `weight` property): untouched\n",
        "      - module whose class is named \"LoRALinear\" but has no `.weight`:\n",
        "        rebuilt around its `.base`, reusing lora_A / lora_B when shapes match\n",
        "      - plain nn.Linear: wrapped in a new LoRALinear\n",
        "      - anything else, including a missing attribute: untouched\n",
        "    \"\"\"\n",
        "    current = getattr(mod, attr, None)\n",
        "\n",
        "    has_weight_property = isinstance(getattr(type(current), \"weight\", None), property)\n",
        "    if isinstance(current, LoRALinear) and has_weight_property:\n",
        "        return\n",
        "\n",
        "    is_stale_wrapper = (\n",
        "        isinstance(current, nn.Module)\n",
        "        and current.__class__.__name__ == \"LoRALinear\"\n",
        "        and not hasattr(current, \"weight\")\n",
        "    )\n",
        "    if is_stale_wrapper:\n",
        "        replacement = LoRALinear(current.base, r=r, alpha=alpha, dropout=dropout)\n",
        "        if hasattr(current, \"lora_A\") and hasattr(current, \"lora_B\"):\n",
        "            with torch.no_grad():\n",
        "                # Carry over each adapter matrix only when its shape is unchanged\n",
        "                for name in (\"lora_A\", \"lora_B\"):\n",
        "                    old_param = getattr(current, name)\n",
        "                    new_param = getattr(replacement, name)\n",
        "                    if new_param.shape == old_param.shape:\n",
        "                        new_param.copy_(old_param.data)\n",
        "        setattr(mod, attr, replacement)\n",
        "        return\n",
        "\n",
        "    if isinstance(current, nn.Linear):\n",
        "        setattr(mod, attr, LoRALinear(current, r=r, alpha=alpha, dropout=dropout))\n",
        "\n",
        "def _iter_resblocks(stack_module: nn.Module):\n",
        "    \"\"\"Return `stack_module.transformer.resblocks` as a list, or [] if either attribute is missing.\n",
        "\n",
        "    Used for both model.visual and the model itself (text tower).\n",
        "    \"\"\"\n",
        "    transformer = getattr(stack_module, \"transformer\", None)\n",
        "    if not hasattr(transformer, \"resblocks\"):\n",
        "        return []\n",
        "    return list(transformer.resblocks)\n",
        "\n",
        "def add_lora_to_clip_vit_b16(model, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Add or upgrade LoRA on attn.out_proj, mlp.c_fc and mlp.c_proj of every\n",
        "    residual block, first in the vision tower (model.visual) and then in the\n",
        "    text tower (the model itself). Blocks lacking an expected attribute are skipped.\n",
        "\n",
        "    Afterwards every parameter is frozen except those whose name ends in\n",
        "    lora_A / lora_B (logit_scale stays frozen too).\n",
        "\n",
        "    Returns:\n",
        "        The list of parameters left trainable.\n",
        "    \"\"\"\n",
        "    def _adapt_tower(tower):\n",
        "        for block in _iter_resblocks(tower):\n",
        "            attention = getattr(block, \"attn\", None)\n",
        "            if attention is not None and hasattr(attention, \"out_proj\"):\n",
        "                _wrap_or_upgrade_linear(attention, \"out_proj\", r, alpha, dropout)\n",
        "            feed_forward = getattr(block, \"mlp\", None)\n",
        "            if feed_forward is None:\n",
        "                continue\n",
        "            for layer_name in (\"c_fc\", \"c_proj\"):\n",
        "                if hasattr(feed_forward, layer_name):\n",
        "                    _wrap_or_upgrade_linear(feed_forward, layer_name, r, alpha, dropout)\n",
        "\n",
        "    _adapt_tower(model.visual)  # vision\n",
        "    _adapt_tower(model)         # text\n",
        "\n",
        "    # Freeze everything, then re-enable only the LoRA matrices\n",
        "    for param in model.parameters():\n",
        "        param.requires_grad = False\n",
        "    if hasattr(model, \"logit_scale\") and isinstance(model.logit_scale, torch.Tensor):\n",
        "        model.logit_scale.requires_grad_(False)\n",
        "    for name, param in model.named_parameters():\n",
        "        if name.endswith((\"lora_A\", \"lora_B\")):\n",
        "            param.requires_grad = True\n",
        "\n",
        "    return [param for param in model.parameters() if param.requires_grad]\n",
        "\n",
        "def _assert_outproj_has_weight(model):\n",
        "    \"\"\"\n",
        "    Check that every attn.out_proj in the vision and text towers exposes a\n",
        "    Tensor `.weight`. Blocks without attn.out_proj are skipped.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: On the first out_proj whose `.weight` is not a Tensor.\n",
        "    \"\"\"\n",
        "    towers = ((\"vision\", model.visual), (\"text\", model))\n",
        "    for tower_name, tower in towers:\n",
        "        for i, blk in enumerate(_iter_resblocks(tower)):\n",
        "            out_proj = getattr(getattr(blk, \"attn\", None), \"out_proj\", None)\n",
        "            if out_proj is None:\n",
        "                continue\n",
        "            w = getattr(out_proj, \"weight\", None)\n",
        "            if isinstance(w, torch.Tensor):\n",
        "                continue\n",
        "            raise AssertionError(\n",
        "                f\"{tower_name} resblock {i} out_proj has no Tensor `.weight` \"\n",
        "                f\"(got {type(w)} from {type(out_proj)})\"\n",
        "            )\n",
        "\n",
        "# -------------------------\n",
        "# Per-head TE (CosTE) for CLIP ViT-B/16 vision\n",
        "# -------------------------\n",
        "def _first_param_dtype(module, default=torch.float32):\n",
        "    \"\"\"dtype of the first parameter `module` yields, or `default` when it has none.\"\"\"\n",
        "    first = next(iter(module.parameters(recurse=True)), None)\n",
        "    return default if first is None else first.dtype\n",
        "\n",
        "def _mha_qkv_slices(embed_dim: int):\n",
        "    \"\"\"Slices selecting the Q, K and V thirds of a fused projection of height 3 * embed_dim.\"\"\"\n",
        "    q_rows = slice(0, embed_dim)\n",
        "    k_rows = slice(embed_dim, 2*embed_dim)\n",
        "    v_rows = slice(2*embed_dim, 3*embed_dim)\n",
        "    return q_rows, k_rows, v_rows\n",
        "\n",
        "@torch.no_grad()\n",
        "def _pool_tokens(x: torch.Tensor, pool: str = \"mean\", include_cls: bool = False):\n",
        "    \"\"\"\n",
        "    Reduce x ([B, T, D]) over the token axis to [B, D].\n",
        "\n",
        "    With include_cls=False and T >= 2, token 0 is dropped first. pool=\"mean\"\n",
        "    averages the remaining tokens; pool=\"cls\" returns the first remaining token.\n",
        "    NOTE(review): pool=\"cls\" with include_cls=False therefore returns the\n",
        "    original token 1, not token 0 - confirm this is intended.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If pool is neither 'cls' nor 'mean'.\n",
        "    \"\"\"\n",
        "    drop_first = (not include_cls) and x.size(1) >= 2\n",
        "    tokens = x[:, 1:, :] if drop_first else x\n",
        "    if pool == \"mean\":\n",
        "        return tokens.mean(dim=1)                    # [B, D]\n",
        "    if pool == \"cls\":\n",
        "        return tokens[:, :1, :].squeeze(1)           # [B, D]\n",
        "    raise ValueError(\"pool must be 'cls' or 'mean'\")\n",
        "\n",
        "@torch.no_grad()\n",
        "def _vision_preln_inputs_two_runs(model, images, sigma_v: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the vision tower twice on `images`, adding independent Gaussian noise of\n",
        "    scale `sigma_v` to the pixels in each run (no noise when sigma_v <= 0).\n",
        "\n",
        "    Returns (run1, run2); each is a list with one tensor per resblock holding\n",
        "    that block's ln_1 output (the tensor handed to its attention), [T, B, D].\n",
        "    `device` is not used: tensors follow the device of the vision tower's\n",
        "    first parameter.\n",
        "    \"\"\"\n",
        "    visual = model.visual\n",
        "    target_device = next(visual.parameters()).device\n",
        "    compute_dtype = _first_param_dtype(visual, torch.float32)\n",
        "\n",
        "    def _forward_collect(batch, noise_scale):\n",
        "        h = batch.to(target_device, non_blocking=True).to(compute_dtype)\n",
        "        if noise_scale > 0:\n",
        "            h = h + noise_scale * torch.randn_like(h, dtype=compute_dtype)\n",
        "        # Patch embedding, then flatten the spatial grid into a token axis\n",
        "        h = visual.conv1(h)\n",
        "        h = h.reshape(h.shape[0], h.shape[1], -1).permute(0, 2, 1)           # [B, T-1, D]\n",
        "        cls_token = visual.class_embedding[None, None, :].to(compute_dtype).expand(h.shape[0], 1, -1)\n",
        "        h = torch.cat([cls_token, h], dim=1)\n",
        "        h = h + visual.positional_embedding.to(compute_dtype)\n",
        "        h = visual.ln_pre(h.to(_first_param_dtype(visual.ln_pre, compute_dtype)))  # [B, T, D]\n",
        "        h = h.transpose(0, 1).contiguous()                                   # [T, B, D] for blocks\n",
        "\n",
        "        collected = []\n",
        "        for block in visual.transformer.resblocks:\n",
        "            normed = block.ln_1(h)                                           # [T, B, D]\n",
        "            collected.append(normed.clone())\n",
        "            # Advance the residual stream to the next block\n",
        "            h = h + block.attn(normed, normed, normed, need_weights=False)[0]\n",
        "            h = h + block.mlp(block.ln_2(h))\n",
        "        return collected\n",
        "\n",
        "    first_run = _forward_collect(images, sigma_v)\n",
        "    second_run = _forward_collect(images, sigma_v)\n",
        "    return first_run, second_run\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Text: pre-LN inputs for two runs (noise on embeddings)\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def _text_preln_inputs_two_runs(model, token_ids, sigma_t: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the text tower twice on `token_ids`, adding independent Gaussian noise of\n",
        "    scale `sigma_t` to the token + positional embeddings in each run (no noise\n",
        "    when sigma_t <= 0).\n",
        "\n",
        "    Each block's own `attn_mask` (when it has one) is passed to its attention\n",
        "    call, the way OpenAI CLIP's ResidualAttentionBlock does; for the text tower\n",
        "    that mask is causal. Calling attention without it would produce hidden\n",
        "    states the real text encoder never sees.\n",
        "\n",
        "    Returns (run1, run2); each is a list with one tensor per resblock holding\n",
        "    that block's ln_1 output (the tensor handed to its attention), [T, B, D].\n",
        "    `device` is not used: tensors follow the device of the model's first parameter.\n",
        "    \"\"\"\n",
        "    txt = model\n",
        "    pdev = next(txt.parameters()).device\n",
        "    t_dtype = _first_param_dtype(txt, torch.float32)\n",
        "\n",
        "    def _mask_for(blk, like):\n",
        "        # Crop the block's mask to the current sequence length and match `like`\n",
        "        mask = getattr(blk, \"attn_mask\", None)\n",
        "        if mask is None:\n",
        "            return None\n",
        "        seq_len = like.size(0)\n",
        "        mask = mask[:seq_len, :seq_len].to(like.device)\n",
        "        # Additive (float) masks must share the activations' dtype; bool masks stay bool\n",
        "        return mask.to(like.dtype) if mask.is_floating_point() else mask\n",
        "\n",
        "    def _one(token_ids, noise):\n",
        "        # token_ids: [B, T]\n",
        "        x = txt.token_embedding(token_ids.to(pdev)).to(t_dtype)              # [B, T, D]\n",
        "        x = x + txt.positional_embedding.to(t_dtype)[: x.size(1)]\n",
        "        if noise > 0:\n",
        "            x = x + noise * torch.randn_like(x, dtype=t_dtype)\n",
        "        x = x.transpose(0, 1).contiguous()                                   # [T, B, D]\n",
        "\n",
        "        preln_list = []\n",
        "        for blk in txt.transformer.resblocks:\n",
        "            Xn = blk.ln_1(x)                                                 # [T, B, D]\n",
        "            preln_list.append(Xn.clone())\n",
        "            attn_out = blk.attn(Xn, Xn, Xn, need_weights=False, attn_mask=_mask_for(blk, Xn))[0]\n",
        "            x = x + attn_out\n",
        "            x = x + blk.mlp(blk.ln_2(x))\n",
        "        return preln_list\n",
        "\n",
        "    return _one(token_ids, sigma_t), _one(token_ids, sigma_t)\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Text: per-head outputs from Xn (same as vision, re-used)\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def _attn_head_outputs_from_Xn_text(attn: nn.MultiheadAttention, Xn_LBD: torch.Tensor):\n",
        "    \"\"\"Text-tower alias of _attn_head_outputs_from_Xn; returns its (per-head list, total) pair.\n",
        "\n",
        "    NOTE(review): no attention mask is passed here - confirm whether text\n",
        "    per-head outputs should be computed under the text tower's mask.\n",
        "    \"\"\"\n",
        "    per_head, total = _attn_head_outputs_from_Xn(attn, Xn_LBD)\n",
        "    return per_head, total\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# TE (CosTE) per head for TEXT\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def compute_te_per_head_text(\n",
        "    model,\n",
        "    token_ids,                                  # [B, T]\n",
        "    device: str = \"cuda\",\n",
        "    sigma_t: float = 1e-2,\n",
        "    pool: str = \"cls\",\n",
        "    include_cls: bool = True,\n",
        "):\n",
        "    \"\"\"\n",
        "    Per-head TE scores for the text tower from two noisy forward passes.\n",
        "\n",
        "    For every layer and head, the change (run 2 minus run 1) of the pooled\n",
        "    per-head output is compared with the change of the pooled summed output;\n",
        "    the score is 0.5 * mean over the batch of the squared cosine between them.\n",
        "\n",
        "    Args:\n",
        "        model: Model whose `transformer.resblocks` form the text tower.\n",
        "        token_ids: Token ids, [B, T].\n",
        "        device: Forwarded to _text_preln_inputs_two_runs.\n",
        "        sigma_t: Noise scale used in both runs.\n",
        "        pool, include_cls: Forwarded to _pool_tokens.\n",
        "\n",
        "    Returns:\n",
        "        A list with one tensor of shape [H] per layer.\n",
        "    \"\"\"\n",
        "    run_a, run_b = _text_preln_inputs_two_runs(model, token_ids, sigma_t, device)\n",
        "    eps = 1e-6\n",
        "\n",
        "    def _pooled_delta(first, second):\n",
        "        # Pool both runs, then take (second - first) with NaN/inf replaced\n",
        "        pooled_first = _pool_tokens(first, pool=pool, include_cls=include_cls)\n",
        "        pooled_second = _pool_tokens(second, pool=pool, include_cls=include_cls)\n",
        "        return torch.nan_to_num(pooled_second - pooled_first).float()\n",
        "\n",
        "    scores_by_layer = []\n",
        "    for layer_idx, (xn_a, xn_b) in enumerate(zip(run_a, run_b)):\n",
        "        attn = model.transformer.resblocks[layer_idx].attn\n",
        "\n",
        "        heads_a, total_a = _attn_head_outputs_from_Xn_text(attn, xn_a)\n",
        "        heads_b, total_b = _attn_head_outputs_from_Xn_text(attn, xn_b)\n",
        "\n",
        "        total_delta = _pooled_delta(total_a, total_b)                        # [B, D]\n",
        "        total_norm = total_delta.norm(dim=-1).clamp_min(eps)                 # [B]\n",
        "\n",
        "        layer_scores = []\n",
        "        for head_idx in range(attn.num_heads):\n",
        "            head_delta = _pooled_delta(heads_a[head_idx], heads_b[head_idx])\n",
        "            dot = (head_delta * total_delta).sum(dim=-1)                     # [B]\n",
        "            norms = head_delta.norm(dim=-1).clamp_min(eps) * total_norm      # [B]\n",
        "            cosine = (dot / norms).clamp(-1.0, 1.0)\n",
        "            layer_scores.append(0.5 * (cosine * cosine).mean())\n",
        "\n",
        "        scores_by_layer.append(torch.stack(layer_scores, dim=0))  # [H]\n",
        "    return scores_by_layer\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_avg_te_per_head_text_over_probe(\n",
        "    model,\n",
        "    probe_loader,\n",
        "    device: str = \"cuda\",\n",
        "    sigma_t: float = 1e-2,\n",
        "    pool: str = \"cls\",\n",
        "    include_cls: bool = True,\n",
        "):\n",
        "    \"\"\"\n",
        "    Average compute_te_per_head_text over every batch of `probe_loader`.\n",
        "\n",
        "    Args:\n",
        "        model: Passed through to compute_te_per_head_text.\n",
        "        probe_loader: Iterable of (images, token_ids) batches; images are ignored.\n",
        "        device, sigma_t, pool, include_cls: Passed through unchanged.\n",
        "\n",
        "    Returns:\n",
        "        Nested list [L][H] of TE scores averaged over batches.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `probe_loader` yields no batches.\n",
        "    \"\"\"\n",
        "    te_accum = None\n",
        "    n = 0\n",
        "    for _images, texts in probe_loader:\n",
        "        te_list = compute_te_per_head_text(\n",
        "            model, texts, device=device, sigma_t=sigma_t, pool=pool, include_cls=include_cls\n",
        "        )\n",
        "        te_stacked = torch.stack([t.cpu() for t in te_list], dim=0)  # [L,H]\n",
        "        te_accum = te_stacked if te_accum is None else (te_accum + te_stacked)\n",
        "        n += 1\n",
        "        if n == 1:\n",
        "            # Report once, after the first batch\n",
        "            print(f\"[Head-TE Debug/Text] collected per-head TE on one probe batch \"\n",
        "                  f\"(sigma_t={sigma_t:.2e}, pool={pool}, include_cls={include_cls})\")\n",
        "    if te_accum is None:\n",
        "        # Previously this fell through to `None / 1` and raised an opaque TypeError\n",
        "        raise ValueError(\"probe_loader yielded no batches; cannot average per-head TE\")\n",
        "    return (te_accum / n).tolist()  # [L][H]\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def _attn_head_outputs_from_Xn(\n",
        "    attn: nn.MultiheadAttention,\n",
        "    Xn_LBD: torch.Tensor,\n",
        "    attn_mask: Optional[torch.Tensor] = None,\n",
        "):\n",
        "    \"\"\"\n",
        "    Split one attention layer's output into per-head contributions.\n",
        "\n",
        "    Args:\n",
        "        attn: Attention module with a fused in_proj_weight / in_proj_bias.\n",
        "        Xn_LBD: Attention input, [T, B, D].\n",
        "        attn_mask: Optional [T, T] mask. A bool mask blocks the positions that\n",
        "            are True; any other mask is added to the attention scores. The\n",
        "            default (None) applies no mask.\n",
        "\n",
        "    Returns:\n",
        "        (O_h, O_all): O_h is a list of H tensors [B, T, D], each head's output\n",
        "        pushed through its own columns of out_proj.weight; O_all is their sum,\n",
        "        [B, T, D]. The out_proj bias is not included in either.\n",
        "    \"\"\"\n",
        "    T, B, D = Xn_LBD.shape\n",
        "    H = attn.num_heads\n",
        "    dh = D // H\n",
        "\n",
        "    # in_proj_weight stacks the Q, K and V projections along dim 0: [3D, D]\n",
        "    W_q, W_k, W_v = attn.in_proj_weight.chunk(3, dim=0)\n",
        "    b = attn.in_proj_bias       # [3D] or None\n",
        "    b_q, b_k, b_v = b.chunk(3, dim=0) if b is not None else (None, None, None)\n",
        "\n",
        "    Xn = Xn_LBD.transpose(0, 1)                                 # [B, T, D]\n",
        "\n",
        "    def _split_heads(t):\n",
        "        return t.view(B, T, H, dh).transpose(1, 2)              # [B, H, T, dh]\n",
        "\n",
        "    Q = _split_heads(F.linear(Xn, W_q, b_q))\n",
        "    K = _split_heads(F.linear(Xn, W_k, b_k))\n",
        "    V = _split_heads(F.linear(Xn, W_v, b_v))\n",
        "\n",
        "    scale = 1.0 / math.sqrt(dh)\n",
        "    attn_scores = torch.matmul(Q, K.transpose(-1, -2)) * scale  # [B, H, T, T]\n",
        "    if attn_mask is not None:\n",
        "        mask = attn_mask.to(attn_scores.device)\n",
        "        if mask.dtype == torch.bool:\n",
        "            attn_scores = attn_scores.masked_fill(mask, float(\"-inf\"))\n",
        "        else:\n",
        "            attn_scores = attn_scores + mask.to(attn_scores.dtype)\n",
        "    attn_probs  = attn_scores.softmax(dim=-1)                   # [B, H, T, T]\n",
        "    Z = torch.matmul(attn_probs, V)                             # [B, H, T, dh]\n",
        "\n",
        "    # Head h only touches columns [h*dh, (h+1)*dh) of out_proj.weight\n",
        "    W_o = attn.out_proj.weight                                  # [D, D]\n",
        "    O_h = [\n",
        "        torch.matmul(Z[:, h], W_o[:, h*dh:(h+1)*dh].transpose(0, 1))  # [B, T, D]\n",
        "        for h in range(H)\n",
        "    ]\n",
        "    O_all = torch.stack(O_h, dim=1).sum(dim=1)                  # [B, T, D]\n",
        "    return O_h, O_all\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Text: attention probs from Xn (for mass features)\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def _attn_probs_from_Xn_text(attn: nn.MultiheadAttention, Xn_LBD: torch.Tensor):\n",
        "    \"\"\"Text-tower alias of _attn_probs_from_Xn; forwards both arguments and returns its result.\"\"\"\n",
        "    probs = _attn_probs_from_Xn(attn, Xn_LBD)\n",
        "    return probs\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_head_roles_text_over_probe(\n",
        "    model,\n",
        "    probe_loader,\n",
        "    device: str = \"cuda\",\n",
        "    te_heads_per_layer: Optional[List[torch.Tensor]] = None,  # [L][H], required\n",
        "    in_hi_pct: float = 60.0,\n",
        "    out_hi_pct: float = 80.0,\n",
        "    te_hi_pct: float = 60.0,\n",
        "    te_lo_pct: float = 30.0,\n",
        "    require_te_for_role: bool = True,\n",
        "    use_te_as_tiebreak: bool = True,\n",
        "):\n",
        "    \"\"\"\n",
        "    Label every text-tower attention head as \"broadcaster\", \"sink\" or \"benign\",\n",
        "    using the first batch of `probe_loader` only.\n",
        "\n",
        "    Per-head features (z-scored across the heads of each layer):\n",
        "      - cls_in_mass  : mean of probs[..., :, 0]  (attention toward token 0)\n",
        "      - cls_out_mass : mean of probs[..., 0, :]  (attention from token 0)\n",
        "      - te           : the supplied per-head TE score\n",
        "\n",
        "    Args:\n",
        "        model: Model whose `transformer.resblocks` form the text tower.\n",
        "        probe_loader: Iterable of (images, token_ids) batches.\n",
        "        device: Forwarded to _text_preln_inputs_two_runs.\n",
        "        te_heads_per_layer: Per-layer sequence of per-head TE scores; required.\n",
        "        in_hi_pct, out_hi_pct: Percentiles of z_in / z_out used as thresholds.\n",
        "        te_hi_pct, te_lo_pct: Percentiles of each layer's TE scores.\n",
        "        require_te_for_role: Demote mass-based roles whose TE is below te_lo.\n",
        "        use_te_as_tiebreak: Give a role to otherwise-benign heads with TE >= te_hi.\n",
        "\n",
        "    Returns:\n",
        "        (roles_per_layer, feats_per_layer): role strings and feature dicts, [L][H].\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If te_heads_per_layer is missing or covers fewer layers than the tower.\n",
        "    \"\"\"\n",
        "    if te_heads_per_layer is None:\n",
        "        raise ValueError(\"te_heads_per_layer is required: pass per-layer, per-head TE scores\")\n",
        "\n",
        "    txt = model\n",
        "    blocks = txt.transformer.resblocks\n",
        "    device_model = next(txt.parameters()).device\n",
        "\n",
        "    # Probe one mini-batch\n",
        "    _images, texts = next(iter(probe_loader))\n",
        "    texts = texts.to(device_model, non_blocking=True)\n",
        "\n",
        "    # Noise-free forward; only the first of the two returned runs is used\n",
        "    Xn_list1, _ = _text_preln_inputs_two_runs(model, texts, sigma_t=0.0, device=device)\n",
        "    if len(te_heads_per_layer) < len(Xn_list1):\n",
        "        raise ValueError(\n",
        "            f\"te_heads_per_layer has {len(te_heads_per_layer)} layers, \"\n",
        "            f\"but the text tower has {len(Xn_list1)}\"\n",
        "        )\n",
        "\n",
        "    def _pct(vals: torch.Tensor, p: float) -> float:\n",
        "        # Nearest-rank percentile of a 1-D tensor\n",
        "        vals = vals.sort().values\n",
        "        k = max(0, min(len(vals)-1, int(round(p/100.0 * (len(vals)-1)))))\n",
        "        return vals[k].item()\n",
        "\n",
        "    def _z(x: torch.Tensor) -> torch.Tensor:\n",
        "        # z-score across the heads of one layer; the clamp avoids dividing by ~0\n",
        "        m = x.mean(); s = x.std().clamp_min(1e-6)\n",
        "        return (x - m) / s\n",
        "\n",
        "    CLS = 0  # index of the first token\n",
        "    roles_per_layer, feats_per_layer = [], []\n",
        "\n",
        "    for l, Xn in enumerate(Xn_list1):\n",
        "        attn = blocks[l].attn\n",
        "        probs = _attn_probs_from_Xn_text(attn, Xn)  # [B,H,T,T]\n",
        "        _, H, _, _ = probs.shape\n",
        "\n",
        "        # Mass features\n",
        "        cls_in_mass  = probs[..., :, CLS].mean(dim=(0, 2))   # [H]   (-> CLS)\n",
        "        # NOTE(review): if probs is softmax-normalized over its last axis, each row\n",
        "        # sums to 1 and this mean equals 1/T for every head, which would make z_out\n",
        "        # uninformative - confirm the intended \"from token 0\" feature.\n",
        "        cls_out_mass = probs[..., CLS, :].mean(dim=(0, 2))   # [H]   (CLS ->)\n",
        "\n",
        "        z_in  = _z(cls_in_mass)\n",
        "        z_out = _z(cls_out_mass)\n",
        "\n",
        "        te_l = torch.as_tensor(te_heads_per_layer[l], dtype=torch.float32, device=z_in.device)\n",
        "        in_hi  = _pct(z_in,  in_hi_pct)\n",
        "        out_hi = _pct(z_out, out_hi_pct)\n",
        "        te_hi  = _pct(te_l,  te_hi_pct)\n",
        "        te_lo  = _pct(te_l,  te_lo_pct)\n",
        "\n",
        "        roles, feats = [], []\n",
        "        for h in range(H):\n",
        "            zin, zout, te = z_in[h].item(), z_out[h].item(), float(te_l[h].item())\n",
        "\n",
        "            is_broad = (zout >= out_hi) and (zin < in_hi)\n",
        "            is_sink  = (zin  >= in_hi)  and (zout < out_hi)\n",
        "\n",
        "            if (is_broad or is_sink):\n",
        "                role = \"broadcaster\" if is_broad else \"sink\"\n",
        "                if require_te_for_role and te < te_lo:\n",
        "                    role = \"benign\"\n",
        "            else:\n",
        "                if use_te_as_tiebreak and te >= te_hi:\n",
        "                    role = \"sink\" if zin >= zout else \"broadcaster\"\n",
        "                else:\n",
        "                    role = \"benign\"\n",
        "\n",
        "            # NOTE(review): this demotion runs even when require_te_for_role is\n",
        "            # False, so that flag currently has no effect - confirm the intent.\n",
        "            if role != \"benign\" and te < te_lo:\n",
        "                role = \"benign\"\n",
        "\n",
        "            roles.append(role)\n",
        "            feats.append({\n",
        "                \"cls_in_mass\":  float(cls_in_mass[h].item()),\n",
        "                \"cls_out_mass\": float(cls_out_mass[h].item()),\n",
        "                \"z_in\":  zin,\n",
        "                \"z_out\": zout,\n",
        "                \"te\":    te,\n",
        "            })\n",
        "\n",
        "        roles_per_layer.append(roles)\n",
        "        feats_per_layer.append(feats)\n",
        "\n",
        "    return roles_per_layer, feats_per_layer\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# TE (CosTE) per head\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def compute_te_per_head_vision(\n",
        "    model,\n",
        "    images,\n",
        "    device: str = \"cuda\",\n",
        "    sigma_v: float = 1e-2,\n",
        "    pool: str = \"mean\",\n",
        "    include_cls: bool = False,\n",
        "):\n",
        "    \"\"\"\n",
        "    Returns list per layer; each element is a tensor of per-head TE scores [H].\n",
        "    Cos^2 between delta per-head output and delta total attention output across two noisy runs.\n",
        "    \"\"\"\n",
        "    Xn_list1, Xn_list2 = _vision_preln_inputs_two_runs(model, images, sigma_v, device)\n",
        "    te_per_layer = []\n",
        "    eps = 1e-6\n",
        "    for l, (Xn1, Xn2) in enumerate(zip(Xn_list1, Xn_list2)):\n",
        "        blk = model.visual.transformer.resblocks[l]\n",
        "        attn = blk.attn\n",
        "\n",
        "        O_h1, O_all1 = _attn_head_outputs_from_Xn(attn, Xn1)     # lists of [B, T, D]; [B, T, D]\n",
        "        O_h2, O_all2 = _attn_head_outputs_from_Xn(attn, Xn2)\n",
        "\n",
        "        O_all1_p = _pool_tokens(O_all1, pool=pool, include_cls=include_cls)  # [B, D]\n",
        "        O_all2_p = _pool_tokens(O_all2, pool=pool, include_cls=include_cls)\n",
        "        delta_all = torch.nan_to_num(O_all2_p - O_all1_p).float()            # [B, D]\n",
        "        denom_all = delta_all.norm(dim=-1).clamp_min(eps)                    # [B]\n",
        "\n",
        "        head_scores = []\n",
        "        for h in range(attn.num_heads):\n",
        "            oh1_p = _pool_tokens(O_h1[h], pool=pool, include_cls=include_cls)  # [B, D]\n",
        "            oh2_p = _pool_tokens(O_h2[h], pool=pool, include_cls=include_cls)\n",
        "            delta_h = torch.nan_to_num(oh2_p - oh1_p).float()                  # [B, D]\n",
        "\n",
        "            num = (delta_h * delta_all).sum(dim=-1)                            # [B]\n",
        "            den = (delta_h.norm(dim=-1).clamp_min(eps) * denom_all)            # [B]\n",
        "            cos = (num / den).clamp(-1.0, 1.0)\n",
        "            head_scores.append(0.5 * (cos * cos).mean())                       # scalar\n",
        "\n",
        "        te_per_layer.append(torch.stack(head_scores, dim=0))    # [H]\n",
        "    return te_per_layer  # list length L (layers)\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_avg_te_per_head_vision_over_probe(\n",
        "    model,\n",
        "    probe_loader,\n",
        "    device: str = \"cuda\",\n",
        "    sigma_v: float = 1e-2,\n",
        "    pool: str = \"mean\",\n",
        "    include_cls: bool = False,\n",
        "):\n",
        "    te_accum = None\n",
        "    n = 0\n",
        "    first = True\n",
        "    for images, _texts in probe_loader:\n",
        "        te_list = compute_te_per_head_vision(\n",
        "            model, images, device=device, sigma_v=sigma_v, pool=pool, include_cls=include_cls\n",
        "        )\n",
        "        te_stacked = torch.stack([t.cpu() for t in te_list], dim=0)  # [L,H]\n",
        "        te_accum = te_stacked if te_accum is None else (te_accum + te_stacked)\n",
        "        n += 1\n",
        "        if first:\n",
        "            print(f\"[Head-TE Debug] collected per-head TE on one probe batch \"\n",
        "                  f\"(sigma_v={sigma_v:.2e}, pool={pool}, include_cls={include_cls})\")\n",
        "            first = False\n",
        "    return (te_accum / max(1, n)).tolist()  # list of length L, each a list len H\n",
        "\n",
        "# -------------------------\n",
        "# Head roles (sums-only) and role-aware pruning\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def _attn_probs_from_Xn(attn: nn.MultiheadAttention, Xn_LBD: torch.Tensor):\n",
        "    \"\"\"\n",
        "    Returns attention probabilities per head for one block given pre-LN inputs.\n",
        "    Xn_LBD: [T,B,D]  -> attn_probs: [B,H,T,T]\n",
        "    \"\"\"\n",
        "    T, B, D = Xn_LBD.shape\n",
        "    H = attn.num_heads\n",
        "    dh = D // H\n",
        "\n",
        "    W = attn.in_proj_weight\n",
        "    b = attn.in_proj_bias\n",
        "    sq = slice(0, D); sk = slice(D, 2*D)\n",
        "\n",
        "    Xn = Xn_LBD.transpose(0, 1)  # [B,T,D]\n",
        "    Q = F.linear(Xn, W[sq], b[sq] if b is not None else None)\n",
        "    K = F.linear(Xn, W[sk], b[sk] if b is not None else None)\n",
        "\n",
        "    Q = Q.view(B, T, H, dh).transpose(1, 2)  # [B,H,T,dh]\n",
        "    K = K.view(B, T, H, dh).transpose(1, 2)  # [B,H,T,dh]\n",
        "\n",
        "    scale = 1.0 / math.sqrt(dh)\n",
        "    attn_scores = torch.matmul(Q, K.transpose(-1, -2)) * scale  # [B,H,T,T]\n",
        "    attn_probs  = attn_scores.softmax(dim=-1)\n",
        "    return attn_probs\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def compute_head_roles_vision_over_probe(\n",
        "    model,\n",
        "    probe_loader,\n",
        "    device: str = \"cuda\",\n",
        "    te_heads_per_layer: Optional[List[torch.Tensor]] = None,  # [L][H], required for TE-aware\n",
        "    # Mass thresholds (percentiles over heads within a layer)\n",
        "    in_hi_pct: float = 60.0,     # lower -> more sinks\n",
        "    out_hi_pct: float = 80.0,    # higher -> fewer broadcasters\n",
        "    # TE thresholds (percentiles over heads within a layer)\n",
        "    te_hi_pct: float = 60.0,     # tie-break threshold for ambiguous heads\n",
        "    te_lo_pct: float = 30.0,     # demote very low-TE heads to benign\n",
        "    # Behavior toggles\n",
        "    require_te_for_role: bool = True,  # require TE >= te_lo to be sink/broadcaster\n",
        "    use_te_as_tiebreak: bool = True,   # if ambiguous, high TE picks side of larger z-score\n",
        "):\n",
        "    \"\"\"\n",
        "    TE-aware head roles using sums of attention mass + per-head TE:\n",
        "\n",
        "      • cls_in_mass  : mean attention TO CLS (sink-ish)\n",
        "      • cls_out_mass : mean attention FROM CLS (broadcaster-ish)\n",
        "      • te           : per-head CosTE (impact proxy)\n",
        "\n",
        "    Rules (per layer):\n",
        "      1) Strong broadcaster if zout >= out_hi and zin < in_hi AND (not require_te_for_role or te >= te_lo)\n",
        "      2) Strong sink        if zin  >= in_hi and zout < out_hi AND (not require_te_for_role or te >= te_lo)\n",
        "      3) Otherwise ambiguous:\n",
        "           - if use_te_as_tiebreak and te >= te_hi:\n",
        "                 choose 'sink' if z_in >= z_out else 'broadcaster'\n",
        "           - else: benign\n",
        "      4) Safety demotion: any non-benign with te < te_lo becomes benign.\n",
        "    \"\"\"\n",
        "    assert te_heads_per_layer is not None, \\\n",
        "        \"TE-aware labeling requires te_heads_per_layer; pass average TE [L][H].\"\n",
        "\n",
        "    v = model.visual\n",
        "    blocks = v.transformer.resblocks\n",
        "    device_model = next(v.parameters()).device\n",
        "\n",
        "    # Probe one mini-batch\n",
        "    images, _ = next(iter(probe_loader))\n",
        "    images = images.to(device_model, non_blocking=True)\n",
        "\n",
        "    # Pre-LN inputs per block\n",
        "    Xn_list1, _ = _vision_preln_inputs_two_runs(model, images, sigma_v=0.0, device=device)\n",
        "\n",
        "    roles_per_layer, feats_per_layer = [], []\n",
        "\n",
        "    def _pct(vals: torch.Tensor, p: float) -> float:\n",
        "        vals = vals.sort().values\n",
        "        k = max(0, min(len(vals)-1, int(round(p/100.0 * (len(vals)-1)))))\n",
        "        return vals[k].item()\n",
        "\n",
        "    for l, Xn in enumerate(Xn_list1):\n",
        "        attn = blocks[l].attn\n",
        "        probs = _attn_probs_from_Xn(attn, Xn)  # [B,H,T,T]\n",
        "        B, H, T, _ = probs.shape\n",
        "        CLS = 0\n",
        "\n",
        "        # Mass features (means over batch & query positions)\n",
        "        cls_in_mass  = probs[..., :, CLS].mean(dim=(0, 2))   # [H]\n",
        "        cls_out_mass = probs[..., CLS, :].mean(dim=(0, 2))   # [H]\n",
        "\n",
        "        # z-scores per layer\n",
        "        def _z(x: torch.Tensor) -> torch.Tensor:\n",
        "            m = x.mean(); s = x.std().clamp_min(1e-6)\n",
        "            return (x - m) / s\n",
        "\n",
        "        z_in  = _z(cls_in_mass)\n",
        "        z_out = _z(cls_out_mass)\n",
        "\n",
        "        # TE vector and thresholds\n",
        "        te_l = torch.as_tensor(te_heads_per_layer[l], dtype=torch.float32, device=z_in.device)\n",
        "        in_hi  = _pct(z_in,  in_hi_pct)\n",
        "        out_hi = _pct(z_out, out_hi_pct)\n",
        "        te_hi  = _pct(te_l,  te_hi_pct)\n",
        "        te_lo  = _pct(te_l,  te_lo_pct)\n",
        "\n",
        "        roles, feats = [], []\n",
        "        for h in range(H):\n",
        "            zin, zout, te = z_in[h].item(), z_out[h].item(), float(te_l[h].item())\n",
        "\n",
        "            is_broad = (zout >= out_hi) and (zin < in_hi)\n",
        "            is_sink  = (zin  >= in_hi)  and (zout < out_hi)\n",
        "\n",
        "            if (is_broad or is_sink):\n",
        "                if (not require_te_for_role) or (te >= te_lo):\n",
        "                    role = \"broadcaster\" if is_broad else \"sink\"\n",
        "                else:\n",
        "                    role = \"benign\"  # demote low-TE “fake” roles\n",
        "            else:\n",
        "                if use_te_as_tiebreak and te >= te_hi:\n",
        "                    role = \"sink\" if zin >= zout else \"broadcaster\"\n",
        "                else:\n",
        "                    role = \"benign\"\n",
        "\n",
        "            if role != \"benign\" and te < te_lo:\n",
        "                role = \"benign\"\n",
        "\n",
        "            roles.append(role)\n",
        "            feats.append({\n",
        "                \"cls_in_mass\":  float(cls_in_mass[h].item()),\n",
        "                \"cls_out_mass\": float(cls_out_mass[h].item()),\n",
        "                \"z_in\":  zin,\n",
        "                \"z_out\": zout,\n",
        "                \"te\":    te,\n",
        "            })\n",
        "\n",
        "        roles_per_layer.append(roles)\n",
        "        feats_per_layer.append(feats)\n",
        "\n",
        "    return roles_per_layer, feats_per_layer\n",
        "\n",
        "\n",
        "\n",
        "# ===== TE-aware token utilities =====\n",
        "\n",
        "def _mad_z_per_layer(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:\n",
        "    med = x.median()\n",
        "    mad = (x - med).abs().median().clamp_min(eps)\n",
        "    return (x - med) / mad\n",
        "\n",
        "@torch.no_grad()\n",
        "def _block_outputs_and_probs(attn: nn.MultiheadAttention, Xn_LBD: torch.Tensor):\n",
        "    T, B, D = Xn_LBD.shape\n",
        "    H = attn.num_heads\n",
        "    dh = D // H\n",
        "\n",
        "    W = attn.in_proj_weight; b = attn.in_proj_bias\n",
        "    Xn = Xn_LBD.transpose(0,1)  # [B,T,D]\n",
        "    Q = F.linear(Xn, W[:D],    b[:D]    if b is not None else None)\n",
        "    K = F.linear(Xn, W[D:2*D], b[D:2*D] if b is not None else None)\n",
        "    V = F.linear(Xn, W[2*D:],  b[2*D:]  if b is not None else None)\n",
        "    Q = Q.view(B,T,H,dh).transpose(1,2)\n",
        "    K = K.view(B,T,H,dh).transpose(1,2)\n",
        "    V = V.view(B,T,H,dh).transpose(1,2)\n",
        "\n",
        "    scale = 1.0 / math.sqrt(dh)\n",
        "    scores = (Q @ K.transpose(-1,-2)) * scale\n",
        "    P = scores.softmax(dim=-1)                      # [B,H,T,T]\n",
        "    Z = P @ V                                       # [B,H,T,dh]\n",
        "\n",
        "    Wo = attn.out_proj.weight\n",
        "    O_parts = []\n",
        "    for h in range(H):\n",
        "        cols = slice(h*dh, (h+1)*dh)\n",
        "        O_parts.append(Z[:,h] @ Wo[:, cols].T)      # [B,T,D]\n",
        "    O_all = torch.stack(O_parts, dim=1).sum(dim=1)  # [B,T,D]\n",
        "    return O_all, P\n",
        "\n",
        "@torch.no_grad()\n",
        "def _te_inout_alpha_for_layer(attn: nn.MultiheadAttention, X1: torch.Tensor, X2: torch.Tensor):\n",
        "    O1, P1 = _block_outputs_and_probs(attn, X1)     # [B,T,D], [B,H,T,T]\n",
        "    O2, P2 = _block_outputs_and_probs(attn, X2)\n",
        "    d = torch.nan_to_num((O2 - O1).float(), nan=0.0)   # [B,T,D]\n",
        "    B, T, D = d.shape\n",
        "\n",
        "    P = torch.stack([P1, P2], dim=0).mean(dim=0)       # [B,H,T,T]\n",
        "    A = P.mean(dim=(0,1))                               # [T,T] (avg over batch, heads)\n",
        "\n",
        "    # cosine^2 between token deltas (avg over batch)\n",
        "    dn = d / d.norm(dim=-1, keepdim=True).clamp_min(1e-6)\n",
        "    C = torch.einsum(\"btd,bsd->bts\", dn, dn).clamp(-1,1)\n",
        "    C = (C * C).mean(dim=0)                             # [T,T]\n",
        "\n",
        "    off = ~torch.eye(T, dtype=torch.bool, device=C.device)\n",
        "    # broadcast (row i -> *)\n",
        "    t_out = (A * C).sum(dim=1) / off.sum(dim=1).clamp_min(1)   # [T]\n",
        "    # collect   (* -> j)\n",
        "    t_in  = (A.T * C.T).sum(dim=1) / off.sum(dim=0).clamp_min(1)  # [T]\n",
        "\n",
        "    alpha = A.mean(dim=0)                                # col-avg attention\n",
        "    Xb = X1.transpose(0,1)                               # [B,T,D]\n",
        "    nu = Xb.mean(dim=0).norm(dim=-1)                     # optional residual norm\n",
        "    return t_out, t_in, alpha, nu\n",
        "\n",
        "@torch.no_grad()\n",
        "def classify_tokens_by_te_rules(\n",
        "    t_out: torch.Tensor, t_in: torch.Tensor, alpha: torch.Tensor, nu: Optional[torch.Tensor],\n",
        "    thresholds: dict, special_keep_idx: List[int]\n",
        ") -> List[str]:\n",
        "    z_tout = _mad_z_per_layer(t_out)\n",
        "    z_tin  = _mad_z_per_layer(t_in)\n",
        "    z_alpha = _mad_z_per_layer(alpha)\n",
        "    Delta = z_tin - z_tout\n",
        "\n",
        "    tau_a  = thresholds[\"tau_alpha\"]\n",
        "    tau_o  = thresholds[\"tau_out\"]\n",
        "    tau_i  = thresholds[\"tau_in\"]\n",
        "    tau_d  = thresholds[\"tau_delta\"]\n",
        "\n",
        "    specials = set(special_keep_idx or [])\n",
        "    roles = []\n",
        "    for j in range(t_out.numel()):\n",
        "        if j in specials:\n",
        "            roles.append(\"special\"); continue\n",
        "        is_sink  = (z_alpha[j] >= tau_a) and (z_tin[j] >= tau_i) and (Delta[j] >= tau_d) and (z_tout[j] < tau_o)\n",
        "        is_broad = (z_tout[j] >= tau_o) and ((z_tout[j] - z_tin[j]) >= tau_d)\n",
        "        roles.append(\"sink\" if is_sink else (\"broadcaster\" if is_broad else \"benign\"))\n",
        "    return roles\n",
        "\n",
        "@torch.no_grad()\n",
        "def te_roles_over_probe_vision(model, probe_loader, device=\"cuda\", sigma_v=1e-2,\n",
        "                               thresholds=dict(tau_alpha=1.5, tau_out=1.0, tau_in=1.0, tau_delta=0.5),\n",
        "                               special_keep_idx=[0]):\n",
        "    images, _ = next(iter(probe_loader))\n",
        "    X1_list, X2_list = _vision_preln_inputs_two_runs(model, images, sigma_v, device)\n",
        "    roles_all, stats_all = [], []\n",
        "    for l, (X1, X2) in enumerate(zip(X1_list, X2_list)):\n",
        "        attn = model.visual.transformer.resblocks[l].attn\n",
        "        t_out, t_in, alpha, nu = _te_inout_alpha_for_layer(attn, X1, X2)\n",
        "        roles_all.append(classify_tokens_by_te_rules(t_out, t_in, alpha, nu, thresholds, special_keep_idx))\n",
        "        stats_all.append({\"t_out\": t_out.cpu(), \"t_in\": t_in.cpu(), \"alpha\": alpha.cpu(), \"nu\": nu.cpu()})\n",
        "    return roles_all, stats_all\n",
        "\n",
        "@torch.no_grad()\n",
        "def te_roles_over_probe_text(model, probe_loader, device=\"cuda\", sigma_t=1e-2,\n",
        "                             thresholds=dict(tau_alpha=0.5, tau_out=1.0, tau_in=0.3, tau_delta=0.2),\n",
        "                             special_keep_idx=[0]):\n",
        "    _imgs, token_ids = next(iter(probe_loader))\n",
        "    X1_list, X2_list = _text_preln_inputs_two_runs(model, token_ids, sigma_t, device)\n",
        "    roles_all, stats_all = [], []\n",
        "    for l, (X1, X2) in enumerate(zip(X1_list, X2_list)):\n",
        "        attn = model.transformer.resblocks[l].attn\n",
        "        t_out, t_in, alpha, nu = _te_inout_alpha_for_layer(attn, X1, X2)\n",
        "        roles_all.append(classify_tokens_by_te_rules(t_out, t_in, alpha, nu, thresholds, special_keep_idx))\n",
        "        stats_all.append({\"t_out\": t_out.cpu(), \"t_in\": t_in.cpu(), \"alpha\": alpha.cpu(), \"nu\": nu.cpu()})\n",
        "    return roles_all, stats_all\n",
        "\n",
        "@torch.no_grad()\n",
        "def select_tokens_to_prune(roles_per_layer, te_tokens_per_layer, prune_frac_per_layer, min_tokens_remaining, special_keep_idx):\n",
        "    pruned = []\n",
        "    specials = set(special_keep_idx or [])\n",
        "    for l, roles in enumerate(roles_per_layer):\n",
        "        T = len(roles)\n",
        "        floor = max(min_tokens_remaining, len(specials))\n",
        "        max_drop = max(0, T - floor)\n",
        "        k = min(int(round(prune_frac_per_layer * T)), max_drop)\n",
        "        if k == 0:\n",
        "            pruned.append([]); continue\n",
        "\n",
        "        benign = [i for i,r in enumerate(roles) if r==\"benign\" and i not in specials]\n",
        "        sinks  = [i for i,r in enumerate(roles) if r==\"sink\" and i not in specials]\n",
        "        broads = [i for i,r in enumerate(roles) if r==\"broadcaster\" and i not in specials]\n",
        "        scores = te_tokens_per_layer[l]  # [T]\n",
        "\n",
        "        def lowest(idx, need, keep_at_least=0):\n",
        "            if need<=0 or not idx: return []\n",
        "            order = torch.argsort(scores[idx], descending=False).tolist()\n",
        "            return [idx[i] for i in order[:min(need, max(0, len(idx)-keep_at_least))]]\n",
        "\n",
        "        to_drop, need = [], k\n",
        "        take = lowest(benign, need); to_drop += take; need -= len(take)\n",
        "        if need>0 and len(sinks)>1:\n",
        "            take = lowest(sinks, min(need, len(sinks)-1)); to_drop += take; need -= len(take)\n",
        "        if need>0 and len(broads)>1:\n",
        "            take = lowest(broads, min(need, len(broads)-1)); to_drop += take; need -= len(take)\n",
        "        pruned.append(sorted(set(to_drop))[:k])\n",
        "    return pruned\n",
        "\n",
        "class _TokenPruneApplier:\n",
        "    \"\"\"\n",
        "    Apply token pruning (mask attention to/from dropped tokens; zero their inputs at each pruned layer).\n",
        "    Does NOT change sequence length.\n",
        "    \"\"\"\n",
        "    def __init__(self, model, pruned_tokens_per_layer, tower=\"vision\"):\n",
        "        layers = _iter_resblocks(model.visual if tower==\"vision\" else model)\n",
        "        self._hooks = []\n",
        "        for l, drop in enumerate(pruned_tokens_per_layer):\n",
        "            if not drop: continue\n",
        "            blk = layers[l]\n",
        "            idx = torch.tensor(drop, dtype=torch.long)\n",
        "\n",
        "            def make_pre_hook(idx):\n",
        "                def _hook(mod, args):\n",
        "                    x = args[0].clone()          # [T,B,D] entering the block\n",
        "                    x.index_fill_(0, idx.to(x.device), 0.0)\n",
        "                    return (x,)\n",
        "                return _hook\n",
        "            self._hooks.append(blk.register_forward_pre_hook(make_pre_hook(idx)))\n",
        "\n",
        "            prev = blk.attention\n",
        "            def make_attn(prev, idx):\n",
        "                def new_attn(x):\n",
        "                    T = x.size(0)\n",
        "                    mask = torch.zeros(T, T, device=x.device, dtype=x.dtype)\n",
        "                    mask[:, idx] = float(\"-inf\")\n",
        "                    mask[idx, :] = float(\"-inf\")\n",
        "                    blk.attn_mask = mask\n",
        "                    return prev(x)\n",
        "                return new_attn\n",
        "            blk.attention = make_attn(prev, idx)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Head pruning (zero Q/K/V rows + zero out_proj columns)\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def _zero_heads_in_mha(attn: nn.MultiheadAttention, heads_to_prune: List[int]):\n",
        "    \"\"\"\n",
        "    Zero Q/K/V rows for selected heads in in_proj_weight (+bias),\n",
        "    and zero the corresponding input columns of out_proj.\n",
        "    Works with LoRALinear-wrapped out_proj by touching base.weight and lora_A.\n",
        "    \"\"\"\n",
        "    D  = attn.embed_dim\n",
        "    H  = attn.num_heads\n",
        "    dh = D // H\n",
        "\n",
        "    # ---- Q/K/V fused weights ----\n",
        "    W = attn.in_proj_weight            # [3D, D]\n",
        "    b = attn.in_proj_bias              # [3D] or None\n",
        "\n",
        "    def offslice(rows: slice, off: int) -> slice:\n",
        "        return slice(off + rows.start, off + rows.stop, rows.step)\n",
        "\n",
        "    for h in heads_to_prune:\n",
        "        rows = slice(h * dh, (h + 1) * dh)\n",
        "\n",
        "        # Zero Q, K, V blocks\n",
        "        for off in (0, D, 2 * D):\n",
        "            s = offslice(rows, off)\n",
        "            W[s, :].zero_()\n",
        "            if b is not None:\n",
        "                b[s].zero_()\n",
        "\n",
        "        # Zero corresponding input columns of out_proj\n",
        "        out = attn.out_proj\n",
        "        if hasattr(out, \"base\") and hasattr(out, \"lora_A\"):\n",
        "            out.base.weight[:, rows].zero_()\n",
        "            out.lora_A[:, rows].zero_()\n",
        "        elif isinstance(out, nn.Linear):\n",
        "            out.weight[:, rows].zero_()\n",
        "        else:\n",
        "            Wout = getattr(out, \"weight\", None)\n",
        "            if isinstance(Wout, torch.Tensor):\n",
        "                Wout[:, rows].zero_()\n",
        "\n",
        "@torch.no_grad()\n",
        "def prune_heads_by_te_vision(\n",
        "    model,\n",
        "    te_heads_per_layer: List[torch.Tensor],          # [L] of [H] tensors\n",
        "    prune_frac_per_layer: float = 0.25,\n",
        "    min_heads_remaining: int = 4,\n",
        "    guard_layers: Tuple[int, int] = (0, -1),         # protect first & last layers if desired\n",
        "    verbose: bool = True,\n",
        "    roles_per_layer: Optional[List[List[str]]] = None,  # role-aware benign-first if provided\n",
        ") -> List[List[int]]:\n",
        "    \"\"\"\n",
        "    Role-aware head pruning (benign-first):\n",
        "      • Prune 'benign' heads first by lowest TE.\n",
        "      • If more pruning needed, prune 'excess' sinks/broadcasters by TE,\n",
        "        but keep at least one sink and one broadcaster per layer (when they exist).\n",
        "      • Respects guard_layers and min_heads_remaining.\n",
        "    \"\"\"\n",
        "    blocks = model.visual.transformer.resblocks\n",
        "    L = len(blocks)\n",
        "    pruned = []\n",
        "\n",
        "    for l, scores in enumerate(te_heads_per_layer):\n",
        "        # Guard entire layers, if requested\n",
        "        if l == guard_layers[0] or (guard_layers[1] in (-1, L-1) and l == L-1) or l == guard_layers[1]:\n",
        "            pruned.append([])\n",
        "            continue\n",
        "\n",
        "        attn = blocks[l].attn\n",
        "        H = attn.num_heads\n",
        "        if H <= 1:\n",
        "            pruned.append([])\n",
        "            continue\n",
        "\n",
        "        # How many can we prune at most while keeping min_heads_remaining?\n",
        "        max_prunable = max(0, H - min_heads_remaining)\n",
        "\n",
        "        # Target k by fraction, clamped\n",
        "        k = int(max(0, round(H * prune_frac_per_layer)))\n",
        "        k = min(k, max_prunable)\n",
        "        if k == 0:\n",
        "            pruned.append([])\n",
        "            continue\n",
        "\n",
        "        # Roles for this layer\n",
        "        if roles_per_layer is None:\n",
        "            roles = [\"benign\"] * H\n",
        "        else:\n",
        "            roles = roles_per_layer[l]\n",
        "\n",
        "        all_heads = list(range(H))\n",
        "        benign = [h for h in all_heads if roles[h] == \"benign\"]\n",
        "        sinks  = [h for h in all_heads if roles[h] == \"sink\"]\n",
        "        broads = [h for h in all_heads if roles[h] == \"broadcaster\"]\n",
        "\n",
        "        # Helper: select lowest-TE heads from a list\n",
        "        def take_lowest(head_list, need):\n",
        "            if need <= 0 or not head_list:\n",
        "                return []\n",
        "            head_scores = torch.tensor([scores[h] for h in head_list])\n",
        "            order = torch.argsort(head_scores, descending=False).tolist()\n",
        "            return [head_list[i] for i in order[:min(need, len(head_list))]]\n",
        "\n",
        "        to_prune, need = [], k\n",
        "\n",
        "        # 1) prune benign first\n",
        "        pick = take_lowest(benign, need)\n",
        "        to_prune.extend(pick)\n",
        "        need -= len(pick)\n",
        "\n",
        "        # 2) if still need, prune surplus sinks/broadcasters but keep ≥1 of each if present\n",
        "        if need > 0 and len(sinks) > 1:\n",
        "            pick = take_lowest(sinks, min(need, len(sinks) - 1))\n",
        "            to_prune.extend(pick)\n",
        "            need -= len(pick)\n",
        "        if need > 0 and len(broads) > 1:\n",
        "            pick = take_lowest(broads, min(need, len(broads) - 1))\n",
        "            to_prune.extend(pick)\n",
        "            need -= len(pick)\n",
        "\n",
        "        to_prune = sorted(set(to_prune))[:k]\n",
        "        _zero_heads_in_mha(attn, to_prune)\n",
        "\n",
        "        if verbose:\n",
        "            print(f\"[HeadPrune-Roles] layer {l}: pruned {to_prune} \"\n",
        "                  f\"(H={H}, keep≥{min_heads_remaining}, roles: benign={len(benign)}, sink={len(sinks)}, broad={len(broads)})\")\n",
        "\n",
        "        pruned.append(to_prune)\n",
        "\n",
        "    return pruned\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def _short_tok_role(r: str) -> str:\n",
        "    return {\"sink\":\"S\", \"broadcaster\":\"B\", \"benign\":\"N\", \"special\":\"*\"}.get(r, \"?\")\n",
        "\n",
        "def _print_token_roles(roles_per_layer):\n",
        "    print(\"[Roles/Tokens] per layer (S=sink, B=broadcaster, N=benign, *=special-kept)\")\n",
        "    for l, roles in enumerate(roles_per_layer):\n",
        "        tags = \" \".join(_short_tok_role(r) for r in roles)\n",
        "        cnt = {\"sink\":0, \"broadcaster\":0, \"benign\":0, \"special\":0}\n",
        "        for r in roles: cnt[r] = cnt.get(r,0)+1\n",
        "        print(f\"  Layer {l:02d}: {tags}   (S={cnt['sink']}, B={cnt['broadcaster']}, N={cnt['benign']}, *={cnt['special']})\")\n",
        "\n",
        "def _print_token_scores(stats_per_layer, label=\"TE/alpha per token\"):\n",
        "    # stats_per_layer: list of dicts with keys: \"t_in\",\"t_out\",\"alpha\",\"nu\" (each [T] tensor on CPU)\n",
        "    print(f\"[Tokens] {label}\")\n",
        "    for l, s in enumerate(stats_per_layer):\n",
        "        t_in   = s[\"t_in\"].tolist()\n",
        "        t_out  = s[\"t_out\"].tolist()\n",
        "        alpha  = s[\"alpha\"].tolist()\n",
        "        mean_in   = sum(t_in)/max(1,len(t_in))\n",
        "        mean_out  = sum(t_out)/max(1,len(t_out))\n",
        "        mean_alpha= sum(alpha)/max(1,len(alpha))\n",
        "        # print short vectors (trim if long)\n",
        "        def _brief(v, k=10):\n",
        "            if len(v) <= k: return [round(float(x),6) for x in v]\n",
        "            return [round(float(x),6) for x in v[:k]] + [\"...\"]\n",
        "        print(f\"  L{l:02d}  mean(TE_in)={mean_in:.6f}  mean(TE_out)={mean_out:.6f}  mean(alpha)={mean_alpha:.6f}\")\n",
        "        print(f\"         TE_in : {_brief(t_in)}\")\n",
        "        print(f\"         TE_out: {_brief(t_out)}\")\n",
        "        print(f\"         alpha : {_brief(alpha)}\")\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Head pruning for TEXT (reuses zeroing helper)\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def prune_heads_by_te_text(\n",
        "    model,\n",
        "    te_heads_per_layer: List[torch.Tensor],          # [L] of [H] tensors\n",
        "    prune_frac_per_layer: float = 0.25,\n",
        "    min_heads_remaining: int = 4,\n",
        "    guard_layers: Tuple[int, int] = (0, -1),\n",
        "    verbose: bool = True,\n",
        "    roles_per_layer: Optional[List[List[str]]] = None,\n",
        ") -> List[List[int]]:\n",
        "    \"\"\"Zero out the lowest-scoring text-tower attention heads, layer by layer.\n",
        "\n",
        "    For each non-guarded layer up to ``round(H * prune_frac_per_layer)`` heads are\n",
        "    selected (capped so that ``min_heads_remaining`` heads stay untouched), taking\n",
        "    \"benign\" heads first, then \"sink\" heads, then \"broadcaster\" heads. The sink and\n",
        "    broadcaster groups are never emptied: one head of each is always held back.\n",
        "    The selected heads are handed to ``_zero_heads_in_mha``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.transformer.resblocks``; every block needs an\n",
        "            ``attn`` attribute with ``num_heads``.\n",
        "        te_heads_per_layer: per-layer head scores; lower scores are pruned first.\n",
        "        prune_frac_per_layer: fraction of a layer's heads to prune.\n",
        "        min_heads_remaining: number of heads per layer that must not be pruned.\n",
        "        guard_layers: layer indices that are skipped; ``-1`` in the second slot\n",
        "            stands for the last layer.\n",
        "        verbose: print one summary line for every layer that reaches the pruning step.\n",
        "        roles_per_layer: optional role name per head and layer; when omitted all\n",
        "            heads count as \"benign\".\n",
        "\n",
        "    Returns:\n",
        "        One list of pruned head indices per entry of ``te_heads_per_layer``\n",
        "        (empty for guarded or skipped layers).\n",
        "    \"\"\"\n",
        "    resblocks = model.transformer.resblocks\n",
        "    last_layer = len(resblocks) - 1\n",
        "\n",
        "    def _is_guarded(idx):\n",
        "        # guard_layers is indexed lazily, in the same order as the checks are made.\n",
        "        if idx == guard_layers[0]:\n",
        "            return True\n",
        "        if guard_layers[1] in (-1, last_layer) and idx == last_layer:\n",
        "            return True\n",
        "        return idx == guard_layers[1]\n",
        "\n",
        "    def _heads_with(roles, num_heads, role_name):\n",
        "        return [h for h in range(num_heads) if roles[h] == role_name]\n",
        "\n",
        "    def _lowest_scoring(candidates, layer_scores, how_many):\n",
        "        # Return up to `how_many` candidates, ordered by ascending score.\n",
        "        if how_many <= 0 or not candidates:\n",
        "            return []\n",
        "        values = torch.tensor([layer_scores[h] for h in candidates])\n",
        "        ranking = torch.argsort(values, descending=False).tolist()\n",
        "        return [candidates[i] for i in ranking[:min(how_many, len(candidates))]]\n",
        "\n",
        "    pruned_per_layer = []\n",
        "    for idx, layer_scores in enumerate(te_heads_per_layer):\n",
        "        if _is_guarded(idx):\n",
        "            pruned_per_layer.append([])\n",
        "            continue\n",
        "\n",
        "        attn = resblocks[idx].attn\n",
        "        num_heads = attn.num_heads\n",
        "        if num_heads <= 1:\n",
        "            pruned_per_layer.append([])\n",
        "            continue\n",
        "\n",
        "        # Requested head count, capped by what min_heads_remaining allows.\n",
        "        budget = min(\n",
        "            int(max(0, round(num_heads * prune_frac_per_layer))),\n",
        "            max(0, num_heads - min_heads_remaining),\n",
        "        )\n",
        "        if budget == 0:\n",
        "            pruned_per_layer.append([])\n",
        "            continue\n",
        "\n",
        "        if roles_per_layer is None:\n",
        "            roles = [\"benign\"] * num_heads\n",
        "        else:\n",
        "            roles = roles_per_layer[idx]\n",
        "        benign = _heads_with(roles, num_heads, \"benign\")\n",
        "        sinks = _heads_with(roles, num_heads, \"sink\")\n",
        "        broads = _heads_with(roles, num_heads, \"broadcaster\")\n",
        "\n",
        "        # Groups are visited in priority order; the number is how many heads of\n",
        "        # that group must survive.\n",
        "        selected, remaining = [], budget\n",
        "        for group, keep_back in ((benign, 0), (sinks, 1), (broads, 1)):\n",
        "            if remaining > 0 and len(group) > keep_back:\n",
        "                chosen = _lowest_scoring(group, layer_scores, min(remaining, len(group) - keep_back))\n",
        "                selected.extend(chosen)\n",
        "                remaining -= len(chosen)\n",
        "\n",
        "        to_prune = sorted(set(selected))[:budget]\n",
        "        _zero_heads_in_mha(attn, to_prune)\n",
        "\n",
        "        if verbose:\n",
        "            print(f\"[HeadPrune-TextRoles] layer {idx}: pruned {to_prune} \"\n",
        "                  f\"(H={num_heads}, keep≥{min_heads_remaining}, roles: benign={len(benign)}, sink={len(sinks)}, broad={len(broads)})\")\n",
        "\n",
        "        pruned_per_layer.append(to_prune)\n",
        "\n",
        "    return pruned_per_layer\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Probe loader (unused if you prefer using train_dataloader as probe)\n",
        "# -------------------------\n",
        "def get_probe_loader(train_dataset, batch_size=64, num_batches=2, seed=None):\n",
        "    \"\"\"Build a DataLoader over a random subset of ``train_dataset`` for probing.\n",
        "\n",
        "    Args:\n",
        "        train_dataset: sized, indexable dataset to sample from.\n",
        "        batch_size: requested probe batch size (reduced if the subset is smaller).\n",
        "        num_batches: number of batches the subset should cover.\n",
        "        seed: optional integer; when given, the subset is drawn from a private\n",
        "            ``random.Random(seed)`` so the same probe set is selected on every call.\n",
        "            When ``None`` the global ``random`` module state is used.\n",
        "\n",
        "    Returns:\n",
        "        A non-shuffling DataLoader over at most ``batch_size * num_batches`` samples.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the dataset is empty or batch_size / num_batches is not positive.\n",
        "    \"\"\"\n",
        "    from torch.utils.data import Subset, DataLoader\n",
        "    total = len(train_dataset)\n",
        "    if total == 0:\n",
        "        raise ValueError(\"Empty train_dataset for probe loader.\")\n",
        "    if batch_size <= 0 or num_batches <= 0:\n",
        "        raise ValueError(\"batch_size and num_batches must be positive for probe loader.\")\n",
        "    need = min(total, batch_size * num_batches)\n",
        "    # A private generator keeps a seeded draw from disturbing the global random state.\n",
        "    sampler = random if seed is None else random.Random(seed)\n",
        "    idxs = sampler.sample(range(total), need)\n",
        "    probe = Subset(train_dataset, idxs)\n",
        "    bs = min(batch_size, len(probe))\n",
        "    return DataLoader(\n",
        "        probe,\n",
        "        batch_size=bs,\n",
        "        shuffle=False,\n",
        "        num_workers=2,\n",
        "        pin_memory=torch.cuda.is_available(),\n",
        "        drop_last=True,   # drop the trailing partial batch (can be B=1, variance==0)\n",
        "    )\n",
        "\n",
        "# -------------------------\n",
        "# Contrastive loss (NaN-safe)\n",
        "# -------------------------\n",
        "def clip_contrastive_loss(model, image_features, text_features):\n",
        "    \"\"\"Symmetric cross-entropy loss over the image/text similarity logits of a batch.\n",
        "\n",
        "    Row i of ``image_features`` is treated as the match of row i of ``text_features``.\n",
        "\n",
        "    Args:\n",
        "        model: passed to ``get_safe_logit_scale`` to obtain the logit scale.\n",
        "        image_features: image embeddings, one row per sample.\n",
        "        text_features: text embeddings, one row per sample.\n",
        "\n",
        "    Returns:\n",
        "        Mean of the image-to-text and text-to-image cross-entropy terms.\n",
        "    \"\"\"\n",
        "    # Work in fp32 and normalise through the epsilon-guarded helper.\n",
        "    img = safe_l2_normalize(image_features.float(), dim=-1, eps=1e-6)\n",
        "    txt = safe_l2_normalize(text_features.float(), dim=-1, eps=1e-6)\n",
        "\n",
        "    scale = get_safe_logit_scale(model)\n",
        "    similarity = img @ txt.t()\n",
        "    # Replace NaN/inf logits with finite values before they reach the loss.\n",
        "    image_logits = torch.nan_to_num(scale * similarity, nan=0.0, posinf=1e4, neginf=-1e4)\n",
        "\n",
        "    # The matching pair of sample i sits on the diagonal.\n",
        "    targets = torch.arange(img.size(0), device=image_logits.device)\n",
        "\n",
        "    image_loss = F.cross_entropy(image_logits, targets)\n",
        "    text_loss = F.cross_entropy(image_logits.t(), targets)\n",
        "    return 0.5 * (image_loss + text_loss)\n",
        "\n",
        "\n",
        "def _short_role(r: str) -> str:\n",
        "    \"\"\"Return the one-letter tag for a role name (\"?\" when the name is not known).\"\"\"\n",
        "    tags = {\"benign\": \"N\", \"broadcaster\": \"B\", \"sink\": \"S\"}\n",
        "    try:\n",
        "        return tags[r]\n",
        "    except KeyError:\n",
        "        return \"?\"\n",
        "\n",
        "def _print_roles(roles_per_layer: List[List[str]]):\n",
        "    \"\"\"Print one line per layer with the head role tags and the count of each role.\"\"\"\n",
        "    def _count(roles, name):\n",
        "        return len([r for r in roles if r == name])\n",
        "\n",
        "    print(\"[Roles] Head roles per layer (S=sink, B=broadcaster, N=benign)\")\n",
        "    for layer_idx, roles in enumerate(roles_per_layer):\n",
        "        tag_line = \" \".join(map(_short_role, roles))\n",
        "        sink_total = _count(roles, \"sink\")\n",
        "        broad_total = _count(roles, \"broadcaster\")\n",
        "        benign_total = _count(roles, \"benign\")\n",
        "        print(f\"  Layer {layer_idx:02d}: {tag_line}   (S={sink_total}, B={broad_total}, N={benign_total})\")\n",
        "\n",
        "\n",
        "def _print_head_masses(feats_per_layer):\n",
        "    \"\"\"\n",
        "    Print, for every layer, the per-head \"cls_in_mass\" (m_in) and \"cls_out_mass\"\n",
        "    (m_out) values rounded to 6 decimals.\n",
        "\n",
        "    feats_per_layer: one sequence per layer holding a mapping per head.\n",
        "    \"\"\"\n",
        "    def _column(layer_feats, key):\n",
        "        return [round(float(head[key]), 6) for head in layer_feats]\n",
        "\n",
        "    print(\"[Mass] Per-head m_in (→CLS) and m_out (CLS→) after warmup\")\n",
        "    for layer_idx, layer_feats in enumerate(feats_per_layer):\n",
        "        in_mass = _column(layer_feats, \"cls_in_mass\")\n",
        "        out_mass = _column(layer_feats, \"cls_out_mass\")\n",
        "        print(f\"  Layer {layer_idx:02d}  m_in : {in_mass}\")\n",
        "        print(f\"  Layer {layer_idx:02d}  m_out: {out_mass}\")\n",
        "\n",
        "\n",
        "\n",
        "# -------------------------\n",
        "# Training (NaN-safe; TE-aware token pruning, benign-first)\n",
        "# -------------------------\n",
        "def train_with_te_prune_lora(\n",
        "    model,\n",
        "    train_dataloader,\n",
        "    train_dataset,\n",
        "    device,\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    # legacy args (kept for API compatibility)\n",
        "    prune_k_vision=0,\n",
        "    prune_k_text=0,\n",
        "    te_pool: Union[str, Tuple[str, str]] = \"mean\",    # (\"mean\",\"cls\") typical\n",
        "    probe_batches=2,\n",
        "    te_sigma_v: float = 1e-2,\n",
        "    te_sigma_t: float = 1e-2,\n",
        "    head_prune_frac: float = 0.25,                    # ← legacy; mapped to token_prune_frac if not provided\n",
        "    min_heads_remaining: int = 4,                     # ← legacy; ignored\n",
        "    guard_layers: Tuple[int, int] = (0, -1),\n",
        "    include_cls_in_te: bool = False,                  # used only for vision pooling if needed elsewhere\n",
        "    # NEW preferred args (safe defaults set inside if omitted)\n",
        "    token_prune_frac: Optional[float] = None,\n",
        "    min_tokens_remaining_vision: int = 48,\n",
        "    min_tokens_remaining_text: int = 28,\n",
        "    # (not used in token mode; kept for API compat)\n",
        "    role_in_hi_pct: float = 60.0,\n",
        "    role_out_hi_pct: float = 80.0,\n",
        "    te_hi_pct: float = 60.0,\n",
        "    te_lo_pct: float = 30.0,\n",
        "    require_te_for_role: bool = True,\n",
        "    use_te_as_tiebreak: bool = True,\n",
        "):\n",
        "    \"\"\"LoRA fine-tuning with a warmup phase, TE-based token pruning, then further training.\n",
        "\n",
        "    Steps, in order:\n",
        "      1. Inject LoRA adapters with `add_lora_to_clip_vit_b16` and optimise the parameters\n",
        "         it returns (AdamW, gradient norm clipped to 1.0).\n",
        "      2. Phase 1: `epochs_before_prune` passes over `train_dataloader`.\n",
        "      3. Switch to eval mode and compute per-token roles/statistics for both towers with\n",
        "         `te_roles_over_probe_vision` / `te_roles_over_probe_text`; the probe loader is\n",
        "         `train_dataloader` itself.\n",
        "      4. Choose tokens to drop per layer with `select_tokens_to_prune`, clear the choices\n",
        "         for guarded layers and hand them to `_TokenPruneApplier` for each tower.\n",
        "      5. Phase 2: `epochs_after_prune` more passes (each pass calls `model.train()` again).\n",
        "\n",
        "    Args:\n",
        "        model: model exposing `encode_image`, `encode_text` and `visual`.\n",
        "        train_dataloader: yields `(images, texts)` tensor batches; also used as probe loader.\n",
        "        device: device the model and batches are moved to.\n",
        "        lora_rank, lora_alpha, lora_dropout: forwarded to `add_lora_to_clip_vit_b16`.\n",
        "        lr, weight_decay: AdamW settings.\n",
        "        epochs_before_prune, epochs_after_prune: number of passes before / after pruning.\n",
        "        te_sigma_v, te_sigma_t: forwarded as `sigma_v` / `sigma_t` to the role functions.\n",
        "        head_prune_frac: used as `token_prune_frac` when that argument is None.\n",
        "        token_prune_frac: per-layer fraction forwarded to `select_tokens_to_prune`.\n",
        "        min_tokens_remaining_vision, min_tokens_remaining_text: forwarded as\n",
        "            `min_tokens_remaining` for the respective tower.\n",
        "        guard_layers: (first, last) layer indices whose drop lists are emptied.\n",
        "        te_pool: unpacked into a vision/text pair that is not used afterwards.\n",
        "        train_dataset, prune_k_vision, prune_k_text, probe_batches, min_heads_remaining,\n",
        "        include_cls_in_te, role_in_hi_pct, role_out_hi_pct, te_hi_pct, te_lo_pct,\n",
        "        require_te_for_role, use_te_as_tiebreak: accepted but never read in this function.\n",
        "\n",
        "    Returns:\n",
        "        dict with the per-layer token statistics and roles of both towers, the pruned\n",
        "        token lists, and `loss_per_epoch` (one average per pass).\n",
        "    \"\"\"\n",
        "    # ---- map legacy fraction to token fraction (back-compat) ----\n",
        "    if token_prune_frac is None:\n",
        "        token_prune_frac = head_prune_frac\n",
        "        print(f\"[Note] Using head_prune_frac={head_prune_frac} as token_prune_frac.\")\n",
        "\n",
        "    # LoRA injection/upgrade (and freeze non-LoRA params)\n",
        "    lora_params = add_lora_to_clip_vit_b16(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)\n",
        "    _assert_outproj_has_weight(model)\n",
        "    model.to(device)\n",
        "\n",
        "    optimizer = AdamW(lora_params, lr=lr, weight_decay=weight_decay)\n",
        "    # NOTE(review): probe_batches and train_dataset are not consulted here, so the probe\n",
        "    # functions receive the complete training loader.\n",
        "    probe_loader = train_dataloader  # reuse\n",
        "\n",
        "    # handle pool specs for (vision, text) — unused in token TE (we compute directional TE directly)\n",
        "    if isinstance(te_pool, tuple):\n",
        "        vision_pool, text_pool = te_pool\n",
        "    else:\n",
        "        vision_pool, text_pool = te_pool, \"cls\"\n",
        "\n",
        "    history = {\n",
        "        # token diagnostics (filled below)\n",
        "        \"tokens_te_in_vision\":   None,\n",
        "        \"tokens_te_out_vision\":  None,\n",
        "        \"tokens_alpha_vision\":   None,\n",
        "        \"tokens_nu_vision\":      None,\n",
        "        \"tokens_te_in_text\":     None,\n",
        "        \"tokens_te_out_text\":    None,\n",
        "        \"tokens_alpha_text\":     None,\n",
        "        \"tokens_nu_text\":        None,\n",
        "        \"roles_before_prune_tokens_vision\": None,\n",
        "        \"roles_before_prune_tokens_text\":   None,\n",
        "        \"pruned_tokens_vision\":  None,\n",
        "        \"pruned_tokens_text\":    None,\n",
        "        # loss trace (optional)\n",
        "        \"loss_per_epoch\": [],\n",
        "    }\n",
        "\n",
        "    def epoch_pass(epoch_idx):\n",
        "        # One full pass over train_dataloader; appends the average loss to history.\n",
        "        model.train()\n",
        "        total_loss, num_steps = 0.0, 0\n",
        "        for images, texts in train_dataloader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            texts  = texts.to(device, non_blocking=True)\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "            img_feat = model.encode_image(images)\n",
        "            txt_feat = model.encode_text(texts)\n",
        "\n",
        "            loss = clip_contrastive_loss(model, img_feat, txt_feat)\n",
        "            # A non-finite loss is reported and the batch is skipped without an optimizer step.\n",
        "            if not torch.isfinite(loss):\n",
        "                with torch.no_grad():\n",
        "                    bad = {\n",
        "                        \"loss\": float(loss.detach().float().item()),\n",
        "                        \"img_feat_has_nan\": bool(torch.isnan(img_feat).any().item()),\n",
        "                        \"txt_feat_has_nan\": bool(torch.isnan(txt_feat).any().item()),\n",
        "                    }\n",
        "                print(\"[Warn] Skipping non-finite batch:\", bad)\n",
        "                continue\n",
        "\n",
        "            loss.backward()\n",
        "            clip_grad_norm_(lora_params, max_norm=1.0)\n",
        "            optimizer.step()\n",
        "\n",
        "            total_loss += float(loss.item())\n",
        "            num_steps += 1\n",
        "\n",
        "        # If every batch was skipped, num_steps is 0 and the reported average is 0.0.\n",
        "        avg_loss = total_loss / max(1, num_steps)\n",
        "        history[\"loss_per_epoch\"].append(avg_loss)\n",
        "        print(f\"[Epoch {epoch_idx}] loss={avg_loss:.6f}\")\n",
        "\n",
        "    # --- Phase 1: warmup ---\n",
        "    for ep in range(1, epochs_before_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    # --- Snapshot TE-aware token roles right after warmup ---\n",
        "    model.eval()\n",
        "\n",
        "    # thresholds from your spec: (vision stricter, text looser)\n",
        "    vision_thr = dict(tau_alpha=1.5, tau_out=1.0, tau_in=1.0, tau_delta=0.5)\n",
        "    text_thr   = dict(tau_alpha=0.5, tau_out=1.0, tau_in=0.3, tau_delta=0.2)\n",
        "\n",
        "    # compute roles + stats per layer (per-token)\n",
        "    roles_v_tok, stats_v = te_roles_over_probe_vision(\n",
        "        model, probe_loader, device=device, sigma_v=te_sigma_v,\n",
        "        thresholds=vision_thr, special_keep_idx=[0]  # keep CLS\n",
        "    )\n",
        "    roles_t_tok, stats_t = te_roles_over_probe_text(\n",
        "        model, probe_loader, device=device, sigma_t=te_sigma_t,\n",
        "        thresholds=text_thr, special_keep_idx=[0]    # keep BOS/CLS (extend if you have PAD/EOT ids)\n",
        "    )\n",
        "\n",
        "    # Print token stats\n",
        "    print(\"\\n[TE/Vision] Per-token directional TE & alpha (averaged over probe)\")\n",
        "    _print_token_scores(stats_v, label=\"Vision TE_in/TE_out/alpha\")\n",
        "    print(\"\\n[TE/Text] Per-token directional TE & alpha (averaged over probe)\")\n",
        "    _print_token_scores(stats_t, label=\"Text TE_in/TE_out/alpha\")\n",
        "\n",
        "    print(\"\\n[Roles/Vision/Tokens] (S=sink, B=broadcaster, N=benign, *=special-kept)\")\n",
        "    _print_token_roles(roles_v_tok)\n",
        "    print(\"\\n[Roles/Text/Tokens] (S=sink, B=broadcaster, N=benign, *=special-kept)\")\n",
        "    _print_token_roles(roles_t_tok)\n",
        "\n",
        "    # Save token stats into history\n",
        "    history[\"tokens_te_in_vision\"]   = [[float(x) for x in s[\"t_in\"]]   for s in stats_v]\n",
        "    history[\"tokens_te_out_vision\"]  = [[float(x) for x in s[\"t_out\"]]  for s in stats_v]\n",
        "    history[\"tokens_alpha_vision\"]   = [[float(x) for x in s[\"alpha\"]]  for s in stats_v]\n",
        "    history[\"tokens_nu_vision\"]      = [[float(x) for x in s[\"nu\"]]     for s in stats_v]\n",
        "    history[\"tokens_te_in_text\"]     = [[float(x) for x in s[\"t_in\"]]   for s in stats_t]\n",
        "    history[\"tokens_te_out_text\"]    = [[float(x) for x in s[\"t_out\"]]  for s in stats_t]\n",
        "    history[\"tokens_alpha_text\"]     = [[float(x) for x in s[\"alpha\"]]  for s in stats_t]\n",
        "    history[\"tokens_nu_text\"]        = [[float(x) for x in s[\"nu\"]]     for s in stats_t]\n",
        "    history[\"roles_before_prune_tokens_vision\"] = roles_v_tok\n",
        "    history[\"roles_before_prune_tokens_text\"]   = roles_t_tok\n",
        "\n",
        "    # --- Build token importance vectors (higher=important) ---\n",
        "    # Importance of a token is the element-wise larger of its incoming and outgoing TE.\n",
        "    te_v_tok = [torch.max(s[\"t_in\"], s[\"t_out\"]) for s in stats_v]  # list of [T] tensors\n",
        "    te_t_tok = [torch.max(s[\"t_in\"], s[\"t_out\"]) for s in stats_t]\n",
        "\n",
        "    # --- Select tokens to drop per layer (benign-first) ---\n",
        "    pruned_tokens_v = select_tokens_to_prune(\n",
        "        roles_v_tok, te_v_tok,\n",
        "        prune_frac_per_layer=token_prune_frac,\n",
        "        min_tokens_remaining=min_tokens_remaining_vision,\n",
        "        special_keep_idx=[0],\n",
        "    )\n",
        "    pruned_tokens_t = select_tokens_to_prune(\n",
        "        roles_t_tok, te_t_tok,\n",
        "        prune_frac_per_layer=token_prune_frac,\n",
        "        min_tokens_remaining=min_tokens_remaining_text,\n",
        "        special_keep_idx=[0],\n",
        "    )\n",
        "\n",
        "    # respect guard layers\n",
        "    def _apply_guards(drop_lists, tower_name):\n",
        "        # Empties the drop list of the guarded layers; drop_lists is modified in place.\n",
        "        layers = _iter_resblocks(model.visual if tower_name==\"vision\" else model)\n",
        "        L = len(layers)\n",
        "        first_guard, last_guard = guard_layers\n",
        "        # Only -1 is translated to the last layer; other negative indices are compared as-is.\n",
        "        if last_guard == -1: last_guard = L - 1\n",
        "        for l in range(L):\n",
        "            if l == first_guard or l == last_guard:\n",
        "                drop_lists[l] = []\n",
        "        return drop_lists\n",
        "\n",
        "    pruned_tokens_v = _apply_guards(pruned_tokens_v, \"vision\")\n",
        "    pruned_tokens_t = _apply_guards(pruned_tokens_t, \"text\")\n",
        "\n",
        "    print(\"\\n[TokenPrune/TE] Vision per-layer drops:\", pruned_tokens_v)\n",
        "    print(\"[TokenPrune/TE] Text   per-layer drops:\", pruned_tokens_t)\n",
        "\n",
        "    # Apply pruning masks (per-layer, non-cumulative)\n",
        "    # NOTE(review): the applier objects are discarded; presumably constructing them\n",
        "    # installs the masks on the model -- confirm against _TokenPruneApplier.\n",
        "    _ = _TokenPruneApplier(model, pruned_tokens_v, tower=\"vision\")\n",
        "    _ = _TokenPruneApplier(model, pruned_tokens_t, tower=\"text\")\n",
        "\n",
        "    history[\"pruned_tokens_vision\"] = pruned_tokens_v\n",
        "    history[\"pruned_tokens_text\"]   = pruned_tokens_t\n",
        "\n",
        "    # --- Phase 2: continue training after pruning ---\n",
        "    # Epoch numbering continues from the warmup phase; the loop is empty when\n",
        "    # epochs_after_prune is 0.\n",
        "    for ep in range(epochs_before_prune + 1, epochs_before_prune + epochs_after_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    return history\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# ======================================================\n",
        "# OPTIONAL: COCO Retrieval Evaluation (I2T / T2I)\n",
        "# ======================================================\n",
        "def evaluate_coco_retrieval(model, data_dir: str, batch_size: int = 64, input_resolution: int = 224, context_length: int = 77):\n",
        "    \"\"\"\n",
        "    Evaluate trained model on COCO val2014 captions and print Recall@1/5/10 for\n",
        "    image-to-text and text-to-image retrieval.\n",
        "\n",
        "    Assumes `pip install pycocotools` and images/annotations present in data_dir\n",
        "    (`<data_dir>/val2014` and `<data_dir>/annotations/captions_val2014.json`).\n",
        "\n",
        "    Args:\n",
        "        model: model providing `encode_image`, `encode_text` and `parameters()`;\n",
        "            it is switched to eval mode and left there.\n",
        "        data_dir: root folder of the COCO 2014 download.\n",
        "        batch_size: number of images per evaluation batch.\n",
        "        input_resolution: side length images are resized to before encoding.\n",
        "        context_length: token context length passed to `clip.tokenize`.\n",
        "\n",
        "    Returns:\n",
        "        None. The six recall values are printed as percentages.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: if no image or caption features could be collected.\n",
        "\n",
        "    Notes:\n",
        "        * A query is a hit at k when any ground-truth match is among its k most similar\n",
        "          candidates; the denominator is the total number of queries.\n",
        "        * Similarities are computed in row chunks and reduced with top-k, so the full\n",
        "          [N_img, N_txt] matrix is never held in memory and no row is fully sorted\n",
        "          (candidates with exactly equal scores may be ordered differently than by a\n",
        "          full sort).\n",
        "        * Captions longer than `context_length` tokens are truncated instead of aborting\n",
        "          the evaluation.\n",
        "    \"\"\"\n",
        "    import clip\n",
        "    from torchvision import datasets, transforms\n",
        "    from torchvision.transforms import InterpolationMode\n",
        "    from torch.utils.data import DataLoader\n",
        "\n",
        "    recall_ks = (1, 5, 10)\n",
        "    query_chunk = 512  # rows of the similarity matrix computed at a time\n",
        "\n",
        "    device = next(model.parameters()).device\n",
        "\n",
        "    eval_transform = transforms.Compose([\n",
        "        transforms.Resize((input_resolution, input_resolution), interpolation=InterpolationMode.BICUBIC),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "                             std=(0.26862954, 0.26130258, 0.27577711))\n",
        "    ])\n",
        "\n",
        "    def coco_collate_fn(batch):\n",
        "        # Captions stay as list[list[str]] because their number varies per image.\n",
        "        images = torch.stack([img for img, _ in batch], dim=0)\n",
        "        captions = [caps for _, caps in batch]\n",
        "        return images, captions\n",
        "\n",
        "    val_img_dir  = os.path.join(data_dir, 'val2014')\n",
        "    val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "    val_dataset  = datasets.CocoCaptions(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,\n",
        "                              pin_memory=torch.cuda.is_available(), collate_fn=coco_collate_fn)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    image_feature_batches = []\n",
        "    text_feature_batches = []\n",
        "    caption_owner = []   # caption_owner[t] = index of the image that caption t describes\n",
        "    images_seen = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for images, batch_captions in val_loader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            image_feats = safe_l2_normalize(model.encode_image(images).float(), dim=-1, eps=1e-6)\n",
        "            image_feature_batches.append(image_feats.cpu())\n",
        "\n",
        "            flat_captions = []\n",
        "            for offset, caps in enumerate(batch_captions):\n",
        "                flat_captions.extend(caps)\n",
        "                caption_owner.extend([images_seen + offset] * len(caps))\n",
        "            images_seen += len(batch_captions)\n",
        "\n",
        "            if not flat_captions:\n",
        "                continue\n",
        "\n",
        "            # truncate=True: an over-long caption is clipped rather than raising.\n",
        "            texts = clip.tokenize(flat_captions, context_length=context_length, truncate=True).to(device)\n",
        "            text_feats = safe_l2_normalize(model.encode_text(texts).float(), dim=-1, eps=1e-6)\n",
        "            text_feature_batches.append(text_feats.cpu())\n",
        "\n",
        "    if not image_feature_batches or not text_feature_batches:\n",
        "        raise RuntimeError(\"No image/caption features were collected; check the contents of data_dir.\")\n",
        "\n",
        "    all_image_features = torch.cat(image_feature_batches, dim=0)  # [N_img, D]\n",
        "    all_text_features  = torch.cat(text_feature_batches,  dim=0)  # [N_txt, D]\n",
        "    text_to_image = torch.tensor(caption_owner, dtype=torch.long)  # [N_txt]\n",
        "\n",
        "    def _topk_neighbors(queries, gallery, k):\n",
        "        # Indices of the k most similar gallery rows for every query row, best first.\n",
        "        k = min(k, gallery.size(0))\n",
        "        gallery_t = gallery.t()\n",
        "        found = []\n",
        "        for start in range(0, queries.size(0), query_chunk):\n",
        "            sims = queries[start:start + query_chunk] @ gallery_t\n",
        "            found.append(sims.topk(k, dim=1).indices)\n",
        "        return torch.cat(found, dim=0)\n",
        "\n",
        "    def _report(title, hits):\n",
        "        # hits[q, r] is True when the r-th ranked candidate of query q is a correct match.\n",
        "        print(title)\n",
        "        for k in recall_ks:\n",
        "            recall = hits[:, :k].any(dim=1).sum().item() / max(1, hits.size(0))\n",
        "            label = f\"Recall@{k}:\"\n",
        "            print(f\"{label:<11}{recall*100:.2f}%\")\n",
        "\n",
        "    max_k = max(recall_ks)\n",
        "\n",
        "    # I2T: a retrieved caption is correct when it belongs to the query image.\n",
        "    top_texts = _topk_neighbors(all_image_features, all_text_features, max_k)   # [N_img, K]\n",
        "    image_ids = torch.arange(all_image_features.size(0)).unsqueeze(1)\n",
        "    _report(\"Image-to-Text Retrieval:\", text_to_image[top_texts] == image_ids)\n",
        "\n",
        "    # T2I: a retrieved image is correct when it is the image the caption describes.\n",
        "    top_images = _topk_neighbors(all_text_features, all_image_features, max_k)  # [N_txt, K]\n",
        "    _report(\"Text-to-Image Retrieval:\", top_images == text_to_image.unsqueeze(1))\n"
      ],
      "metadata": {
        "id": "ftKgCZzre6B-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Warm up for 2 epochs, then compute token roles and install the pruning masks.\n",
        "# epochs_after_prune=0 leaves the post-prune loop of train_with_te_prune_lora empty,\n",
        "# so no further training happens once the masks are applied.\n",
        "# NOTE(review): train_dataset, te_pool, probe_batches and include_cls_in_te are accepted\n",
        "# by the function but do not influence its result as it is currently written.\n",
        "history = train_with_te_prune_lora(\n",
        "    model=model,\n",
        "    train_dataloader=train_dataloader,\n",
        "    train_dataset=train_dataset,\n",
        "    device=device,\n",
        "    lora_rank=8, lora_alpha=16, lora_dropout=0.0,\n",
        "    lr=1e-4, weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=0,\n",
        "    te_pool=(\"mean\", \"cls\"),\n",
        "    probe_batches=2,\n",
        "    te_sigma_v=1e-2,\n",
        "    te_sigma_t=1e-2,\n",
        "    token_prune_frac=0.25,\n",
        "    guard_layers=(0, -1),\n",
        "    include_cls_in_te=False,\n",
        ")\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "KIEmJrN6K9Mf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "P2Ida1SbHcWf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "bHybAl_qHc2V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LQXgZ_alw9gc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "d42i-P_Yrq-Y"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}