{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "f9b64788-52f1-475b-f569-96c7f7efe673"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m80.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m41.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m43.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m98.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# Install the packages this notebook needs on top of the base environment.\n",
        "# `%pip` (instead of `!pip`) installs into the environment of the running kernel.\n",
        "%pip install --quiet ftfy regex tqdm\n",
        "%pip install --quiet git+https://github.com/openai/CLIP.git\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the Flickr8k Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "# Download and extract Flickr8k only when it is not already on disk, so that\n",
        "# re-running the notebook does not fetch the ~1 GB archive again (and `unzip`\n",
        "# does not stop to ask whether existing files may be overwritten).\n",
        "# The steps are chained with `&&` so the success message is printed only if\n",
        "# the download, the extraction and the cleanup all succeeded.\n",
        "if not os.path.isdir(\"./flickr8k\"):\n",
        "    !wget \"https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip\" && unzip -q flickr8k.zip -d ./flickr8k && rm flickr8k.zip && echo \"Downloaded Flickr8k dataset successfully.\"\n",
        "else:\n",
        "    print(\"Flickr8k dataset already present; skipping download.\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OcIsaazVKPio",
        "outputId": "a1e97d74-9611-490c-f01f-be0cdbda37d6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2025-05-14 04:08:17--  https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip\n",
            "Resolving github.com (github.com)... 20.205.243.166\n",
            "Connecting to github.com (github.com)|20.205.243.166|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/753516996/d7c62b13-1e50-40ea-8fae-f34a44b1695f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250514%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250514T040817Z&X-Amz-Expires=300&X-Amz-Signature=ee78a52331af04129eb03d08155ba99874dae1d67869b4082cb80587c9d7f348&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dflickr8k.zip&response-content-type=application%2Foctet-stream [following]\n",
            "--2025-05-14 04:08:17--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/753516996/d7c62b13-1e50-40ea-8fae-f34a44b1695f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250514%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250514T040817Z&X-Amz-Expires=300&X-Amz-Signature=ee78a52331af04129eb03d08155ba99874dae1d67869b4082cb80587c9d7f348&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dflickr8k.zip&response-content-type=application%2Foctet-stream\n",
            "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n",
            "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 1112971163 (1.0G) [application/octet-stream]\n",
            "Saving to: ‘flickr8k.zip’\n",
            "\n",
            "flickr8k.zip        100%[===================>]   1.04G   176MB/s    in 6.1s    \n",
            "\n",
            "2025-05-14 04:08:24 (175 MB/s) - ‘flickr8k.zip’ saved [1112971163/1112971163]\n",
            "\n",
            "Downloaded Flickr8k dataset successfully.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os, glob, random, torch\n",
        "from PIL import Image\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from torchvision import transforms\n",
        "import clip\n",
        "\n",
        "import csv\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# ─── 0. Dataset location ────────────────────────────────────────────────\n",
        "# Directory that flickr8k.zip was extracted into.\n",
        "data_root = \"./flickr8k\"\n",
        "# File extensions treated as images (compared case-insensitively).\n",
        "image_exts = (\".jpg\", \".png\")\n",
        "\n",
        "# ─── 1. Locate the captions file ────────────────────────────────────────\n",
        "# Search recursively for .txt files. The candidates are sorted and a file named\n",
        "# \"captions.txt\" is preferred, so the choice does not depend on the arbitrary\n",
        "# order in which glob returns paths.\n",
        "caption_candidates = sorted(\n",
        "    glob.glob(os.path.join(data_root, \"**\", \"*.txt\"), recursive=True)\n",
        ")\n",
        "assert caption_candidates, f\"No .txt files found under {data_root}\"\n",
        "captions_file = next(\n",
        "    (p for p in caption_candidates if os.path.basename(p).lower() == \"captions.txt\"),\n",
        "    caption_candidates[0],\n",
        ")\n",
        "print(\"Using captions file:\", captions_file)\n",
        "\n",
        "# ─── 2. Locate the images directory ─────────────────────────────────────\n",
        "# Use the first directory, in sorted traversal order, that contains an image.\n",
        "images_dir = None\n",
        "for root, dirs, files in os.walk(data_root):\n",
        "    dirs.sort()  # in-place sort makes os.walk descend in a deterministic order\n",
        "    if any(f.lower().endswith(image_exts) for f in files):\n",
        "        images_dir = root\n",
        "        break\n",
        "assert images_dir, f\"No image files found under {data_root}\"\n",
        "print(\"Using images dir:\", images_dir)\n",
        "\n",
        "# ─── 3. Hyperparameters ─────────────────────────────────────────────────\n",
        "input_resolution = 224\n",
        "context_length   = 77\n",
        "batch_size       = 64\n",
        "seed             = 42\n",
        "\n",
        "random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "# ─── 4. Read & split the captions ───────────────────────────────────────\n",
        "# The captions file is parsed as CSV, one \"<image_name>,<caption>\" row per\n",
        "# caption. A \"#<idx>\" suffix on the image name, if present, is stripped.\n",
        "caps = {}\n",
        "# newline=\"\" is what the csv module expects of the file object it reads from.\n",
        "with open(captions_file, \"r\", encoding=\"utf-8\", newline=\"\") as f:\n",
        "    for row_idx, row in enumerate(csv.reader(f)):\n",
        "        if len(row) < 2:\n",
        "            continue\n",
        "        img_name = row[0].split(\"#\")[0]\n",
        "        # Skip the first row only when it really is a header, i.e. when its\n",
        "        # first field is not an image file name.\n",
        "        if row_idx == 0 and not img_name.lower().endswith(image_exts):\n",
        "            continue\n",
        "        # Re-join the remaining fields so a caption containing unquoted commas\n",
        "        # is kept whole instead of being cut at the first comma.\n",
        "        caption = \",\".join(row[1:])\n",
        "        caps.setdefault(img_name, []).append(caption)\n",
        "\n",
        "# Keep only the images that actually exist on disk.\n",
        "all_imgs = [img for img in caps\n",
        "            if os.path.exists(os.path.join(images_dir, img))]\n",
        "assert all_imgs, f\"No captioned images from {captions_file} were found in {images_dir}\"\n",
        "\n",
        "# 85/15 train/validation split at the image level, so every caption of an\n",
        "# image lands on the same side of the split.\n",
        "train_imgs, val_imgs = train_test_split(all_imgs,\n",
        "                                        test_size=0.15,\n",
        "                                        random_state=seed)\n",
        "\n",
        "\n",
        "# ─── 5. Transforms ───────────────────────────────────────────────────────\n",
        "# Resize to a square of side `input_resolution`, convert to a tensor, then\n",
        "# normalize each channel with the statistics below.\n",
        "# NOTE(review): these look like the CLIP preprocessing mean/std — confirm.\n",
        "_square_size = (input_resolution, input_resolution)\n",
        "_channel_mean = (0.48145466, 0.4578275, 0.40821073)\n",
        "_channel_std = (0.26862954, 0.26130258, 0.27577711)\n",
        "\n",
        "_pipeline_steps = [\n",
        "    transforms.Resize(_square_size),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(mean=_channel_mean, std=_channel_std),\n",
        "]\n",
        "train_tf = transforms.Compose(_pipeline_steps)\n",
        "# Evaluation reuses the very same transform object as training.\n",
        "eval_tf = train_tf\n",
        "\n",
        "# ─── 6. Dataset classes ──────────────────────────────────────────────────\n",
        "class Flickr8kTrain(Dataset):\n",
        "    \"\"\"Training dataset yielding one (image, tokenized caption) pair per caption.\n",
        "\n",
        "    Args:\n",
        "        img_list: image file names, joined onto the global ``images_dir``.\n",
        "        caps: mapping from image file name to its list of caption strings.\n",
        "        transform: optional callable applied to the RGB PIL image.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, img_list, caps, transform=None):\n",
        "        self.transform = transform\n",
        "        # Flatten to one sample per (image path, caption) pair.\n",
        "        self.data = [\n",
        "            (os.path.join(images_dir, img), cap)\n",
        "            for img in img_list for cap in caps[img]\n",
        "        ]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, i):\n",
        "        \"\"\"Return (image, token ids) for sample ``i``.\"\"\"\n",
        "        img_path, caption = self.data[i]\n",
        "        # Close the file handle promptly; convert() returns an independent image.\n",
        "        with Image.open(img_path) as raw:\n",
        "            img = raw.convert(\"RGB\")\n",
        "        if self.transform:\n",
        "            img = self.transform(img)\n",
        "        # truncate=True clips an over-long caption to context_length instead of raising.\n",
        "        txt = clip.tokenize(caption, context_length=context_length, truncate=True)[0]\n",
        "        return img, txt\n",
        "\n",
        "class Flickr8kEval(Dataset):\n",
        "    \"\"\"Evaluation dataset yielding one (image, list of captions) pair per image.\n",
        "\n",
        "    Args:\n",
        "        img_list: image file names, joined onto the global ``images_dir``.\n",
        "        caps: mapping from image file name to its list of caption strings.\n",
        "        transform: optional callable applied to the RGB PIL image.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, img_list, caps, transform=None):\n",
        "        self.transform = transform\n",
        "        # One sample per image, carrying all of that image's captions.\n",
        "        self.data = [\n",
        "            (os.path.join(images_dir, img), caps[img])\n",
        "            for img in img_list\n",
        "        ]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, i):\n",
        "        \"\"\"Return (image, captions) for sample ``i``; captions are raw strings.\"\"\"\n",
        "        img_path, captions = self.data[i]\n",
        "        # Close the file handle promptly; convert() returns an independent image.\n",
        "        with Image.open(img_path) as raw:\n",
        "            img = raw.convert(\"RGB\")\n",
        "        if self.transform:\n",
        "            img = self.transform(img)\n",
        "        return img, captions\n",
        "\n",
        "def eval_collate(batch):\n",
        "    imgs, caps = zip(*batch)\n",
        "    return torch.stack(imgs, 0), list(caps)\n",
        "\n",
        "# ─── 7. Build dataloaders ────────────────────────────────────────────────\n",
        "train_ds = Flickr8kTrain(train_imgs, caps, transform=train_tf)\n",
        "val_ds = Flickr8kEval(val_imgs, caps, transform=eval_tf)\n",
        "\n",
        "# Settings shared by both loaders.\n",
        "_loader_opts = dict(batch_size=batch_size, num_workers=2, pin_memory=True)\n",
        "\n",
        "# Training batches are shuffled and use the default collation.\n",
        "train_dataloader = DataLoader(train_ds, shuffle=True, **_loader_opts)\n",
        "# Validation keeps dataset order; eval_collate keeps each image's caption list intact.\n",
        "val_dataloader = DataLoader(\n",
        "    val_ds, shuffle=False, collate_fn=eval_collate, **_loader_opts\n",
        ")\n",
        "\n",
        "print(f\"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WfxmIzvYCRI6",
        "outputId": "575c2acc-6b14-41aa-a583-94743c604106"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using captions file: ./flickr8k/captions.txt\n",
            "Using images dir: ./flickr8k/Images\n",
            "Train samples: 34385, Val samples: 1214\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP ViT-B/16"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "# Run on the GPU when one is available, otherwise on the CPU.\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "# Load the CLIP ViT-B/16 teacher and put it in evaluation mode.\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "model.eval()\n",
        "\n",
        "# Read the teacher's input and tokenizer dimensions off the loaded model.\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "# Total number of scalar parameters in the teacher.\n",
        "param_count = sum(p.numel() for p in model.parameters())\n",
        "\n",
        "print(\"Model parameters:\", f\"{param_count:,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ce605f04-b326-46a7-89c0-b7a9239f8b2b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|████████████████████████████████████████| 335M/335M [00:01<00:00, 184MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 149,620,737\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the Student Model (ResNet-34)"
      ],
      "metadata": {
        "id": "GifHBhlWlr29"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch.nn as nn\n",
        "import torchvision.models as models\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "# Student Image Encoder (ResNet-34)\n",
        "class StudentImageEncoder(nn.Module):\n",
        "    \"\"\"ResNet-34 student image encoder that outputs unit-norm embeddings.\n",
        "\n",
        "    The classification head of the ImageNet-pretrained backbone is replaced by\n",
        "    a linear projection to ``output_dim``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, output_dim):\n",
        "        super().__init__()\n",
        "        # `pretrained=True` is deprecated since torchvision 0.13; this weights\n",
        "        # enum selects the same ImageNet checkpoint without the warning.\n",
        "        self.encoder = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)\n",
        "        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return L2-normalized embeddings whose last dimension is ``output_dim``.\"\"\"\n",
        "        x = self.encoder(x)\n",
        "        # F.normalize clamps the norm from below, so a zero vector cannot yield NaN.\n",
        "        return F.normalize(x, dim=-1)\n",
        "\n",
        "\n",
        "class StudentTextEncoder(nn.Module):\n",
        "    def __init__(self, vocab_size, context_length, output_dim):\n",
        "        super(StudentTextEncoder, self).__init__()\n",
        "        self.token_embedding = nn.Embedding(vocab_size, output_dim)\n",
        "        self.positional_embedding = nn.Parameter(torch.zeros(context_length, output_dim))\n",
        "        nn.init.normal_(self.positional_embedding, std=0.01)\n",
        "        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=8)\n",
        "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)\n",
        "        self.ln_final = nn.LayerNorm(output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x shape: (batch_size, seq_len)\n",
        "        x = self.token_embedding(x) + self.positional_embedding  # (batch_size, seq_len, output_dim)\n",
        "        x = x.permute(1, 0, 2)  # (seq_len, batch_size, output_dim)\n",
        "        x = self.transformer(x)\n",
        "        x = x.permute(1, 0, 2)  # (batch_size, seq_len, output_dim)\n",
        "        x = self.ln_final(x)\n",
        "        x = x.mean(dim=1)  # Mean pooling over the sequence length\n",
        "        x = x / x.norm(dim=-1, keepdim=True)  # Normalize to unit length\n",
        "        return x  # (batch_size, output_dim)\n",
        "\n"
      ],
      "metadata": {
        "id": "Bu5lXd8ehL2K"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define the Contrastive Loss with Two Transfer-Entropy Surrogates (CL-TE1-TE2)"
      ],
      "metadata": {
        "id": "KijNXkynlvLL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def contrastive_loss_with_kl_l2_te(\n",
        "    student_image_features,\n",
        "    student_text_features,\n",
        "    teacher_image_features,\n",
        "    teacher_text_features,\n",
        "    temperature=0.07,\n",
        "    alpha=1.0,  # weight for KL term\n",
        "    beta=100,   # weight for L2 term\n",
        "    gamma=5  # weight for TE rewards\n",
        "):\n",
        "    eps = 1e-6\n",
        "\n",
        "    # Normalize features\n",
        "    student_image_features = student_image_features / student_image_features.norm(dim=-1, keepdim=True)\n",
        "    student_text_features = student_text_features / student_text_features.norm(dim=-1, keepdim=True)\n",
        "    teacher_image_features = teacher_image_features / teacher_image_features.norm(dim=-1, keepdim=True)\n",
        "    teacher_text_features = teacher_text_features / teacher_text_features.norm(dim=-1, keepdim=True)\n",
        "\n",
        "    # Compute student logits\n",
        "    logits_per_image_student = student_image_features @ student_text_features.t() / temperature\n",
        "    logits_per_text_student = logits_per_image_student.t()\n",
        "\n",
        "    # Compute teacher logits (no gradients)\n",
        "    with torch.no_grad():\n",
        "        logits_per_image_teacher = teacher_image_features @ teacher_text_features.t() / temperature\n",
        "        logits_per_text_teacher = logits_per_image_teacher.t()\n",
        "\n",
        "    # Contrastive loss\n",
        "    batch_size = student_image_features.size(0)\n",
        "    labels = torch.arange(batch_size, device=student_image_features.device)\n",
        "    loss_image = F.cross_entropy(logits_per_image_student, labels)\n",
        "    loss_text = F.cross_entropy(logits_per_text_student, labels)\n",
        "    contrastive_loss = (loss_image + loss_text) / 2\n",
        "\n",
        "\n",
        "    # L2 distance loss\n",
        "    #l2_img = F.mse_loss(student_image_features, teacher_image_features)\n",
        "    #l2_txt = F.mse_loss(student_text_features, teacher_text_features)\n",
        "    #l2_loss = (l2_img + l2_txt) / 2\n",
        "\n",
        "\n",
        "\n",
        "    # ---------------------\n",
        "    # Cosine Similarity TE Surrogate\n",
        "    # ---------------------\n",
        "    def cosine_te(student_features, teacher_features):\n",
        "        \"\"\"\n",
        "        Approximates transfer entropy by computing the cosine similarity between\n",
        "        the differences (i.e., directional changes) of consecutive embeddings.\n",
        "        A higher cosine similarity indicates that the student is following the teacher's direction.\n",
        "        \"\"\"\n",
        "        # Compute differences along the batch dimension (assumes batch ordering approximates temporal ordering)\n",
        "        student_diff = student_features[1:] - student_features[:-1]\n",
        "        teacher_diff = teacher_features[1:] - teacher_features[:-1]\n",
        "        cos_sim = F.cosine_similarity(student_diff, teacher_diff, dim=-1, eps=eps)\n",
        "        return cos_sim.mean()\n",
        "\n",
        "    te_img = cosine_te(student_image_features, teacher_image_features)\n",
        "    te_txt = cosine_te(student_text_features, teacher_text_features)\n",
        "    te1 = (te_img + te_txt) / 2\n",
        "\n",
        "\n",
        "\n",
        "    # ---------------------\n",
        "    # Cosine Similarity TE Surrogate on Concatenated Differences\n",
        "    # ---------------------\n",
        "    # Compute differences along the batch dimension (assumes batch ordering approximates temporal ordering)\n",
        "    student_diff_img = student_image_features[1:] - student_image_features[:-1]\n",
        "    teacher_diff_img = teacher_image_features[1:] - teacher_image_features[:-1]\n",
        "    student_diff_txt = student_text_features[1:] - student_text_features[:-1]\n",
        "    teacher_diff_txt = teacher_text_features[1:] - teacher_text_features[:-1]\n",
        "\n",
        "    # Concatenate differences from image and text modalities\n",
        "    student_diff_cat = torch.cat([student_diff_img, student_diff_txt], dim=-1)\n",
        "    teacher_diff_cat = torch.cat([teacher_diff_img, teacher_diff_txt], dim=-1)\n",
        "\n",
        "    # Compute cosine similarity between concatenated difference vectors\n",
        "    te2 = F.cosine_similarity(student_diff_cat, teacher_diff_cat, dim=-1, eps=eps).mean()\n",
        "\n",
        "\n",
        "    # Combine all losses\n",
        "    # Increase synergy by subtracting gamma * synergy\n",
        "    # Maximize redundancy by subtracting epsilon * redundancy\n",
        "    #total_loss = (contrastive_loss + alpha * kl_loss + beta * l2_loss - gamma * te)\n",
        "    total_loss = (contrastive_loss - gamma * te1 - gamma * te2)\n",
        "\n",
        "\n",
        "    return total_loss, te1, te2\n"
      ],
      "metadata": {
        "id": "adWN2SVlXyrb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set Up the Training Loop"
      ],
      "metadata": {
        "id": "BijGiufjlw8K"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the two student encoders (512-dimensional embeddings) and move them to the device.\n",
        "student_image_encoder = StudentImageEncoder(output_dim=512)\n",
        "student_image_encoder = student_image_encoder.to(device)\n",
        "student_text_encoder = StudentTextEncoder(vocab_size, context_length, output_dim=512)\n",
        "student_text_encoder = student_text_encoder.to(device)\n",
        "\n",
        "# One Adam optimizer updates the parameters of both students.\n",
        "_student_params = [\n",
        "    *student_image_encoder.parameters(),\n",
        "    *student_text_encoder.parameters(),\n",
        "]\n",
        "optimizer = torch.optim.Adam(_student_params, lr=1e-4)\n"
      ],
      "metadata": {
        "id": "UyA-cl0LhM6g",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c6748839-7e74-4975-e671-bf608b63d148"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n",
            "Downloading: \"https://download.pytorch.org/models/resnet34-b627a593.pth\" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth\n",
            "100%|██████████| 83.3M/83.3M [00:00<00:00, 227MB/s]\n",
            "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
            "  warnings.warn(\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train the Student Model"
      ],
      "metadata": {
        "id": "0isNcb4ClyWa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Training Loop\n",
        "num_epochs = 10  # the number of epochs\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    student_image_encoder.train()\n",
        "    student_text_encoder.train()\n",
        "\n",
        "    # Running sums used for the epoch averages\n",
        "    total_loss = 0.0\n",
        "    total_te1 = 0.0\n",
        "    total_te2 = 0.0\n",
        "\n",
        "    total_batches = len(train_dataloader)\n",
        "\n",
        "    for batch_idx, (images, texts) in enumerate(train_dataloader):\n",
        "        images = images.to(device)\n",
        "        texts = texts.to(device)\n",
        "\n",
        "        # Teacher outputs: no gradients. They are cast up to float32 so the loss\n",
        "        # is computed in full precision, instead of casting the student features\n",
        "        # down to the teacher's dtype.\n",
        "        with torch.no_grad():\n",
        "            teacher_image_features = model.encode_image(images).float()\n",
        "            teacher_text_features = model.encode_text(texts).float()\n",
        "\n",
        "        # Student outputs stay in the students' own precision\n",
        "        student_image_features = student_image_encoder(images)\n",
        "        student_text_features = student_text_encoder(texts)\n",
        "\n",
        "        # Contrastive loss plus the two TE surrogate values for logging\n",
        "        loss, te1, te2 = contrastive_loss_with_kl_l2_te(\n",
        "            student_image_features,\n",
        "            student_text_features,\n",
        "            teacher_image_features,\n",
        "            teacher_text_features\n",
        "        )\n",
        "\n",
        "        # Backpropagation\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Update trackers with plain Python floats\n",
        "        loss_value = loss.item()\n",
        "        te1_value = te1.item()\n",
        "        te2_value = te2.item()\n",
        "        total_loss += loss_value\n",
        "        total_te1 += te1_value\n",
        "        total_te2 += te2_value\n",
        "\n",
        "        if batch_idx % 100 == 0:\n",
        "            print(\n",
        "                f\"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{total_batches}], \"\n",
        "                f\"Loss: {loss_value:.4f}, \"\n",
        "                f\"TE1: {te1_value:.4f}, TE2: {te2_value:.4f}\"\n",
        "            )\n",
        "\n",
        "    # Compute average metrics for the epoch\n",
        "    avg_loss = total_loss / total_batches\n",
        "    avg_te1 = total_te1 / total_batches\n",
        "    avg_te2 = total_te2 / total_batches\n",
        "\n",
        "    # Print epoch-level metrics\n",
        "    print(f\"Epoch [{epoch+1}/{num_epochs}] Averages -> \"\n",
        "          f\"Loss: {avg_loss:.4f}, \"\n",
        "          f\"TE1: {avg_te1:.4f}, TE2: {avg_te2:.4f}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OIO4zuMbXdbY",
        "outputId": "2a5e3c5f-c2f6-4cb1-ebd2-63f0e8801219"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch [1/10], Step [0/538], Loss: 4.1836, TE1: 0.0005, TE2: 0.0004\n",
            "Epoch [1/10], Step [100/538], Loss: -1.5020, TE1: 0.4373, TE2: 0.4385\n",
            "Epoch [1/10], Step [200/538], Loss: -3.3555, TE1: 0.5225, TE2: 0.5244\n",
            "Epoch [1/10], Step [300/538], Loss: -4.1367, TE1: 0.5762, TE2: 0.5776\n",
            "Epoch [1/10], Step [400/538], Loss: -3.9746, TE1: 0.5527, TE2: 0.5552\n",
            "Epoch [1/10], Step [500/538], Loss: -4.4219, TE1: 0.5811, TE2: 0.5811\n",
            "Epoch [1/10] Averages -> Loss: -2.9808, TE1: 0.5037, TE2: 0.5049\n",
            "Epoch [2/10], Step [0/538], Loss: -5.1055, TE1: 0.6143, TE2: 0.6152\n",
            "Epoch [2/10], Step [100/538], Loss: -5.2148, TE1: 0.6104, TE2: 0.6084\n",
            "Epoch [2/10], Step [200/538], Loss: -5.3047, TE1: 0.6387, TE2: 0.6387\n",
            "Epoch [2/10], Step [300/538], Loss: -5.7656, TE1: 0.6685, TE2: 0.6689\n",
            "Epoch [2/10], Step [400/538], Loss: -5.7656, TE1: 0.6445, TE2: 0.6450\n",
            "Epoch [2/10], Step [500/538], Loss: -5.7812, TE1: 0.6768, TE2: 0.6787\n",
            "Epoch [2/10] Averages -> Loss: -5.4984, TE1: 0.6416, TE2: 0.6421\n",
            "Epoch [3/10], Step [0/538], Loss: -6.1875, TE1: 0.6826, TE2: 0.6821\n",
            "Epoch [3/10], Step [100/538], Loss: -5.9688, TE1: 0.6680, TE2: 0.6685\n",
            "Epoch [3/10], Step [200/538], Loss: -6.2500, TE1: 0.6807, TE2: 0.6802\n",
            "Epoch [3/10], Step [300/538], Loss: -6.2422, TE1: 0.7021, TE2: 0.7021\n",
            "Epoch [3/10], Step [400/538], Loss: -6.4844, TE1: 0.7158, TE2: 0.7178\n",
            "Epoch [3/10], Step [500/538], Loss: -6.4492, TE1: 0.7207, TE2: 0.7222\n",
            "Epoch [3/10] Averages -> Loss: -6.2809, TE1: 0.6922, TE2: 0.6924\n",
            "Epoch [4/10], Step [0/538], Loss: -6.6719, TE1: 0.7188, TE2: 0.7192\n",
            "Epoch [4/10], Step [100/538], Loss: -6.7383, TE1: 0.7114, TE2: 0.7119\n",
            "Epoch [4/10], Step [200/538], Loss: -6.6445, TE1: 0.7227, TE2: 0.7241\n",
            "Epoch [4/10], Step [300/538], Loss: -6.9141, TE1: 0.7422, TE2: 0.7412\n",
            "Epoch [4/10], Step [400/538], Loss: -6.7656, TE1: 0.7246, TE2: 0.7241\n",
            "Epoch [4/10], Step [500/538], Loss: -6.8516, TE1: 0.7388, TE2: 0.7402\n",
            "Epoch [4/10] Averages -> Loss: -6.7443, TE1: 0.7265, TE2: 0.7263\n",
            "Epoch [5/10], Step [0/538], Loss: -7.0938, TE1: 0.7500, TE2: 0.7505\n",
            "Epoch [5/10], Step [100/538], Loss: -7.2109, TE1: 0.7617, TE2: 0.7612\n",
            "Epoch [5/10], Step [200/538], Loss: -7.1328, TE1: 0.7637, TE2: 0.7627\n",
            "Epoch [5/10], Step [300/538], Loss: -6.9102, TE1: 0.7393, TE2: 0.7422\n",
            "Epoch [5/10], Step [400/538], Loss: -7.0078, TE1: 0.7466, TE2: 0.7456\n",
            "Epoch [5/10], Step [500/538], Loss: -7.2695, TE1: 0.7666, TE2: 0.7656\n",
            "Epoch [5/10] Averages -> Loss: -7.0718, TE1: 0.7527, TE2: 0.7524\n",
            "Epoch [6/10], Step [0/538], Loss: -7.4688, TE1: 0.7764, TE2: 0.7749\n",
            "Epoch [6/10], Step [100/538], Loss: -7.1602, TE1: 0.7539, TE2: 0.7539\n",
            "Epoch [6/10], Step [200/538], Loss: -7.2031, TE1: 0.7607, TE2: 0.7607\n",
            "Epoch [6/10], Step [300/538], Loss: -7.1875, TE1: 0.7666, TE2: 0.7671\n",
            "Epoch [6/10], Step [400/538], Loss: -7.6016, TE1: 0.7988, TE2: 0.7983\n",
            "Epoch [6/10], Step [500/538], Loss: -7.5234, TE1: 0.7920, TE2: 0.7910\n",
            "Epoch [6/10] Averages -> Loss: -7.3281, TE1: 0.7738, TE2: 0.7732\n",
            "Epoch [7/10], Step [0/538], Loss: -7.3789, TE1: 0.7759, TE2: 0.7739\n",
            "Epoch [7/10], Step [100/538], Loss: -7.6602, TE1: 0.7959, TE2: 0.7954\n",
            "Epoch [7/10], Step [200/538], Loss: -7.5625, TE1: 0.7920, TE2: 0.7910\n",
            "Epoch [7/10], Step [300/538], Loss: -7.5938, TE1: 0.7974, TE2: 0.7964\n",
            "Epoch [7/10], Step [400/538], Loss: -7.5039, TE1: 0.7876, TE2: 0.7861\n",
            "Epoch [7/10], Step [500/538], Loss: -7.6406, TE1: 0.7964, TE2: 0.7949\n",
            "Epoch [7/10] Averages -> Loss: -7.5222, TE1: 0.7903, TE2: 0.7896\n",
            "Epoch [8/10], Step [0/538], Loss: -7.5625, TE1: 0.7876, TE2: 0.7871\n",
            "Epoch [8/10], Step [100/538], Loss: -7.6836, TE1: 0.8047, TE2: 0.8047\n",
            "Epoch [8/10], Step [200/538], Loss: -7.6641, TE1: 0.7988, TE2: 0.7988\n",
            "Epoch [8/10], Step [300/538], Loss: -7.7188, TE1: 0.8086, TE2: 0.8076\n",
            "Epoch [8/10], Step [400/538], Loss: -7.7266, TE1: 0.8052, TE2: 0.8032\n",
            "Epoch [8/10], Step [500/538], Loss: -7.8047, TE1: 0.8218, TE2: 0.8213\n",
            "Epoch [8/10] Averages -> Loss: -7.6800, TE1: 0.8042, TE2: 0.8032\n",
            "Epoch [9/10], Step [0/538], Loss: -7.9023, TE1: 0.8262, TE2: 0.8257\n",
            "Epoch [9/10], Step [100/538], Loss: -7.6055, TE1: 0.7954, TE2: 0.7930\n",
            "Epoch [9/10], Step [200/538], Loss: -7.7461, TE1: 0.8105, TE2: 0.8096\n",
            "Epoch [9/10], Step [300/538], Loss: -7.5938, TE1: 0.7969, TE2: 0.8013\n",
            "Epoch [9/10], Step [400/538], Loss: -7.9141, TE1: 0.8242, TE2: 0.8223\n",
            "Epoch [9/10], Step [500/538], Loss: -7.8203, TE1: 0.8179, TE2: 0.8169\n",
            "Epoch [9/10] Averages -> Loss: -7.8189, TE1: 0.8161, TE2: 0.8151\n",
            "Epoch [10/10], Step [0/538], Loss: -7.9766, TE1: 0.8257, TE2: 0.8257\n",
            "Epoch [10/10], Step [100/538], Loss: -7.8633, TE1: 0.8184, TE2: 0.8159\n",
            "Epoch [10/10], Step [200/538], Loss: -7.9219, TE1: 0.8252, TE2: 0.8242\n",
            "Epoch [10/10], Step [300/538], Loss: -7.9844, TE1: 0.8291, TE2: 0.8281\n",
            "Epoch [10/10], Step [400/538], Loss: -7.9219, TE1: 0.8247, TE2: 0.8237\n",
            "Epoch [10/10], Step [500/538], Loss: -7.9219, TE1: 0.8242, TE2: 0.8232\n",
            "Epoch [10/10] Averages -> Loss: -7.9307, TE1: 0.8260, TE2: 0.8250\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluate the Trained Student Model"
      ],
      "metadata": {
        "id": "Q_tN8A1Cyzzj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset\n",
        "import clip\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "# Evaluation settings. `input_resolution` is defined here but not read in this cell.\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "input_resolution = 224\n",
        "context_length = 77\n",
        "\n",
        "# Switch both student encoders to inference mode before extracting embeddings.\n",
        "# NOTE(review): `student_image_encoder`, `student_text_encoder` and `val_dataloader`\n",
        "# are not defined in this cell; they must already exist in the kernel.\n",
        "student_image_encoder.eval()\n",
        "student_text_encoder.eval()\n",
        "\n",
        "all_image_features = []     # one unit-norm image-embedding tensor per batch (moved to CPU)\n",
        "all_text_features = []      # one unit-norm caption-embedding tensor per batch (moved to CPU)\n",
        "image_to_text_indices = []  # entry i: global caption indices that belong to image i\n",
        "all_captions_flat = []      # every caption, in the same order as the caption embeddings\n",
        "\n",
        "image_count = 0\n",
        "text_count = 0\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, batch_captions in val_dataloader:\n",
        "        # Image side: embed the batch and scale every row to unit length.\n",
        "        images = images.to(device)\n",
        "        image_feats = student_image_encoder(images)\n",
        "        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)\n",
        "        all_image_features.append(image_feats.cpu())\n",
        "\n",
        "        # Text side: pool the captions of every image in the batch into one list and\n",
        "        # remember the global [first, last) caption-index range owned by each image.\n",
        "        flat_captions = []\n",
        "        caption_ranges = []\n",
        "        for image_captions in batch_captions:\n",
        "            first = text_count + len(flat_captions)\n",
        "            flat_captions.extend(image_captions)\n",
        "            last = text_count + len(flat_captions)\n",
        "            caption_ranges.append((first, last))\n",
        "\n",
        "        # One tokenizer call and one encoder call cover all captions of the batch.\n",
        "        texts = clip.tokenize(flat_captions, context_length=context_length).to(device)\n",
        "        text_feats = student_text_encoder(texts)\n",
        "        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        # Publish this batch's results to the global accumulators.\n",
        "        all_text_features.append(text_feats.cpu())\n",
        "        all_captions_flat.extend(flat_captions)\n",
        "        image_to_text_indices.extend(list(range(first, last)) for first, last in caption_ranges)\n",
        "\n",
        "        image_count += images.size(0)\n",
        "        text_count += len(flat_captions)\n",
        "\n",
        "all_image_features = torch.cat(all_image_features, dim=0)  # (N_images, embed_dim)\n",
        "all_text_features = torch.cat(all_text_features, dim=0)    # (N_captions_total, embed_dim)\n",
        "\n",
        "# Similarity of every image to every caption: shape (N_images, N_captions_total)\n",
        "sim_matrix = all_image_features @ all_text_features.t()\n",
        "\n",
        "def compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=1):\n",
        "    \"\"\"Image-to-text Recall@k when an image can own several correct captions.\n",
        "\n",
        "    An image counts as a success when at least one of its captions is among the\n",
        "    k highest-scoring columns of its row.\n",
        "\n",
        "    Args:\n",
        "        sim_matrix: 2-D tensor; row i holds the score of image i against every caption.\n",
        "        image_to_text_indices: indexable by row number; entry i is an iterable of the\n",
        "            column indices that are correct captions for image i.\n",
        "        k: number of top-ranked captions inspected per image.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of rows with a correct caption in their top k, as a float.\n",
        "        0.0 when the matrix has no rows or no columns, or when k <= 0.\n",
        "    \"\"\"\n",
        "    n = sim_matrix.size(0)\n",
        "    # topk rejects k larger than the number of columns, so clamp it.\n",
        "    k_eff = min(k, sim_matrix.size(1))\n",
        "    if n == 0 or k_eff <= 0:\n",
        "        return 0.0\n",
        "\n",
        "    # Only the k best columns per row matter, so one batched top-k replaces a full\n",
        "    # sort per row plus a scan per caption. Upcasting to float32 leaves the ordering\n",
        "    # unchanged and is a no-op for float32 input. Tied scores are ordered arbitrarily.\n",
        "    top_k = torch.topk(sim_matrix.float(), k_eff, dim=1).indices.tolist()\n",
        "\n",
        "    successes = 0\n",
        "    for i, retrieved in enumerate(top_k):\n",
        "        # Caption indices that are not columns of sim_matrix can never be retrieved,\n",
        "        # so they count as misses instead of raising.\n",
        "        if not set(image_to_text_indices[i]).isdisjoint(retrieved):\n",
        "            successes += 1\n",
        "    return successes / n\n",
        "\n",
        "# Image-to-Text Retrieval: Recall@k for k = 1, 5, 10 over the image-by-caption matrix\n",
        "r1, r5, r10 = (\n",
        "    compute_recall_with_multiple_captions(sim_matrix, image_to_text_indices, k=cutoff)\n",
        "    for cutoff in (1, 5, 10)\n",
        ")\n",
        "\n",
        "print(\"Image-to-Text Retrieval:\")\n",
        "print(f\"Recall@1: {r1*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "# Text-to-Image Retrieval\n",
        "# Invert the image -> captions mapping: text_to_image[t] is the image owning caption t.\n",
        "# A caption that no image lists keeps the None placeholder.\n",
        "text_to_image = [None] * all_text_features.size(0)\n",
        "for image_idx, caption_ids in enumerate(image_to_text_indices):\n",
        "    for caption_idx in caption_ids:\n",
        "        text_to_image[caption_idx] = image_idx\n",
        "\n",
        "sim_matrix_t2i = sim_matrix.transpose(0, 1)  # (N_captions_total, N_images)\n",
        "\n",
        "def compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=1):\n",
        "    \"\"\"Text-to-image Recall@k with exactly one correct image per caption.\n",
        "\n",
        "    Args:\n",
        "        sim_matrix_t2i: 2-D tensor; row j holds the score of caption j against every image.\n",
        "        text_to_image: indexable by row number; entry j is the column index of the\n",
        "            correct image for caption j, or None when the caption has no image.\n",
        "        k: number of top-ranked images inspected per caption.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of rows whose correct image is in their top k, as a float.\n",
        "        0.0 when the matrix has no rows or no columns, or when k <= 0.\n",
        "    \"\"\"\n",
        "    m = sim_matrix_t2i.size(0)\n",
        "    # topk rejects k larger than the number of columns, so clamp it.\n",
        "    k_eff = min(k, sim_matrix_t2i.size(1))\n",
        "    if m == 0 or k_eff <= 0:\n",
        "        return 0.0\n",
        "\n",
        "    # One batched top-k replaces a full sort of every row. Upcasting to float32 leaves\n",
        "    # the ordering unchanged and is a no-op for float32 input. Ties are ordered arbitrarily.\n",
        "    top_k = torch.topk(sim_matrix_t2i.float(), k_eff, dim=1).indices.tolist()\n",
        "\n",
        "    # A None target (caption without an image) is never in the retrieved list, so it\n",
        "    # is scored as a miss rather than crashing the evaluation.\n",
        "    successes = sum(1 for j, retrieved in enumerate(top_k) if text_to_image[j] in retrieved)\n",
        "    return successes / m\n",
        "\n",
        "# Recall@k for k = 1, 5, 10 over the caption-by-image matrix\n",
        "r1_t2i, r5_t2i, r10_t2i = (\n",
        "    compute_recall_text_to_image(sim_matrix_t2i, text_to_image, k=cutoff)\n",
        "    for cutoff in (1, 5, 10)\n",
        ")\n",
        "\n",
        "print(\"Text-to-Image Retrieval:\")\n",
        "print(f\"Recall@1: {r1_t2i*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5_t2i*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "id": "d42i-P_Yrq-Y",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "24d5409f-8d06-4a31-bbde-c5a590159fb0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Image-to-Text Retrieval:\n",
            "Recall@1: 28.67%\n",
            "Recall@5: 57.58%\n",
            "Recall@10: 70.43%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1: 22.32%\n",
            "Recall@5: 49.75%\n",
            "Recall@10: 63.21%\n"
          ]
        }
      ]
    }
  ]
}