{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "9fd2b00e-aa0b-4adc-af69-96b920573166"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install --quiet ftfy regex tqdm\n",
        "!pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "!pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the MS COCO Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "from zipfile import ZipFile\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Define paths\n",
        "data_dir = '/content/coco2014'\n",
        "os.makedirs(data_dir, exist_ok=True)\n",
        "\n",
        "# URLs for datasets and annotations\n",
        "datasets = {\n",
        "#    \"train2014\": \"http://images.cocodataset.org/zips/train2014.zip\",\n",
        "    \"val2014\": \"http://images.cocodataset.org/zips/val2014.zip\",\n",
        "    \"annotations_trainval2014\": \"http://images.cocodataset.org/annotations/annotations_trainval2014.zip\",\n",
        "}\n",
        "\n",
        "# Download helper function with progress bar\n",
        "def download_file(url, dest_path):\n",
        "    response = requests.get(url, stream=True)\n",
        "    total_size = int(response.headers.get('content-length', 0))\n",
        "    with open(dest_path, 'wb') as f, tqdm(\n",
        "        desc=f\"Downloading {os.path.basename(dest_path)}\",\n",
        "        total=total_size,\n",
        "        unit='B',\n",
        "        unit_scale=True,\n",
        "        unit_divisor=1024\n",
        "    ) as bar:\n",
        "        for data in response.iter_content(chunk_size=1024):\n",
        "            f.write(data)\n",
        "            bar.update(len(data))\n",
        "\n",
        "# Download and extract datasets\n",
        "for name, url in datasets.items():\n",
        "    zip_path = os.path.join(data_dir, f\"{name}.zip\")\n",
        "    print(f\"Processing {name}...\")\n",
        "\n",
        "    # Download the dataset\n",
        "    download_file(url, zip_path)\n",
        "\n",
        "    # Unzip the dataset\n",
        "    with ZipFile(zip_path, 'r') as zip_ref:\n",
        "        zip_ref.extractall(data_dir)\n",
        "\n",
        "    # Remove the zip file to save space\n",
        "    os.remove(zip_path)\n",
        "    print(f\"{name} downloaded and extracted.\")\n",
        "\n",
        "print(\"All datasets and annotations successfully downloaded and extracted!\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XMDkrWBhLGg",
        "outputId": "1fc679c3-adc6-4528-bb96-24f239fd35f5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Processing val2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading val2014.zip: 100%|██████████| 6.19G/6.19G [14:09<00:00, 7.83MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "val2014 downloaded and extracted.\n",
            "Processing annotations_trainval2014...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:17<00:00, 14.4MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "annotations_trainval2014 downloaded and extracted.\n",
            "All datasets and annotations successfully downloaded and extracted!\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluate the Teacher Model ResNet50"
      ],
      "metadata": {
        "id": "tibEIEfoHYDd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset\n",
        "import clip\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Load the CLIP model (Teacher)\n",
        "model, preprocess = clip.load(\"RN50\", device)\n",
        "model.eval()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "\n",
        "# Evaluation transforms (same as training)\n",
        "eval_transform = preprocess\n",
        "\n",
        "class CocoEvalDataset(Dataset):\n",
        "    def __init__(self, root, annFile, transform=None):\n",
        "        self.dataset = datasets.CocoCaptions(root=root, annFile=annFile, transform=transform)\n",
        "        self.transform = transform\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.dataset)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        image, captions = self.dataset[idx]\n",
        "        caption = captions[0]\n",
        "        return image, caption\n",
        "\n",
        "\n",
        "# Paths for validation\n",
        "val_img_dir = os.path.join(data_dir, 'val2014')\n",
        "val_ann_file = os.path.join(data_dir, 'annotations', 'captions_val2014.json')\n",
        "\n",
        "val_dataset = CocoEvalDataset(root=val_img_dir, annFile=val_ann_file, transform=eval_transform)\n",
        "val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)\n",
        "\n",
        "all_image_features = []\n",
        "all_text_features = []\n",
        "all_captions = []  # Store captions for each image in order\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, captions in val_dataloader:\n",
        "        images = images.to(device)\n",
        "        # Tokenize captions here\n",
        "        texts = clip.tokenize(captions, context_length=context_length).to(device)\n",
        "\n",
        "        # Encode images and texts using the teacher model\n",
        "        image_feats = model.encode_image(images)\n",
        "        text_feats = model.encode_text(texts)\n",
        "\n",
        "        # Normalize\n",
        "        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)\n",
        "        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        all_image_features.append(image_feats.cpu())\n",
        "        all_text_features.append(text_feats.cpu())\n",
        "        all_captions.extend(captions)\n",
        "\n",
        "all_image_features = torch.cat(all_image_features, dim=0)  # (N, 512)\n",
        "all_text_features = torch.cat(all_text_features, dim=0)    # (N, 512)\n",
        "\n",
        "# Compute similarity matrix\n",
        "# image-to-text similarity: each image vs all texts\n",
        "sim_matrix = all_image_features @ all_text_features.t()  # (N, N)\n",
        "\n",
        "# Function to compute recall@K\n",
        "def compute_recall(sim_matrix, k=1):\n",
        "    ranks = []\n",
        "    n = sim_matrix.size(0)\n",
        "    for i in range(n):\n",
        "        # Sort texts by similarity to image i\n",
        "        sorted_indices = torch.argsort(sim_matrix[i], descending=True)\n",
        "        rank = (sorted_indices == i).nonzero(as_tuple=True)[0].item()\n",
        "        ranks.append(rank)\n",
        "    ranks = torch.tensor(ranks)\n",
        "    recall = (ranks < k).float().mean().item()\n",
        "    return recall\n",
        "\n",
        "r1 = compute_recall(sim_matrix, k=1)\n",
        "r5 = compute_recall(sim_matrix, k=5)\n",
        "r10 = compute_recall(sim_matrix, k=10)\n",
        "\n",
        "print(\"Image-to-Text Retrieval:\")\n",
        "print(f\"Recall@1: {r1*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "# For text-to-image retrieval, we do the same but transpose the matrix\n",
        "sim_matrix_t2i = sim_matrix.t()  # (N, N)\n",
        "\n",
        "r1_t2i = compute_recall(sim_matrix_t2i, k=1)\n",
        "r5_t2i = compute_recall(sim_matrix_t2i, k=5)\n",
        "r10_t2i = compute_recall(sim_matrix_t2i, k=10)\n",
        "\n",
        "print(\"Text-to-Image Retrieval:\")\n",
        "print(f\"Recall@1: {r1_t2i*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5_t2i*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Yua1RSSIG_mr",
        "outputId": "c7d450d2-0318-4dbe-de4c-ad8ed9b2dd38"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 244M/244M [00:05<00:00, 46.7MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading annotations into memory...\n",
            "Done (t=0.26s)\n",
            "creating index...\n",
            "index created!\n",
            "Image-to-Text Retrieval:\n",
            "Recall@1: 15.27%\n",
            "Recall@5: 30.73%\n",
            "Recall@10: 39.05%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1: 11.68%\n",
            "Recall@5: 25.52%\n",
            "Recall@10: 33.50%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-nQFuNGG2YJv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "FUNccvGF2YsX"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}