{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "00bb1f736cc1429bb76ca8f896ab135d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0edb7bf930954a598546a0d600c5df2a",
              "IPY_MODEL_f554d43ff3314cd4adabea173fb9d1b0",
              "IPY_MODEL_e00698f395914dc89d6130f7562cb3d0"
            ],
            "layout": "IPY_MODEL_b4025e355ed84fd5a291f49607a6642a"
          }
        },
        "0edb7bf930954a598546a0d600c5df2a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e832ca8cbe7647bcacbed0d5a111e527",
            "placeholder": "​",
            "style": "IPY_MODEL_0e540530f7224389953099fab770960c",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "f554d43ff3314cd4adabea173fb9d1b0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ceb9cad7342045399fa8e4b605f0d7b9",
            "max": 3,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_33f03393b1294964ad37ce262d66a68c",
            "value": 3
          }
        },
        "e00698f395914dc89d6130f7562cb3d0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b13dfe9ce5ab431c8e17080e4de7222a",
            "placeholder": "​",
            "style": "IPY_MODEL_9b52248ecede4b3386e50a65b6fe3622",
            "value": " 3/3 [00:00&lt;00:00,  4.80it/s]"
          }
        },
        "b4025e355ed84fd5a291f49607a6642a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e832ca8cbe7647bcacbed0d5a111e527": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0e540530f7224389953099fab770960c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ceb9cad7342045399fa8e4b605f0d7b9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "33f03393b1294964ad37ce262d66a68c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "b13dfe9ce5ab431c8e17080e4de7222a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9b52248ecede4b3386e50a65b6fe3622": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n9NaU-YAWPjl",
        "outputId": "95bd1ad6-be36-4982-c021-a5fb36eae435"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "COCO2014 not found, downloading...\n",
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [14:21<00:00, 15.7MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:20<00:00, 12.3MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ],
      "source": [
        "# ============================================\n",
        "# 0. COCO2014 download + paths\n",
        "# ============================================\n",
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "data_dir = \"/content/coco2014\"\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "datasets_urls = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "def download_file(url, dest_path):\n",
        "    response = requests.get(url, stream=True)\n",
        "    total_size = int(response.headers.get(\"content-length\", 0))\n",
        "    with open(dest_path, \"wb\") as f, tqdm(\n",
        "        desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "        total=total_size,\n",
        "        unit=\"B\",\n",
        "        unit_scale=True,\n",
        "        unit_divisor=1024,\n",
        "    ) as bar:\n",
        "        for data in response.iter_content(chunk_size=1024):\n",
        "            f.write(data)\n",
        "            bar.update(len(data))\n",
        "\n",
        "def ensure_coco2014(data_dir: str):\n",
        "    ann_path = os.path.join(data_dir, \"annotations\", \"captions_train2014.json\")\n",
        "    if os.path.exists(ann_path):\n",
        "        print(f\"COCO2014 already present under {data_dir}\")\n",
        "        return\n",
        "\n",
        "    print(\"COCO2014 not found, downloading...\")\n",
        "    for name, url in datasets_urls.items():\n",
        "        zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "        print(f\"Processing {name}...\")\n",
        "        download_file(url, zip_path)\n",
        "        with ZipFile(zip_path, \"r\") as zip_ref:\n",
        "            zip_ref.extractall(data_dir)\n",
        "        os.remove(zip_path)\n",
        "        print(f\"{name} downloaded and extracted.\")\n",
        "    print(\"All datasets and annotations successfully downloaded and extracted!\")\n",
        "\n",
        "ensure_coco2014(data_dir)\n",
        "\n",
        "COCO_ROOT = os.path.join(data_dir, \"train2014\")\n",
        "COCO_ANN  = os.path.join(data_dir, \"annotations\", \"captions_train2014.json\")\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ============================================\n",
        "# 1. LLaVA + LoRA + TE imports / config\n",
        "# ============================================\n",
        "import math\n",
        "import random\n",
        "from typing import List, Tuple\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from torchvision import datasets as tv_datasets\n",
        "\n",
        "from transformers import LlavaForConditionalGeneration, AutoProcessor\n",
        "from peft import LoraConfig, get_peft_model\n",
        "\n",
        "MODEL_ID = \"llava-hf/llava-1.5-7b-hf\"\n",
        "\n",
        "DEVICE   = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "DTYPE    = torch.float16 if DEVICE == \"cuda\" else torch.float32\n",
        "\n",
        "# Training config\n",
        "BATCH_SIZE_TRAIN = 4        # keep tiny for 7B\n",
        "NUM_EPOCHS       = 5\n",
        "LR               = 1e-4\n",
        "\n",
        "# TE config (perturbation-based, like CLIP TE)\n",
        "IMG_NOISE_STD    = 1e-5     # small Gaussian noise on images\n",
        "VISION_POOL      = \"cls\"    # \"cls\" or \"mean\" or \"last\"\n",
        "LANG_POOL        = \"mean\"\n",
        "BATCH_SIZE_TE    = 16       # >=2 for TE proxy\n",
        "TE_MAX_STEPS     = 256      # max TE batches; images ~= TE_MAX_STEPS * BATCH_SIZE_TE\n",
        "\n",
        "# LoRA config\n",
        "LORA_R           = 4\n",
        "LORA_ALPHA       = 16\n",
        "LORA_DROPOUT     = 0.1\n",
        "LORA_TARGET_MODULES = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"]\n",
        "\n",
        "SEED             = 42\n",
        "\n",
        "random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 2. COCO datasets for LLaVA (train + probe)\n",
        "# ============================================\n",
        "class CocoLlavaTrainDataset(Dataset):\n",
        "    \"\"\"\n",
        "    Training dataset: (PIL image, prompt-with-caption-as-target).\n",
        "    Uses the first COCO caption as the assistant answer.\n",
        "    \"\"\"\n",
        "    def __init__(self, root: str, annFile: str, max_images: int = 512):\n",
        "        self.base = tv_datasets.CocoCaptions(root=root, annFile=annFile)\n",
        "        idxs = list(range(len(self.base)))\n",
        "        random.shuffle(idxs)\n",
        "        self.indices = idxs[:max_images]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indices)\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        img_idx = self.indices[idx]\n",
        "        image, captions = self.base[img_idx]\n",
        "        image = image.convert(\"RGB\")\n",
        "\n",
        "        caption = captions[0] if len(captions) > 0 else \"A photo.\"\n",
        "\n",
        "        prompt = (\n",
        "            \"USER: <image>\\nDescribe this image in one sentence.\\n\"\n",
        "            f\"ASSISTANT: {caption}\"\n",
        "        )\n",
        "        return image, prompt\n",
        "\n",
        "\n",
        "class CocoLlavaProbeDataset(Dataset):\n",
        "    \"\"\"\n",
        "    TE probe dataset (no labels, generic prompt).\n",
        "    Each item: (PIL image, prompt string).\n",
        "    \"\"\"\n",
        "    def __init__(self, root: str, annFile: str, max_images: int = 512):\n",
        "        self.base = tv_datasets.CocoCaptions(root=root, annFile=annFile)\n",
        "        idxs = list(range(len(self.base)))\n",
        "        random.shuffle(idxs)\n",
        "        self.indices = idxs[:max_images]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indices)\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        img_idx = self.indices[idx]\n",
        "        image, _captions = self.base[img_idx]\n",
        "        image = image.convert(\"RGB\")\n",
        "        prompt = \"USER: <image>\\nDescribe this image in one sentence.\\nASSISTANT:\"\n",
        "        return image, prompt\n",
        "\n",
        "\n",
        "def collate_coco(batch):\n",
        "    images, prompts = zip(*batch)\n",
        "    return list(images), list(prompts)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 3. Vision tower helper + TE utilities\n",
        "# ============================================\n",
        "def get_vision_tower(model: LlavaForConditionalGeneration) -> nn.Module:\n",
        "    \"\"\"\n",
        "    Extract the CLIP vision tower from a LLaVA HF model.\n",
        "    \"\"\"\n",
        "    base = getattr(model, \"base_model\", model)\n",
        "    vt = getattr(base, \"vision_tower\", None)\n",
        "    if vt is None:\n",
        "        raise ValueError(\"This LLaVA model has no `vision_tower` attribute.\")\n",
        "    if isinstance(vt, (list, tuple)):\n",
        "        vt = vt[0]\n",
        "    if hasattr(vt, \"vision_tower\"):\n",
        "        vt = vt.vision_tower\n",
        "    return vt\n",
        "\n",
        "\n",
        "def _pool_hidden(h: torch.Tensor, pool: str) -> torch.Tensor:\n",
        "    # h: [B, T, D]\n",
        "    if pool == \"cls\":\n",
        "        return h[:, 0, :]\n",
        "    elif pool == \"last\":\n",
        "        return h[:, -1, :]\n",
        "    elif pool == \"mean\":\n",
        "        return h.mean(dim=1)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown pool={pool}\")\n",
        "\n",
        "\n",
        "def _te_from_two_runs(pooled1: List[torch.Tensor],\n",
        "                      pooled2: List[torch.Tensor]) -> List[float]:\n",
        "    \"\"\"\n",
        "    Perturbation-based TE proxy (same idea as your CLIP TE):\n",
        "\n",
        "    For each edge ℓ -> ℓ+1:\n",
        "      Δℓ   = zℓ(2)   − zℓ(1)\n",
        "      Δℓ+1 = zℓ+1(2) − zℓ+1(1)\n",
        "      TEℓ  = 0.5 * mean_b cos²(Δℓ, Δℓ+1)\n",
        "    \"\"\"\n",
        "    assert len(pooled1) == len(pooled2) and len(pooled1) >= 2\n",
        "    eps = 1e-6\n",
        "    te_vals = []\n",
        "    for l in range(len(pooled1) - 1):\n",
        "        A = torch.nan_to_num(pooled2[l]   - pooled1[l]).float()   # [B, D]\n",
        "        B = torch.nan_to_num(pooled2[l+1] - pooled1[l+1]).float() # [B, D]\n",
        "        num = (A * B).sum(dim=1)\n",
        "        den = (A.norm(dim=1) * B.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "    return te_vals\n",
        "\n",
        "\n",
        "def _v2t_te_from_two_runs(\n",
        "    hs1: Tuple[torch.Tensor, ...],\n",
        "    hs2: Tuple[torch.Tensor, ...],\n",
        "    input_ids: torch.Tensor,\n",
        "    attention_mask: torch.Tensor,\n",
        "    image_token_id: int,\n",
        ") -> List[float]:\n",
        "    \"\"\"\n",
        "    Cross-modal TE proxy per *language* layer:\n",
        "      For each transformer layer ℓ (1..L):\n",
        "        z_vis^r  = mean over image tokens (run r)\n",
        "        z_text^r = mean over text tokens (run r)\n",
        "        Δ_vis  = z_vis^2  - z_vis^1\n",
        "        Δ_text = z_text^2 - z_text^1\n",
        "        TE_{vis->text, ℓ} = 0.5 * mean_b cos²(Δ_vis, Δ_text)\n",
        "    \"\"\"\n",
        "    assert len(hs1) == len(hs2) and len(hs1) >= 2\n",
        "    B, T, D = hs1[0].shape\n",
        "    device = hs1[0].device\n",
        "    eps = 1e-6\n",
        "\n",
        "    te_vals: List[float] = []\n",
        "\n",
        "    # hidden_states[0] is embedding; [1:] are transformer blocks\n",
        "    for l in range(1, len(hs1)):\n",
        "        h1 = hs1[l]  # [B, T, D]\n",
        "        h2 = hs2[l]\n",
        "\n",
        "        z1_vis = torch.zeros(B, D, device=device)\n",
        "        z2_vis = torch.zeros(B, D, device=device)\n",
        "        z1_txt = torch.zeros(B, D, device=device)\n",
        "        z2_txt = torch.zeros(B, D, device=device)\n",
        "\n",
        "        for b in range(B):\n",
        "            valid = attention_mask[b].bool()\n",
        "            ids_b = input_ids[b]\n",
        "\n",
        "            img_mask = (ids_b == image_token_id) & valid\n",
        "            txt_mask = (~img_mask) & valid\n",
        "\n",
        "            if img_mask.any():\n",
        "                z1_vis[b] = h1[b][img_mask].mean(dim=0)\n",
        "                z2_vis[b] = h2[b][img_mask].mean(dim=0)\n",
        "            else:\n",
        "                # fallback: use all valid tokens\n",
        "                z1_vis[b] = h1[b][valid].mean(dim=0)\n",
        "                z2_vis[b] = h2[b][valid].mean(dim=0)\n",
        "\n",
        "            if txt_mask.any():\n",
        "                z1_txt[b] = h1[b][txt_mask].mean(dim=0)\n",
        "                z2_txt[b] = h2[b][txt_mask].mean(dim=0)\n",
        "            else:\n",
        "                z1_txt[b] = h1[b][valid].mean(dim=0)\n",
        "                z2_txt[b] = h2[b][valid].mean(dim=0)\n",
        "\n",
        "        d_vis  = torch.nan_to_num(z2_vis  - z1_vis).float()\n",
        "        d_text = torch.nan_to_num(z2_txt  - z1_txt).float()\n",
        "\n",
        "        num = (d_vis * d_text).sum(dim=1)\n",
        "        den = (d_vis.norm(dim=1) * d_text.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "\n",
        "    return te_vals\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 4. Perturbation TE: vision tower\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def compute_te_vision_batch(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    pixel_values: torch.Tensor,\n",
        "    sigma: float = IMG_NOISE_STD,\n",
        "    pool: str = VISION_POOL,\n",
        ") -> List[float]:\n",
        "    vt = get_vision_tower(model)\n",
        "    vt = vt.to(DEVICE)\n",
        "    vt_dtype = next(vt.parameters()).dtype\n",
        "\n",
        "    pv = pixel_values.to(DEVICE, dtype=vt_dtype)\n",
        "    noise1 = sigma * torch.randn_like(pv)\n",
        "    noise2 = sigma * torch.randn_like(pv)\n",
        "\n",
        "    out1 = vt(pixel_values=pv + noise1,\n",
        "              output_hidden_states=True,\n",
        "              return_dict=True)\n",
        "    out2 = vt(pixel_values=pv + noise2,\n",
        "              output_hidden_states=True,\n",
        "              return_dict=True)\n",
        "\n",
        "    hs1 = out1.hidden_states  # tuple(len = L+1), each [B, T, D]\n",
        "    hs2 = out2.hidden_states\n",
        "\n",
        "    pooled1 = [_pool_hidden(h, pool) for h in hs1]  # [B, D]\n",
        "    pooled2 = [_pool_hidden(h, pool) for h in hs2]\n",
        "\n",
        "    return _te_from_two_runs(pooled1, pooled2)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 5. Perturbation TE: language tower (LLaMA) + vision→text\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def compute_te_language_and_v2t_batch(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    inputs: dict,\n",
        "    sigma_img: float = IMG_NOISE_STD,\n",
        "    pool: str = LANG_POOL,\n",
        ") -> Tuple[List[float], List[float]]:\n",
        "    \"\"\"\n",
        "    Returns:\n",
        "      lang_te: TE across language edges (ℓ -> ℓ+1), pooled over all tokens.\n",
        "      v2t_te:  cross-modal TE per language layer (vision → text).\n",
        "    \"\"\"\n",
        "    input_ids      = inputs[\"input_ids\"].to(DEVICE)\n",
        "    attention_mask = inputs[\"attention_mask\"].to(DEVICE)\n",
        "    pixel_values   = inputs[\"pixel_values\"].to(DEVICE, dtype=model.dtype)\n",
        "\n",
        "    noise1 = sigma_img * torch.randn_like(pixel_values)\n",
        "    noise2 = sigma_img * torch.randn_like(pixel_values)\n",
        "\n",
        "    kwargs1 = dict(\n",
        "        input_ids=input_ids,\n",
        "        attention_mask=attention_mask,\n",
        "        pixel_values=pixel_values + noise1,\n",
        "        output_hidden_states=True,\n",
        "        use_cache=False,\n",
        "        return_dict=True,\n",
        "    )\n",
        "    kwargs2 = dict(\n",
        "        input_ids=input_ids,\n",
        "        attention_mask=attention_mask,\n",
        "        pixel_values=pixel_values + noise2,\n",
        "        output_hidden_states=True,\n",
        "        use_cache=False,\n",
        "        return_dict=True,\n",
        "    )\n",
        "\n",
        "    out1 = model(**kwargs1)\n",
        "    out2 = model(**kwargs2)\n",
        "\n",
        "    hs1 = out1.hidden_states\n",
        "    hs2 = out2.hidden_states\n",
        "\n",
        "    # Standard language TE (edges ℓ -> ℓ+1, pooled over all tokens)\n",
        "    pooled1 = [_pool_hidden(h, pool) for h in hs1]  # [B, D]\n",
        "    pooled2 = [_pool_hidden(h, pool) for h in hs2]\n",
        "    lang_te = _te_from_two_runs(pooled1, pooled2)\n",
        "\n",
        "    # Cross-modal TE per layer: vision → text\n",
        "    image_token_id = model.config.image_token_index\n",
        "    v2t_te = _v2t_te_from_two_runs(\n",
        "        hs1, hs2,\n",
        "        input_ids=input_ids,\n",
        "        attention_mask=attention_mask,\n",
        "        image_token_id=image_token_id,\n",
        "    )\n",
        "\n",
        "    return lang_te, v2t_te\n",
        "\n",
        "\n",
        "# keep a wrapper if you still want \"language-only\" TE interface\n",
        "@torch.no_grad()\n",
        "def compute_te_language_batch(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    inputs: dict,\n",
        "    sigma_img: float = IMG_NOISE_STD,\n",
        "    pool: str = LANG_POOL,\n",
        ") -> List[float]:\n",
        "    lang_te, _ = compute_te_language_and_v2t_batch(\n",
        "        model, inputs, sigma_img=sigma_img, pool=pool\n",
        "    )\n",
        "    return lang_te\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 6. TE estimation loop over a probe dataloader\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def estimate_te_llava(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    processor: AutoProcessor,\n",
        "    dataloader: DataLoader,\n",
        "    max_steps: int = TE_MAX_STEPS,\n",
        "    img_noise_std: float = IMG_NOISE_STD,\n",
        ") -> Tuple[List[float], List[float], List[float]]:\n",
        "    \"\"\"\n",
        "    Returns:\n",
        "      vision_te:  [L_v] TE values for CLIP vision edges (ℓ = layer ℓ -> ℓ+1)\n",
        "      lang_te:    [L_l] TE values for LLaMA language edges\n",
        "      v2t_te:     [L_l] cross-modal TE per language layer (vision → text)\n",
        "    \"\"\"\n",
        "    vision_sum, lang_sum, v2t_sum = None, None, None\n",
        "    v_count = l_count = v2t_count = 0\n",
        "\n",
        "    for step, (images, prompts) in enumerate(dataloader):\n",
        "        if max_steps is not None and (step + 1) > max_steps:\n",
        "            break\n",
        "\n",
        "        encoded = processor(\n",
        "            images=images,\n",
        "            text=list(prompts),\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        )\n",
        "\n",
        "        # Vision TE\n",
        "        v_te = compute_te_vision_batch(\n",
        "            model,\n",
        "            encoded[\"pixel_values\"],\n",
        "            sigma=img_noise_std,\n",
        "            pool=VISION_POOL,\n",
        "        )\n",
        "        v_te = torch.tensor(v_te)\n",
        "\n",
        "        # Language TE + vision→text TE\n",
        "        l_te, v2t_te = compute_te_language_and_v2t_batch(\n",
        "            model,\n",
        "            encoded,\n",
        "            sigma_img=img_noise_std,\n",
        "            pool=LANG_POOL,\n",
        "        )\n",
        "        l_te   = torch.tensor(l_te)\n",
        "        v2t_te = torch.tensor(v2t_te)\n",
        "\n",
        "        if vision_sum is None:\n",
        "            vision_sum = v_te\n",
        "            lang_sum   = l_te\n",
        "            v2t_sum    = v2t_te\n",
        "        else:\n",
        "            vision_sum = vision_sum + v_te\n",
        "            lang_sum   = lang_sum   + l_te\n",
        "            v2t_sum    = v2t_sum    + v2t_te\n",
        "\n",
        "        v_count  += 1\n",
        "        l_count  += 1\n",
        "        v2t_count += 1\n",
        "\n",
        "    vision_avg = (vision_sum / max(1, v_count)).tolist()\n",
        "    lang_avg   = (lang_sum   / max(1, l_count)).tolist()\n",
        "    v2t_avg    = (v2t_sum    / max(1, v2t_count)).tolist()\n",
        "    return vision_avg, lang_avg, v2t_avg\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 7. LoRA training on LLaVA (rank=4)\n",
        "# ============================================\n",
        "def print_trainable_parameters(m: nn.Module):\n",
        "    trainable, total = 0, 0\n",
        "    for _, p in m.named_parameters():\n",
        "        total += p.numel()\n",
        "        if p.requires_grad:\n",
        "            trainable += p.numel()\n",
        "    pct = 100 * trainable / max(1, total)\n",
        "    print(f\"Trainable params: {trainable} / {total} ({pct:.4f}%)\")\n",
        "\n",
        "\n",
        "def train_one_epoch_llava_lora(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    processor: AutoProcessor,\n",
        "    dataloader: DataLoader,\n",
        "    optimizer: torch.optim.Optimizer,\n",
        "    epoch: int,\n",
        "):\n",
        "    model.train()\n",
        "    total_loss = 0.0\n",
        "    n_steps = 0\n",
        "\n",
        "    use_amp = (DEVICE == \"cuda\" and DTYPE == torch.float16)\n",
        "    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
        "\n",
        "    for step, (images, prompts) in enumerate(dataloader):\n",
        "        optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "        encoded = processor(\n",
        "            images=images,\n",
        "            text=list(prompts),\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        )\n",
        "        input_ids      = encoded[\"input_ids\"].to(DEVICE)\n",
        "        attention_mask = encoded[\"attention_mask\"].to(DEVICE)\n",
        "        pixel_values   = encoded[\"pixel_values\"].to(DEVICE, dtype=model.dtype)\n",
        "\n",
        "        # Causal LM labels: predict next tokens; mask pads\n",
        "        labels = input_ids.clone()\n",
        "        labels[attention_mask == 0] = -100\n",
        "\n",
        "        if use_amp:\n",
        "            with torch.cuda.amp.autocast(dtype=DTYPE):\n",
        "                out = model(\n",
        "                    input_ids=input_ids,\n",
        "                    attention_mask=attention_mask,\n",
        "                    pixel_values=pixel_values,\n",
        "                    labels=labels,\n",
        "                    use_cache=False,\n",
        "                    return_dict=True,\n",
        "                )\n",
        "                loss = out.loss\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(optimizer)\n",
        "            scaler.update()\n",
        "        else:\n",
        "            out = model(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                pixel_values=pixel_values,\n",
        "                labels=labels,\n",
        "                use_cache=False,\n",
        "                return_dict=True,\n",
        "            )\n",
        "            loss = out.loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        total_loss += loss.item()\n",
        "        n_steps += 1\n",
        "\n",
        "        if (step + 1) % 100 == 0:\n",
        "            print(f\"  [Epoch {epoch+1} | Step {step+1}] loss={total_loss / max(1, n_steps):.4f}\")\n",
        "\n",
        "    avg_loss = total_loss / max(1, n_steps)\n",
        "    print(f\"Epoch {epoch+1} finished. Avg train loss = {avg_loss:.4f}\")\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 8. Main: LoRA rank=4 (2 epochs) + TE after each epoch\n",
        "# ============================================\n",
        "def main():\n",
        "    print(\"Loading LLaVA model:\", MODEL_ID)\n",
        "    base_model = LlavaForConditionalGeneration.from_pretrained(\n",
        "        MODEL_ID,\n",
        "        torch_dtype=DTYPE,\n",
        "        low_cpu_mem_usage=True,\n",
        "        device_map=None,\n",
        "    )\n",
        "    base_model.to(DEVICE)\n",
        "\n",
        "    # Freeze base parameters\n",
        "    for p in base_model.parameters():\n",
        "        p.requires_grad = False\n",
        "\n",
        "    # LoRA rank=4 on attention projections\n",
        "    lora_cfg = LoraConfig(\n",
        "        r=LORA_R,\n",
        "        lora_alpha=LORA_ALPHA,\n",
        "        lora_dropout=LORA_DROPOUT,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "        target_modules=LORA_TARGET_MODULES,\n",
        "    )\n",
        "    model = get_peft_model(base_model, lora_cfg)\n",
        "    model.to(DEVICE)\n",
        "    model.eval()\n",
        "\n",
        "    print(\"LoRA-wrapped model.\")\n",
        "    print_trainable_parameters(model)\n",
        "\n",
        "    processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
        "\n",
        "    print(\"Building COCO train + probe datasets...\")\n",
        "    train_ds = CocoLlavaTrainDataset(COCO_ROOT, COCO_ANN, max_images=5000)\n",
        "    probe_ds = CocoLlavaProbeDataset(COCO_ROOT, COCO_ANN, max_images=10000)\n",
        "\n",
        "    train_loader = DataLoader(\n",
        "        train_ds,\n",
        "        batch_size=BATCH_SIZE_TRAIN,\n",
        "        shuffle=True,\n",
        "        num_workers=2,\n",
        "        collate_fn=collate_coco,\n",
        "        drop_last=True,\n",
        "    )\n",
        "    probe_loader = DataLoader(\n",
        "        probe_ds,\n",
        "        batch_size=BATCH_SIZE_TE,\n",
        "        shuffle=True,\n",
        "        num_workers=2,\n",
        "        collate_fn=collate_coco,\n",
        "        drop_last=True,\n",
        "    )\n",
        "\n",
        "    optimizer = torch.optim.AdamW(\n",
        "        [p for p in model.parameters() if p.requires_grad],\n",
        "        lr=LR,\n",
        "        weight_decay=0.0,\n",
        "    )\n",
        "\n",
        "    epoch_te_history = []\n",
        "\n",
        "    for epoch in range(NUM_EPOCHS):\n",
        "        print(f\"\\n=== LoRA training: epoch {epoch+1}/{NUM_EPOCHS} ===\")\n",
        "        train_one_epoch_llava_lora(model, processor, train_loader, optimizer, epoch)\n",
        "\n",
        "        print(f\"\\nEstimating TE after epoch {epoch+1} (vision + language + vision→text)...\")\n",
        "        vision_te, lang_te, v2t_te = estimate_te_llava(\n",
        "            model,\n",
        "            processor,\n",
        "            probe_loader,\n",
        "            max_steps=TE_MAX_STEPS,\n",
        "            img_noise_std=IMG_NOISE_STD,\n",
        "        )\n",
        "        epoch_te_history.append((vision_te, lang_te, v2t_te))\n",
        "\n",
        "        print(\"\\n=== Averaged TE (vision tower) ===\")\n",
        "        print(f\"#edges = {len(vision_te)} (edge ℓ = layer ℓ -> layer ℓ+1)\")\n",
        "        print([round(x, 6) for x in vision_te])\n",
        "\n",
        "        print(\"\\n=== Averaged TE (language tower) ===\")\n",
        "        print(f\"#edges = {len(lang_te)} (edge ℓ = layer ℓ -> layer ℓ+1)\")\n",
        "        print([round(x, 6) for x in lang_te])\n",
        "\n",
        "        print(\"\\n=== Cross-modal TE (vision → text, per language layer) ===\")\n",
        "        print(f\"#layers = {len(v2t_te)} (layer i is LLaMA block i)\")\n",
        "        print([round(x, 6) for x in v2t_te])\n",
        "\n",
        "    print(\"\\nDone. `epoch_te_history` now holds TE per layer per epoch (vision, language, vision→text).\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "tXuzpXdjWQTM",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "00bb1f736cc1429bb76ca8f896ab135d",
            "0edb7bf930954a598546a0d600c5df2a",
            "f554d43ff3314cd4adabea173fb9d1b0",
            "e00698f395914dc89d6130f7562cb3d0",
            "b4025e355ed84fd5a291f49607a6642a",
            "e832ca8cbe7647bcacbed0d5a111e527",
            "0e540530f7224389953099fab770960c",
            "ceb9cad7342045399fa8e4b605f0d7b9",
            "33f03393b1294964ad37ce262d66a68c",
            "b13dfe9ce5ab431c8e17080e4de7222a",
            "9b52248ecede4b3386e50a65b6fe3622"
          ]
        },
        "outputId": "6fdd00b8-7da1-4573-f9b6-0a982cae6b02"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading LLaVA model: llava-hf/llava-1.5-7b-hf\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "00bb1f736cc1429bb76ca8f896ab135d"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "LoRA-wrapped model.\n",
            "Trainable params: 4784128 / 7068211200 (0.0677%)\n",
            "Building COCO train + probe datasets...\n",
            "loading annotations into memory...\n",
            "Done (t=0.69s)\n",
            "creating index...\n",
            "index created!\n",
            "loading annotations into memory...\n",
            "Done (t=0.70s)\n",
            "creating index...\n",
            "index created!\n",
            "\n",
            "=== LoRA training: epoch 1/5 ===\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-44385239.py:431: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
            "  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
            "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
            "/tmp/ipython-input-44385239.py:452: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
            "  with torch.cuda.amp.autocast(dtype=DTYPE):\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  [Epoch 1 | Step 100] loss=5.8274\n",
            "  [Epoch 1 | Step 200] loss=4.8545\n",
            "  [Epoch 1 | Step 300] loss=4.4797\n",
            "  [Epoch 1 | Step 400] loss=4.2862\n",
            "  [Epoch 1 | Step 500] loss=4.1694\n",
            "  [Epoch 1 | Step 600] loss=4.0910\n",
            "  [Epoch 1 | Step 700] loss=4.0352\n",
            "  [Epoch 1 | Step 800] loss=3.9932\n",
            "  [Epoch 1 | Step 900] loss=3.9606\n",
            "  [Epoch 1 | Step 1000] loss=3.9342\n",
            "  [Epoch 1 | Step 1100] loss=3.9128\n",
            "  [Epoch 1 | Step 1200] loss=3.8948\n",
            "Epoch 1 finished. Avg train loss = 3.8869\n",
            "\n",
            "Estimating TE after epoch 1 (vision + language + vision→text)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.191455, 0.205068, 0.212704, 0.253446, 0.294574, 0.273211, 0.266992, 0.239615, 0.260743, 0.288651, 0.09233, 0.202806, 0.207214, 0.207371, 0.196218, 0.196254, 0.264486, 0.262906, 0.259218, 0.242855, 0.191075, 0.186666, 0.238649]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.498186, 0.475987, 0.468155, 0.452813, 0.441949, 0.410128, 0.218861, 0.362198, 0.289213, 0.33414, 0.372135, 0.333555, 0.34166, 0.354187, 0.306789, 0.340324, 0.320095, 0.30872, 0.342242, 0.359406, 0.346304, 0.388862, 0.396695, 0.373957, 0.387519, 0.390044, 0.395625, 0.260736, 0.299599, 0.425339, 0.427077, 0.280998]\n",
            "\n",
            "=== Cross-modal TE (vision → text, per language layer) ===\n",
            "#layers = 32 (layer i is LLaMA block i)\n",
            "[0.000285, 0.000455, 0.000518, 0.002116, 0.003397, 0.004607, 0.007564, 0.007834, 0.008405, 0.010323, 0.010628, 0.014145, 0.015823, 0.012645, 0.01348, 0.013269, 0.021144, 0.023427, 0.029229, 0.035464, 0.038224, 0.040676, 0.050166, 0.051215, 0.058364, 0.063507, 0.071233, 0.080725, 0.110552, 0.122285, 0.127648, 0.030381]\n",
            "\n",
            "=== LoRA training: epoch 2/5 ===\n",
            "  [Epoch 2 | Step 100] loss=3.6969\n",
            "  [Epoch 2 | Step 200] loss=3.6966\n",
            "  [Epoch 2 | Step 300] loss=3.6962\n",
            "  [Epoch 2 | Step 400] loss=3.6963\n",
            "  [Epoch 2 | Step 500] loss=3.6961\n",
            "  [Epoch 2 | Step 600] loss=3.6963\n",
            "  [Epoch 2 | Step 700] loss=3.6962\n",
            "  [Epoch 2 | Step 800] loss=3.6962\n",
            "  [Epoch 2 | Step 900] loss=3.6960\n",
            "  [Epoch 2 | Step 1000] loss=3.6960\n",
            "  [Epoch 2 | Step 1100] loss=3.6960\n",
            "  [Epoch 2 | Step 1200] loss=3.6960\n",
            "Epoch 2 finished. Avg train loss = 3.6961\n",
            "\n",
            "Estimating TE after epoch 2 (vision + language + vision→text)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.186839, 0.215888, 0.215021, 0.249156, 0.293105, 0.282557, 0.263733, 0.228799, 0.257151, 0.285976, 0.08437, 0.14487, 0.175013, 0.19853, 0.189448, 0.216634, 0.272201, 0.282825, 0.277201, 0.262839, 0.230325, 0.218267, 0.253068]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.498398, 0.480062, 0.466358, 0.45147, 0.44622, 0.41151, 0.247297, 0.374818, 0.284372, 0.326921, 0.36986, 0.320885, 0.316285, 0.337694, 0.282138, 0.306521, 0.283644, 0.276957, 0.312257, 0.344619, 0.311812, 0.382983, 0.375015, 0.371403, 0.381091, 0.391172, 0.377975, 0.293446, 0.315307, 0.414282, 0.413206, 0.290647]\n",
            "\n",
            "=== Cross-modal TE (vision → text, per language layer) ===\n",
            "#layers = 32 (layer i is LLaMA block i)\n",
            "[0.000291, 0.000271, 0.000373, 0.002286, 0.003529, 0.004965, 0.006954, 0.007639, 0.007784, 0.009848, 0.010458, 0.014595, 0.01847, 0.014052, 0.015984, 0.014925, 0.025303, 0.02774, 0.036382, 0.042245, 0.044293, 0.047745, 0.06034, 0.061825, 0.071714, 0.075922, 0.088326, 0.105674, 0.112849, 0.111766, 0.113251, 0.048344]\n",
            "\n",
            "=== LoRA training: epoch 3/5 ===\n",
            "  [Epoch 3 | Step 100] loss=3.6951\n",
            "  [Epoch 3 | Step 200] loss=3.6945\n",
            "  [Epoch 3 | Step 300] loss=3.6940\n",
            "  [Epoch 3 | Step 400] loss=3.6943\n",
            "  [Epoch 3 | Step 500] loss=3.6942\n",
            "  [Epoch 3 | Step 600] loss=3.6945\n",
            "  [Epoch 3 | Step 700] loss=3.6945\n",
            "  [Epoch 3 | Step 800] loss=3.6942\n",
            "  [Epoch 3 | Step 900] loss=3.6942\n",
            "  [Epoch 3 | Step 1000] loss=3.6942\n",
            "  [Epoch 3 | Step 1100] loss=3.6942\n",
            "  [Epoch 3 | Step 1200] loss=3.6942\n",
            "Epoch 3 finished. Avg train loss = 3.6943\n",
            "\n",
            "Estimating TE after epoch 3 (vision + language + vision→text)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.169949, 0.215979, 0.20457, 0.250538, 0.284583, 0.287032, 0.273873, 0.226213, 0.253751, 0.27821, 0.089645, 0.126356, 0.150528, 0.195008, 0.193891, 0.217825, 0.268951, 0.281822, 0.273312, 0.265258, 0.230807, 0.215452, 0.248284]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.498828, 0.488072, 0.464442, 0.457539, 0.453905, 0.42586, 0.32476, 0.387424, 0.303174, 0.327165, 0.372551, 0.320313, 0.311061, 0.340829, 0.270308, 0.289688, 0.27613, 0.262655, 0.304431, 0.329191, 0.301636, 0.37654, 0.366879, 0.379963, 0.371071, 0.396358, 0.373253, 0.324053, 0.334494, 0.405746, 0.40465, 0.300784]\n",
            "\n",
            "=== Cross-modal TE (vision → text, per language layer) ===\n",
            "#layers = 32 (layer i is LLaMA block i)\n",
            "[0.000294, 0.000212, 0.000368, 0.002764, 0.004003, 0.005307, 0.007544, 0.007949, 0.008698, 0.010182, 0.011002, 0.015356, 0.019982, 0.015951, 0.018609, 0.017953, 0.030528, 0.034205, 0.043563, 0.051458, 0.046898, 0.050954, 0.065125, 0.067563, 0.07967, 0.080436, 0.096987, 0.111103, 0.093244, 0.085921, 0.090096, 0.049849]\n",
            "\n",
            "=== LoRA training: epoch 4/5 ===\n",
            "  [Epoch 4 | Step 100] loss=3.6918\n",
            "  [Epoch 4 | Step 200] loss=3.6915\n",
            "  [Epoch 4 | Step 300] loss=3.6917\n",
            "  [Epoch 4 | Step 400] loss=3.6916\n",
            "  [Epoch 4 | Step 500] loss=3.6917\n",
            "  [Epoch 4 | Step 600] loss=3.6920\n",
            "  [Epoch 4 | Step 700] loss=3.6920\n",
            "  [Epoch 4 | Step 800] loss=3.6921\n",
            "  [Epoch 4 | Step 900] loss=3.6921\n",
            "  [Epoch 4 | Step 1000] loss=3.6922\n",
            "  [Epoch 4 | Step 1100] loss=3.6922\n",
            "  [Epoch 4 | Step 1200] loss=3.6923\n",
            "Epoch 4 finished. Avg train loss = 3.6923\n",
            "\n",
            "Estimating TE after epoch 4 (vision + language + vision→text)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.167277, 0.205844, 0.196485, 0.259159, 0.287849, 0.299862, 0.269123, 0.215267, 0.234732, 0.262328, 0.111142, 0.107614, 0.150454, 0.196992, 0.200728, 0.183416, 0.256611, 0.276556, 0.280952, 0.279859, 0.2388, 0.22372, 0.229135]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.499179, 0.493111, 0.471383, 0.46808, 0.460628, 0.438455, 0.354083, 0.387914, 0.311499, 0.359292, 0.390391, 0.336579, 0.323904, 0.344655, 0.266476, 0.315123, 0.280027, 0.278418, 0.320369, 0.32099, 0.285829, 0.379521, 0.361427, 0.384917, 0.37438, 0.407765, 0.366283, 0.342331, 0.337216, 0.402126, 0.406515, 0.314174]\n",
            "\n",
            "=== Cross-modal TE (vision → text, per language layer) ===\n",
            "#layers = 32 (layer i is LLaMA block i)\n",
            "[0.000301, 0.000185, 0.000369, 0.002302, 0.00358, 0.004676, 0.006805, 0.007004, 0.007655, 0.00964, 0.009407, 0.012957, 0.017696, 0.014352, 0.016243, 0.015704, 0.028487, 0.034174, 0.043077, 0.05114, 0.04475, 0.049667, 0.058236, 0.06022, 0.069523, 0.066114, 0.0818, 0.090186, 0.07008, 0.063691, 0.066477, 0.042168]\n",
            "\n",
            "=== LoRA training: epoch 5/5 ===\n",
            "  [Epoch 5 | Step 100] loss=3.6897\n",
            "  [Epoch 5 | Step 200] loss=3.6891\n",
            "  [Epoch 5 | Step 300] loss=3.6894\n",
            "  [Epoch 5 | Step 400] loss=3.6893\n",
            "  [Epoch 5 | Step 500] loss=3.6892\n",
            "  [Epoch 5 | Step 600] loss=3.6894\n",
            "  [Epoch 5 | Step 700] loss=3.6896\n",
            "  [Epoch 5 | Step 800] loss=3.6897\n",
            "  [Epoch 5 | Step 900] loss=3.6898\n",
            "  [Epoch 5 | Step 1000] loss=3.6899\n",
            "  [Epoch 5 | Step 1100] loss=3.6898\n",
            "  [Epoch 5 | Step 1200] loss=3.6898\n",
            "Epoch 5 finished. Avg train loss = 3.6899\n",
            "\n",
            "Estimating TE after epoch 5 (vision + language + vision→text)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.143103, 0.198448, 0.206825, 0.249116, 0.283741, 0.295905, 0.279676, 0.2133, 0.243979, 0.261149, 0.131485, 0.118128, 0.161669, 0.200404, 0.199589, 0.206731, 0.24733, 0.255116, 0.253607, 0.26156, 0.230031, 0.212323, 0.199553]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.499287, 0.494638, 0.468245, 0.473288, 0.46635, 0.443152, 0.396828, 0.406675, 0.329252, 0.361899, 0.394036, 0.342782, 0.342863, 0.340454, 0.241144, 0.314581, 0.263026, 0.283912, 0.300775, 0.32431, 0.296035, 0.368973, 0.357769, 0.390724, 0.384536, 0.410849, 0.362675, 0.353718, 0.339265, 0.401149, 0.400023, 0.335119]\n",
            "\n",
            "=== Cross-modal TE (vision → text, per language layer) ===\n",
            "#layers = 32 (layer i is LLaMA block i)\n",
            "[0.000346, 0.000154, 0.000534, 0.001622, 0.002624, 0.003598, 0.006035, 0.006443, 0.006726, 0.008075, 0.0078, 0.00967, 0.013698, 0.012183, 0.01319, 0.013487, 0.028614, 0.036078, 0.043079, 0.052472, 0.045059, 0.048945, 0.057963, 0.057666, 0.063763, 0.061465, 0.07813, 0.08316, 0.064028, 0.053855, 0.056916, 0.035096]\n",
            "\n",
            "Done. `epoch_te_history` now holds TE per layer per epoch (vision, language, vision→text).\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "x6XOSFCTWQWZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "5YtdDSQbWQaU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "AxtSo8u7WQd7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LMerMwH6WQhU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "agMGXoPwWQru"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}