{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "1b97feb2-59ec-400b-d951-a3c5e8d772ba"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "%pip install --quiet ftfy regex tqdm pycocotools\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define paths (set COCO_DATA_DIR to override the Colab default location)\n",
        "data_dir = os.environ.get('COCO_DATA_DIR', '/content/coco2014')\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# URLs for datasets and annotations (archive name -> download URL).\n",
        "# NOTE: the data-loader cell's `from torchvision import transforms, datasets` rebinds this name.\n",
        "datasets = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"val2014\": \"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "# Download helper function with progress bar\n",
        "def download_file(url, dest_path, chunk_size=1 << 20, timeout=60):\n",
        "    \"\"\"Stream `url` into `dest_path`, showing a tqdm progress bar.\n",
        "\n",
        "    Raises requests.HTTPError on a non-2xx response instead of silently saving an\n",
        "    error page as the archive, and always closes the connection. `timeout` (seconds)\n",
        "    bounds connect/read waits; `chunk_size` is the streaming block size in bytes\n",
        "    (1 MiB default, since the COCO archives are several GB).\n",
        "    \"\"\"\n",
        "    with requests.get(url, stream=True, timeout=timeout) as response:\n",
        "        response.raise_for_status()\n",
        "        total_size = int(response.headers.get('content-length', 0))\n",
        "        with open(dest_path, 'wb') as f, tqdm(\n",
        "            desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "            total=total_size,\n",
        "            unit='B',\n",
        "            unit_scale=True,\n",
        "            unit_divisor=1024\n",
        "        ) as bar:\n",
        "            for data in response.iter_content(chunk_size=chunk_size):\n",
        "                f.write(data)\n",
        "                bar.update(len(data))\n",
        "\n",
        "# Download and extract datasets (archives already extracted by a previous run are skipped)\n",
        "for name, url in datasets.items():\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    # Written only after a successful extraction, so an interrupted run is redone\n",
        "    done_marker = os.path.join(data_dir, f\".{name}.done\")\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    if os.path.exists(done_marker):\n",
        "        print(f\"{name} already downloaded and extracted, skipping.\")\n",
        "        continue\n",
        "\n",
        "    # Download the dataset\n",
        "    download_file(url, zip_path)\n",
        "\n",
        "    # Unzip the dataset\n",
        "    with ZipFile(zip_path, 'r') as zip_ref:\n",
        "        zip_ref.extractall(data_dir)\n",
        "\n",
        "    # Remove the zip file to save space, then record completion\n",
        "    os.remove(zip_path)\n",
        "    open(done_marker, 'w').close()\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XMDkrWBhLGg",
        "outputId": "60387522-6a6a-4cb8-b485-377e3154875b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [05:28<00:00, 41.2MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing val2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [20:54<00:00, 5.30MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:10<00:00, 24.4MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP ViT-B/16"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Load the CLIP model (clip.load also returns its reference preprocessing transform)\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "\n",
        "# Unify all weights to fp32 (clip.load may return fp16 weights on CUDA). The LoRA\n",
        "# adapters injected later are created in the base layer's dtype, so they will be fp32 too.\n",
        "model = model.float()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "print(\"Model parameters:\", f\"{sum(p.numel() for p in model.parameters()):,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "429ed8b6-0eb3-4535-c895-fd2fc9296aa4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 335M/335M [00:11<00:00, 30.1MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 149,620,737\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the MS COCO Data Loader"
      ],
      "metadata": {
        "id": "h_cfc6FcltX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision import transforms, datasets\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "# Match CLIP's own preprocessing: bicubic resize of the shorter side to the model\n",
        "# resolution, then a center crop, so images are not squashed to a square.\n",
        "transform = transforms.Compose([\n",
        "    transforms.Resize(input_resolution, interpolation=transforms.InterpolationMode.BICUBIC),\n",
        "    transforms.CenterCrop(input_resolution),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(\n",
        "        mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "        std=(0.26862954, 0.26130258, 0.27577711)\n",
        "    )\n",
        "])\n",
        "\n",
        "class CocoDataset(Dataset):\n",
        "    \"\"\"COCO captions flattened so that every annotation is its own sample.\n",
        "\n",
        "    Each item is `(image, token_ids)`: the (optionally transformed) image and the\n",
        "    CLIP tokenization of ONE of its captions, padded/truncated to `context_length`.\n",
        "    \"\"\"\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        # Local import: the download cell rebinds the global name `datasets` to a dict,\n",
        "        # which would otherwise break construction after a re-run of that cell.\n",
        "        from torchvision.datasets import CocoCaptions\n",
        "        self.dataset = CocoCaptions(root=root, annFile=annFile, transform=None)\n",
        "        self.transform = transform\n",
        "        # Flatten (image, captions) so each item is (image, single_caption).\n",
        "        # Captions come straight from the COCO annotation index: indexing\n",
        "        # self.dataset here would decode every image from disk just to read captions.\n",
        "        coco = self.dataset.coco\n",
        "        self.data = [\n",
        "            (img_idx, ann[\"caption\"])\n",
        "            for img_idx, img_id in enumerate(self.dataset.ids)\n",
        "            for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id))\n",
        "        ]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        img_idx, caption = self.data[idx]\n",
        "        image, _ = self.dataset[img_idx]  # decode the image; its captions are ignored\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        # truncate=True: an over-long caption is cut (ending in EOT) instead of raising\n",
        "        text = clip.tokenize(caption, context_length=context_length, truncate=True)[0]\n",
        "        return image, text\n",
        "\n",
        "# Training split: image folder + caption annotations\n",
        "train_img_dir = os.path.join(data_dir, 'train2014')\n",
        "train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')\n",
        "train_dataset = CocoDataset(root=train_img_dir, annFile=train_ann_file, transform=transform)\n",
        "\n",
        "# Shuffled mini-batches of (image, tokenized caption) pairs\n",
        "train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True,\n",
        "                              num_workers=2, pin_memory=True)\n"
      ],
      "metadata": {
        "id": "dGosFCtHhMMi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "461b643a-da47-4d58-e68c-89c96a2a5eff"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.63s)\n",
            "creating index...\n",
            "index created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ftKgCZzre6B-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# =========================\n",
        "#  CLIP + LoRA + Cosine-Pruning\n",
        "#  (NaN-safe training + COCO eval) — prune successor blocks with high cos(z_l, z_{l+1})\n",
        "# =========================\n",
        "\n",
        "import math\n",
        "import os\n",
        "import random\n",
        "from typing import List, Tuple, Union\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import AdamW\n",
        "from torch.nn.utils import clip_grad_norm_\n",
        "\n",
        "# -------------------------\n",
        "# Utils: safe numerics\n",
        "# -------------------------\n",
        "def safe_l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:\n",
        "    \"\"\"Scale `x` to unit L2 norm along `dim`; norms below `eps` are floored to `eps`.\"\"\"\n",
        "    denom = x.norm(dim=dim, keepdim=True).clamp_min(eps)\n",
        "    return x / denom\n",
        "\n",
        "def get_safe_logit_scale(model, clamp_low: float = -5.0, clamp_high: float = 5.0) -> torch.Tensor:\n",
        "    \"\"\"Return exp(logit_scale) with the log-scale clamped to [clamp_low, clamp_high].\n",
        "\n",
        "    Clamping happens BEFORE exp so the result cannot overflow. A model without a\n",
        "    `logit_scale` attribute gets a neutral 1.0 on the device of its first parameter.\n",
        "    \"\"\"\n",
        "    if not hasattr(model, \"logit_scale\"):\n",
        "        return torch.tensor(1.0, device=next(model.parameters()).device)\n",
        "    clamped_log_scale = model.logit_scale.float().clamp(clamp_low, clamp_high)\n",
        "    return clamped_log_scale.exp()\n",
        "\n",
        "# -------------------------\n",
        "# LoRA wrapper (MHA-compatible)\n",
        "# -------------------------\n",
        "class LoRALinear(nn.Module):\n",
        "    \"\"\"\n",
        "    Low-rank adapter around a frozen nn.Linear:\n",
        "        y = base(x) + (alpha / r) * dropout(x) @ A^T @ B^T\n",
        "\n",
        "    Read-only merged `.weight` (W + (alpha / r) * B @ A) and `.bias` properties are\n",
        "    exposed as well, because nn.MultiheadAttention reads `out_proj.weight` directly\n",
        "    instead of calling `out_proj(...)`. That merged path does not apply `dropout`.\n",
        "    \"\"\"\n",
        "    def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):\n",
        "        super().__init__()\n",
        "        assert isinstance(base_linear, nn.Linear), \"LoRALinear expects an nn.Linear as base\"\n",
        "        self.base = base_linear\n",
        "        # The wrapped layer stays frozen; only lora_A / lora_B are meant to train.\n",
        "        self.base.requires_grad_(False)\n",
        "\n",
        "        self.r = int(r)\n",
        "        self.scaling = float(alpha) / float(r)\n",
        "\n",
        "        ref = self.base.weight\n",
        "        factory = dict(device=ref.device, dtype=ref.dtype)\n",
        "        # A: Kaiming-uniform, B: zeros -> the adapter starts as an exact no-op.\n",
        "        self.lora_A = nn.Parameter(torch.zeros(r, self.base.in_features, **factory))\n",
        "        self.lora_B = nn.Parameter(torch.zeros(self.base.out_features, r, **factory))\n",
        "        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n",
        "        nn.init.zeros_(self.lora_B)\n",
        "\n",
        "        if dropout > 0:\n",
        "            self.dropout = nn.Dropout(dropout)\n",
        "        else:\n",
        "            self.dropout = nn.Identity()\n",
        "\n",
        "    @property\n",
        "    def weight(self) -> torch.Tensor:\n",
        "        # Merged weight consumed by nn.MultiheadAttention (no dropout on this path)\n",
        "        low_rank_update = self.lora_B @ self.lora_A\n",
        "        return self.base.weight + self.scaling * low_rank_update\n",
        "\n",
        "    @property\n",
        "    def bias(self):\n",
        "        return self.base.bias\n",
        "\n",
        "    def forward(self, x):\n",
        "        projected = self.dropout(x) @ self.lora_A.t()\n",
        "        update = projected @ self.lora_B.t()\n",
        "        return self.base(x) + self.scaling * update\n",
        "\n",
        "# -------------------------\n",
        "# Injection / upgrade helpers\n",
        "# -------------------------\n",
        "def _wrap_or_upgrade_linear(mod: nn.Module, attr: str, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Ensure `mod.<attr>` is a current-version LoRALinear, editing `mod` in place.\n",
        "\n",
        "    Depending on what `mod.<attr>` currently holds:\n",
        "      - current LoRALinear (its class defines a `weight` property): left untouched\n",
        "      - legacy module named \"LoRALinear\" without `.weight`: rebuilt around its\n",
        "        `.base`, copying lora_A / lora_B when their shapes match\n",
        "      - plain nn.Linear: wrapped in a fresh LoRALinear\n",
        "      - anything else (missing, Identity, ...): left untouched\n",
        "    \"\"\"\n",
        "    cur = getattr(mod, attr, None)\n",
        "\n",
        "    if isinstance(cur, LoRALinear) and isinstance(getattr(type(cur), \"weight\", None), property):\n",
        "        return\n",
        "\n",
        "    is_legacy = (\n",
        "        isinstance(cur, nn.Module)\n",
        "        and type(cur).__name__ == \"LoRALinear\"\n",
        "        and not hasattr(cur, \"weight\")\n",
        "    )\n",
        "    if is_legacy:\n",
        "        replacement = LoRALinear(cur.base, r=r, alpha=alpha, dropout=dropout)\n",
        "        if hasattr(cur, \"lora_A\") and hasattr(cur, \"lora_B\"):\n",
        "            with torch.no_grad():\n",
        "                for pname in (\"lora_A\", \"lora_B\"):\n",
        "                    old_p = getattr(cur, pname)\n",
        "                    new_p = getattr(replacement, pname)\n",
        "                    if new_p.shape == old_p.shape:\n",
        "                        new_p.copy_(old_p.data)\n",
        "        setattr(mod, attr, replacement)\n",
        "        return\n",
        "\n",
        "    if isinstance(cur, nn.Linear):\n",
        "        setattr(mod, attr, LoRALinear(cur, r=r, alpha=alpha, dropout=dropout))\n",
        "\n",
        "# --- helpers to iterate resblocks safely ---\n",
        "def _iter_resblocks(stack_module: nn.Module):\n",
        "    \"\"\"Return `stack_module.transformer.resblocks` as a list, or [] if either attr is missing.\n",
        "\n",
        "    Serves both the vision tower (model.visual) and the text tower (the CLIP model itself).\n",
        "    \"\"\"\n",
        "    if not hasattr(stack_module, \"transformer\"):\n",
        "        return []\n",
        "    transformer = stack_module.transformer\n",
        "    if not hasattr(transformer, \"resblocks\"):\n",
        "        return []\n",
        "    return list(transformer.resblocks)\n",
        "\n",
        "def _add_lora_to_tower(tower: nn.Module, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"Wrap attn.out_proj, mlp.c_fc and mlp.c_proj of every resblock in `tower`, skipping absent ones.\"\"\"\n",
        "    for blk in _iter_resblocks(tower):\n",
        "        attn = getattr(blk, \"attn\", None)\n",
        "        if attn is not None and hasattr(attn, \"out_proj\"):\n",
        "            _wrap_or_upgrade_linear(attn, \"out_proj\", r, alpha, dropout)\n",
        "        mlp = getattr(blk, \"mlp\", None)\n",
        "        if mlp is None:\n",
        "            continue\n",
        "        for name in (\"c_fc\", \"c_proj\"):\n",
        "            if hasattr(mlp, name):\n",
        "                _wrap_or_upgrade_linear(mlp, name, r, alpha, dropout)\n",
        "\n",
        "def add_lora_to_clip_vit_b16(model, r=8, alpha=16, dropout=0.0):\n",
        "    \"\"\"\n",
        "    Add/upgrade LoRA adapters on a CLIP model and make them the only trainable params.\n",
        "\n",
        "    Targets, in the vision tower (model.visual) and then the text tower (model):\n",
        "      attn.out_proj, mlp.c_fc, mlp.c_proj\n",
        "    Pruned/identity blocks or any block missing the expected attrs are skipped.\n",
        "\n",
        "    Returns:\n",
        "        List of parameters left with requires_grad=True (the lora_A / lora_B tensors).\n",
        "    \"\"\"\n",
        "    _add_lora_to_tower(model.visual, r, alpha, dropout)\n",
        "    _add_lora_to_tower(model, r, alpha, dropout)\n",
        "\n",
        "    # Freeze all, then enable ONLY LoRA params (and keep logit_scale frozen)\n",
        "    for p in model.parameters():\n",
        "        p.requires_grad = False\n",
        "    if hasattr(model, \"logit_scale\") and isinstance(model.logit_scale, torch.Tensor):\n",
        "        # Explicit in case logit_scale is not reached through model.parameters()\n",
        "        model.logit_scale.requires_grad_(False)\n",
        "    for n, p in model.named_parameters():\n",
        "        if n.endswith((\"lora_A\", \"lora_B\")):\n",
        "            p.requires_grad = True\n",
        "\n",
        "    return [p for p in model.parameters() if p.requires_grad]\n",
        "\n",
        "def _assert_outproj_has_weight(model):\n",
        "    \"\"\"\n",
        "    Raise AssertionError if any resblock's attn.out_proj lacks a Tensor `.weight`.\n",
        "\n",
        "    nn.MultiheadAttention reads out_proj.weight directly, so a wrapper without it\n",
        "    would break attention. Blocks with no attn.out_proj (e.g. Identity/pruned\n",
        "    blocks) are skipped.\n",
        "    \"\"\"\n",
        "    towers = ((\"vision\", model.visual), (\"text\", model))\n",
        "    for tower_name, tower in towers:\n",
        "        for i, blk in enumerate(_iter_resblocks(tower)):\n",
        "            proj = getattr(getattr(blk, \"attn\", None), \"out_proj\", None)\n",
        "            if proj is None:\n",
        "                continue\n",
        "            w = getattr(proj, \"weight\", None)\n",
        "            if isinstance(w, torch.Tensor):\n",
        "                continue\n",
        "            raise AssertionError(\n",
        "                f\"{tower_name} resblock {i} out_proj has no Tensor `.weight` \"\n",
        "                f\"(got {type(w)} from {type(proj)})\"\n",
        "            )\n",
        "\n",
        "# -------------------------\n",
        "# Explicit per-block forwards (no hooks)\n",
        "# -------------------------\n",
        "def _first_param_dtype(module, default=torch.float32):\n",
        "    \"\"\"Dtype of the first parameter yielded by `module.parameters()`, or `default` if there is none.\"\"\"\n",
        "    first = next(iter(module.parameters(recurse=True)), None)\n",
        "    return default if first is None else first.dtype\n",
        "\n",
        "@torch.no_grad()\n",
        "def _vision_block_outputs(model, images, sigma_v: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the CLIP vision tower block by block and return every resblock's output.\n",
        "\n",
        "    The stem (conv1 -> flatten -> prepend class token -> add positional embedding\n",
        "    -> ln_pre) is re-implemented here, then each entry of\n",
        "    `model.visual.transformer.resblocks` is called in turn, so IdentityBlock\n",
        "    replacements are honoured. No final layer norm or projection is applied.\n",
        "\n",
        "    Args:\n",
        "        model: CLIP model exposing `.visual`.\n",
        "        images: pixel batch, moved to the vision tower's device and parameter dtype\n",
        "            (presumably [B, 3, H, W] from the dataset transform — TODO confirm).\n",
        "        sigma_v: std of Gaussian noise added to the input pixels when > 0.\n",
        "        device: unused; the device is taken from the vision tower's parameters.\n",
        "\n",
        "    Returns:\n",
        "        List with one tensor per resblock, each permuted to [B, T, D].\n",
        "    \"\"\"\n",
        "    v = model.visual\n",
        "    pdev    = next(v.parameters()).device\n",
        "    v_dtype = _first_param_dtype(v, torch.float32)  # CLIP ViT blocks typically float()\n",
        "\n",
        "    x = images.to(pdev, non_blocking=True).to(v_dtype)\n",
        "    if sigma_v > 0:\n",
        "        x = x + sigma_v * torch.randn_like(x, dtype=v_dtype)\n",
        "\n",
        "    # Apply conv1, then flatten its spatial dims into a token axis: [B, D, ...] -> [B, N, D]\n",
        "    x = v.conv1(x)\n",
        "    x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)\n",
        "    # Prepend the learned class token to every sequence (token 0, read by \"cls\" pooling)\n",
        "    x = torch.cat([v.class_embedding[None, None, :].to(v_dtype).expand(x.shape[0], 1, -1), x], dim=1)\n",
        "    x = x + v.positional_embedding.to(v_dtype)\n",
        "    x = v.ln_pre(x.to(_first_param_dtype(v.ln_pre, v_dtype)))\n",
        "\n",
        "    # Resblocks consume sequence-first input: [B, T, D] -> [T, B, D]\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in v.transformer.resblocks:\n",
        "        # Keep activations in the tower's parameter dtype between blocks\n",
        "        if x.dtype != v_dtype:\n",
        "            x = x.to(v_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())  # [B,T,D]\n",
        "    return outs\n",
        "\n",
        "@torch.no_grad()\n",
        "def _text_block_outputs(model, texts, sigma_t: float, device: str):\n",
        "    \"\"\"\n",
        "    Run the CLIP text transformer block by block and return every resblock's output.\n",
        "\n",
        "    Applies the token and positional embeddings (plus optional Gaussian noise), then\n",
        "    calls each entry of `model.transformer.resblocks` in turn, so IdentityBlock\n",
        "    replacements are honoured. No final layer norm or projection is applied.\n",
        "    NOTE(review): no attention mask is passed here, so this relies on each resblock\n",
        "    applying the causal text mask itself — confirm for the CLIP build in use.\n",
        "\n",
        "    Args:\n",
        "        model: CLIP model exposing token_embedding, positional_embedding, transformer.\n",
        "        texts: token ids (presumably [B, context_length] from clip.tokenize — TODO confirm).\n",
        "        sigma_t: std of Gaussian noise added to the embeddings when > 0.\n",
        "        device: device the token ids are moved to before the embedding lookup.\n",
        "\n",
        "    Returns:\n",
        "        List with one tensor per resblock, each permuted to [B, T, D].\n",
        "    \"\"\"\n",
        "    tr = model.transformer\n",
        "    tr_dtype = _first_param_dtype(tr, torch.float32)\n",
        "\n",
        "    x = model.token_embedding(texts.to(device, non_blocking=True))\n",
        "    x = x.to(tr_dtype) + model.positional_embedding.to(tr_dtype)\n",
        "    if sigma_t > 0:\n",
        "        x = x + sigma_t * torch.randn_like(x, dtype=tr_dtype)\n",
        "\n",
        "    # Resblocks consume sequence-first input: [B, T, D] -> [T, B, D]\n",
        "    x = x.permute(1, 0, 2).contiguous()\n",
        "    outs = []\n",
        "    for blk in tr.resblocks:\n",
        "        # Keep activations in the transformer's parameter dtype between blocks\n",
        "        if x.dtype != tr_dtype:\n",
        "            x = x.to(tr_dtype)\n",
        "        x = blk(x)\n",
        "        outs.append(x.permute(1, 0, 2).contiguous())  # [B,T,D]\n",
        "    return outs\n",
        "\n",
        "# -------------------------\n",
        "# Probe loader\n",
        "# -------------------------\n",
        "def get_probe_loader(train_dataset, batch_size=64, num_batches=2, seed=None):\n",
        "    \"\"\"\n",
        "    Build a loader over a small random subset used to probe per-layer statistics.\n",
        "\n",
        "    Args:\n",
        "        train_dataset: indexable dataset to sample from.\n",
        "        batch_size: target batch size (reduced to the subset size if that is smaller).\n",
        "        num_batches: number of full batches to draw.\n",
        "        seed: optional int making the subset reproducible; None draws from the\n",
        "            global `random` state (the previous behaviour).\n",
        "\n",
        "    Returns:\n",
        "        DataLoader over a Subset of `train_dataset` (no shuffling, partial batch dropped).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if batch_size / num_batches are not positive or the dataset is empty.\n",
        "    \"\"\"\n",
        "    from torch.utils.data import Subset, DataLoader\n",
        "    if batch_size <= 0 or num_batches <= 0:\n",
        "        raise ValueError(\"batch_size and num_batches must be positive.\")\n",
        "    total = len(train_dataset)\n",
        "    need = min(total, batch_size * num_batches)\n",
        "    if need == 0:\n",
        "        raise ValueError(\"Empty train_dataset for probe loader.\")\n",
        "    rng = random if seed is None else random.Random(seed)\n",
        "    idxs = rng.sample(range(total), need)\n",
        "    probe = Subset(train_dataset, idxs)\n",
        "    bs = min(batch_size, len(probe))\n",
        "    return DataLoader(\n",
        "        probe,\n",
        "        batch_size=bs,\n",
        "        shuffle=False,\n",
        "        num_workers=2,\n",
        "        pin_memory=True,\n",
        "        drop_last=True,   # avoid a trailing partial (possibly B=1) batch\n",
        "    )\n",
        "\n",
        "# -------------------------\n",
        "# Pooling helpers\n",
        "# -------------------------\n",
        "def _pooled(list_of_tensors, kind: str):\n",
        "    \"\"\"\n",
        "    Pool every [B, T, D] tensor in `list_of_tensors` down to [B, D].\n",
        "\n",
        "    kind=\"cls\" keeps token 0; kind=\"mean\" averages over the token axis.\n",
        "    Any other kind raises ValueError (even for an empty list).\n",
        "    \"\"\"\n",
        "    if kind not in (\"cls\", \"mean\"):\n",
        "        raise ValueError(\"pool must be 'cls' or 'mean'\")\n",
        "    if kind == \"cls\":\n",
        "        return [x[:, 0, :] for x in list_of_tensors]\n",
        "    return [x.mean(dim=1) for x in list_of_tensors]\n",
        "\n",
        "# -------------------------\n",
        "# Cosine similarity metric between adjacent blocks\n",
        "# -------------------------\n",
        "@torch.no_grad()\n",
        "def compute_cos_per_layer(\n",
        "    model,\n",
        "    images,\n",
        "    texts,\n",
        "    pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    device: str = \"cuda\",\n",
        "):\n",
        "    \"\"\"\n",
        "    Cosine similarity between consecutive blocks' pooled outputs:\n",
        "      cos_l = mean_i  cos( z_l[i], z_{l+1}[i] ), l=0..L-2\n",
        "    High cosine => near-identity transition => prune successor block (l+1).\n",
        "\n",
        "    Args:\n",
        "        model: CLIP model (vision tower at `.visual`, text tower on the model itself).\n",
        "        images, texts: one probe batch of images and token ids.\n",
        "        pool: pooling for (vision, text), or one string for both: \"cls\" or \"mean\".\n",
        "            For the text tower \"cls\" reads the EOT token (position of the largest\n",
        "            token id, the token CLIP's encode_text pools) instead of position 0:\n",
        "            CLIP's text transformer is causal, so position 0 (start-of-text) never\n",
        "            attends to the caption and would yield caption-independent features.\n",
        "        device: device the token ids are moved to.\n",
        "\n",
        "    Returns:\n",
        "        (cos_v, cos_t), each of length L-1; the score at index l maps to block l+1.\n",
        "    \"\"\"\n",
        "    if isinstance(pool, str):\n",
        "        pool_v, pool_t = pool, pool\n",
        "    else:\n",
        "        pool_v, pool_t = pool\n",
        "\n",
        "    vouts = _vision_block_outputs(model, images, sigma_v=0.0, device=device)  # list of [B,T,D]\n",
        "    touts = _text_block_outputs(model, texts,  sigma_t=0.0, device=device)\n",
        "\n",
        "    pv = _pooled(vouts, pool_v)  # list of [B,D]\n",
        "    if pool_t == \"cls\":\n",
        "        # EOT pooling as in CLIP's encode_text: index of the highest token id per row\n",
        "        eot_idx = texts.argmax(dim=-1)\n",
        "        pt = [\n",
        "            x[torch.arange(x.shape[0], device=x.device), eot_idx.to(x.device)]\n",
        "            for x in touts\n",
        "        ]\n",
        "    else:\n",
        "        pt = _pooled(touts, pool_t)  # also validates pool_t\n",
        "\n",
        "    def _cos_adjacent(plist):\n",
        "        eps = 1e-6\n",
        "        vals = []\n",
        "        for l in range(len(plist) - 1):\n",
        "            a, b = plist[l].float(), plist[l+1].float()       # [B,D]\n",
        "            num = (a * b).sum(dim=1)\n",
        "            den = (a.norm(dim=1) * b.norm(dim=1)).clamp_min(eps)\n",
        "            c = (num / den).clamp(-1.0, 1.0).mean().item()\n",
        "            vals.append(float(c))\n",
        "        return vals\n",
        "\n",
        "    cos_v = _cos_adjacent(pv)\n",
        "    cos_t = _cos_adjacent(pt)\n",
        "    return cos_v, cos_t  # each length L-1 (score at l maps to prune block l+1)\n",
        "\n",
        "# -------------------------\n",
        "# Pruning by cosine (highest = most redundant)\n",
        "# -------------------------\n",
        "class IdentityBlock(nn.Module):\n",
        "    \"\"\"Stand-in for a pruned resblock: returns its input unchanged.\"\"\"\n",
        "    def forward(self, x): return x\n",
        "\n",
        "def _select_successors_to_prune(\n",
        "    cos: List[float],\n",
        "    n_blocks: int,\n",
        "    k: int,\n",
        "    guard_skip_first_last: bool,\n",
        "    avoid_adjacent: bool,\n",
        ") -> List[int]:\n",
        "    \"\"\"\n",
        "    Choose up to `k` block indices to prune, most redundant first.\n",
        "\n",
        "    `cos[l]` scores the transition between blocks l and l+1; a high score marks the\n",
        "    successor l+1 as a pruning candidate. NaN scores rank last.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `cos` does not hold exactly max(n_blocks - 1, 0) scores.\n",
        "    \"\"\"\n",
        "    expected = max(n_blocks - 1, 0)\n",
        "    if len(cos) != expected:\n",
        "        raise ValueError(\n",
        "            f\"expected {expected} cosine scores for {n_blocks} blocks, got {len(cos)}\"\n",
        "        )\n",
        "\n",
        "    candidates = list(range(1, n_blocks))  # successor index of each adjacent pair\n",
        "    if guard_skip_first_last:\n",
        "        candidates = [i for i in candidates if i not in (0, n_blocks - 1)]\n",
        "\n",
        "    def _score(idx: int) -> float:\n",
        "        # NaN would make the sort order arbitrary; rank it as least redundant instead.\n",
        "        c = cos[idx - 1]\n",
        "        return -math.inf if math.isnan(c) else c\n",
        "\n",
        "    chosen: List[int] = []\n",
        "    for idx in sorted(candidates, key=_score, reverse=True):\n",
        "        if len(chosen) >= k:\n",
        "            break\n",
        "        if avoid_adjacent and any(abs(idx - j) == 1 for j in chosen):\n",
        "            continue  # never prune two neighbouring blocks\n",
        "        chosen.append(idx)\n",
        "    return chosen\n",
        "\n",
        "def prune_topk_layers_by_cos(\n",
        "    model,\n",
        "    cos_v: List[float],\n",
        "    cos_t: List[float],\n",
        "    k_v: int = 4,\n",
        "    k_t: int = 4,\n",
        "    guard_skip_first_last: bool = True,\n",
        "    avoid_adjacent: bool = True,\n",
        ") -> Tuple[List[int], List[int]]:\n",
        "    \"\"\"\n",
        "    Replace the most redundant vision/text resblocks with IdentityBlock, in place.\n",
        "\n",
        "    cos_v[l] / cos_t[l] score the alignment between blocks l and l+1 (as returned by\n",
        "    compute_cos_per_layer); the successor index l+1 is pruned for high cosine.\n",
        "\n",
        "    Args:\n",
        "        model: CLIP model with `.visual.transformer.resblocks` and `.transformer.resblocks`.\n",
        "        k_v, k_t: maximum number of vision / text blocks to prune.\n",
        "        guard_skip_first_last: never prune the last block (block 0 is never a successor).\n",
        "        avoid_adjacent: never prune two neighbouring blocks.\n",
        "\n",
        "    Returns:\n",
        "        (chosen_v, chosen_t): pruned block indices, most redundant first.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if a score list does not have (num_blocks - 1) entries. Raised\n",
        "            before any block is replaced, so the model is never left half-pruned.\n",
        "    \"\"\"\n",
        "    v_blocks = model.visual.transformer.resblocks\n",
        "    t_blocks = model.transformer.resblocks\n",
        "    # Select for both towers first so a bad score list cannot leave one tower pruned.\n",
        "    chosen_v = _select_successors_to_prune(cos_v, len(v_blocks), k_v, guard_skip_first_last, avoid_adjacent)\n",
        "    chosen_t = _select_successors_to_prune(cos_t, len(t_blocks), k_t, guard_skip_first_last, avoid_adjacent)\n",
        "\n",
        "    for idx in chosen_v:\n",
        "        v_blocks[idx] = IdentityBlock()\n",
        "    for idx in chosen_t:\n",
        "        t_blocks[idx] = IdentityBlock()\n",
        "\n",
        "    return chosen_v, chosen_t\n",
        "\n",
        "# -------------------------\n",
        "# Contrastive loss (NaN-safe)\n",
        "# -------------------------\n",
        "def clip_contrastive_loss(model, image_features, text_features):\n",
        "    \"\"\"\n",
        "    Symmetric contrastive loss over a batch of paired image/text features.\n",
        "\n",
        "    Features are cast to fp32 and L2-normalized. The BxB similarity matrix is scaled\n",
        "    by the clamped logit scale, and cross-entropy against the diagonal is averaged\n",
        "    over the image->text and text->image directions.\n",
        "\n",
        "    NOTE(review): nan_to_num maps NaN logits to 0, so NaN features can still produce\n",
        "    a finite loss and the training loop's `torch.isfinite(loss)` guard may not fire.\n",
        "    Confirm this masking is intended.\n",
        "    \"\"\"\n",
        "    img = safe_l2_normalize(image_features.float(), dim=-1, eps=1e-6)\n",
        "    txt = safe_l2_normalize(text_features.float(), dim=-1, eps=1e-6)\n",
        "\n",
        "    scale = get_safe_logit_scale(model)  # scalar tensor on the model's device\n",
        "    sims = img @ txt.t()\n",
        "    logits_per_image = torch.nan_to_num(scale * sims, nan=0.0, posinf=1e4, neginf=-1e4)\n",
        "    targets = torch.arange(img.size(0), device=logits_per_image.device)\n",
        "\n",
        "    loss_i = F.cross_entropy(logits_per_image, targets)\n",
        "    loss_t = F.cross_entropy(logits_per_image.t(), targets)\n",
        "    return 0.5 * (loss_i + loss_t)\n",
        "\n",
        "# -------------------------\n",
        "# Training (cosine-based pruning)\n",
        "# -------------------------\n",
        "def train_with_cos_prune_lora(\n",
        "    model,\n",
        "    train_dataloader,\n",
        "    train_dataset,\n",
        "    device,\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "    cos_pool: Union[str, Tuple[str, str]] = (\"cls\", \"cls\"),\n",
        "    probe_batches=2,\n",
        "):\n",
        "    # LoRA injection/upgrade\n",
        "    lora_params = add_lora_to_clip_vit_b16(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)\n",
        "    _assert_outproj_has_weight(model)\n",
        "    model.to(device)\n",
        "\n",
        "    optimizer = AdamW(lora_params, lr=lr, weight_decay=weight_decay)\n",
        "    probe_loader = get_probe_loader(train_dataset, batch_size=64, num_batches=probe_batches)\n",
        "\n",
        "    history = {\"vision_cos\": [], \"text_cos\": []}\n",
        "\n",
        "    def epoch_pass(epoch_idx):\n",
        "        model.train()\n",
        "        total_loss, num_steps = 0.0, 0\n",
        "        for images, texts in train_dataloader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            texts  = texts.to(device,  non_blocking=True)\n",
        "\n",
        "            optimizer.zero_grad(set_to_none=True)\n",
        "            img_feat = model.encode_image(images)\n",
        "            txt_feat = model.encode_text(texts)\n",
        "            loss = clip_contrastive_loss(model, img_feat, txt_feat)\n",
        "\n",
        "            if not torch.isfinite(loss):\n",
        "                with torch.no_grad():\n",
        "                    bad = {\n",
        "                        \"loss\": float(loss.detach().float().item()),\n",
        "                        \"img_feat_has_nan\": bool(torch.isnan(img_feat).any().item()),\n",
        "                        \"txt_feat_has_nan\": bool(torch.isnan(txt_feat).any().item()),\n",
        "                    }\n",
        "                print(\"[Warn] Skipping non-finite batch:\", bad)\n",
        "                continue\n",
        "\n",
        "            loss.backward()\n",
        "            clip_grad_norm_(lora_params, max_norm=1.0)\n",
        "            optimizer.step()\n",
        "\n",
        "            total_loss += float(loss.item())\n",
        "            num_steps += 1\n",
        "\n",
        "        avg_loss = total_loss / max(1, num_steps)\n",
        "\n",
        "        # --- Cosine on probe (eval mode) ---\n",
        "        model.eval()\n",
        "        cos_v_all, cos_t_all = [], []\n",
        "        with torch.no_grad():\n",
        "            for images, texts in probe_loader:\n",
        "                images = images.to(device, non_blocking=True)\n",
        "                texts  = texts.to(device,  non_blocking=True)\n",
        "                cv, ct = compute_cos_per_layer(model, images, texts, pool=cos_pool, device=device)\n",
        "                cos_v_all.append(torch.tensor(cv))\n",
        "                cos_t_all.append(torch.tensor(ct))\n",
        "\n",
        "        cos_v_mean = torch.stack(cos_v_all).nanmean(dim=0).tolist()\n",
        "        cos_t_mean = torch.stack(cos_t_all).nanmean(dim=0).tolist()\n",
        "\n",
        "        history[\"vision_cos\"].append(cos_v_mean)\n",
        "        history[\"text_cos\"].append(cos_t_mean)\n",
        "\n",
        "        print(f\"[Epoch {epoch_idx}] loss={avg_loss:.6f}\")\n",
        "        print(\"  Vision COS per layer (cos(z_l, z_{l+1})):\")\n",
        "        print(\"   \", [round(x, 6) for x in cos_v_mean])\n",
        "        print(\"  Text   COS per layer (cos(z_l, z_{l+1})):\")\n",
        "        print(\"   \", [round(x, 6) for x in cos_t_mean])\n",
        "\n",
        "\n",
        "    # --- Phase 1 (warmup before pruning) ---\n",
        "    for ep in range(1, epochs_before_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    # --- Prune based on highest cosine (redundant transitions) ---\n",
        "    last_cos_v = history[\"vision_cos\"][-1]\n",
        "    last_cos_t = history[\"text_cos\"][-1]\n",
        "    chosen_v, chosen_t = prune_topk_layers_by_cos(\n",
        "        model, last_cos_v, last_cos_t,\n",
        "        k_v=prune_k_vision, k_t=prune_k_text,\n",
        "        guard_skip_first_last=True, avoid_adjacent=True\n",
        "    )\n",
        "    print(f\"Pruned Vision blocks at indices: {chosen_v} (Identity)\")\n",
        "    print(f\"Pruned Text   blocks at indices: {chosen_t} (Identity)\")\n",
        "\n",
        "    # --- Phase 2 (continue training) ---\n",
        "    for ep in range(epochs_before_prune + 1, epochs_before_prune + epochs_after_prune + 1):\n",
        "        epoch_pass(ep)\n",
        "\n",
        "    return history\n",
        "\n",
        "# ======================================================\n",
        "# OPTIONAL: COCO Retrieval Evaluation (I2T / T2I)\n",
        "# ======================================================\n",
        "def evaluate_coco_retrieval(model, data_dir: str, batch_size: int = 64, input_resolution: int = 224,\n",
        "                            context_length: int = 77, sim_chunk_size: int = 512):\n",
        "    \"\"\"\n",
        "    Evaluate a CLIP-style model on COCO val2014 caption retrieval (I2T and T2I); print Recall@1/5/10.\n",
        "\n",
        "    Images are resized to `input_resolution` x `input_resolution` and normalized with the CLIP\n",
        "    mean/std; every caption of every image is tokenized and encoded. Both feature sets go through\n",
        "    `safe_l2_normalize`, so a dot product between them is a cosine similarity.\n",
        "\n",
        "    Ranking: a query's rank is the number of gallery items scoring strictly higher than its\n",
        "    best-scoring ground-truth match (0 = retrieved first; exact ties count in the query's favour).\n",
        "    Recall@k is the fraction of queries with rank < k, over queries that have at least one match.\n",
        "    Similarities are computed `sim_chunk_size` queries at a time on the model's device, so the full\n",
        "    [N_img, N_txt] float32 matrix (N_img * N_txt * 4 bytes) is never materialized.\n",
        "\n",
        "    Args:\n",
        "        model: object with `encode_image`, `encode_text` and `parameters()`; it is run on the\n",
        "            device of its first parameter and left in eval mode.\n",
        "        data_dir: directory containing `val2014/` and `annotations/captions_val2014.json`.\n",
        "        batch_size: images per forward pass.\n",
        "        input_resolution: side length images are resized to.\n",
        "        context_length: token length for `clip.tokenize`; longer captions are truncated.\n",
        "        sim_chunk_size: number of queries scored per similarity chunk (bounds peak memory).\n",
        "\n",
        "    Returns:\n",
        "        dict with keys \"i2t_r1\", \"i2t_r5\", \"i2t_r10\", \"t2i_r1\", \"t2i_r5\", \"t2i_r10\" (fractions\n",
        "        in [0, 1]), or None if no captions were loaded.\n",
        "\n",
        "    Assumes `pip install pycocotools` and images/annotations present in data_dir.\n",
        "    \"\"\"\n",
        "    import clip\n",
        "    from torchvision import datasets, transforms\n",
        "    from torchvision.transforms import InterpolationMode\n",
        "    from torch.utils.data import DataLoader\n",
        "\n",
        "    device = next(model.parameters()).device\n",
        "\n",
        "    eval_transform = transforms.Compose([\n",
        "        transforms.Resize((input_resolution, input_resolution), interpolation=InterpolationMode.BICUBIC),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "                             std=(0.26862954, 0.26130258, 0.27577711))\n",
        "    ])\n",
        "\n",
        "    def coco_collate_fn(batch):\n",
        "        # Captions stay a list[list[str]]: images have a variable number of captions.\n",
        "        images = torch.stack([img for img, _ in batch], dim=0)\n",
        "        captions = [list(caps) for _, caps in batch]\n",
        "        return images, captions\n",
        "\n",
        "    val_img_dir  = os.path.join(data_dir, 'val2014')\n",
        "    val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "    val_dataset  = datasets.CocoCaptions(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,\n",
        "                              pin_memory=torch.cuda.is_available(), collate_fn=coco_collate_fn)\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    image_feature_chunks, text_feature_chunks = [], []\n",
        "    text_to_image = []  # text_to_image[t] = index of the image that caption t belongs to\n",
        "    num_images = 0\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for images, batch_captions in val_loader:\n",
        "            images = images.to(device, non_blocking=True)\n",
        "            image_feats = model.encode_image(images)\n",
        "            image_feature_chunks.append(safe_l2_normalize(image_feats.float(), dim=-1, eps=1e-6).cpu())\n",
        "\n",
        "            flat_captions = []\n",
        "            for offset, caps in enumerate(batch_captions):\n",
        "                flat_captions.extend(caps)\n",
        "                text_to_image.extend([num_images + offset] * len(caps))\n",
        "            num_images += len(batch_captions)\n",
        "\n",
        "            if not flat_captions:\n",
        "                continue\n",
        "\n",
        "            # truncate=True: without it clip.tokenize raises on any caption longer than context_length.\n",
        "            texts = clip.tokenize(flat_captions, context_length=context_length, truncate=True).to(device)\n",
        "            text_feats = model.encode_text(texts)\n",
        "            text_feature_chunks.append(safe_l2_normalize(text_feats.float(), dim=-1, eps=1e-6).cpu())\n",
        "\n",
        "    if not text_feature_chunks:\n",
        "        print(\"No captions found; skipping COCO retrieval evaluation.\")\n",
        "        return None\n",
        "\n",
        "    image_features = torch.cat(image_feature_chunks, dim=0)          # [N_img, D]\n",
        "    text_features  = torch.cat(text_feature_chunks, dim=0)           # [N_txt, D]\n",
        "    text_owner     = torch.tensor(text_to_image, dtype=torch.long)   # [N_txt] image index per caption\n",
        "    image_ids      = torch.arange(image_features.size(0))            # [N_img]\n",
        "\n",
        "    def best_positive_ranks(query_feats, query_labels, gallery_feats, gallery_labels):\n",
        "        \"\"\"0-based rank of each query's best-scoring positive (same label); -1 if it has none.\"\"\"\n",
        "        gallery_feats = gallery_feats.to(device)\n",
        "        gallery_labels = gallery_labels.to(device)\n",
        "        rank_chunks = []\n",
        "        for start in range(0, query_feats.size(0), sim_chunk_size):\n",
        "            stop = start + sim_chunk_size\n",
        "            q_feats = query_feats[start:stop].to(device)\n",
        "            q_labels = query_labels[start:stop].to(device)\n",
        "            sims = q_feats @ gallery_feats.t()                              # [chunk, N_gallery]\n",
        "            is_pos = gallery_labels.unsqueeze(0) == q_labels.unsqueeze(1)  # [chunk, N_gallery]\n",
        "            best_pos = sims.masked_fill(~is_pos, float(\"-inf\")).amax(dim=1, keepdim=True)\n",
        "            # Rank = number of gallery items strictly better than the best positive.\n",
        "            chunk_ranks = (sims > best_pos).sum(dim=1)\n",
        "            chunk_ranks[~is_pos.any(dim=1)] = -1\n",
        "            rank_chunks.append(chunk_ranks.cpu())\n",
        "        return torch.cat(rank_chunks)\n",
        "\n",
        "    def recall_at(ranks, k):\n",
        "        valid = ranks >= 0\n",
        "        return int(((ranks < k) & valid).sum()) / max(1, int(valid.sum()))\n",
        "\n",
        "    # I2T: the positives of image i are its captions. T2I: the positive of caption t is its image.\n",
        "    i2t_ranks = best_positive_ranks(image_features, image_ids, text_features, text_owner)\n",
        "    t2i_ranks = best_positive_ranks(text_features, text_owner, image_features, image_ids)\n",
        "\n",
        "    metrics = {}\n",
        "    for prefix, ranks, title in ((\"i2t\", i2t_ranks, \"Image-to-Text Retrieval:\"),\n",
        "                                 (\"t2i\", t2i_ranks, \"Text-to-Image Retrieval:\")):\n",
        "        print(title)\n",
        "        for k in (1, 5, 10):\n",
        "            value = recall_at(ranks, k)\n",
        "            metrics[f\"{prefix}_r{k}\"] = value\n",
        "            label = f\"Recall@{k}:\"\n",
        "            print(f\"{label:<11}{value * 100:.2f}%\")\n",
        "    return metrics\n",
        "\n"
      ],
      "metadata": {
        "id": "ashw2_19e6k6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- quick run: warmup epochs -> prune the most redundant blocks by cosine -> continue training ---\n",
        "if torch.cuda.is_available():\n",
        "    device = \"cuda\"\n",
        "else:\n",
        "    device = \"cpu\"\n",
        "\n",
        "history = train_with_cos_prune_lora(\n",
        "    # model and data\n",
        "    model=model,\n",
        "    train_dataloader=train_dataloader,\n",
        "    train_dataset=train_dataset,\n",
        "    device=device,\n",
        "    # LoRA adapters and AdamW optimizer\n",
        "    lora_rank=8,\n",
        "    lora_alpha=16,\n",
        "    lora_dropout=0.0,\n",
        "    lr=1e-4,\n",
        "    weight_decay=0.0,\n",
        "    # schedule: epochs before / after the pruning step\n",
        "    epochs_before_prune=2,\n",
        "    epochs_after_prune=3,\n",
        "    # number of blocks pruned (replaced by Identity) in each tower\n",
        "    prune_k_vision=4,\n",
        "    prune_k_text=4,\n",
        "    # pooling for the per-layer cosine probe; alternative: (\"cls\", \"cls\")\n",
        "    cos_pool=\"mean\",\n",
        "    probe_batches=2,\n",
        ")"
      ],
      "metadata": {
        "id": "etHTsgFsNDdf",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8bc2965d-d099-4a4a-a44b-215e49584265"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[Epoch 1] loss=0.161243\n",
            "  Vision COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.65502, 0.742011, 0.795352, 0.724441, 0.780022, 0.757468, 0.847595, 0.878538, 0.885164, 0.884294, 0.162023]\n",
            "  Text   COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.973751, 0.941971, 0.922186, 0.978636, 0.984134, 0.981334, 0.973206, 0.974065, 0.949328, 0.957886, 0.663371]\n",
            "[Epoch 2] loss=0.126116\n",
            "  Vision COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.639934, 0.720125, 0.778841, 0.711403, 0.764447, 0.757799, 0.838901, 0.874032, 0.879603, 0.878873, 0.170623]\n",
            "  Text   COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.972238, 0.941775, 0.925949, 0.981238, 0.983476, 0.979199, 0.974325, 0.973078, 0.948275, 0.957019, 0.677478]\n",
            "Pruned Vision blocks at indices: [9, 7, 3, 5] (Identity)\n",
            "Pruned Text   blocks at indices: [5, 7, 1, 10] (Identity)\n",
            "[Epoch 3] loss=0.614583\n",
            "  Vision COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.627279, 0.572, 1.0, 0.595354, 1.0, 0.624992, 1.0, 0.748053, 1.0, 0.773061, 0.050909]\n",
            "  Text   COS per layer (cos(z_l, z_{l+1})):\n",
            "    [1.0, 0.938676, 0.910933, 0.97559, 1.0, 0.97127, 1.0, 0.936413, 0.946688, 1.0, 0.718774]\n",
            "[Epoch 4] loss=0.342055\n",
            "  Vision COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.627584, 0.559396, 1.0, 0.593829, 1.0, 0.648783, 1.0, 0.768954, 1.0, 0.792982, 0.065442]\n",
            "  Text   COS per layer (cos(z_l, z_{l+1})):\n",
            "    [1.0, 0.936281, 0.919301, 0.975151, 1.0, 0.972927, 1.0, 0.944136, 0.949046, 1.0, 0.706562]\n",
            "[Epoch 5] loss=0.275506\n",
            "  Vision COS per layer (cos(z_l, z_{l+1})):\n",
            "    [0.627438, 0.566216, 1.0, 0.592625, 1.0, 0.622518, 1.0, 0.767353, 1.0, 0.790749, 0.039474]\n",
            "  Text   COS per layer (cos(z_l, z_{l+1})):\n",
            "    [1.0, 0.940836, 0.921211, 0.975797, 1.0, 0.971567, 1.0, 0.944952, 0.953472, 1.0, 0.713215]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Recall@1/5/10 for image->text and text->image retrieval on COCO val2014\n",
        "evaluate_coco_retrieval(model, data_dir=data_dir, batch_size=64)\n"
      ],
      "metadata": {
        "id": "cBmSHD9-kipW",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "717f841b-6cad-4809-b472-d5633ce8f86d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.34s)\n",
            "creating index...\n",
            "index created!\n",
            "Image-to-Text Retrieval:\n",
            "Recall@1:  12.02%\n",
            "Recall@5:  28.49%\n",
            "Recall@10: 38.18%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1:  7.57%\n",
            "Recall@5:  20.03%\n",
            "Recall@10: 28.37%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "RPv7HIqxkjON"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LQXgZ_alw9gc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "d42i-P_Yrq-Y"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}