{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "b6b9c0b3-28fc-49e0-f7a3-0cd81039417f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m76.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m85.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m48.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m39.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m96.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# Use %pip (not !pip) so the packages are installed into the environment of the running kernel.\n",
        "%pip install --quiet ftfy regex tqdm\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "%pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define paths\n",
        "data_dir = '/content/coco2014'\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# URLs for the image archives and the caption annotations.\n",
        "# Deliberately not named `datasets`: the data-loader cell imports\n",
        "# `torchvision.datasets` under that name, which would shadow this dict and\n",
        "# break any re-run of this cell.\n",
        "coco_archives = {\n",
        "    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"val2014\": \"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "\n",
        "def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):\n",
        "    \"\"\"Stream `url` to `dest_path`, showing a tqdm progress bar.\n",
        "\n",
        "    Args:\n",
        "        url: Address of the file to fetch.\n",
        "        dest_path: Local path the response body is written to (overwritten if present).\n",
        "        chunk_size: Number of bytes requested per read from the response stream.\n",
        "        timeout: Seconds passed to `requests.get` as its timeout.\n",
        "\n",
        "    Raises:\n",
        "        requests.HTTPError: If the server answers with an error status, so an\n",
        "            error page is never saved (and later unzipped) as if it were the archive.\n",
        "    \"\"\"\n",
        "    # The context manager releases the connection even if writing fails.\n",
        "    with requests.get(url, stream=True, timeout=timeout) as response:\n",
        "        response.raise_for_status()\n",
        "        # A missing Content-Length header yields None, so tqdm shows a plain\n",
        "        # counter instead of a bar with a bogus total of 0.\n",
        "        total_size = int(response.headers.get('content-length', 0)) or None\n",
        "        with open(dest_path, 'wb') as f, tqdm(\n",
        "            desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "            total=total_size,\n",
        "            unit='B',\n",
        "            unit_scale=True,\n",
        "            unit_divisor=1024\n",
        "        ) as bar:\n",
        "            for data in response.iter_content(chunk_size=chunk_size):\n",
        "                f.write(data)\n",
        "                bar.update(len(data))\n",
        "\n",
        "\n",
        "# Download and extract each archive once. A marker file records completion so\n",
        "# re-running the cell does not fetch the multi-gigabyte archives again.\n",
        "for name, url in coco_archives.items():\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    done_marker = os.path.join(data_dir, f\".{name}.extracted\")\n",
        "    if os.path.exists(done_marker):\n",
        "        print(f\"{name} already downloaded and extracted, skipping.\")\n",
        "        continue\n",
        "\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    # Download the dataset\n",
        "    download_file(url, zip_path)\n",
        "\n",
        "    # Unzip the dataset\n",
        "    with ZipFile(zip_path, 'r') as zip_ref:\n",
        "        zip_ref.extractall(data_dir)\n",
        "\n",
        "    # Remove the zip file to save space\n",
        "    os.remove(zip_path)\n",
        "\n",
        "    # Written only after extraction succeeded.\n",
        "    with open(done_marker, 'w'):\n",
        "        pass\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XMDkrWBhLGg",
        "outputId": "2824016d-75d5-4369-ac49-defb16788922"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Processing train2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading train2014.zip: 100%|██████████| 12.6G/12.6G [15:34<00:00, 14.5MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train2014 downloaded and extracted.\n",
            "Processing val2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [06:39<00:00, 16.6MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:21<00:00, 11.9MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP RN50 Model"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "# Prefer the GPU when one is visible to torch.\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Teacher: CLIP RN50, switched to eval mode.\n",
        "model, preprocess = clip.load(\"RN50\", device)\n",
        "model.eval()\n",
        "\n",
        "# Model hyper-parameters reused by later cells.\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "# Total number of scalar parameters in the teacher.\n",
        "param_count = sum(int(np.prod(p.shape)) for p in model.parameters())\n",
        "\n",
        "print(\"Model parameters:\", f\"{param_count:,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e1c7f69f-daea-4005-b6e0-80a1b4eec4d7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 244M/244M [00:20<00:00, 12.7MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 102,007,137\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the Student Models (ResNet-18 Image Encoder and Transformer Text Encoder)"
      ],
      "metadata": {
        "id": "GifHBhlWlr29"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.nn as nn\n",
        "import torchvision.models as models\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "# Student Image Encoder (ResNet-18)\n",
        "class StudentImageEncoder(nn.Module):\n",
        "    \"\"\"ResNet-18 backbone whose classifier head is replaced by a projection to `output_dim`.\n",
        "\n",
        "    The forward pass returns L2-normalised embeddings.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, output_dim):\n",
        "        \"\"\"Build the backbone with ImageNet weights and a fresh `nn.Linear` head.\n",
        "\n",
        "        Args:\n",
        "            output_dim: Size of the embedding produced for each image.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        # `pretrained=True` is deprecated since torchvision 0.13; this enum is the\n",
        "        # documented equivalent and loads the same checkpoint.\n",
        "        self.encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
        "        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Encode a batch of images into unit-length embeddings.\"\"\"\n",
        "        x = self.encoder(x)\n",
        "        # F.normalize clamps the norm with a small eps, so an all-zero row\n",
        "        # cannot produce NaNs the way a plain division would.\n",
        "        return F.normalize(x, dim=-1)\n",
        "\n",
        "\n",
        "class StudentTextEncoder(nn.Module):\n",
        "    \"\"\"Small Transformer text encoder producing one unit-length vector per token sequence.\"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, context_length, output_dim):\n",
        "        \"\"\"Create the embeddings, a 2-layer Transformer encoder and the final LayerNorm.\n",
        "\n",
        "        Args:\n",
        "            vocab_size: Number of rows in the token embedding table.\n",
        "            context_length: Maximum sequence length (rows of the positional embedding).\n",
        "            output_dim: Embedding / model width; must be divisible by the 8 attention heads.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.token_embedding = nn.Embedding(vocab_size, output_dim)\n",
        "        self.positional_embedding = nn.Parameter(torch.zeros(context_length, output_dim))\n",
        "        nn.init.normal_(self.positional_embedding, std=0.01)\n",
        "        # batch_first=True lets the encoder consume (batch, seq, dim) directly,\n",
        "        # so no permutes are needed around the transformer call.\n",
        "        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=8, batch_first=True)\n",
        "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)\n",
        "        self.ln_final = nn.LayerNorm(output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Encode token ids of shape (batch_size, seq_len) into (batch_size, output_dim).\"\"\"\n",
        "        seq_len = x.size(1)\n",
        "        # Slice the positional table so sequences shorter than context_length work too.\n",
        "        x = self.token_embedding(x) + self.positional_embedding[:seq_len]  # (batch_size, seq_len, output_dim)\n",
        "        x = self.transformer(x)\n",
        "        x = self.ln_final(x)\n",
        "        # NOTE(review): the mean runs over every position, padding included - confirm this is intended.\n",
        "        x = x.mean(dim=1)  # Mean pooling over the sequence length\n",
        "        return F.normalize(x, dim=-1)  # (batch_size, output_dim), unit length\n",
        "\n"
      ],
      "metadata": {
        "id": "Bu5lXd8ehL2K"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Prepare the MS COCO Training Data Loader"
      ],
      "metadata": {
        "id": "h_cfc6FcltX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision import transforms, datasets\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "\n",
        "# Per-channel statistics applied after ToTensor.\n",
        "_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)\n",
        "_NORM_STD = (0.26862954, 0.26130258, 0.27577711)\n",
        "\n",
        "# Image pipeline: square resize to the teacher's input size -> tensor -> normalise.\n",
        "_resize = transforms.Resize((input_resolution, input_resolution))\n",
        "_normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)\n",
        "transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])\n",
        "\n",
        "class CocoDataset(Dataset):\n",
        "    \"\"\"Flattened view of COCO captions: one item per (image, single caption) pair.\n",
        "\n",
        "    Items are returned as (image, token_ids), where token_ids come from `clip.tokenize`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        \"\"\"Index every caption of every image without loading any image.\n",
        "\n",
        "        Args:\n",
        "            root: Directory containing the images.\n",
        "            annFile: Path to the COCO captions annotation JSON.\n",
        "            transform: Optional callable applied to each loaded image.\n",
        "        \"\"\"\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=None)\n",
        "        self.transform = transform\n",
        "        self.data = []\n",
        "        # Flatten (image, captions) so each item is (image index, single caption).\n",
        "        # Captions are read straight from the annotation index: indexing\n",
        "        # `self.dataset[i]` here would also open and decode every image just to\n",
        "        # throw it away.\n",
        "        coco = self.dataset.coco\n",
        "        for img_idx, img_id in enumerate(self.dataset.ids):\n",
        "            for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):\n",
        "                self.data.append((img_idx, ann['caption']))\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Number of (image, caption) pairs.\"\"\"\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        \"\"\"Return the (optionally transformed) image and the tokenized caption for pair `idx`.\"\"\"\n",
        "        img_idx, caption = self.data[idx]\n",
        "        image, _ = self.dataset[img_idx]  # get the image; its caption list is not needed here\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "        # truncate=True clips over-long captions instead of raising inside a worker.\n",
        "        text = clip.tokenize(caption, context_length=context_length, truncate=True)[0]\n",
        "        return image, text\n",
        "\n",
        "# Locations of the training images and their caption annotations.\n",
        "train_img_dir = os.path.join(data_dir, 'train2014')\n",
        "train_ann_file = os.path.join(data_dir, 'annotations', 'captions_train2014.json')\n",
        "\n",
        "train_dataset = CocoDataset(\n",
        "    root=train_img_dir,\n",
        "    annFile=train_ann_file,\n",
        "    transform=transform,\n",
        ")\n",
        "\n",
        "# Shuffled batches of 64, loaded by two workers into pinned memory.\n",
        "_loader_options = dict(batch_size=64, shuffle=True, num_workers=2, pin_memory=True)\n",
        "train_dataloader = DataLoader(train_dataset, **_loader_options)\n",
        "\n"
      ],
      "metadata": {
        "id": "dGosFCtHhMMi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5207f886-c948-4ace-c026-963b6dd60f09"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.60s)\n",
            "creating index...\n",
            "index created!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the Distillation Loss: Contrastive + ICL + KL + MSE - TE Reward"
      ],
      "metadata": {
        "id": "KijNXkynlvLL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def contrastive_loss_with_kl_l2_te(\n",
        "    student_image_features,\n",
        "    student_text_features,\n",
        "    teacher_image_features,\n",
        "    teacher_text_features,\n",
        "    temperature=0.07,\n",
        "    alpha=1.0,  # weight for KL term\n",
        "    beta=50,   # weight for L2 term\n",
        "    gamma=1  # weight for TE rewards\n",
        "):\n",
        "    \"\"\"Distillation loss combining contrastive, ICL, KL, L2 and cosine TE-surrogate terms.\n",
        "\n",
        "    total = contrastive + alpha * KL + beta * L2 + ICL - gamma * (te1 + te2)\n",
        "\n",
        "    All feature arguments are (batch, dim) tensors in which row i of every\n",
        "    tensor belongs to the same image/caption pair. The computation runs in\n",
        "    float32 whatever the input dtype, and the teacher features are treated as\n",
        "    constants (no gradient flows into them).\n",
        "\n",
        "    Args:\n",
        "        student_image_features: Student image embeddings.\n",
        "        student_text_features: Student text embeddings.\n",
        "        teacher_image_features: Teacher image embeddings.\n",
        "        teacher_text_features: Teacher text embeddings.\n",
        "        temperature: Divisor applied to every similarity matrix.\n",
        "        alpha: Weight of the KL term.\n",
        "        beta: Weight of the L2 (MSE) term.\n",
        "        gamma: Weight of each TE reward (the rewards are subtracted from the loss).\n",
        "\n",
        "    Returns:\n",
        "        Tuple (total_loss, kl_loss, l2_loss, icl_loss, te1, te2) of scalar tensors.\n",
        "    \"\"\"\n",
        "    eps = 1e-6\n",
        "\n",
        "    # Normalize features. Half-precision inputs are upcast first: norms, softmax\n",
        "    # over logits / temperature and the small MSE values lose accuracy in fp16.\n",
        "    # F.normalize also guards against zero-norm rows.\n",
        "    student_image_features = F.normalize(student_image_features.float(), dim=-1)\n",
        "    student_text_features = F.normalize(student_text_features.float(), dim=-1)\n",
        "    # detach(): the teacher is a fixed target in every term, not only in the logits.\n",
        "    teacher_image_features = F.normalize(teacher_image_features.detach().float(), dim=-1)\n",
        "    teacher_text_features = F.normalize(teacher_text_features.detach().float(), dim=-1)\n",
        "\n",
        "    # Compute student logits\n",
        "    logits_per_image_student = student_image_features @ student_text_features.t() / temperature\n",
        "    logits_per_text_student = logits_per_image_student.t()\n",
        "\n",
        "    # Compute teacher logits (inputs are already detached)\n",
        "    logits_per_image_teacher = teacher_image_features @ teacher_text_features.t() / temperature\n",
        "    logits_per_text_teacher = logits_per_image_teacher.t()\n",
        "\n",
        "    # Contrastive loss: the matching pair of row i sits on the diagonal\n",
        "    batch_size = student_image_features.size(0)\n",
        "    labels = torch.arange(batch_size, device=student_image_features.device)\n",
        "    loss_image = F.cross_entropy(logits_per_image_student, labels)\n",
        "    loss_text = F.cross_entropy(logits_per_text_student, labels)\n",
        "    contrastive_loss = (loss_image + loss_text) / 2\n",
        "\n",
        "    # ---------------------\n",
        "    # Interactive Contrastive Learning (ICL)\n",
        "    # ---------------------\n",
        "    # Image-to-text: student image features are the anchors, teacher text features the candidates\n",
        "    icl_logits_I_to_T = student_image_features @ teacher_text_features.t() / temperature\n",
        "    icl_loss_I_to_T = F.cross_entropy(icl_logits_I_to_T, labels)\n",
        "\n",
        "    # Text-to-image: student text features are the anchors, teacher image features the candidates\n",
        "    icl_logits_T_to_I = student_text_features @ teacher_image_features.t() / temperature\n",
        "    icl_loss_T_to_I = F.cross_entropy(icl_logits_T_to_I, labels)\n",
        "\n",
        "    # Total ICL loss is the average of both directions\n",
        "    icl_loss = (icl_loss_I_to_T + icl_loss_T_to_I) / 2\n",
        "\n",
        "    # KL-divergence loss between teacher and student similarity distributions\n",
        "    student_img_log_probs = F.log_softmax(logits_per_image_student, dim=-1)\n",
        "    teacher_img_probs = F.softmax(logits_per_image_teacher, dim=-1)\n",
        "    kl_img = F.kl_div(student_img_log_probs, teacher_img_probs, reduction='batchmean')\n",
        "\n",
        "    student_txt_log_probs = F.log_softmax(logits_per_text_student, dim=-1)\n",
        "    teacher_txt_probs = F.softmax(logits_per_text_teacher, dim=-1)\n",
        "    kl_txt = F.kl_div(student_txt_log_probs, teacher_txt_probs, reduction='batchmean')\n",
        "\n",
        "    kl_loss = (kl_img + kl_txt) / 2\n",
        "\n",
        "    # L2 distance loss between normalized student and teacher embeddings\n",
        "    l2_img = F.mse_loss(student_image_features, teacher_image_features)\n",
        "    l2_txt = F.mse_loss(student_text_features, teacher_text_features)\n",
        "    l2_loss = (l2_img + l2_txt) / 2\n",
        "\n",
        "    # ---------------------\n",
        "    # Cosine Similarity TE Surrogate\n",
        "    # ---------------------\n",
        "    def cosine_te(student_features, teacher_features):\n",
        "        \"\"\"\n",
        "        Approximates transfer entropy by computing the cosine similarity between\n",
        "        the differences (i.e., directional changes) of consecutive embeddings.\n",
        "        A higher cosine similarity indicates that the student is following the teacher's direction.\n",
        "        Returns 0 when the batch holds fewer than two rows.\n",
        "        \"\"\"\n",
        "        if student_features.size(0) < 2:\n",
        "            # No consecutive pair exists; the mean of an empty tensor would be NaN.\n",
        "            return student_features.new_zeros(())\n",
        "        # Compute differences along the batch dimension (assumes batch ordering approximates temporal ordering)\n",
        "        student_diff = student_features[1:] - student_features[:-1]\n",
        "        teacher_diff = teacher_features[1:] - teacher_features[:-1]\n",
        "        return F.cosine_similarity(student_diff, teacher_diff, dim=-1, eps=eps).mean()\n",
        "\n",
        "    te_img = cosine_te(student_image_features, teacher_image_features)\n",
        "    te_txt = cosine_te(student_text_features, teacher_text_features)\n",
        "    te1 = (te_img + te_txt) / 2\n",
        "\n",
        "    # ---------------------\n",
        "    # Cosine Similarity TE Surrogate on Concatenated Differences\n",
        "    # ---------------------\n",
        "    # Differencing commutes with concatenation, so concatenating the image and\n",
        "    # text features first gives the same vectors as concatenating their differences.\n",
        "    te2 = cosine_te(\n",
        "        torch.cat([student_image_features, student_text_features], dim=-1),\n",
        "        torch.cat([teacher_image_features, teacher_text_features], dim=-1),\n",
        "    )\n",
        "\n",
        "    # Combine all terms; both TE surrogates are rewards and are therefore subtracted.\n",
        "    total_loss = (contrastive_loss + alpha * kl_loss + beta * l2_loss + icl_loss - gamma * te1 - gamma * te2)\n",
        "\n",
        "    return total_loss, kl_loss, l2_loss, icl_loss, te1, te2\n"
      ],
      "metadata": {
        "id": "adWN2SVlXyrb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set Up the Training Loop"
      ],
      "metadata": {
        "id": "BijGiufjlw8K"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Embedding width shared by both student encoders.\n",
        "student_embed_dim = 1024\n",
        "\n",
        "# Instantiate student models on the same device as the teacher\n",
        "student_image_encoder = StudentImageEncoder(output_dim=student_embed_dim).to(device)\n",
        "student_text_encoder = StudentTextEncoder(\n",
        "    vocab_size, context_length, output_dim=student_embed_dim\n",
        ").to(device)\n",
        "\n",
        "# A single Adam optimizer updates the parameters of both encoders.\n",
        "student_parameters = [\n",
        "    *student_image_encoder.parameters(),\n",
        "    *student_text_encoder.parameters(),\n",
        "]\n",
        "optimizer = torch.optim.Adam(student_parameters, lr=1e-4)\n"
      ],
      "metadata": {
        "id": "UyA-cl0LhM6g",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c321cc15-ee6e-4c6f-8688-4a67fae6396a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n",
            "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
            "100%|██████████| 44.7M/44.7M [00:00<00:00, 221MB/s]\n",
            "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/transformer.py:379: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
            "  warnings.warn(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train the Student Model"
      ],
      "metadata": {
        "id": "0isNcb4ClyWa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Training Loop\n",
        "num_epochs = 10  # the number of epochs\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    student_image_encoder.train()\n",
        "    student_text_encoder.train()\n",
        "\n",
        "    # Running sums used for the epoch averages\n",
        "    total_loss = 0.0\n",
        "    total_kl_loss = 0.0\n",
        "    total_l2_loss = 0.0\n",
        "    total_icl_loss = 0.0\n",
        "    total_te1 = 0.0\n",
        "    total_te2 = 0.0\n",
        "\n",
        "    total_batches = len(train_dataloader)\n",
        "\n",
        "    for batch_idx, (images, texts) in enumerate(train_dataloader):\n",
        "        # The loader uses pinned memory, so the copies can overlap with compute.\n",
        "        images = images.to(device, non_blocking=True)\n",
        "        texts = texts.to(device, non_blocking=True)\n",
        "\n",
        "        # Teacher outputs (no gradients). They are upcast to float32 so the loss\n",
        "        # and its gradients are not computed in the teacher's dtype (half\n",
        "        # precision when CLIP runs on GPU); previously the student outputs were\n",
        "        # downcast to that dtype instead.\n",
        "        with torch.no_grad():\n",
        "            teacher_image_features = model.encode_image(images).float()\n",
        "            teacher_text_features = model.encode_text(texts).float()\n",
        "\n",
        "        # Student outputs\n",
        "        student_image_features = student_image_encoder(images)\n",
        "        student_text_features = student_text_encoder(texts)\n",
        "\n",
        "        # Compute the combined loss and its individual components\n",
        "        loss, kl_loss, l2_loss, icl_loss, te1, te2 = contrastive_loss_with_kl_l2_te(\n",
        "            student_image_features,\n",
        "            student_text_features,\n",
        "            teacher_image_features,\n",
        "            teacher_text_features\n",
        "        )\n",
        "\n",
        "        # Backpropagation\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Read every metric once as a Python float; the same numbers feed the\n",
        "        # running sums and the progress line.\n",
        "        loss_value = loss.item()\n",
        "        kl_value = kl_loss.item()\n",
        "        l2_value = l2_loss.item()\n",
        "        icl_value = icl_loss.item()\n",
        "        te1_value = te1.item()\n",
        "        te2_value = te2.item()\n",
        "\n",
        "        # Update trackers\n",
        "        total_loss += loss_value\n",
        "        total_kl_loss += kl_value\n",
        "        total_l2_loss += l2_value\n",
        "        total_icl_loss += icl_value\n",
        "        total_te1 += te1_value\n",
        "        total_te2 += te2_value\n",
        "\n",
        "        if batch_idx % 1000 == 0:\n",
        "            print(\n",
        "                f\"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{total_batches}], \"\n",
        "                f\"Loss: {loss_value:.4f}, KL Loss: {kl_value:.4f}, L2 Loss: {l2_value:.6f}, ICL Loss: {icl_value:.4f}, \"\n",
        "                f\"TE1: {te1_value:.4f}, TE2: {te2_value:.4f}\"\n",
        "            )\n",
        "\n",
        "    # Compute average metrics for the epoch\n",
        "    avg_loss = total_loss / total_batches\n",
        "    avg_kl_loss = total_kl_loss / total_batches\n",
        "    avg_l2_loss = total_l2_loss / total_batches\n",
        "    avg_icl_loss = total_icl_loss / total_batches\n",
        "    avg_te1 = total_te1 / total_batches\n",
        "    avg_te2 = total_te2 / total_batches\n",
        "\n",
        "    # Print epoch-level metrics\n",
        "    print(f\"Epoch [{epoch+1}/{num_epochs}] Averages -> \"\n",
        "          f\"Loss: {avg_loss:.4f}, KL Loss: {avg_kl_loss:.4f}, L2 Loss: {avg_l2_loss:.6f}, ICL Loss: {avg_icl_loss:.4f}, \"\n",
        "          f\"TE1: {avg_te1:.4f}, TE2: {avg_te2:.4f}\")\n"
      ],
      "metadata": {
        "id": "OIO4zuMbXdbY",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "27c44e8d-7259-432b-f65a-b0a4860c0f6f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch [1/10], Step [0/6471], Loss: 8.7500, KL Loss: 0.2915, L2 Loss: 0.001919, ICL Loss: 4.1914, TE1: 0.0034, TE2: 0.0010\n",
            "Epoch [1/10], Step [1000/6471], Loss: 2.4297, KL Loss: 0.5347, L2 Loss: 0.001184, ICL Loss: 1.2383, TE1: 0.4287, TE2: 0.4260\n",
            "Epoch [1/10], Step [2000/6471], Loss: 2.5020, KL Loss: 0.6318, L2 Loss: 0.001144, ICL Loss: 1.2910, TE1: 0.4526, TE2: 0.4507\n",
            "Epoch [1/10], Step [3000/6471], Loss: 1.8604, KL Loss: 0.6064, L2 Loss: 0.001091, ICL Loss: 1.0527, TE1: 0.4819, TE2: 0.4814\n",
            "Epoch [1/10], Step [4000/6471], Loss: 1.5469, KL Loss: 0.5938, L2 Loss: 0.001077, ICL Loss: 0.9282, TE1: 0.5049, TE2: 0.5034\n",
            "Epoch [1/10], Step [5000/6471], Loss: 1.6836, KL Loss: 0.5483, L2 Loss: 0.001052, ICL Loss: 0.9805, TE1: 0.5171, TE2: 0.5161\n",
            "Epoch [1/10], Step [6000/6471], Loss: 1.5410, KL Loss: 0.4868, L2 Loss: 0.001041, ICL Loss: 0.9766, TE1: 0.5303, TE2: 0.5293\n",
            "Epoch [1/10] Averages -> Loss: 2.1167, KL Loss: 0.5647, L2 Loss: 0.001130, ICL Loss: 1.1566, TE1: 0.4667, TE2: 0.4651\n",
            "Epoch [2/10], Step [0/6471], Loss: 1.4453, KL Loss: 0.6133, L2 Loss: 0.001040, ICL Loss: 0.8711, TE1: 0.5156, TE2: 0.5151\n",
            "Epoch [2/10], Step [1000/6471], Loss: 1.6602, KL Loss: 0.5503, L2 Loss: 0.001038, ICL Loss: 0.9829, TE1: 0.5137, TE2: 0.5112\n",
            "Epoch [2/10], Step [2000/6471], Loss: 1.4785, KL Loss: 0.5967, L2 Loss: 0.001055, ICL Loss: 0.9028, TE1: 0.5200, TE2: 0.5195\n",
            "Epoch [2/10], Step [3000/6471], Loss: 1.3145, KL Loss: 0.5859, L2 Loss: 0.001007, ICL Loss: 0.8618, TE1: 0.5474, TE2: 0.5454\n",
            "Epoch [2/10], Step [4000/6471], Loss: 1.4629, KL Loss: 0.5376, L2 Loss: 0.001046, ICL Loss: 0.8975, TE1: 0.5034, TE2: 0.5020\n",
            "Epoch [2/10], Step [5000/6471], Loss: 1.1943, KL Loss: 0.5273, L2 Loss: 0.001007, ICL Loss: 0.8232, TE1: 0.5557, TE2: 0.5547\n",
            "Epoch [2/10], Step [6000/6471], Loss: 1.2656, KL Loss: 0.5693, L2 Loss: 0.001006, ICL Loss: 0.8154, TE1: 0.5410, TE2: 0.5405\n",
            "Epoch [2/10] Averages -> Loss: 1.3827, KL Loss: 0.5696, L2 Loss: 0.001024, ICL Loss: 0.8820, TE1: 0.5340, TE2: 0.5324\n",
            "Epoch [3/10], Step [0/6471], Loss: 1.2305, KL Loss: 0.5332, L2 Loss: 0.000986, ICL Loss: 0.8584, TE1: 0.5615, TE2: 0.5591\n",
            "Epoch [3/10], Step [1000/6471], Loss: 1.2998, KL Loss: 0.5811, L2 Loss: 0.001001, ICL Loss: 0.8545, TE1: 0.5430, TE2: 0.5400\n",
            "Epoch [3/10], Step [2000/6471], Loss: 1.2012, KL Loss: 0.5542, L2 Loss: 0.000998, ICL Loss: 0.8301, TE1: 0.5537, TE2: 0.5518\n",
            "Epoch [3/10], Step [3000/6471], Loss: 1.1777, KL Loss: 0.5166, L2 Loss: 0.000974, ICL Loss: 0.8359, TE1: 0.5630, TE2: 0.5601\n",
            "Epoch [3/10], Step [4000/6471], Loss: 1.2725, KL Loss: 0.6016, L2 Loss: 0.001030, ICL Loss: 0.8252, TE1: 0.5337, TE2: 0.5322\n",
            "Epoch [3/10], Step [5000/6471], Loss: 1.0078, KL Loss: 0.5615, L2 Loss: 0.000968, ICL Loss: 0.7285, TE1: 0.5693, TE2: 0.5698\n",
            "Epoch [3/10], Step [6000/6471], Loss: 1.1699, KL Loss: 0.5449, L2 Loss: 0.000999, ICL Loss: 0.7905, TE1: 0.5562, TE2: 0.5542\n",
            "Epoch [3/10] Averages -> Loss: 1.1853, KL Loss: 0.5670, L2 Loss: 0.000994, ICL Loss: 0.8127, TE1: 0.5539, TE2: 0.5523\n",
            "Epoch [4/10], Step [0/6471], Loss: 0.9990, KL Loss: 0.6143, L2 Loss: 0.000967, ICL Loss: 0.7437, TE1: 0.5625, TE2: 0.5615\n",
            "Epoch [4/10], Step [1000/6471], Loss: 1.2988, KL Loss: 0.5830, L2 Loss: 0.001000, ICL Loss: 0.8853, TE1: 0.5508, TE2: 0.5483\n",
            "Epoch [4/10], Step [2000/6471], Loss: 0.9331, KL Loss: 0.5781, L2 Loss: 0.000968, ICL Loss: 0.7031, TE1: 0.5674, TE2: 0.5659\n",
            "Epoch [4/10], Step [3000/6471], Loss: 1.0430, KL Loss: 0.5674, L2 Loss: 0.001005, ICL Loss: 0.7598, TE1: 0.5557, TE2: 0.5532\n",
            "Epoch [4/10], Step [4000/6471], Loss: 1.0156, KL Loss: 0.5405, L2 Loss: 0.000964, ICL Loss: 0.7588, TE1: 0.5752, TE2: 0.5752\n",
            "Epoch [4/10], Step [5000/6471], Loss: 0.9512, KL Loss: 0.5664, L2 Loss: 0.000970, ICL Loss: 0.7148, TE1: 0.5703, TE2: 0.5703\n",
            "Epoch [4/10], Step [6000/6471], Loss: 1.0293, KL Loss: 0.5859, L2 Loss: 0.000965, ICL Loss: 0.7461, TE1: 0.5625, TE2: 0.5601\n",
            "Epoch [4/10] Averages -> Loss: 1.0736, KL Loss: 0.5651, L2 Loss: 0.000975, ICL Loss: 0.7732, TE1: 0.5659, TE2: 0.5644\n",
            "Epoch [5/10], Step [0/6471], Loss: 0.9150, KL Loss: 0.6665, L2 Loss: 0.000982, ICL Loss: 0.7109, TE1: 0.5703, TE2: 0.5693\n",
            "Epoch [5/10], Step [1000/6471], Loss: 1.0039, KL Loss: 0.5977, L2 Loss: 0.000982, ICL Loss: 0.7178, TE1: 0.5557, TE2: 0.5522\n",
            "Epoch [5/10], Step [2000/6471], Loss: 0.8433, KL Loss: 0.5708, L2 Loss: 0.000949, ICL Loss: 0.6851, TE1: 0.5889, TE2: 0.5874\n",
            "Epoch [5/10], Step [3000/6471], Loss: 1.0859, KL Loss: 0.5581, L2 Loss: 0.000974, ICL Loss: 0.7930, TE1: 0.5698, TE2: 0.5679\n",
            "Epoch [5/10], Step [4000/6471], Loss: 0.9741, KL Loss: 0.5566, L2 Loss: 0.000967, ICL Loss: 0.7119, TE1: 0.5645, TE2: 0.5630\n",
            "Epoch [5/10], Step [5000/6471], Loss: 0.9839, KL Loss: 0.5742, L2 Loss: 0.000943, ICL Loss: 0.7568, TE1: 0.5830, TE2: 0.5815\n",
            "Epoch [5/10], Step [6000/6471], Loss: 1.1504, KL Loss: 0.5664, L2 Loss: 0.000967, ICL Loss: 0.8242, TE1: 0.5615, TE2: 0.5615\n",
            "Epoch [5/10] Averages -> Loss: 0.9962, KL Loss: 0.5640, L2 Loss: 0.000961, ICL Loss: 0.7463, TE1: 0.5748, TE2: 0.5733\n",
            "Epoch [6/10], Step [0/6471], Loss: 0.7437, KL Loss: 0.5879, L2 Loss: 0.000942, ICL Loss: 0.6230, TE1: 0.5947, TE2: 0.5933\n",
            "Epoch [6/10], Step [1000/6471], Loss: 0.9570, KL Loss: 0.6055, L2 Loss: 0.000967, ICL Loss: 0.7104, TE1: 0.5684, TE2: 0.5664\n",
            "Epoch [6/10], Step [2000/6471], Loss: 1.0654, KL Loss: 0.5654, L2 Loss: 0.000972, ICL Loss: 0.8105, TE1: 0.5771, TE2: 0.5762\n",
            "Epoch [6/10], Step [3000/6471], Loss: 0.9297, KL Loss: 0.5308, L2 Loss: 0.000945, ICL Loss: 0.7295, TE1: 0.5820, TE2: 0.5820\n",
            "Epoch [6/10], Step [4000/6471], Loss: 0.9971, KL Loss: 0.5625, L2 Loss: 0.000964, ICL Loss: 0.7842, TE1: 0.5752, TE2: 0.5742\n",
            "Epoch [6/10], Step [5000/6471], Loss: 0.8960, KL Loss: 0.5659, L2 Loss: 0.000937, ICL Loss: 0.7134, TE1: 0.5903, TE2: 0.5884\n",
            "Epoch [6/10], Step [6000/6471], Loss: 0.7939, KL Loss: 0.5518, L2 Loss: 0.000939, ICL Loss: 0.6528, TE1: 0.5918, TE2: 0.5908\n",
            "Epoch [6/10] Averages -> Loss: 0.9386, KL Loss: 0.5627, L2 Loss: 0.000951, ICL Loss: 0.7250, TE1: 0.5814, TE2: 0.5800\n",
            "Epoch [7/10], Step [0/6471], Loss: 1.1875, KL Loss: 0.6123, L2 Loss: 0.000961, ICL Loss: 0.8491, TE1: 0.5635, TE2: 0.5610\n",
            "Epoch [7/10], Step [1000/6471], Loss: 0.9937, KL Loss: 0.5977, L2 Loss: 0.000932, ICL Loss: 0.7759, TE1: 0.5879, TE2: 0.5864\n",
            "Epoch [7/10], Step [2000/6471], Loss: 0.8940, KL Loss: 0.5635, L2 Loss: 0.000938, ICL Loss: 0.7188, TE1: 0.5889, TE2: 0.5874\n",
            "Epoch [7/10], Step [3000/6471], Loss: 0.8799, KL Loss: 0.6064, L2 Loss: 0.000938, ICL Loss: 0.6938, TE1: 0.5884, TE2: 0.5869\n",
            "Epoch [7/10], Step [4000/6471], Loss: 0.7241, KL Loss: 0.5723, L2 Loss: 0.000937, ICL Loss: 0.6123, TE1: 0.5957, TE2: 0.5942\n",
            "Epoch [7/10], Step [5000/6471], Loss: 0.9766, KL Loss: 0.5771, L2 Loss: 0.000944, ICL Loss: 0.7510, TE1: 0.5864, TE2: 0.5859\n",
            "Epoch [7/10], Step [6000/6471], Loss: 0.7920, KL Loss: 0.5127, L2 Loss: 0.000937, ICL Loss: 0.6543, TE1: 0.5830, TE2: 0.5820\n",
            "Epoch [7/10] Averages -> Loss: 0.8922, KL Loss: 0.5621, L2 Loss: 0.000943, ICL Loss: 0.7081, TE1: 0.5869, TE2: 0.5856\n",
            "Epoch [8/10], Step [0/6471], Loss: 0.9482, KL Loss: 0.6309, L2 Loss: 0.000964, ICL Loss: 0.7383, TE1: 0.5869, TE2: 0.5859\n",
            "Epoch [8/10], Step [1000/6471], Loss: 0.8745, KL Loss: 0.6167, L2 Loss: 0.000940, ICL Loss: 0.6909, TE1: 0.5830, TE2: 0.5835\n",
            "Epoch [8/10], Step [2000/6471], Loss: 0.8984, KL Loss: 0.5908, L2 Loss: 0.000948, ICL Loss: 0.7012, TE1: 0.5898, TE2: 0.5879\n",
            "Epoch [8/10], Step [3000/6471], Loss: 1.1025, KL Loss: 0.5229, L2 Loss: 0.000937, ICL Loss: 0.8613, TE1: 0.5913, TE2: 0.5908\n",
            "Epoch [8/10], Step [4000/6471], Loss: 0.8726, KL Loss: 0.6152, L2 Loss: 0.000953, ICL Loss: 0.6494, TE1: 0.5742, TE2: 0.5728\n",
            "Epoch [8/10], Step [5000/6471], Loss: 0.7778, KL Loss: 0.5449, L2 Loss: 0.000926, ICL Loss: 0.6782, TE1: 0.5986, TE2: 0.5981\n",
            "Epoch [8/10], Step [6000/6471], Loss: 0.7026, KL Loss: 0.5850, L2 Loss: 0.000930, ICL Loss: 0.6006, TE1: 0.5938, TE2: 0.5933\n",
            "Epoch [8/10] Averages -> Loss: 0.8534, KL Loss: 0.5614, L2 Loss: 0.000935, ICL Loss: 0.6936, TE1: 0.5915, TE2: 0.5903\n",
            "Epoch [9/10], Step [0/6471], Loss: 0.8843, KL Loss: 0.5576, L2 Loss: 0.000931, ICL Loss: 0.7451, TE1: 0.5996, TE2: 0.6001\n",
            "Epoch [9/10], Step [1000/6471], Loss: 0.8003, KL Loss: 0.5576, L2 Loss: 0.000947, ICL Loss: 0.6675, TE1: 0.5859, TE2: 0.5845\n",
            "Epoch [9/10], Step [2000/6471], Loss: 0.7710, KL Loss: 0.5693, L2 Loss: 0.000933, ICL Loss: 0.6533, TE1: 0.6006, TE2: 0.6001\n",
            "Epoch [9/10], Step [3000/6471], Loss: 0.8813, KL Loss: 0.5820, L2 Loss: 0.000946, ICL Loss: 0.7051, TE1: 0.5820, TE2: 0.5815\n",
            "Epoch [9/10], Step [4000/6471], Loss: 0.8281, KL Loss: 0.5347, L2 Loss: 0.000920, ICL Loss: 0.7021, TE1: 0.6045, TE2: 0.6025\n",
            "Epoch [9/10], Step [5000/6471], Loss: 1.0215, KL Loss: 0.5542, L2 Loss: 0.000930, ICL Loss: 0.7959, TE1: 0.5957, TE2: 0.5942\n",
            "Epoch [9/10], Step [6000/6471], Loss: 0.6733, KL Loss: 0.6025, L2 Loss: 0.000935, ICL Loss: 0.5830, TE1: 0.5942, TE2: 0.5942\n",
            "Epoch [9/10] Averages -> Loss: 0.8214, KL Loss: 0.5610, L2 Loss: 0.000929, ICL Loss: 0.6815, TE1: 0.5955, TE2: 0.5943\n",
            "Epoch [10/10], Step [0/6471], Loss: 0.9521, KL Loss: 0.5742, L2 Loss: 0.000931, ICL Loss: 0.7803, TE1: 0.5991, TE2: 0.6006\n",
            "Epoch [10/10], Step [1000/6471], Loss: 0.6558, KL Loss: 0.5957, L2 Loss: 0.000921, ICL Loss: 0.5752, TE1: 0.5972, TE2: 0.5981\n",
            "Epoch [10/10], Step [2000/6471], Loss: 0.7231, KL Loss: 0.5488, L2 Loss: 0.000909, ICL Loss: 0.6406, TE1: 0.6006, TE2: 0.5991\n",
            "Epoch [10/10], Step [3000/6471], Loss: 0.8291, KL Loss: 0.5439, L2 Loss: 0.000914, ICL Loss: 0.6973, TE1: 0.6045, TE2: 0.6035\n",
            "Epoch [10/10], Step [4000/6471], Loss: 0.7603, KL Loss: 0.5342, L2 Loss: 0.000919, ICL Loss: 0.6758, TE1: 0.6025, TE2: 0.6021\n",
            "Epoch [10/10], Step [5000/6471], Loss: 0.9590, KL Loss: 0.5996, L2 Loss: 0.000937, ICL Loss: 0.7354, TE1: 0.5913, TE2: 0.5898\n",
            "Epoch [10/10], Step [6000/6471], Loss: 0.7896, KL Loss: 0.5801, L2 Loss: 0.000919, ICL Loss: 0.6650, TE1: 0.6084, TE2: 0.6060\n",
            "Epoch [10/10] Averages -> Loss: 0.7930, KL Loss: 0.5607, L2 Loss: 0.000923, ICL Loss: 0.6704, TE1: 0.5990, TE2: 0.5979\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluate the Trained Student Model"
      ],
      "metadata": {
        "id": "Q_tN8A1Cyzzj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset\n",
        "import clip\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "input_resolution = 224\n",
        "context_length = 77\n",
        "\n",
        "# Evaluation transforms (same as training)\n",
        "eval_transform = transforms.Compose([\n",
        "    transforms.Resize((input_resolution, input_resolution)),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "                         std=(0.26862954, 0.26130258, 0.27577711))\n",
        "])\n",
        "\n",
        "class CocoEvalDataset(Dataset):\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.dataset)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, captions = self.dataset[idx]\n",
        "        # Return the full list of captions for each image\n",
        "        return image, captions\n",
        "\n",
        "def coco_collate_fn(batch):\n",
        "    # batch is a list of (image, captions_list) tuples\n",
        "    images = []\n",
        "    captions = []\n",
        "    for img, caps in batch:\n",
        "        images.append(img)      # img is a Tensor\n",
        "        captions.append(caps)   # caps is a list of strings\n",
        "    images = torch.stack(images, dim=0)  # stack all images into a single tensor\n",
        "    return images, captions\n",
        "\n",
        "\n",
        "# Paths for validation\n",
        "val_img_dir = os.path.join(data_dir, 'val2014')\n",
        "val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "\n",
        "\n",
        "\n",
        "val_dataset = CocoEvalDataset(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, collate_fn=coco_collate_fn)\n",
        "\n",
        "\n",
        "student_image_encoder.eval()\n",
        "student_text_encoder.eval()\n",
        "\n",
        "all_image_features = []\n",
        "all_text_features = []\n",
        "image_to_text_indices = []  # For each image, store which text indices correspond to its captions\n",
        "all_captions_flat = []  # We'll store all captions globally\n",
        "\n",
        "with torch.no_grad():\n",
        "    image_count = 0\n",
        "    text_count = 0\n",
        "    for images, batch_captions in val_dataloader:\n",
        "        # images: (B, C, H, W)\n",
        "        # batch_captions: list of length B, each item is a list of captions for that image\n",
        "\n",
        "        images = images.to(device)\n",
        "\n",
        "        # Encode images\n",
        "        image_feats = student_image_encoder(images)\n",
        "        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)\n",
        "        all_image_features.append(image_feats.cpu())\n",
        "\n",
        "        # Flatten captions for this batch\n",
        "        flat_captions = []\n",
        "        image_to_text_map_for_batch = []\n",
        "        for caps in batch_captions:\n",
        "            start_idx = len(flat_captions)\n",
        "            flat_captions.extend(caps)  # add all captions from this image\n",
        "            end_idx = len(flat_captions)\n",
        "            # This image's captions correspond to indices [start_idx+text_count, end_idx+text_count)\n",
        "            image_to_text_map_for_batch.append((start_idx + text_count, end_idx + text_count))\n",
        "\n",
        "        # Tokenize all captions in the batch at once\n",
        "        texts = clip.tokenize(flat_captions, context_length=context_length).to(device)\n",
        "        text_feats = student_text_encoder(texts)\n",
        "        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        # Store text features globally\n",
        "        all_text_features.append(text_feats.cpu())\n",
        "        all_captions_flat.extend(flat_captions)\n",
        "\n",
        "        # Update the global mapping\n",
        "        for (start_idx, end_idx) in image_to_text_map_for_batch:\n",
        "            image_to_text_indices.append(list(range(start_idx, end_idx)))\n",
        "\n",
        "        image_count += images.size(0)\n",
        "        text_count += len(flat_captions)\n",
        "\n",
        "all_image_features = torch.cat(all_image_features, dim=0)  # (N_images, embed_dim)\n",
        "all_text_features = torch.cat(all_text_features, dim=0)    # (N_captions_total, embed_dim)\n",
        "\n",
        "# Compute similarity matrix: shape (N_images, N_captions_total)\n",
        "sim_matrix = all_image_features @ all_text_features.t()\n",
        "\n",
        "def compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1):\n",
        "    n = sim_matrix.size(0)\n",
        "    successes = 0\n",
        "    for i in range(n):\n",
        "        scores = sim_matrix[i]\n",
        "        sorted_indices = torch.argsort(scores, descending=True)\n",
        "\n",
        "        correct_indices = set(image_to_text_indices[i])\n",
        "        ranks_of_correct = []\n",
        "        for cidx in correct_indices:\n",
        "            pos = (sorted_indices == cidx).nonzero(as_tuple=True)\n",
        "            if len(pos) > 0:\n",
        "                ranks_of_correct.append(pos[0].item())\n",
        "\n",
        "        if len(ranks_of_correct) > 0:\n",
        "            min_rank = min(ranks_of_correct)\n",
        "            if min_rank < k:\n",
        "                successes += 1\n",
        "    recall = successes / n\n",
        "    return recall\n",
        "\n",
        "# Image-to-Text Retrieval\n",
        "r1 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1)\n",
        "r5 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=5)\n",
        "r10 = compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=10)\n",
        "\n",
        "print(\"Image-to-Text Retrieval:\")\n",
        "print(f\"Recall@1: {r1*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "# Text-to-Image Retrieval\n",
        "# Create reverse mapping from text index to image index\n",
        "text_to_image = [None]*all_text_features.size(0)\n",
        "for i, tinds in enumerate(image_to_text_indices):\n",
        "    for t in tinds:\n",
        "        text_to_image[t] = i\n",
        "\n",
        "sim_matrix_t2i = sim_matrix.t()  # (N_captions_total, N_images)\n",
        "\n",
        "def compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1):\n",
        "    m = sim_matrix_t2i.size(0)\n",
        "    successes = 0\n",
        "    for j in range(m):\n",
        "        scores = sim_matrix_t2i[j]\n",
        "        sorted_indices = torch.argsort(scores, descending=True)\n",
        "        correct_image = text_to_image[j]\n",
        "        rank = (sorted_indices == correct_image).nonzero(as_tuple=True)[0].item()\n",
        "        if rank < k:\n",
        "            successes += 1\n",
        "    recall = successes / m\n",
        "    return recall\n",
        "\n",
        "r1_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1)\n",
        "r5_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=5)\n",
        "r10_t2i = compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=10)\n",
        "\n",
        "print(\"Text-to-Image Retrieval:\")\n",
        "print(f\"Recall@1: {r1_t2i*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5_t2i*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "id": "d42i-P_Yrq-Y",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0634689b-80fd-44e7-d9c0-3820e78f1de0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.30s)\n",
            "creating index...\n",
            "index created!\n",
            "Image-to-Text Retrieval:\n",
            "Recall@1: 6.52%\n",
            "Recall@5: 18.60%\n",
            "Recall@10: 27.16%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1: 5.79%\n",
            "Recall@5: 16.78%\n",
            "Recall@10: 24.47%\n"
          ]
        }
      ]
    }
  ]
}