{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 292,
     "status": "ok",
     "timestamp": 1755204403980,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "cx4i3qG3R_QN"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "import ast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 500
    },
    "executionInfo": {
     "elapsed": 397,
     "status": "ok",
     "timestamp": 1755204413727,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "104941b7",
    "outputId": "c75cce23-d3bb-434b-ceff-1ca423d32656"
   },
   "outputs": [],
   "source": [
    "# Load the JSON file\n",
    "try:\n",
    "    with open('final_flickr_mergedGT_test.json', 'r') as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # Check if the data has an 'images' key and it's a list\n",
    "    if 'images' in data and isinstance(data['images'], list):\n",
    "        images_data = data['images']\n",
    "        annotations_data = data['annotations']\n",
    "\n",
    "        # Normalize the 'images' data\n",
    "        # This will flatten nested structures and handle varying list lengths\n",
    "        df_images = pd.json_normalize(images_data)\n",
    "        df_annotations = pd.json_normalize(annotations_data)\n",
    "\n",
    "        print(\"JSON file loaded and normalized successfully.\")\n",
    "        display(df_images.head())\n",
    "        display(df_annotations.head())\n",
    "    else:\n",
    "        print(\"Error: The JSON structure does not contain a list under the key 'images'.\")\n",
    "\n",
    "except FileNotFoundError:\n",
    "    print(\"Error: final_flickr_mergedGT_test.json not found. Please make sure the file is in the correct directory.\")\n",
    "except Exception as e:\n",
    "    print(f\"An error occurred: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 224
    },
    "executionInfo": {
     "elapsed": 57,
     "status": "ok",
     "timestamp": 1755204415949,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "e9f9270d",
    "outputId": "cc95af1a-6d9b-4dbe-baca-926af43a4ae1"
   },
   "outputs": [],
   "source": [
    "def extract_positive_spans(row):\n",
    "    caption = row['caption']\n",
    "    tokens_positive_eval = row['tokens_positive_eval']\n",
    "    if isinstance(tokens_positive_eval, list):\n",
    "        positive_spans = [caption[start:end] for group in tokens_positive_eval for start, end in group]\n",
    "        return positive_spans\n",
    "    else:\n",
    "        return [] # Return an empty list or other appropriate value if tokens_positive_eval is not a list\n",
    "\n",
    "df_images['positive_spans'] = df_images.apply(extract_positive_spans, axis=1)\n",
    "\n",
    "print(\"New column 'positive_spans' created successfully.\")\n",
    "display(df_images[['caption', 'tokens_positive_eval', 'positive_spans']].head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 524
    },
    "executionInfo": {
     "elapsed": 23,
     "status": "ok",
     "timestamp": 1755204417274,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "ad02d2a8",
    "outputId": "ebc01cf7-ae17-478e-db45-81daa9a81a9b"
   },
   "outputs": [],
   "source": [
    "# Calculate the number of positive spans for each row\n",
    "df_images['objects_num'] = df_images['positive_spans'].apply(len)\n",
    "\n",
    "# Get the distribution of object counts\n",
    "distribution = df_images['objects_num'].value_counts().sort_index()\n",
    "\n",
    "# Convert to DataFrame for cumulative calculation\n",
    "dist_df = distribution.reset_index()\n",
    "dist_df.columns = ['objects_num', 'count']\n",
    "dist_df = dist_df.sort_values('objects_num')\n",
    "\n",
    "# Compute cumulative sum and cumulative percentage\n",
    "dist_df['cumulative'] = dist_df['count'].cumsum()\n",
    "total = dist_df['count'].sum()\n",
    "dist_df['cumulative_percent'] = dist_df['cumulative'] / total * 100\n",
    "\n",
    "# Determine cutoff where cumulative percent reaches > 95%\n",
    "# Find the first row where cumulative_percent is greater than 95%\n",
    "cutoff_row = dist_df[dist_df['cumulative_percent'] > 95].iloc[0]\n",
    "cutoff_value = cutoff_row['objects_num']\n",
    "\n",
    "print(f\"\\nCutoff where cumulative percent exceeds 95%: {cutoff_value} objects\")\n",
    "\n",
    "# Optional: display the full distribution table with cumulative info\n",
    "from IPython.display import display\n",
    "display(dist_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 542
    },
    "executionInfo": {
     "elapsed": 1121,
     "status": "ok",
     "timestamp": 1755204423817,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "db29e640",
    "outputId": "7859d920-32e5-419f-a434-0f1ad4d82054"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot the histogram\n",
    "plt.figure(figsize=(10, 6))\n",
    "counts, bin_edges, patches = plt.hist(df_images['objects_num'], bins=range(1, df_images['objects_num'].max() + 2), align='left', edgecolor='black')\n",
    "plt.xlabel('Number of Objects')\n",
    "plt.ylabel('Number of Captions')\n",
    "plt.xticks(bin_edges) # Set xticks to be the bin edges\n",
    "plt.grid(axis='y', alpha=0.75)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2978,
     "status": "ok",
     "timestamp": 1755204918301,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "rwkjkQQfMf2i"
   },
   "outputs": [],
   "source": [
    "# Only convert tokens_positive if it's stored as string\n",
    "if isinstance(df_annotations['tokens_positive'].iloc[0], str):\n",
    "    df_annotations['tokens_positive'] = df_annotations['tokens_positive'].apply(ast.literal_eval)\n",
    "\n",
    "# Find local prompt based on the tokens and the caption\n",
    "def findLocalPrompt(tokens, caption):\n",
    "    local_prompt = [caption[start:end] for start, end in tokens]\n",
    "    return local_prompt[0]\n",
    "\n",
    "def xywh_to_xyxy(box):\n",
    "    x, y, w, h = box\n",
    "    return [x, y, x + w, y + h]\n",
    "\n",
    "def rescale_box_xyxy(box, orig_w, orig_h, target_size=512):\n",
    "    x1, y1, x2, y2 = box\n",
    "    x_scale = target_size / orig_w\n",
    "    y_scale = target_size / orig_h\n",
    "    return [\n",
    "        int(x1 * x_scale),\n",
    "        int(y1 * y_scale),\n",
    "        int(x2 * x_scale),\n",
    "        int(y2 * y_scale)\n",
    "    ]\n",
    "\n",
    "image_boxes = {}\n",
    "\n",
    "# Iterate over each image\n",
    "for _, img_row in df_images.iterrows():\n",
    "    img_id = img_row['id']\n",
    "    caption = img_row['caption']\n",
    "    eval_tokens = img_row['tokens_positive_eval']\n",
    "    orig_w = int(img_row['width'])\n",
    "    orig_h = int(img_row['height'])\n",
    "    boxes = []\n",
    "\n",
    "    # Filter annotations for the current image\n",
    "    matching_annots = df_annotations[df_annotations['image_id'] == img_id]\n",
    "\n",
    "    for _, ann_row in matching_annots.iterrows():\n",
    "        if ann_row['tokens_positive'] in eval_tokens:\n",
    "            # extract local prompt from the captions using the tokens\n",
    "            obj = findLocalPrompt(ann_row['tokens_positive'], caption)\n",
    "            # convert bbox format\n",
    "            xyxy_box = xywh_to_xyxy(ann_row['bbox'])\n",
    "            # rescale bbox from original size to 512x512 for our models compatibility\n",
    "            scaled_box = rescale_box_xyxy(xyxy_box, orig_w, orig_h)\n",
    "            boxes.append({\n",
    "                'obj': obj,\n",
    "                'bbox': scaled_box\n",
    "            })\n",
    "\n",
    "    image_boxes[img_id] = boxes\n",
    "\n",
    "# Add boxes to df_images\n",
    "df_images['boxes'] = df_images['id'].map(image_boxes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1755204918318,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "oBX9PwHjUJpQ"
   },
   "outputs": [],
   "source": [
    "# Keep only the rows with 5 objects or fewer\n",
    "df_images = df_images[df_images['boxes'].map(len) <= 8].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 25,
     "status": "ok",
     "timestamp": 1755204918861,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "K0BOaSZPNORy",
    "outputId": "d20788bb-2439-4ab7-d67d-0931d43340b9"
   },
   "outputs": [],
   "source": [
    "df_images[['id', 'caption', 'boxes']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11,
     "status": "ok",
     "timestamp": 1755204921896,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "k4dvUMmXNjYw",
    "outputId": "a3c1971a-3fe1-4572-cd19-b43c954e0486"
   },
   "outputs": [],
   "source": [
    "print(df_images['boxes'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1755204922491,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "WJ2lQ82hSuXX"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import math\n",
    "import random\n",
    "import textwrap\n",
    "\n",
    "\n",
    "def visualize_df_boxes_grid(df_images, num_samples=35, boxes_per_row=5, random_seed=42):\n",
    "    # Sample rows with at least one box\n",
    "    sampled_rows = df_images[df_images['boxes'].map(lambda x: len(x) > 0)].sample(n=num_samples, random_state=random_seed)\n",
    "\n",
    "    num_cols = boxes_per_row\n",
    "    num_rows = math.ceil(num_samples / num_cols)\n",
    "\n",
    "    fig, axes = plt.subplots(num_rows, num_cols, figsize=(boxes_per_row * 4, num_rows * 4))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for ax, (_, row) in zip(axes, sampled_rows.iterrows()):\n",
    "        boxes = row['boxes']\n",
    "        caption = row['caption']\n",
    "        image_id = row['id']\n",
    "\n",
    "        # Prepare plot\n",
    "        ax.set_xlim(0, 512)\n",
    "        ax.set_ylim(0, 512)\n",
    "        ax.invert_yaxis()\n",
    "        ax.set_aspect('equal')\n",
    "\n",
    "        # Plot boxes\n",
    "        for box in boxes:\n",
    "            x1, y1, x2, y2 = box['bbox']\n",
    "            width = x2 - x1\n",
    "            height = y2 - y1\n",
    "            rect = patches.Rectangle((x1, y1), width, height, linewidth=2, edgecolor='red', facecolor='none')\n",
    "            ax.add_patch(rect)\n",
    "            ax.text(x1 + 3, y1 - 5, box['obj'], color='blue', fontsize=8)\n",
    "\n",
    "        # Wrap long captions\n",
    "        wrapped_caption = textwrap.fill(caption, width=50)\n",
    "        ax.set_title(f\"ID {image_id}\\n{wrapped_caption}\", fontsize=9)\n",
    "        ax.grid(True)\n",
    "\n",
    "    # Hide unused axes\n",
    "    for ax in axes[len(sampled_rows):]:\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 819
    },
    "executionInfo": {
     "elapsed": 2941,
     "status": "ok",
     "timestamp": 1753430208095,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "KFvlICtCTGUE",
    "outputId": "cbf4b11b-cbe2-4a30-e17c-aac178f6988e"
   },
   "outputs": [],
   "source": [
    "visualize_df_boxes_grid(df_images, num_samples=16, boxes_per_row=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 23,
     "status": "ok",
     "timestamp": 1755204925430,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "f04530c6",
    "outputId": "7b111ae6-4616-46e3-89ef-eb40962e5667"
   },
   "outputs": [],
   "source": [
    "# Filter the DataFrame to get rows with exactly 8 objects\n",
    "df_8_objects = df_images[df_images['boxes'].map(len) == 8]\n",
    "\n",
    "# Iterate through the filtered DataFrame and print the full caption for each row\n",
    "for index, row in df_8_objects.iterrows():\n",
    "    print(f\"ID: {row['id']}\")\n",
    "    print(f\"Caption: {row['caption']}\")\n",
    "    print(f\"Boxes: {row['boxes']}\")\n",
    "    print(\"-\" * 50) # Print a separator line for clarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a0ef142d"
   },
   "source": [
    "# Task\n",
    "Sample the dataframe `df_images` to approximately 3000 rows, maintaining the original distribution of the 'objects_num' column. Display the head of the sampled dataframe and the total number of rows."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7f90d82d"
   },
   "source": [
    "## Calculate original distribution\n",
    "\n",
    "### Subtask:\n",
    "Determine the current distribution of 'objects_num' in `df_images`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e6b93949"
   },
   "source": [
    "**Reasoning**:\n",
    "The subtask is to determine the current distribution of 'objects_num' in `df_images`. This can be achieved by calculating the value counts of the 'objects_num' column and storing it in `original_distribution`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1755204928150,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "3682fcef",
    "outputId": "533800e1-c14e-4c92-92da-e25b22f58e1d"
   },
   "outputs": [],
   "source": [
    "original_distribution = df_images['objects_num'].value_counts().sort_index()\n",
    "print(\"Original distribution of 'objects_num':\")\n",
    "print(original_distribution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3e6a2af2"
   },
   "source": [
    "## Determine target counts\n",
    "\n",
    "### Subtask:\n",
     "Calculate the number of rows to sample for each 'objects_num' category to achieve a total of approximately 3300 rows (set by `target_total_rows`) while maintaining the original proportions.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "35620a04"
   },
   "source": [
    "**Reasoning**:\n",
     "Calculate the target number of rows for each `objects_num` category based on the original distribution to achieve a total of approximately 3300 rows (`target_total_rows`).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1755204929738,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "3feddfca",
    "outputId": "76bb7d58-074a-4a23-8af8-c691501889ae"
   },
   "outputs": [],
   "source": [
    "# Calculate the total number of rows in the original distribution\n",
    "total_original_rows = original_distribution.sum()\n",
    "\n",
    "# Define the target number of rows\n",
    "target_total_rows = 3300\n",
    "\n",
    "# Calculate the sampling factor\n",
    "sampling_factor = (target_total_rows / total_original_rows).round(2)\n",
    "\n",
    "# Calculate the target counts for each category, ensuring they are integers\n",
    "target_counts = (original_distribution * sampling_factor).round().astype(int)\n",
    "\n",
    "print(\"Target number of rows per 'objects_num' category for sampling:\")\n",
    "print(target_counts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b0059e2d"
   },
   "source": [
    "## Sample based on target counts\n",
    "\n",
    "### Subtask:\n",
    "Sample the `df_images` DataFrame, taking the calculated number of rows for each 'objects_num' category.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "12a437a1"
   },
   "source": [
    "**Reasoning**:\n",
    "Sample the DataFrame based on the calculated target counts for each category to maintain the distribution.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 610
    },
    "executionInfo": {
     "elapsed": 82,
     "status": "ok",
     "timestamp": 1755204931171,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "d2dadaaf",
    "outputId": "70aa7c25-86ca-4350-b2c6-84cee99ed990"
   },
   "outputs": [],
   "source": [
    "sampled_dfs = []\n",
    "\n",
    "for obj_num, count in target_counts.items():\n",
    "    # Filter rows for the current objects_num category\n",
    "    category_df = df_images[df_images['objects_num'] == obj_num]\n",
    "\n",
    "    # Sample 'count' rows, or all available rows if 'count' is larger\n",
    "    if len(category_df) >= count:\n",
    "        sampled_category_df = category_df.sample(n=count, random_state=42) # Use random_state for reproducibility\n",
    "    else:\n",
    "        sampled_category_df = category_df # Take all rows if not enough\n",
    "\n",
    "    sampled_dfs.append(sampled_category_df)\n",
    "\n",
    "# Concatenate the sampled dataframes\n",
    "df_images_sampled = pd.concat(sampled_dfs).reset_index(drop=True)\n",
    "\n",
    "print(\"DataFrame sampled successfully, maintaining the original distribution.\")\n",
    "display(df_images_sampled.head())\n",
    "print(f\"\\nTotal number of rows in the sampled DataFrame: {len(df_images_sampled)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fb1ef69f"
   },
   "source": [
    "## Verify sampled distribution\n",
    "\n",
    "### Subtask:\n",
    "Check the distribution of 'objects_num' in the new sampled DataFrame to confirm it closely matches the original distribution.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ccacd893"
   },
   "source": [
    "**Reasoning**:\n",
    "Calculate and print the distribution of 'objects_num' in the sampled DataFrame and compare it to the original distribution.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1755204933509,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "1b3f4608",
    "outputId": "72dbceda-27c8-4d97-c85c-58535ce9d857"
   },
   "outputs": [],
   "source": [
    "sampled_distribution = df_images_sampled['objects_num'].value_counts().sort_index()\n",
    "print(\"Sampled distribution of 'objects_num':\")\n",
    "print(sampled_distribution)\n",
    "\n",
    "print(\"\\nOriginal distribution of 'objects_num':\")\n",
    "print(original_distribution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 564
    },
    "executionInfo": {
     "elapsed": 195,
     "status": "ok",
     "timestamp": 1755204997858,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "rVLzzb62W998",
    "outputId": "58e6afa7-e7e4-4987-e93e-23d28141a8a5"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "counts, bin_edges, patches = plt.hist(df_images_sampled['objects_num'], bins=range(1, df_images_sampled['objects_num'].max() + 2), align='left', edgecolor='black')\n",
    "plt.xlabel('Number of Objects')\n",
    "plt.ylabel('Number of Captions')\n",
    "plt.title('Distribution of Objects in Sampled DataFrame')\n",
    "plt.xticks(bin_edges) # Set xticks to be the bin edges\n",
    "plt.grid(axis='y', alpha=0.75)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5db9f8ab"
   },
   "source": [
    "# Task\n",
    "Export the sampled dataset to a CSV file named \"sampled_prompts.csv\" with the following columns: 'id' (0-padded to 4 digits), 'original_id', 'category' (always 'open_set'), 'prompt', and dynamically generated columns for objects and bounding boxes (e.g., 'obj1', 'bbox1', 'obj2', 'bbox2', etc.) based on the number of objects in each prompt."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a404800e"
   },
   "source": [
    "## Determine maximum objects\n",
    "\n",
    "### Subtask:\n",
    "Find the maximum number of objects in the 'boxes' column of the sampled DataFrame.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9ff16b61"
   },
   "source": [
    "**Reasoning**:\n",
    "Calculate the maximum number of objects in the 'boxes' column of the sampled DataFrame and store it in a variable named `max_objects`. Then print `max_objects`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1753430208145,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "84f2839b",
    "outputId": "230e212d-26d2-4ed2-c6ee-856f7239c431"
   },
   "outputs": [],
   "source": [
    "max_objects = df_images_sampled['boxes'].apply(len).max()\n",
    "print(f\"Maximum number of objects in the sampled DataFrame: {max_objects}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "514e159f"
   },
   "source": [
    "## Restructure data\n",
    "\n",
    "Create a new DataFrame with the required columns ('id', 'original_id', 'category', 'prompt') and dynamically generated columns for objects and bounding boxes (e.g., 'obj1', 'bbox1', 'obj2', 'bbox2', etc.) based on the maximum number of objects.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "25bfe7c4"
   },
   "source": [
    "**Reasoning**:\n",
    "Create the list of required columns and dynamically generate the object and bbox column names, then create an empty DataFrame with these columns.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 91
    },
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1753430208149,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "32a776d1",
    "outputId": "2d6100d8-ba9f-44cc-a0b7-030d9103ef7c"
   },
   "outputs": [],
   "source": [
    "# Define base columns\n",
    "base_columns = ['id', 'original_id', 'category', 'prompt']\n",
    "\n",
    "# Dynamically generate object and bbox column names\n",
    "object_bbox_columns = []\n",
    "for i in range(1, max_objects + 1):\n",
    "    object_bbox_columns.append(f'obj{i}')\n",
    "    object_bbox_columns.append(f'bbox{i}')\n",
    "\n",
    "# Combine all column names\n",
    "all_columns = base_columns + object_bbox_columns\n",
    "\n",
    "# Create an empty DataFrame with the defined columns\n",
    "export_df = pd.DataFrame(columns=all_columns)\n",
    "\n",
    "print(\"Empty DataFrame with required columns created successfully.\")\n",
    "display(export_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "37c712ed"
   },
   "source": [
    "Populate the `export_df` with data from `df_images_sampled`, ensuring correct data mapping and handling of missing object/bbox values for rows with fewer than `max_objects`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 556
    },
    "executionInfo": {
     "elapsed": 334,
     "status": "ok",
     "timestamp": 1753430208483,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "a8e16b04",
    "outputId": "de842b31-c1a0-4b4e-c634-bc54962557d0"
   },
   "outputs": [],
   "source": [
    "# Create an empty list to store the data for the new DataFrame\n",
    "export_data = []\n",
    "\n",
    "# Iterate over each row in the sampled DataFrame\n",
    "for index, row in df_images_sampled.iterrows():\n",
    "    # The 'prompt' column will now only contain the original caption\n",
    "    formatted_prompt = row['caption']\n",
    "    boxes = row['boxes']\n",
    "\n",
    "    row_data = {\n",
    "        'id': str(index).zfill(4),  # 0-pad index to 4 digits\n",
    "        'original_id': str(row['id']),\n",
    "        'category': 'open_set',      # Set category to 'open_set'\n",
    "        'prompt': formatted_prompt\n",
    "    }\n",
    "\n",
    "    # Add object and bbox data to their respective columns\n",
    "    for i in range(max_objects):\n",
    "        if i < len(boxes):\n",
    "            row_data[f'obj{i+1}'] = boxes[i]['obj']\n",
    "            # Convert bbox list to a comma-separated string for bbox columns\n",
    "            row_data[f'bbox{i+1}'] = \",\".join(map(str, boxes[i]['bbox']))\n",
    "        else:\n",
    "            row_data[f'obj{i+1}'] = None # Use None for missing objects\n",
    "            row_data[f'bbox{i+1}'] = None # Use None for missing bboxes\n",
    "\n",
    "\n",
    "    export_data.append(row_data)\n",
    "\n",
    "# Create the new DataFrame from the collected data\n",
    "export_df = pd.DataFrame(export_data, columns=all_columns)\n",
    "\n",
    "print(\"DataFrame populated with sampled data.\")\n",
    "display(export_df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "13907084"
   },
   "source": [
    "## Export to csv\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 46,
     "status": "ok",
     "timestamp": 1753430208530,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "90a72509",
    "outputId": "ef6aa542-2b77-4a4e-97d2-7a6a4444e136"
   },
   "outputs": [],
   "source": [
    "export_df.to_csv(\"sampled_prompts.csv\", index=False)\n",
    "print(\"DataFrame saved to sampled_prompts.csv successfully.\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyMfeksVYlk115ZF7ZsBfrv5",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
