{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating Balanced Image Datasets from COCO\n",
    "\n",
    "This notebook creates balanced datasets from COCO val2017 by:\n",
    "1. Finding images with significant coverage (>5%) of specific categories\n",
    "2. Pairing each with 99 random images that don't contain that category\n",
    "3. Storing the indices to create balanced datasets (1:99 ratio)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pycocotools.coco import COCO\n",
    "import skimage.io as io\n",
    "import random\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Set Up Paths & Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory containing COCO dataset - update these paths to match your local setup\n",
    "dataDir = '/your-path'\n",
    "dataType = 'val2017'\n",
    "annFile = f'/your-path/coco_annotations/instances_{dataType}.json'\n",
    "imgDir = f'{dataDir}/{dataType}/'\n",
    "\n",
    "# Coverage threshold (as a percentage of image area)\n",
    "coverage_threshold = 5.0  # Minimum coverage percentage\n",
    "\n",
    "# Number of negative examples to select for each positive example\n",
    "num_negative_examples = 99  # 99 negative examples + 1 positive = 100 total\n",
    "\n",
    "# Random seed for reproducibility\n",
    "random_seed = 42\n",
    "random.seed(random_seed)\n",
    "np.random.seed(random_seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Initialize COCO API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize COCO API for instance annotations\n",
    "coco = COCO(annFile)\n",
    "\n",
    "# Display dataset info\n",
    "print(f\"COCO {dataType} dataset loaded successfully!\")\n",
    "print(f\"Number of images: {len(coco.imgs)}\")\n",
    "print(f\"Number of categories: {len(coco.cats)}\")\n",
    "print(f\"Number of annotations: {len(coco.anns)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Get All Categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get all categories\n",
    "categories = coco.loadCats(coco.getCatIds())\n",
    "print(f\"COCO has {len(categories)} categories:\")\n",
    "\n",
    "# Display categories in a more readable format\n",
    "for i, cat in enumerate(categories):\n",
    "    print(f\"{i+1}. ID: {cat['id']}, Name: {cat['name']}, Supercategory: {cat['supercategory']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Define Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_category_coverage(coco, img_id, cat_id):\n",
    "    \"\"\"\n",
    "    Calculate what percentage of the image area is covered by a specific category.\n",
    "    \n",
    "    Args:\n",
    "        coco: COCO API instance\n",
    "        img_id: Image ID\n",
    "        cat_id: Category ID\n",
    "        \n",
    "    Returns:\n",
    "        float: Coverage percentage (0-100)\n",
    "        list: Annotations for the category in this image\n",
    "    \"\"\"\n",
    "    # Get image info\n",
    "    img_info = coco.loadImgs(img_id)[0]\n",
    "    image_area = img_info['width'] * img_info['height']\n",
    "    \n",
    "    # Get annotations for this category in this image\n",
    "    ann_ids = coco.getAnnIds(imgIds=img_id, catIds=cat_id)\n",
    "    anns = coco.loadAnns(ann_ids)\n",
    "    \n",
    "    if not anns:\n",
    "        return 0.0, []\n",
    "    \n",
    "    # Calculate total area covered by annotations\n",
    "    total_area = sum(ann['area'] for ann in anns)\n",
    "    \n",
    "    # Calculate coverage percentage\n",
    "    coverage_percent = (total_area / image_area) * 100\n",
    "    \n",
    "    return coverage_percent, anns\n",
    "\n",
    "def visualize_sample(coco, dataset_info, category_name, num_samples=3):\n",
    "    \"\"\"\n",
    "    Visualize sample images from a dataset.\n",
    "    \n",
    "    Args:\n",
    "        coco: COCO API instance\n",
    "        dataset_info: Dictionary with dataset information\n",
    "        category_name: Name of the category\n",
    "        num_samples: Number of samples to visualize\n",
    "    \"\"\"\n",
    "    if not dataset_info['positive_examples']:\n",
    "        print(f\"No positive examples found for category: {category_name}\")\n",
    "        return\n",
    "    \n",
    "    # Get category ID\n",
    "    cat_id = dataset_info['category_id']\n",
    "    \n",
    "    # Select a few positive examples\n",
    "    samples = min(num_samples, len(dataset_info['positive_examples']))\n",
    "    positive_samples = random.sample(dataset_info['positive_examples'], samples)\n",
    "    \n",
    "    # Select an equal number of negative examples\n",
    "    negative_samples = random.sample(dataset_info['negative_examples'], samples)\n",
    "    \n",
    "    # Create figure with subplots\n",
    "    fig, axes = plt.subplots(samples, 2, figsize=(12, 5*samples))\n",
    "    if samples == 1:\n",
    "        axes = axes.reshape(1, 2)\n",
    "    \n",
    "    for i in range(samples):\n",
    "        # Positive example\n",
    "        pos_img_id = positive_samples[i]\n",
    "        pos_img_info = coco.loadImgs(pos_img_id)[0]\n",
    "        pos_img_path = os.path.join(imgDir, pos_img_info['file_name'])\n",
    "        pos_img = io.imread(pos_img_path)\n",
    "        \n",
    "        # Get annotations for the category\n",
    "        pos_ann_ids = coco.getAnnIds(imgIds=pos_img_id, catIds=cat_id)\n",
    "        pos_anns = coco.loadAnns(pos_ann_ids)\n",
    "        \n",
    "        # Calculate coverage\n",
    "        pos_coverage, _ = calculate_category_coverage(coco, pos_img_id, cat_id)\n",
    "        \n",
    "        # Plot positive example\n",
    "        axes[i, 0].imshow(pos_img)\n",
    "        axes[i, 0].set_title(f\"Positive Example\\nID: {pos_img_id}, Coverage: {pos_coverage:.1f}%\")\n",
    "        axes[i, 0].axis('off')\n",
    "        \n",
    "        # Draw annotations on positive example\n",
    "        for ax in fig.axes:\n",
    "            if ax == axes[i, 0]:\n",
    "                coco.showAnns(pos_anns, draw_bbox=True)\n",
    "        \n",
    "        # Negative example\n",
    "        neg_img_id = negative_samples[i]\n",
    "        neg_img_info = coco.loadImgs(neg_img_id)[0]\n",
    "        neg_img_path = os.path.join(imgDir, neg_img_info['file_name'])\n",
    "        neg_img = io.imread(neg_img_path)\n",
    "        \n",
    "        # Plot negative example\n",
    "        axes[i, 1].imshow(neg_img)\n",
    "        axes[i, 1].set_title(f\"Negative Example\\nID: {neg_img_id}, No {category_name}\")\n",
    "        axes[i, 1].axis('off')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.suptitle(f\"{category_name} Examples (Coverage Threshold: {coverage_threshold}%)\", \n",
    "                 fontsize=16, y=1.02)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Process All Categories to Create Balanced Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a dictionary to store all datasets\n",
    "new_data = []\n",
    "# Process each category\n",
    "for category in tqdm(categories, desc=\"Processing Categories\"):\n",
    "    cat_id = category['id']\n",
    "    cat_name = category['name']\n",
    "    if category['supercategory'] == 'animal':\n",
    "        continue\n",
    "    \n",
    "    # Initialize dataset info for this category\n",
    "    dataset_info = {\n",
    "        'category_id': cat_id,\n",
    "        'category_name': cat_name,\n",
    "        'positive_examples': [],\n",
    "        'negative_examples': [],\n",
    "        'dataset_indices': []\n",
    "    }\n",
    "    \n",
    "    # Get all image IDs containing this category\n",
    "    img_ids_with_category = coco.getImgIds(catIds=cat_id)\n",
    "    \n",
    "    # For each image containing the category, check if coverage > threshold\n",
    "    images_with_significant_coverage = []\n",
    "    \n",
    "    for img_id in img_ids_with_category:\n",
    "        coverage, anns = calculate_category_coverage(coco, img_id, cat_id)\n",
    "        if coverage >= coverage_threshold:\n",
    "            images_with_significant_coverage.append(img_id)\n",
    "            dataset_info['positive_examples'].append(img_id)\n",
    "    \n",
    "    # Get all image IDs that do NOT contain this category\n",
    "    all_img_ids = list(coco.imgs.keys())\n",
    "    img_ids_without_category = list(set(all_img_ids) - set(img_ids_with_category))\n",
    "    dataset_info['negative_examples'] = img_ids_without_category\n",
    "    \n",
    "    # For each positive example, randomly select negative examples\n",
    "    for i, pos_img_id in enumerate(images_with_significant_coverage):\n",
    "        # If we don't have enough negative examples, use all available with replacement\n",
    "        if len(img_ids_without_category) < num_negative_examples:\n",
    "            neg_img_ids = random.choices(img_ids_without_category, k=num_negative_examples)\n",
    "        else:\n",
    "            neg_img_ids = random.sample(img_ids_without_category, num_negative_examples)\n",
    "        \n",
    "        # Create dataset with 1 positive + 99 negative examples\n",
    "        dataset_indices = [pos_img_id] + neg_img_ids\n",
    "\n",
    "        new_item = {}\n",
    "        new_item['qry_text'] = f\"Find me an image that contains any {cat_name}.\\n\" # the scene of ...\n",
    "        new_item['qry_img_path'] = ''\n",
    "        new_item['tgt_text'] = \"<|image_1|> Represent the given image.\"\n",
    "        new_item['tgt_img_path'] = [\"val2017/{:012d}.jpg\".format(img_id) for img_id in dataset_indices]\n",
    "        new_data.append(new_item)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(new_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('COCO_object_retrieval.json', 'w') as f:\n",
    "    json.dump(new_data, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "# Test whether we can load it using load_dataset\n",
    "new_eval_data = load_dataset('json', \n",
    "                      data_files='COCO_object_retrieval.json',\n",
    "                      split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define animal and non-living object categories\n",
    "animal_categories = ['cat', 'dog', 'bird', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe']\n",
    "\n",
    "# Filter the dataset\n",
    "animal_data = []\n",
    "non_living_data = []\n",
    "\n",
    "for item in new_eval_data:\n",
    "    if 'person' in item['qry_text']:\n",
    "        continue  # Skip rows related to the \"person\" category\n",
    "    if 'teddy bear' in item['qry_text'] or 'hot dog' in item['qry_text']:\n",
    "        non_living_data.append(item)\n",
    "    elif any(animal in item['qry_text'] for animal in animal_categories):\n",
    "        animal_data.append(item)\n",
    "    else:\n",
    "        non_living_data.append(item)\n",
    "\n",
    "# Save the filtered data to JSON files\n",
    "with open('COCO_animal_retrieval.json', 'w') as f:\n",
    "    json.dump(animal_data, f, indent=4)\n",
    "\n",
    "with open('COCO_object_retrieval_new.json', 'w') as f:\n",
    "    json.dump(non_living_data, f, indent=4)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
