{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdPysIbDhFsa",
        "outputId": "af4ec693-d81b-4938-8bc1-98067ec73709"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m96.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m78.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m59.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m38.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m18.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m73.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install --quiet ftfy regex tqdm\n",
        "!pip install --quiet git+https://github.com/openai/CLIP.git\n",
        "#!pip install --quiet pycocotools\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and Prepare the Flickr8k Dataset"
      ],
      "metadata": {
        "id": "OQ5eAnW_lmiK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget \"https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip\"\n",
        "!unzip -q flickr8k.zip -d ./flickr8k\n",
        "!rm flickr8k.zip\n",
        "!echo \"Downloaded Flickr8k dataset successfully.\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OcIsaazVKPio",
        "outputId": "45cf40f2-248e-415e-fabe-a2cf9e89d0bc"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2025-05-13 22:12:34--  https://github.com/awsaf49/flickr-dataset/releases/download/v1.0/flickr8k.zip\n",
            "Resolving github.com (github.com)... 140.82.113.4\n",
            "Connecting to github.com (github.com)|140.82.113.4|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/753516996/d7c62b13-1e50-40ea-8fae-f34a44b1695f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250513%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250513T221234Z&X-Amz-Expires=300&X-Amz-Signature=81cfc3ccad128c3e279ff018620aff954e5363a9676e47cae6f0df732e086dbe&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dflickr8k.zip&response-content-type=application%2Foctet-stream [following]\n",
            "--2025-05-13 22:12:34--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/753516996/d7c62b13-1e50-40ea-8fae-f34a44b1695f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250513%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250513T221234Z&X-Amz-Expires=300&X-Amz-Signature=81cfc3ccad128c3e279ff018620aff954e5363a9676e47cae6f0df732e086dbe&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dflickr8k.zip&response-content-type=application%2Foctet-stream\n",
            "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...\n",
            "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 1112971163 (1.0G) [application/octet-stream]\n",
            "Saving to: ‘flickr8k.zip’\n",
            "\n",
            "flickr8k.zip        100%[===================>]   1.04G   211MB/s    in 5.1s    \n",
            "\n",
            "2025-05-13 22:12:39 (208 MB/s) - ‘flickr8k.zip’ saved [1112971163/1112971163]\n",
            "\n",
            "Downloaded Flickr8k dataset successfully.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os, glob, random, torch\n",
        "from PIL import Image\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "from torchvision import transforms\n",
        "import clip\n",
        "\n",
        "# ─── 0. Set data_root to where you unzipped flickr8k.zip ────────────────\n",
        "data_root = \"./flickr8k\"\n",
        "\n",
        "# ─── 1. Auto‐detect the captions file ────────────────────────────────────\n",
        "#    (search recursively for any .txt under data_root)\n",
        "caption_candidates = glob.glob(os.path.join(data_root, \"**\", \"*.txt\"), recursive=True)\n",
        "assert caption_candidates, f\"No .txt files found under {data_root}\"\n",
        "captions_file = caption_candidates[0]\n",
        "print(\"Using captions file:\", captions_file)\n",
        "\n",
        "# ─── 2. Auto‐detect the images directory ─────────────────────────────────\n",
        "#    (find the first subfolder containing JPGs)\n",
        "images_dir = None\n",
        "for root, dirs, files in os.walk(data_root):\n",
        "    if any(f.lower().endswith((\".jpg\", \".png\")) for f in files):\n",
        "        images_dir = root\n",
        "        break\n",
        "assert images_dir, f\"No image files found under {data_root}\"\n",
        "print(\"Using images dir:\", images_dir)\n",
        "\n",
        "# ─── 3. Hyperparameters ─────────────────────────────────────────────────\n",
        "input_resolution = 224\n",
        "context_length   = 77\n",
        "batch_size       = 64\n",
        "seed             = 42\n",
        "\n",
        "random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "# ─── 4. Read & split the captions ────────────────────────────────────────\n",
        "#    Flickr8k.token.txt lines are \"<image_name>#<idx>  <caption>\"\n",
        "# ─── 4. Read & split the captions (fix split on space) ────────────────────\n",
        "# captions.txt lines look like:\n",
        "#   \"1000268201_693b08cb0e.jpg#0 A child in a pink dress is climbing up a set of stairs .\"\n",
        "# ─── 4. Read & split the captions (CSV-style) ───────────────────────────────\n",
        "import csv\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# ─── 4. Read & split the captions (CSV parsing + skip header) ──────────────\n",
        "caps = {}\n",
        "with open(captions_file, \"r\") as f:\n",
        "    reader = csv.reader(f)\n",
        "    next(reader, None)           # skip the header row if present\n",
        "    for row in reader:\n",
        "        if len(row) < 2: continue\n",
        "        key, caption = row[0], row[1]\n",
        "        img_name = key.split(\"#\")[0]\n",
        "        caps.setdefault(img_name, []).append(caption)\n",
        "\n",
        "# Now filter out any keys that didn’t get images\n",
        "all_imgs = [img for img in caps.keys()\n",
        "            if os.path.exists(os.path.join(images_dir, img))]\n",
        "# and do your 85/15 split\n",
        "train_imgs, val_imgs = train_test_split(all_imgs,\n",
        "                                        test_size=0.15,\n",
        "                                        random_state=seed)\n",
        "\n",
        "\n",
        "# ─── 5. Transforms ───────────────────────────────────────────────────────\n",
        "train_tf = transforms.Compose([\n",
        "    transforms.Resize((input_resolution, input_resolution)),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(\n",
        "        mean=(0.48145466, 0.4578275, 0.40821073),\n",
        "        std=(0.26862954, 0.26130258, 0.27577711)\n",
        "    )\n",
        "])\n",
        "eval_tf = train_tf\n",
        "\n",
        "# ─── 6. Dataset classes ──────────────────────────────────────────────────\n",
        "class Flickr8kTrain(Dataset):\n",
        "    def __init__(self, img_list, caps, transform=None):\n",
        "        self.transform = transform\n",
        "        self.data = [\n",
        "            (os.path.join(images_dir, img), cap)\n",
        "            for img in img_list for cap in caps[img]\n",
        "        ]\n",
        "    def __len__(self): return len(self.data)\n",
        "    def __getitem__(self, i):\n",
        "        img_path, caption = self.data[i]\n",
        "        img = Image.open(img_path).convert(\"RGB\")\n",
        "        if self.transform: img = self.transform(img)\n",
        "        txt = clip.tokenize(caption, context_length=context_length)[0]\n",
        "        return img, txt\n",
        "\n",
        "class Flickr8kEval(Dataset):\n",
        "    def __init__(self, img_list, caps, transform=None):\n",
        "        self.transform = transform\n",
        "        self.data = [\n",
        "            (os.path.join(images_dir, img), caps[img])\n",
        "            for img in img_list\n",
        "        ]\n",
        "    def __len__(self): return len(self.data)\n",
        "    def __getitem__(self, i):\n",
        "        img_path, captions = self.data[i]\n",
        "        img = Image.open(img_path).convert(\"RGB\")\n",
        "        if self.transform: img = self.transform(img)\n",
        "        return img, captions\n",
        "\n",
        "def eval_collate(batch):\n",
        "    imgs, caps = zip(*batch)\n",
        "    return torch.stack(imgs, 0), list(caps)\n",
        "\n",
        "# ─── 7. Build dataloaders ────────────────────────────────────────────────\n",
        "train_ds = Flickr8kTrain(train_imgs, caps, transform=train_tf)\n",
        "val_ds   = Flickr8kEval(  val_imgs,   caps, transform=eval_tf)\n",
        "\n",
        "train_dataloader = DataLoader(\n",
        "    train_ds, batch_size=batch_size, shuffle=True,\n",
        "    num_workers=2, pin_memory=True\n",
        ")\n",
        "val_dataloader = DataLoader(\n",
        "    val_ds, batch_size=batch_size, shuffle=False,\n",
        "    num_workers=2, pin_memory=True, collate_fn=eval_collate\n",
        ")\n",
        "\n",
        "print(f\"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WfxmIzvYCRI6",
        "outputId": "0e179fb3-3cf4-455d-da53-02ed01d58d8c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using captions file: ./flickr8k/captions.txt\n",
            "Using images dir: ./flickr8k/Images\n",
            "Train samples: 34385, Val samples: 1214\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the Teacher Model: CLIP RN50 Model"
      ],
      "metadata": {
        "id": "FMucybFulqLQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import clip\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "\n",
        "# Load the CLIP model\n",
        "model, preprocess = clip.load(\"RN50\", device)\n",
        "model.eval()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "vocab_size = model.vocab_size\n",
        "\n",
        "print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n",
        "print(\"Input resolution:\", input_resolution)\n",
        "print(\"Context length:\", context_length)\n",
        "print(\"Vocab size:\", vocab_size)\n"
      ],
      "metadata": {
        "id": "NtcJ2B3fhLfo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "36e8ac6b-9f29-4b7e-8d27-b303cac5d245"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 244M/244M [00:13<00:00, 19.4MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model parameters: 102,007,137\n",
            "Input resolution: 224\n",
            "Context length: 77\n",
            "Vocab size: 49408\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluate the Pre-Trained Teacher Model"
      ],
      "metadata": {
        "id": "Q_tN8A1Cyzzj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset\n",
        "import clip\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Load the CLIP model (Teacher)\n",
        "model, preprocess = clip.load(\"RN50\", device)\n",
        "model.eval()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "\n",
        "\n",
        "all_image_features = []\n",
        "all_text_features = []\n",
        "all_captions = []  # Store captions for each image in order\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, captions in val_dataloader:\n",
        "        images = images.to(device)\n",
        "        # Tokenize captions here\n",
        "        one_caption_per_image = [caps[0] for caps in captions]\n",
        "\n",
        "        texts = clip.tokenize(one_caption_per_image, context_length=context_length).to(device)\n",
        "\n",
        "        # Encode images and texts using the teacher model\n",
        "        image_feats = model.encode_image(images)\n",
        "        text_feats = model.encode_text(texts)\n",
        "\n",
        "        # Normalize\n",
        "        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)\n",
        "        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        all_image_features.append(image_feats.cpu())\n",
        "        all_text_features.append(text_feats.cpu())\n",
        "        all_captions.extend(captions)\n",
        "\n",
        "all_image_features = torch.cat(all_image_features, dim=0)  # (N, 512)\n",
        "all_text_features = torch.cat(all_text_features, dim=0)    # (N, 512)\n",
        "\n",
        "# Compute similarity matrix\n",
        "# image-to-text similarity: each image vs all texts\n",
        "sim_matrix = all_image_features @ all_text_features.t()  # (N, N)\n",
        "\n",
        "# Function to compute recall@K\n",
        "def compute_recall(sim_matrix, k=1):\n",
        "    ranks = []\n",
        "    n = sim_matrix.size(0)\n",
        "    for i in range(n):\n",
        "        # Sort texts by similarity to image i\n",
        "        sorted_indices = torch.argsort(sim_matrix[i], descending=True)\n",
        "        rank = (sorted_indices == i).nonzero(as_tuple=True)[0].item()\n",
        "        ranks.append(rank)\n",
        "    ranks = torch.tensor(ranks)\n",
        "    recall = (ranks < k).float().mean().item()\n",
        "    return recall\n",
        "\n",
        "r1 = compute_recall(sim_matrix, k=1)\n",
        "r5 = compute_recall(sim_matrix, k=5)\n",
        "r10 = compute_recall(sim_matrix, k=10)\n",
        "\n",
        "print(\"Image-to-Text Retrieval:\")\n",
        "print(f\"Recall@1: {r1*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "# For text-to-image retrieval, we do the same but transpose the matrix\n",
        "sim_matrix_t2i = sim_matrix.t()  # (N, N)\n",
        "\n",
        "r1_t2i = compute_recall(sim_matrix_t2i, k=1)\n",
        "r5_t2i = compute_recall(sim_matrix_t2i, k=5)\n",
        "r10_t2i = compute_recall(sim_matrix_t2i, k=10)\n",
        "\n",
        "print(\"Text-to-Image Retrieval:\")\n",
        "print(f\"Recall@1: {r1_t2i*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5_t2i*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VsB1BHbQwFRu",
        "outputId": "bab6dbd5-5d62-404d-f0d6-8f81dd2e8a87"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Image-to-Text Retrieval:\n",
            "Recall@1: 51.65%\n",
            "Recall@5: 78.17%\n",
            "Recall@10: 87.73%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1: 47.28%\n",
            "Recall@5: 75.21%\n",
            "Recall@10: 84.60%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "from torch.utils.data import Dataset\n",
        "import clip\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Load the CLIP model (Teacher)\n",
        "model, preprocess = clip.load(\"ViT-B/16\", device)\n",
        "model.eval()\n",
        "\n",
        "input_resolution = model.visual.input_resolution\n",
        "context_length = model.context_length\n",
        "\n",
        "\n",
        "all_image_features = []\n",
        "all_text_features = []\n",
        "all_captions = []  # Store captions for each image in order\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, captions in val_dataloader:\n",
        "        images = images.to(device)\n",
        "        # Tokenize captions here\n",
        "        one_caption_per_image = [caps[0] for caps in captions]\n",
        "\n",
        "        texts = clip.tokenize(one_caption_per_image, context_length=context_length).to(device)\n",
        "\n",
        "        # Encode images and texts using the teacher model\n",
        "        image_feats = model.encode_image(images)\n",
        "        text_feats = model.encode_text(texts)\n",
        "\n",
        "        # Normalize\n",
        "        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)\n",
        "        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        all_image_features.append(image_feats.cpu())\n",
        "        all_text_features.append(text_feats.cpu())\n",
        "        all_captions.extend(captions)\n",
        "\n",
        "all_image_features = torch.cat(all_image_features, dim=0)  # (N, 512)\n",
        "all_text_features = torch.cat(all_text_features, dim=0)    # (N, 512)\n",
        "\n",
        "# Compute similarity matrix\n",
        "# image-to-text similarity: each image vs all texts\n",
        "sim_matrix = all_image_features @ all_text_features.t()  # (N, N)\n",
        "\n",
        "# Function to compute recall@K\n",
        "def compute_recall(sim_matrix, k=1):\n",
        "    ranks = []\n",
        "    n = sim_matrix.size(0)\n",
        "    for i in range(n):\n",
        "        # Sort texts by similarity to image i\n",
        "        sorted_indices = torch.argsort(sim_matrix[i], descending=True)\n",
        "        rank = (sorted_indices == i).nonzero(as_tuple=True)[0].item()\n",
        "        ranks.append(rank)\n",
        "    ranks = torch.tensor(ranks)\n",
        "    recall = (ranks < k).float().mean().item()\n",
        "    return recall\n",
        "\n",
        "r1 = compute_recall(sim_matrix, k=1)\n",
        "r5 = compute_recall(sim_matrix, k=5)\n",
        "r10 = compute_recall(sim_matrix, k=10)\n",
        "\n",
        "print(\"Image-to-Text Retrieval:\")\n",
        "print(f\"Recall@1: {r1*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10*100:.2f}%\")\n",
        "\n",
        "# For text-to-image retrieval, we do the same but transpose the matrix\n",
        "sim_matrix_t2i = sim_matrix.t()  # (N, N)\n",
        "\n",
        "r1_t2i = compute_recall(sim_matrix_t2i, k=1)\n",
        "r5_t2i = compute_recall(sim_matrix_t2i, k=5)\n",
        "r10_t2i = compute_recall(sim_matrix_t2i, k=10)\n",
        "\n",
        "print(\"Text-to-Image Retrieval:\")\n",
        "print(f\"Recall@1: {r1_t2i*100:.2f}%\")\n",
        "print(f\"Recall@5: {r5_t2i*100:.2f}%\")\n",
        "print(f\"Recall@10: {r10_t2i*100:.2f}%\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NN8K4fskxRTa",
        "outputId": "26736b00-61c1-476c-fb4a-67effb0fabc1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|███████████████████████████████████████| 335M/335M [00:13<00:00, 25.3MiB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Image-to-Text Retrieval:\n",
            "Recall@1: 57.41%\n",
            "Recall@5: 82.70%\n",
            "Recall@10: 90.61%\n",
            "Text-to-Image Retrieval:\n",
            "Recall@1: 55.02%\n",
            "Recall@5: 81.63%\n",
            "Recall@10: 87.64%\n"
          ]
        }
      ]
    }
  ]
}