{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "7adf1231caa24572aa7cc5809f746bd0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7152f146393b4bfd966a6af81d1ad994",
              "IPY_MODEL_556cb5d5d91945688917e664b88c65e5",
              "IPY_MODEL_9b64ba3dbd0348a586d4a2620e56c442"
            ],
            "layout": "IPY_MODEL_fd697c90fa094346a6381912ccda42cc"
          }
        },
        "7152f146393b4bfd966a6af81d1ad994": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d6d3661a1b244c80823867e569deaec0",
            "placeholder": "​",
            "style": "IPY_MODEL_b50fb00f90424e4f8a67e53cc2f74e12",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "556cb5d5d91945688917e664b88c65e5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9f108fb36acd4ac5895e9b537a87c065",
            "max": 3,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_da7d5bfc871f4520988ecefbcbbeb34c",
            "value": 3
          }
        },
        "9b64ba3dbd0348a586d4a2620e56c442": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_46bf9df8442f4b86a22dcede25b914a6",
            "placeholder": "​",
            "style": "IPY_MODEL_7727dbef492b4d1c82b8ec10f9895657",
            "value": " 3/3 [00:00&lt;00:00,  5.39it/s]"
          }
        },
        "fd697c90fa094346a6381912ccda42cc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d6d3661a1b244c80823867e569deaec0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b50fb00f90424e4f8a67e53cc2f74e12": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9f108fb36acd4ac5895e9b537a87c065": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "da7d5bfc871f4520988ecefbcbbeb34c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "46bf9df8442f4b86a22dcede25b914a6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7727dbef492b4d1c82b8ec10f9895657": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n9NaU-YAWPjl",
        "outputId": "e2a67707-736f-4c73-cb40-a9160f574cb3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "COCO2014 not found, downloading...\n",
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [17:11<00:00, 13.1MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:21<00:00, 11.8MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ],
      "source": [
        "# ============================================\n",
        "# 0. COCO2014 download + paths\n",
        "# ============================================\n",
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "data_dir = \"/content/coco2014\"\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "datasets_urls = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "def download_file(url, dest_path):\n",
        "    response = requests.get(url, stream=True)\n",
        "    total_size = int(response.headers.get(\"content-length\", 0))\n",
        "    with open(dest_path, \"wb\") as f, tqdm(\n",
        "        desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "        total=total_size,\n",
        "        unit=\"B\",\n",
        "        unit_scale=True,\n",
        "        unit_divisor=1024,\n",
        "    ) as bar:\n",
        "        for data in response.iter_content(chunk_size=1024):\n",
        "            f.write(data)\n",
        "            bar.update(len(data))\n",
        "\n",
        "def ensure_coco2014(data_dir: str):\n",
        "    ann_path = os.path.join(data_dir, \"annotations\", \"captions_train2014.json\")\n",
        "    if os.path.exists(ann_path):\n",
        "        print(f\"COCO2014 already present under {data_dir}\")\n",
        "        return\n",
        "\n",
        "    print(\"COCO2014 not found, downloading...\")\n",
        "    for name, url in datasets_urls.items():\n",
        "        zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "        print(f\"Processing {name}...\")\n",
        "        download_file(url, zip_path)\n",
        "        with ZipFile(zip_path, \"r\") as zip_ref:\n",
        "            zip_ref.extractall(data_dir)\n",
        "        os.remove(zip_path)\n",
        "        print(f\"{name} downloaded and extracted.\")\n",
        "    print(\"All datasets and annotations successfully downloaded and extracted!\")\n",
        "\n",
        "ensure_coco2014(data_dir)\n",
        "\n",
        "COCO_ROOT = os.path.join(data_dir, \"train2014\")\n",
        "COCO_ANN  = os.path.join(data_dir, \"annotations\", \"captions_train2014.json\")\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# ============================================\n",
        "# 1. LLaVA + LoRA + TE imports / config\n",
        "# ============================================\n",
        "import math\n",
        "import random\n",
        "from typing import List, Tuple\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from torchvision import datasets as tv_datasets\n",
        "\n",
        "from transformers import LlavaForConditionalGeneration, AutoProcessor\n",
        "from peft import LoraConfig, get_peft_model\n",
        "\n",
        "MODEL_ID = \"llava-hf/llava-1.5-7b-hf\"\n",
        "\n",
        "DEVICE   = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "DTYPE    = torch.float16 if DEVICE == \"cuda\" else torch.float32\n",
        "\n",
        "# Training config\n",
        "BATCH_SIZE_TRAIN = 4        # keep tiny for 7B\n",
        "NUM_EPOCHS       = 2\n",
        "LR               = 5e-5\n",
        "\n",
        "# TE config (perturbation-based, like CLIP TE)\n",
        "IMG_NOISE_STD    = 1e-3     # small Gaussian noise on images\n",
        "VISION_POOL      = \"cls\"    # \"cls\" or \"mean\" or \"last\"\n",
        "LANG_POOL        = \"mean\"\n",
        "BATCH_SIZE_TE    = 16        # >=2 for TE proxy\n",
        "TE_MAX_STEPS     = 128      # max TE batches; images ~= TE_MAX_STEPS * BATCH_SIZE_TE\n",
        "\n",
        "# LoRA config\n",
        "LORA_R           = 4\n",
        "LORA_ALPHA       = 16\n",
        "LORA_DROPOUT     = 0.1\n",
        "LORA_TARGET_MODULES = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"]\n",
        "\n",
        "SEED             = 42\n",
        "\n",
        "random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 2. COCO datasets for LLaVA (train + probe)\n",
        "# ============================================\n",
        "class CocoLlavaTrainDataset(Dataset):\n",
        "    \"\"\"\n",
        "    Training dataset: (PIL image, prompt-with-caption-as-target).\n",
        "    Uses the first COCO caption as the assistant answer.\n",
        "    \"\"\"\n",
        "    def __init__(self, root: str, annFile: str, max_images: int = 512):\n",
        "        self.base = tv_datasets.CocoCaptions(root=root, annFile=annFile)\n",
        "        idxs = list(range(len(self.base)))\n",
        "        random.shuffle(idxs)\n",
        "        self.indices = idxs[:max_images]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indices)\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        img_idx = self.indices[idx]\n",
        "        image, captions = self.base[img_idx]\n",
        "        image = image.convert(\"RGB\")\n",
        "\n",
        "        caption = captions[0] if len(captions) > 0 else \"A photo.\"\n",
        "\n",
        "        prompt = (\n",
        "            \"USER: <image>\\nDescribe this image in one sentence.\\n\"\n",
        "            f\"ASSISTANT: {caption}\"\n",
        "        )\n",
        "        return image, prompt\n",
        "\n",
        "\n",
        "class CocoLlavaProbeDataset(Dataset):\n",
        "    \"\"\"\n",
        "    TE probe dataset (no labels, generic prompt).\n",
        "    Each item: (PIL image, prompt string).\n",
        "    \"\"\"\n",
        "    def __init__(self, root: str, annFile: str, max_images: int = 512):\n",
        "        self.base = tv_datasets.CocoCaptions(root=root, annFile=annFile)\n",
        "        idxs = list(range(len(self.base)))\n",
        "        random.shuffle(idxs)\n",
        "        self.indices = idxs[:max_images]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.indices)\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        img_idx = self.indices[idx]\n",
        "        image, _captions = self.base[img_idx]\n",
        "        image = image.convert(\"RGB\")\n",
        "        prompt = \"USER: <image>\\nDescribe this image in one sentence.\\nASSISTANT:\"\n",
        "        return image, prompt\n",
        "\n",
        "\n",
        "def collate_coco(batch):\n",
        "    images, prompts = zip(*batch)\n",
        "    return list(images), list(prompts)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 3. Vision tower helper + TE utilities\n",
        "# ============================================\n",
        "def get_vision_tower(model: LlavaForConditionalGeneration) -> nn.Module:\n",
        "    \"\"\"\n",
        "    Extract the CLIP vision tower from a LLaVA HF model.\n",
        "    \"\"\"\n",
        "    base = getattr(model, \"base_model\", model)\n",
        "    vt = getattr(base, \"vision_tower\", None)\n",
        "    if vt is None:\n",
        "        raise ValueError(\"This LLaVA model has no `vision_tower` attribute.\")\n",
        "    if isinstance(vt, (list, tuple)):\n",
        "        vt = vt[0]\n",
        "    if hasattr(vt, \"vision_tower\"):\n",
        "        vt = vt.vision_tower\n",
        "    return vt\n",
        "\n",
        "\n",
        "def _pool_hidden(h: torch.Tensor, pool: str) -> torch.Tensor:\n",
        "    # h: [B, T, D]\n",
        "    if pool == \"cls\":\n",
        "        return h[:, 0, :]\n",
        "    elif pool == \"last\":\n",
        "        return h[:, -1, :]\n",
        "    elif pool == \"mean\":\n",
        "        return h.mean(dim=1)\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown pool={pool}\")\n",
        "\n",
        "\n",
        "def _te_from_two_runs(pooled1: List[torch.Tensor],\n",
        "                      pooled2: List[torch.Tensor]) -> List[float]:\n",
        "    \"\"\"\n",
        "    Perturbation-based TE proxy (same idea as your CLIP TE):\n",
        "\n",
        "    For each edge ℓ -> ℓ+1:\n",
        "      Δℓ   = zℓ(2)   − zℓ(1)\n",
        "      Δℓ+1 = zℓ+1(2) − zℓ+1(1)\n",
        "      TEℓ  = 0.5 * mean_b cos²(Δℓ, Δℓ+1)\n",
        "    \"\"\"\n",
        "    assert len(pooled1) == len(pooled2) and len(pooled1) >= 2\n",
        "    eps = 1e-6\n",
        "    te_vals = []\n",
        "    for l in range(len(pooled1) - 1):\n",
        "        A = torch.nan_to_num(pooled2[l]   - pooled1[l]).float()   # [B, D]\n",
        "        B = torch.nan_to_num(pooled2[l+1] - pooled1[l+1]).float() # [B, D]\n",
        "        num = (A * B).sum(dim=1)\n",
        "        den = (A.norm(dim=1) * B.norm(dim=1)).clamp_min(eps)\n",
        "        cos = (num / den).clamp(-1.0, 1.0)\n",
        "        te_vals.append(0.5 * float((cos * cos).mean().item()))\n",
        "    return te_vals\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 4. Perturbation TE: vision tower\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def compute_te_vision_batch(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    pixel_values: torch.Tensor,\n",
        "    sigma: float = IMG_NOISE_STD,\n",
        "    pool: str = VISION_POOL,\n",
        ") -> List[float]:\n",
        "    vt = get_vision_tower(model)\n",
        "    vt = vt.to(DEVICE)\n",
        "    vt_dtype = next(vt.parameters()).dtype\n",
        "\n",
        "    pv = pixel_values.to(DEVICE, dtype=vt_dtype)\n",
        "    noise1 = sigma * torch.randn_like(pv)\n",
        "    noise2 = sigma * torch.randn_like(pv)\n",
        "\n",
        "    out1 = vt(pixel_values=pv + noise1,\n",
        "              output_hidden_states=True,\n",
        "              return_dict=True)\n",
        "    out2 = vt(pixel_values=pv + noise2,\n",
        "              output_hidden_states=True,\n",
        "              return_dict=True)\n",
        "\n",
        "    hs1 = out1.hidden_states  # tuple(len = L+1), each [B, T, D]\n",
        "    hs2 = out2.hidden_states\n",
        "\n",
        "    pooled1 = [_pool_hidden(h, pool) for h in hs1]  # [B, D]\n",
        "    pooled2 = [_pool_hidden(h, pool) for h in hs2]\n",
        "\n",
        "    return _te_from_two_runs(pooled1, pooled2)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 5. Perturbation TE: language tower (LLaMA)\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def compute_te_language_batch(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    inputs: dict,\n",
        "    sigma_img: float = IMG_NOISE_STD,\n",
        "    pool: str = LANG_POOL,\n",
        ") -> List[float]:\n",
        "    input_ids      = inputs[\"input_ids\"].to(DEVICE)\n",
        "    attention_mask = inputs[\"attention_mask\"].to(DEVICE)\n",
        "    pixel_values   = inputs[\"pixel_values\"].to(DEVICE, dtype=model.dtype)\n",
        "\n",
        "    noise1 = sigma_img * torch.randn_like(pixel_values)\n",
        "    noise2 = sigma_img * torch.randn_like(pixel_values)\n",
        "\n",
        "    kwargs1 = dict(\n",
        "        input_ids=input_ids,\n",
        "        attention_mask=attention_mask,\n",
        "        pixel_values=pixel_values + noise1,\n",
        "        output_hidden_states=True,\n",
        "        use_cache=False,\n",
        "        return_dict=True,\n",
        "    )\n",
        "    kwargs2 = dict(\n",
        "        input_ids=input_ids,\n",
        "        attention_mask=attention_mask,\n",
        "        pixel_values=pixel_values + noise2,\n",
        "        output_hidden_states=True,\n",
        "        use_cache=False,\n",
        "        return_dict=True,\n",
        "    )\n",
        "\n",
        "    out1 = model(**kwargs1)\n",
        "    out2 = model(**kwargs2)\n",
        "\n",
        "    hs1 = out1.hidden_states\n",
        "    hs2 = out2.hidden_states\n",
        "\n",
        "    pooled1 = [_pool_hidden(h, pool) for h in hs1]  # [B, D]\n",
        "    pooled2 = [_pool_hidden(h, pool) for h in hs2]\n",
        "\n",
        "    return _te_from_two_runs(pooled1, pooled2)\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 6. TE estimation loop over a probe dataloader\n",
        "# ============================================\n",
        "@torch.no_grad()\n",
        "def estimate_te_llava(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    processor: AutoProcessor,\n",
        "    dataloader: DataLoader,\n",
        "    max_steps: int = TE_MAX_STEPS,\n",
        "    img_noise_std: float = IMG_NOISE_STD,\n",
        ") -> Tuple[List[float], List[float]]:\n",
        "    \"\"\"\n",
        "    Returns:\n",
        "      vision_te:  [L_v] TE values for CLIP vision edges (ℓ = layer ℓ -> ℓ+1)\n",
        "      lang_te:    [L_l] TE values for LLaMA language edges\n",
        "    \"\"\"\n",
        "    vision_sum, lang_sum = None, None\n",
        "    v_count = l_count = 0\n",
        "\n",
        "    for step, (images, prompts) in enumerate(dataloader):\n",
        "        if max_steps is not None and (step + 1) > max_steps:\n",
        "            break\n",
        "\n",
        "        encoded = processor(\n",
        "            images=images,\n",
        "            text=list(prompts),\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        )\n",
        "\n",
        "        v_te = compute_te_vision_batch(\n",
        "            model,\n",
        "            encoded[\"pixel_values\"],\n",
        "            sigma=img_noise_std,\n",
        "            pool=VISION_POOL,\n",
        "        )\n",
        "        l_te = compute_te_language_batch(\n",
        "            model,\n",
        "            encoded,\n",
        "            sigma_img=img_noise_std,\n",
        "            pool=LANG_POOL,\n",
        "        )\n",
        "\n",
        "        v_te = torch.tensor(v_te)\n",
        "        l_te = torch.tensor(l_te)\n",
        "\n",
        "        if vision_sum is None:\n",
        "            vision_sum = v_te\n",
        "            lang_sum   = l_te\n",
        "        else:\n",
        "            vision_sum = vision_sum + v_te\n",
        "            lang_sum   = lang_sum + l_te\n",
        "\n",
        "        v_count += 1\n",
        "        l_count += 1\n",
        "\n",
        "    vision_avg  = (vision_sum / max(1, v_count)).tolist()\n",
        "    lang_avg    = (lang_sum   / max(1, l_count)).tolist()\n",
        "    return vision_avg, lang_avg\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 7. LoRA training on LLaVA (rank=4)\n",
        "# ============================================\n",
        "def print_trainable_parameters(m: nn.Module):\n",
        "    trainable, total = 0, 0\n",
        "    for _, p in m.named_parameters():\n",
        "        total += p.numel()\n",
        "        if p.requires_grad:\n",
        "            trainable += p.numel()\n",
        "    pct = 100 * trainable / max(1, total)\n",
        "    print(f\"Trainable params: {trainable} / {total} ({pct:.4f}%)\")\n",
        "\n",
        "\n",
        "def train_one_epoch_llava_lora(\n",
        "    model: LlavaForConditionalGeneration,\n",
        "    processor: AutoProcessor,\n",
        "    dataloader: DataLoader,\n",
        "    optimizer: torch.optim.Optimizer,\n",
        "    epoch: int,\n",
        "):\n",
        "    model.train()\n",
        "    total_loss = 0.0\n",
        "    n_steps = 0\n",
        "\n",
        "    use_amp = (DEVICE == \"cuda\" and DTYPE == torch.float16)\n",
        "    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
        "\n",
        "    for step, (images, prompts) in enumerate(dataloader):\n",
        "        optimizer.zero_grad(set_to_none=True)\n",
        "\n",
        "        encoded = processor(\n",
        "            images=images,\n",
        "            text=list(prompts),\n",
        "            return_tensors=\"pt\",\n",
        "            padding=True,\n",
        "            truncation=True,\n",
        "        )\n",
        "        input_ids      = encoded[\"input_ids\"].to(DEVICE)\n",
        "        attention_mask = encoded[\"attention_mask\"].to(DEVICE)\n",
        "        pixel_values   = encoded[\"pixel_values\"].to(DEVICE, dtype=model.dtype)\n",
        "\n",
        "        # Causal LM labels: predict next tokens; mask pads\n",
        "        labels = input_ids.clone()\n",
        "        labels[attention_mask == 0] = -100\n",
        "\n",
        "        if use_amp:\n",
        "            with torch.cuda.amp.autocast(dtype=DTYPE):\n",
        "                out = model(\n",
        "                    input_ids=input_ids,\n",
        "                    attention_mask=attention_mask,\n",
        "                    pixel_values=pixel_values,\n",
        "                    labels=labels,\n",
        "                    use_cache=False,\n",
        "                    return_dict=True,\n",
        "                )\n",
        "                loss = out.loss\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(optimizer)\n",
        "            scaler.update()\n",
        "        else:\n",
        "            out = model(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                pixel_values=pixel_values,\n",
        "                labels=labels,\n",
        "                use_cache=False,\n",
        "                return_dict=True,\n",
        "            )\n",
        "            loss = out.loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        total_loss += loss.item()\n",
        "        n_steps += 1\n",
        "\n",
        "        if (step + 1) % 100 == 0:\n",
        "            print(f\"  [Epoch {epoch+1} | Step {step+1}] loss={total_loss / max(1, n_steps):.4f}\")\n",
        "\n",
        "    avg_loss = total_loss / max(1, n_steps)\n",
        "    print(f\"Epoch {epoch+1} finished. Avg train loss = {avg_loss:.4f}\")\n",
        "\n",
        "\n",
        "# ============================================\n",
        "# 8. Main: LoRA rank=4 (2 epochs) + TE after each epoch\n",
        "# ============================================\n",
        "def main():\n",
        "    print(\"Loading LLaVA model:\", MODEL_ID)\n",
        "    base_model = LlavaForConditionalGeneration.from_pretrained(\n",
        "        MODEL_ID,\n",
        "        torch_dtype=DTYPE,\n",
        "        low_cpu_mem_usage=True,\n",
        "        device_map=None,\n",
        "    )\n",
        "    base_model.to(DEVICE)\n",
        "\n",
        "    # Freeze base parameters\n",
        "    for p in base_model.parameters():\n",
        "        p.requires_grad = False\n",
        "\n",
        "    # LoRA rank=4 on attention projections\n",
        "    lora_cfg = LoraConfig(\n",
        "        r=LORA_R,\n",
        "        lora_alpha=LORA_ALPHA,\n",
        "        lora_dropout=LORA_DROPOUT,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "        target_modules=LORA_TARGET_MODULES,\n",
        "    )\n",
        "    model = get_peft_model(base_model, lora_cfg)\n",
        "    model.to(DEVICE)\n",
        "    model.eval()\n",
        "\n",
        "    print(\"LoRA-wrapped model.\")\n",
        "    print_trainable_parameters(model)\n",
        "\n",
        "    processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
        "\n",
        "    print(\"Building COCO train + probe datasets...\")\n",
        "    train_ds = CocoLlavaTrainDataset(COCO_ROOT, COCO_ANN, max_images=2000)\n",
        "    probe_ds = CocoLlavaProbeDataset(COCO_ROOT, COCO_ANN, max_images=4096)\n",
        "\n",
        "    train_loader = DataLoader(\n",
        "        train_ds,\n",
        "        batch_size=BATCH_SIZE_TRAIN,\n",
        "        shuffle=True,\n",
        "        num_workers=2,\n",
        "        collate_fn=collate_coco,\n",
        "        drop_last=True,\n",
        "    )\n",
        "    probe_loader = DataLoader(\n",
        "        probe_ds,\n",
        "        batch_size=BATCH_SIZE_TE,\n",
        "        shuffle=True,\n",
        "        num_workers=2,\n",
        "        collate_fn=collate_coco,\n",
        "        drop_last=True,\n",
        "    )\n",
        "\n",
        "    optimizer = torch.optim.AdamW(\n",
        "        [p for p in model.parameters() if p.requires_grad],\n",
        "        lr=LR,\n",
        "        weight_decay=0.0,\n",
        "    )\n",
        "\n",
        "    epoch_te_history = []\n",
        "\n",
        "    for epoch in range(NUM_EPOCHS):\n",
        "        print(f\"\\n=== LoRA training: epoch {epoch+1}/{NUM_EPOCHS} ===\")\n",
        "        train_one_epoch_llava_lora(model, processor, train_loader, optimizer, epoch)\n",
        "\n",
        "        print(f\"\\nEstimating TE after epoch {epoch+1} (vision + language towers)...\")\n",
        "        vision_te, lang_te = estimate_te_llava(\n",
        "            model,\n",
        "            processor,\n",
        "            probe_loader,\n",
        "            max_steps=TE_MAX_STEPS,\n",
        "            img_noise_std=IMG_NOISE_STD,\n",
        "        )\n",
        "        epoch_te_history.append((vision_te, lang_te))\n",
        "\n",
        "        print(\"\\n=== Averaged TE (vision tower) ===\")\n",
        "        print(f\"#edges = {len(vision_te)} (edge ℓ = layer ℓ -> layer ℓ+1)\")\n",
        "        print([round(x, 6) for x in vision_te])\n",
        "\n",
        "        print(\"\\n=== Averaged TE (language tower) ===\")\n",
        "        print(f\"#edges = {len(lang_te)} (edge ℓ = layer ℓ -> layer ℓ+1)\")\n",
        "        print([round(x, 6) for x in lang_te])\n",
        "\n",
        "    print(\"\\nDone. `epoch_te_history` now holds TE per layer per epoch.\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "7adf1231caa24572aa7cc5809f746bd0",
            "7152f146393b4bfd966a6af81d1ad994",
            "556cb5d5d91945688917e664b88c65e5",
            "9b64ba3dbd0348a586d4a2620e56c442",
            "fd697c90fa094346a6381912ccda42cc",
            "d6d3661a1b244c80823867e569deaec0",
            "b50fb00f90424e4f8a67e53cc2f74e12",
            "9f108fb36acd4ac5895e9b537a87c065",
            "da7d5bfc871f4520988ecefbcbbeb34c",
            "46bf9df8442f4b86a22dcede25b914a6",
            "7727dbef492b4d1c82b8ec10f9895657"
          ]
        },
        "id": "p3PogFf0WQLJ",
        "outputId": "2d2ea136-38a7-4d07-b439-dab1d119647b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading LLaVA model: llava-hf/llava-1.5-7b-hf\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7adf1231caa24572aa7cc5809f746bd0"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "LoRA-wrapped model.\n",
            "Trainable params: 4784128 / 7068211200 (0.0677%)\n",
            "Building COCO train + probe datasets...\n",
            "loading annotations into memory...\n",
            "Done (t=0.69s)\n",
            "creating index...\n",
            "index created!\n",
            "loading annotations into memory...\n",
            "Done (t=0.70s)\n",
            "creating index...\n",
            "index created!\n",
            "\n",
            "=== LoRA training: epoch 1/2 ===\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-1368885993.py:326: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
            "  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
            "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
            "/tmp/ipython-input-1368885993.py:347: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
            "  with torch.cuda.amp.autocast(dtype=DTYPE):\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  [Epoch 1 | Step 100] loss=7.0116\n",
            "  [Epoch 1 | Step 200] loss=5.5307\n",
            "  [Epoch 1 | Step 300] loss=4.9675\n",
            "  [Epoch 1 | Step 400] loss=4.6647\n",
            "  [Epoch 1 | Step 500] loss=4.4773\n",
            "Epoch 1 finished. Avg train loss = 4.4773\n",
            "\n",
            "Estimating TE after epoch 1 (vision + language towers)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.175469, 0.212128, 0.227631, 0.227094, 0.221401, 0.260484, 0.17297, 0.194362, 0.25686, 0.334507, 0.150285, 0.216711, 0.228454, 0.185307, 0.214263, 0.217679, 0.265438, 0.247861, 0.25104, 0.240116, 0.204173, 0.18036, 0.198366]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.498937, 0.439427, 0.467656, 0.47349, 0.462265, 0.453441, 0.363507, 0.32762, 0.297578, 0.358366, 0.37551, 0.359817, 0.375669, 0.370887, 0.375615, 0.385202, 0.403485, 0.382477, 0.412371, 0.418366, 0.409708, 0.429224, 0.419295, 0.408145, 0.390597, 0.404841, 0.366907, 0.302987, 0.343938, 0.427732, 0.428054, 0.27301]\n",
            "\n",
            "=== LoRA training: epoch 2/2 ===\n",
            "  [Epoch 2 | Step 100] loss=3.7103\n",
            "  [Epoch 2 | Step 200] loss=3.7073\n",
            "  [Epoch 2 | Step 300] loss=3.7054\n",
            "  [Epoch 2 | Step 400] loss=3.7042\n",
            "  [Epoch 2 | Step 500] loss=3.7033\n",
            "Epoch 2 finished. Avg train loss = 3.7033\n",
            "\n",
            "Estimating TE after epoch 2 (vision + language towers)...\n",
            "\n",
            "=== Averaged TE (vision tower) ===\n",
            "#edges = 24 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.0, 0.176785, 0.21216, 0.232166, 0.231098, 0.219414, 0.259877, 0.170106, 0.193804, 0.256638, 0.3332, 0.136328, 0.218012, 0.233609, 0.188534, 0.21286, 0.198716, 0.244589, 0.241365, 0.246436, 0.22373, 0.199047, 0.18697, 0.208931]\n",
            "\n",
            "=== Averaged TE (language tower) ===\n",
            "#edges = 32 (edge ℓ = layer ℓ -> layer ℓ+1)\n",
            "[0.498915, 0.484827, 0.472803, 0.473821, 0.462154, 0.44623, 0.324688, 0.299919, 0.271569, 0.344501, 0.362353, 0.33974, 0.355216, 0.352455, 0.344569, 0.365764, 0.383891, 0.360026, 0.392755, 0.397821, 0.397414, 0.417938, 0.413493, 0.398534, 0.377269, 0.398099, 0.35208, 0.284063, 0.347549, 0.435848, 0.429603, 0.277234]\n",
            "\n",
            "Done. `epoch_te_history` now holds TE per layer per epoch.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "tXuzpXdjWQTM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "x6XOSFCTWQWZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "5YtdDSQbWQaU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "AxtSo8u7WQd7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LMerMwH6WQhU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "agMGXoPwWQru"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}