{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1750689244042,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "REIFjR2qUwKh"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import csv\n",
    "import itertools\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "prompts_per_diff_level = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ruUd9GfJjgZx"
   },
   "source": [
    "# Create the prompts and the boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 74
    },
    "executionInfo": {
     "elapsed": 4488,
     "status": "ok",
     "timestamp": 1750689248529,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "pzEFHgwJluqs",
    "outputId": "36329e12-1c02-4cd3-9ddf-1d199ae6614e"
   },
   "outputs": [],
   "source": [
    "# Step 1: upload the csv file containing the complex compositions prompts\n",
    "from google.colab import files\n",
    "uploaded = files.upload()\n",
    "\n",
    "# Load the uploaded CSV into a DataFrame\n",
    "filename = list(uploaded.keys())[0]  # get the uploaded file name\n",
    "complex_prompts_df = pd.read_csv(filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qLDzAytRLiaY"
   },
   "source": [
    "COCO classes extended to a total of 128\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1750689248544,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "xUchDn19HTQv"
   },
   "outputs": [],
   "source": [
    "obj = ['rose', 'oak', 'beetle', 'skyscraper', 'tree', 'baby', 'bed', 'lamp', 'dog', 'laptop', 'bicycle', 'person',\n",
    "       'car', 'bus', 'cat', 'book', 'chair', 'boy', 'couch', 'table', 'plant', 'toilet', 'cellphone', 'microwave',\n",
    "       'sheep', 'boat', 'banana', 'stop sign', 'donut', 'cow', 'clock', 'bottle', 'umbrella', 'bird', 'guitar',\n",
    "       'toothbrush', 'parking meter', 'bench', 'platypus', 'keyboard', 'baseball bat', 'vase', 'surfboard', 'tiger',\n",
    "       'train', 'flower', 'sandwich', 'spoon', 'pizza', 'carrot', 'teddy bear', 'hot-dog', 'skateboard', 'kite', 'broom',\n",
    "       'apple', 'handbag', 'horse', 'snowboard', 'giraffe', 'tie', 'shower', 'traffic light', 'bear', 'toaster', 'knife',\n",
    "       'baseball glove', 'crocodile', 'suitcase', 'fork', 'cake', 'cup', 'bowl', 'hair drier', 'elephant', 'mouse',\n",
    "       'mushroom', 'motorcycle', 'turtle', 'tennis racket', 'truck', 'zebra', 'fire hydrant', 'oven', 'sink', 'frisbee',\n",
    "       'hat', 'ruler', 'shoe', 'ball', 'candle', 'ladder', 'charger', 'mug', 'tape', 'shirt', 'pillow', 'pan', 'plate',\n",
    "       'shampoo', 'hammer', 'blender', 'basket', 'screwdriver', 'wallet', 'bin', 'leaf', 'bucket', 'monitor', 'watch',\n",
    "       'flashlight', 'sock', 'door', 'scarf', 'speaker', 'desk', 'backpack', 'printer', 'remote', 'glass', 'curtain',\n",
    "       'toolbox', 'drill', 'notebook', 'television', 'soap', 'ring', 'refrigerator']\n",
    "\n",
    "obj_with_articles = [\n",
    "    'a rose', 'an oak', 'a beetle', 'a skyscraper', 'a tree', 'a baby', 'a bed', 'a lamp', 'a dog', 'a laptop',\n",
    "    'a bicycle', 'a person', 'a car', 'a bus', 'a cat', 'a book', 'a chair', 'a boy', 'a couch', 'a table',\n",
    "    'a plant', 'a toilet', 'a cellphone', 'a microwave', 'a sheep', 'a boat', 'a banana', 'a stop sign',\n",
    "    'a donut', 'a cow', 'a clock', 'a bottle', 'an umbrella', 'a bird', 'a guitar', 'a toothbrush', 'a parking meter',\n",
    "    'a bench', 'a platypus', 'a keyboard', 'a baseball bat', 'a vase', 'a surfboard', 'a tiger', 'a train', 'a flower',\n",
    "    'a sandwich', 'a spoon', 'a pizza', 'a carrot', 'a teddy bear', 'an hot-dog', 'a skateboard', 'a kite', 'a broom',\n",
    "    'an apple', 'a handbag', 'a horse', 'a snowboard', 'a giraffe', 'a tie', 'a shower', 'a traffic light', 'a bear',\n",
    "    'a toaster', 'a knife', 'a baseball glove', 'a crocodile', 'a suitcase', 'a fork', 'a cake', 'a cup', 'a bowl',\n",
    "    'a hair drier', 'an elephant', 'a mouse', 'a mushroom', 'a motorcycle', 'a turtle', 'a tennis racket', 'a truck',\n",
    "    'a zebra', 'a fire hydrant', 'an oven', 'a sink', 'a frisbee', 'a hat', 'a ruler', 'a shoe', 'a ball', 'a candle',\n",
    "    'a ladder', 'a charger', 'a mug', 'a tape', 'a shirt', 'a pillow', 'a pan', 'a plate', 'a shampoo', 'a hammer',\n",
    "    'a blender', 'a basket', 'a screwdriver', 'a wallet', 'a bin', 'a leaf', 'a bucket', 'a monitor', 'a watch',\n",
    "    'a flashlight', 'a sock', 'a door', 'a scarf', 'a speaker', 'a desk', 'a backpack', 'a printer', 'a remote',\n",
    "    'a glass', 'a curtain', 'a toolbox', 'a drill', 'a notebook', 'a television', 'a soap', 'a ring', 'a refrigerator'\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1750689248546,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "Y8M6pDn-zvhk"
   },
   "outputs": [],
   "source": [
    "colors = ['black', 'blue', 'brown', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow', 'orange']\n",
    "\n",
    "spatial_relations = ['above', 'below', 'beside', 'far from', 'near', 'next to', 'on', 'over', 'to the left of', 'to the right of', 'under']\n",
    "\n",
    "\n",
    "\n",
    "attributes = [\n",
    "    'aggressive', 'black', 'blue', 'bright', 'clean', 'crowded', 'dark', 'fast', 'fluffy', 'fuzzy', 'green', 'happy', 'large', 'pink', 'red', 'rotten',\n",
    "    'rough', 'shiny', 'short', 'silver', 'small', 'smooth', 'snowy', 'soft', 'tall', 'warm', 'white', 'wooden', 'yellow'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ejQsl3kGy4op"
   },
   "source": [
    "For one object prompts, the size of the box can be bigger since there is no risk of overlapping.\n",
    "\n",
    "The more objects are in the prompt, the smaller the boxes should be to avoid overlapping."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1750689248560,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "CbvyrmbrLvZk"
   },
   "outputs": [],
   "source": [
    "# image size\n",
    "IMAGE_SIZE = 512\n",
    "\n",
    "# box size ranges depending on number of objects\n",
    "NO_OVERLAPPING_RANGES_OLD = {\n",
    "    1: (150, 350),  # larger boxes\n",
    "    2: (120, 250),\n",
    "    3: (100, 180),\n",
    "    4: (80, 150),   # smaller boxes\n",
    "}\n",
    "\n",
    "NO_OVERLAPPING_RANGES = {\n",
    "    1: (150, 500),  # larger boxes\n",
    "    2: (120, 250),\n",
    "    3: (100, 180),\n",
    "    4: (80, 150),   # smaller boxes\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c04dBopGxPjc"
   },
   "source": [
    "**Overlap check** ⬇️:\n",
    "\n",
    "Each box is defined by:\n",
    "*   top-left corner → (x_min, y_min)\n",
    "*   bottom-right corner → (x_max, y_max)\n",
    "\n",
    "Two boxes **do NOT overlap** if one of these is true:\n",
    "* Box1 is completely to the left of Box2 → x1_max < x2_min\n",
    "* Box1 is completely to the right of Box2 → x2_max < x1_min\n",
    "* Box1 is completely above Box2 → y1_max < y2_min\n",
    "* Box1 is completely below Box2 → y2_max < y1_min\n",
    "\n",
    "When all of these are false, the two boxes overlap.*italicised text*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 48,
     "status": "ok",
     "timestamp": 1750689248608,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "tXlJXmAQU3Yu"
   },
   "outputs": [],
   "source": [
    "# overlap check\n",
    "def boxes_overlap(box1, box2):\n",
    "    x1_min, y1_min, x1_max, y1_max = box1\n",
    "    x2_min, y2_min, x2_max, y2_max = box2\n",
    "    return not (x1_max < x2_min or x2_max < x1_min or y1_max < y2_min or y2_max < y1_min)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yiT_EPWOzXyd"
   },
   "source": [
    "**Random box generation tools** ⬇️:\n",
    "\n",
    "First width and height of the box are generated randomly between the correct sizes based on the number of object there will be in the prompt.\n",
    "\n",
    "Then the box is positioned inside the space by choosing randomly the top-left corner according to the height and width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1750689248609,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "E7f3BOzPU7sb"
   },
   "outputs": [],
   "source": [
    "# generate box\n",
    "def generate_random_box(min_size, max_size):\n",
    "    width = random.randint(min_size, max_size)\n",
    "    height = random.randint(min_size, max_size)\n",
    "    x1 = random.randint(0, IMAGE_SIZE - width)\n",
    "    y1 = random.randint(0, IMAGE_SIZE - height)\n",
    "    x3 = x1 + width\n",
    "    y3 = y1 + height\n",
    "    return (x1, y1, x3, y3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1750689248610,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "CaEjOhJ-VFlG"
   },
   "outputs": [],
   "source": [
    "# build prompt string: fixed structure based on the number of objects\n",
    "def build_prompt(objs):\n",
    "    if len(objs) == 1:\n",
    "        return objs[0]\n",
    "    elif len(objs) == 2:\n",
    "        return f'{objs[0]} and {objs[1]}'\n",
    "    elif len(objs) == 3:\n",
    "        return f'{objs[0]}, {objs[1]} and {objs[2]}'\n",
    "    elif len(objs) == 4:\n",
    "        return f'{objs[0]}, {objs[1]}, {objs[2]} and {objs[3]}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1750689248610,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "80sDbk_nU-k9"
   },
   "outputs": [],
   "source": [
    "# generate N non-overlapping boxes\n",
    "def generate_non_overlapping_boxes(num_boxes, min_size, max_size):\n",
    "    boxes = []\n",
    "    attempts = 0\n",
    "    max_attempts = 1000\n",
    "    while len(boxes) < num_boxes and attempts < max_attempts:\n",
    "        new_box = generate_random_box(min_size, max_size)\n",
    "        if all(not boxes_overlap(new_box, existing_box) for existing_box in boxes):\n",
    "            boxes.append(new_box)\n",
    "        attempts += 1\n",
    "    # if after 1000 times still not overlapping bounging boxes were generated, raise an error\n",
    "    if len(boxes) < num_boxes:\n",
    "        raise RuntimeError(\"Failed to generate non-overlapping boxes after many attempts.\")\n",
    "    return boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EXbFg5ss6MNs"
   },
   "source": [
    "### Object binding functions\n",
    "\n",
    "With no overlapping bboxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1750689248611,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "I6Ro-Of_VJLi"
   },
   "outputs": [],
   "source": [
    "def generate_csv_object_binding(output_path, id_counter):\n",
    "    mode = 'w' if id_counter == 0 else 'a'\n",
    "\n",
    "    with open(output_path, mode=mode, newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        # Write header only if starting fresh\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt', 'obj1', 'bbox1', 'obj2', 'bbox2', 'obj3', 'bbox3', 'obj4', 'bbox4'])\n",
    "\n",
    "\n",
    "        # for each number of objects\n",
    "        for num_objs in [1, 2, 3, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects...\")\n",
    "            min_size, max_size = NO_OVERLAPPING_RANGES[num_objs]\n",
    "\n",
    "            # prepare unique combinations of objects: they should not repeat in the same prompt\n",
    "            if num_objs == 1:\n",
    "                combinations = [[obj] for obj in obj_with_articles]\n",
    "            else:\n",
    "                # itertools.combinations only generates combinations of unique objects (no repetition within a tuple)\n",
    "                combinations = list(itertools.combinations(obj_with_articles, num_objs))\n",
    "\n",
    "            # Shuffle combinations to ensure that when we repeat and truncate them to 512 prompts,\n",
    "            # the resulting prompts are more varied and not biased toward always starting with the same combinations.\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            # how many times do we need to repeat combinations to reach 512 prompts per number of objects\n",
    "            total_needed = prompts_per_diff_level\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            # tqdm is to create the progress bar\n",
    "            for objs in tqdm(used_combinations):\n",
    "                boxes = generate_non_overlapping_boxes(num_objs, min_size, max_size)\n",
    "\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'object_binding'\n",
    "                prompt = build_prompt(objs)\n",
    "                row = [id_str, category, prompt]\n",
    "\n",
    "                # pair together objects + boxes\n",
    "                for obj, box in itertools.zip_longest(objs, boxes, fillvalue=''):\n",
    "                    row.append(obj)\n",
    "                    # write the corresponding bounding box, or empty string if there is no box\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                # fill the row with empty string for prompts with less than 4 objects\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "\n",
    "    return id_counter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1XOMRz7q6RQn"
   },
   "source": [
    "### Color binding functions\n",
    "\n",
    "Still no overlapping bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1750689248613,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "8_ogXncE2z0Y"
   },
   "outputs": [],
   "source": [
    "def build_color_prompt(objs, colors_sampled):\n",
    "    colored_objs = []\n",
    "    for obj, color in zip(objs, colors_sampled):\n",
    "        article, noun = obj.split(' ', 1)\n",
    "        colored_objs.append(f'{article} {color} {noun}')\n",
    "\n",
    "    if len(colored_objs) == 1:\n",
    "        prompt = colored_objs[0]\n",
    "    else:\n",
    "        prompt = ', '.join(colored_objs[:-1]) + ' and ' + colored_objs[-1]\n",
    "\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1750689248614,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "vYW-K_3B2pvE"
   },
   "outputs": [],
   "source": [
    "def generate_csv_color_binding(output_path, id_counter, colors):\n",
    "    with open(output_path, mode='a', newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        # Write header only if starting fresh\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt', 'obj1', 'bbox1', 'obj2', 'bbox2', 'obj3', 'bbox3', 'obj4', 'bbox4'])\n",
    "\n",
    "        for num_objs in [1, 2, 3, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects (with colors)...\")\n",
    "            min_size, max_size = NO_OVERLAPPING_RANGES[num_objs]\n",
    "\n",
    "            if num_objs == 1:\n",
    "                combinations = [[obj] for obj in obj_with_articles]\n",
    "            else:\n",
    "                combinations = list(itertools.combinations(obj_with_articles, num_objs))\n",
    "\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            total_needed = prompts_per_diff_level\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            for objs in tqdm(used_combinations):\n",
    "                boxes = generate_non_overlapping_boxes(num_objs, min_size, max_size)\n",
    "\n",
    "                # Randomly sample N colors, all different\n",
    "                colors_sampled = random.sample(colors, num_objs)\n",
    "\n",
    "                prompt = build_color_prompt(objs, colors_sampled)\n",
    "\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'color_binding'\n",
    "                row = [id_str, category, prompt]\n",
    "\n",
    "                for obj, box, color in itertools.zip_longest(objs, boxes, colors_sampled, fillvalue=''):\n",
    "                    article, noun = obj.split(' ', 1)\n",
    "                    obj_with_color = f'{article} {color} {noun}'\n",
    "\n",
    "                    row.append(obj_with_color)\n",
    "\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "\n",
    "    return id_counter\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JwmkqkAM6Vj6"
   },
   "source": [
    "### Attribute binding functions\n",
    "\n",
    "Still no overlapping bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 35,
     "status": "ok",
     "timestamp": 1750689248650,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "tH4mXdPC6U7j"
   },
   "outputs": [],
   "source": [
    "def build_attribute_prompt(objs, attribute_sampled):\n",
    "    attribute_objs = []\n",
    "    for obj, attribute in zip(objs, attribute_sampled):\n",
    "        article, noun = obj.split(' ', 1)\n",
    "        attribute_objs.append(f'{article} {attribute} {noun}')\n",
    "\n",
    "    if len(attribute_objs) == 1:\n",
    "        prompt = attribute_objs[0]\n",
    "    else:\n",
    "        prompt = ', '.join(attribute_objs[:-1]) + ' and ' + attribute_objs[-1]\n",
    "\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689248651,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "IKRvHgui6wei"
   },
   "outputs": [],
   "source": [
    "def generate_csv_attribute_binding(output_path, id_counter, attributes):\n",
    "    with open(output_path, mode='a', newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        # Write header only if starting fresh\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt', 'obj1', 'bbox1', 'obj2', 'bbox2', 'obj3', 'bbox3', 'obj4', 'bbox4'])\n",
    "\n",
    "        for num_objs in [1, 2, 3, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects (with attributes)...\")\n",
    "            min_size, max_size = NO_OVERLAPPING_RANGES[num_objs]\n",
    "\n",
    "            if num_objs == 1:\n",
    "                combinations = [[obj] for obj in obj_with_articles]\n",
    "            else:\n",
    "                combinations = list(itertools.combinations(obj_with_articles, num_objs))\n",
    "\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            total_needed = prompts_per_diff_level\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            for objs in tqdm(used_combinations):\n",
    "                boxes = generate_non_overlapping_boxes(num_objs, min_size, max_size)\n",
    "\n",
    "                # Randomly sample N attributes, all different\n",
    "                attributes_sampled = random.sample(attributes, num_objs)\n",
    "\n",
    "                prompt = build_attribute_prompt(objs, attributes_sampled)\n",
    "\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'attribute_binding'\n",
    "                row = [id_str, category, prompt]\n",
    "\n",
    "                for obj, box, attribute in itertools.zip_longest(objs, boxes, attributes_sampled, fillvalue=''):\n",
    "                    article, noun = obj.split(' ', 1)\n",
    "                    obj_with_attribute = f'{article} {attribute} {noun}'\n",
    "\n",
    "                    row.append(obj_with_attribute)\n",
    "\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "\n",
    "    return id_counter\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "byCqk_2c91ln"
   },
   "source": [
    "### Overlapping bounding boxes functions\n",
    "\n",
    "No more contraints on the dimension of the boxes since now they must overlap.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F2Rgu-9b_ebJ"
   },
   "source": [
    "**Overlapping boxes generation ⬇️:**\n",
    "\n",
    "- choose the first box randomly. The size will be between 10x10 and 510x510\n",
    "- the other boxes are chosen randomly so that they overlap with at least one of the already generated boxes by picking an initial point inside of one of them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1750689248651,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "3uN6XeKF9-v2"
   },
   "outputs": [],
   "source": [
    "def generate_overlapping_boxes(num_objs, min_size=80, max_size=510):\n",
    "    boxes = []\n",
    "\n",
    "    # First box chosen randomly from 10x10 to a max of 510x510\n",
    "    w = random.randint(min_size, max_size)\n",
    "    h = random.randint(min_size, max_size)\n",
    "\n",
    "    x1 = random.randint(0, 512 - w)\n",
    "    y1 = random.randint(0, 512 - h)\n",
    "\n",
    "    box1 = (x1, y1, x1 + w, y1 + h)\n",
    "    boxes.append(box1)\n",
    "\n",
    "    for _ in range(1, num_objs):\n",
    "        # Random size for new box\n",
    "        w = random.randint(min_size, max_size)\n",
    "        h = random.randint(min_size, max_size)\n",
    "\n",
    "        # Pick one of the boxes already generated so far\n",
    "        target_box = random.choice(boxes)\n",
    "\n",
    "        # Pick a random point inside that target box\n",
    "        overlap_x1 = random.randint(target_box[0], target_box[2] - 1)\n",
    "        overlap_y1 = random.randint(target_box[1], target_box[3] - 1)\n",
    "\n",
    "        # Now choose x1_new/y1_new so that box fully fits inside image\n",
    "        # and contains the overlap point\n",
    "\n",
    "        # For x1_new:\n",
    "        x1_min = max(0, overlap_x1 - (w - 1))\n",
    "        x1_max = min(overlap_x1, 512 - w)\n",
    "        x1_new = random.randint(x1_min, x1_max)\n",
    "\n",
    "        # For y1_new:\n",
    "        y1_min = max(0, overlap_y1 - (h - 1))\n",
    "        y1_max = min(overlap_y1, 512 - h)\n",
    "        y1_new = random.randint(y1_min, y1_max)\n",
    "\n",
    "        # Now the box is guaranteed to fit and to be >= 80x80\n",
    "        x2_new = x1_new + w\n",
    "        y2_new = y1_new + h\n",
    "\n",
    "        box_new = (x1_new, y1_new, x2_new, y2_new)\n",
    "        boxes.append(box_new)\n",
    "\n",
    "    return boxes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 0,
     "status": "ok",
     "timestamp": 1750689248652,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "Cq9P_HdY91Kv"
   },
   "outputs": [],
   "source": [
    "def generate_csv_overlapping_bboxes(output_path, id_counter):\n",
    "    mode = 'w' if id_counter == 0 else 'a'\n",
    "\n",
    "    with open(output_path, mode=mode, newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt',\n",
    "                             'obj1', 'bbox1',\n",
    "                             'obj2', 'bbox2',\n",
    "                             'obj3', 'bbox3',\n",
    "                             'obj4', 'bbox4'])\n",
    "\n",
    "        for num_objs in [1, 2, 3, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects (with overlapping bboxes)...\")\n",
    "            min_size = 80\n",
    "            max_size = 510\n",
    "\n",
    "            if num_objs == 1:\n",
    "                combinations = [[obj] for obj in obj_with_articles]\n",
    "            else:\n",
    "                combinations = list(itertools.combinations(obj_with_articles, num_objs))\n",
    "\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            total_needed = prompts_per_diff_level\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            for objs in tqdm(used_combinations):\n",
    "                boxes = generate_overlapping_boxes(num_objs, min_size, max_size)\n",
    "\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'overlapping_bboxes'\n",
    "                prompt = build_prompt(objs)\n",
    "\n",
    "                row = [id_str, category, prompt]\n",
    "\n",
    "                for obj, box in itertools.zip_longest(objs, boxes, fillvalue=''):\n",
    "                    row.append(obj)\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "\n",
    "    return id_counter\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zsSXe6SqGv7d"
   },
   "source": [
    "### Small bounding boxes functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 0,
     "status": "ok",
     "timestamp": 1750689248653,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "j1A3eOtzXBU5"
   },
   "outputs": [],
   "source": [
    "def generate_box_with_area_constraints(min_area, max_area, max_attempts=100):\n",
    "    for _ in range(max_attempts):\n",
    "        area = random.randint(min_area, max_area)\n",
    "        aspect_ratio = random.uniform(0.5, 2)\n",
    "        h = int(round((area / aspect_ratio) ** 0.5))\n",
    "        w = int(round(h * aspect_ratio))\n",
    "        if w <= IMAGE_SIZE and h <= IMAGE_SIZE:\n",
    "            x = random.randint(0, IMAGE_SIZE - w)\n",
    "            y = random.randint(0, IMAGE_SIZE - h)\n",
    "            return (x, y, x + w, y + h)\n",
    "\n",
    "def generate_non_overlapping_boxes_with_area(num_boxes, min_area, max_area, max_attempts=1000):\n",
    "    boxes = []\n",
    "    attempts = 0\n",
    "    while len(boxes) < num_boxes and attempts < max_attempts:\n",
    "        new_box = generate_box_with_area_constraints(min_area, max_area)\n",
    "        if all(not boxes_overlap(new_box, existing) for existing in boxes):\n",
    "            boxes.append(new_box)\n",
    "        attempts += 1\n",
    "    if len(boxes) < num_boxes:\n",
    "        raise RuntimeError(f\"Could not generate {num_boxes} non-overlapping boxes in time.\")\n",
    "    return boxes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 0,
     "status": "ok",
     "timestamp": 1750689248654,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "KgPVEkHCG1lV"
   },
   "outputs": [],
   "source": [
    "def generate_csv_small_bboxes(output_path, id_counter):\n",
    "    mode = 'w' if id_counter == 0 else 'a'\n",
    "\n",
    "    with open(output_path, mode=mode, newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt',\n",
    "                             'obj1', 'bbox1',\n",
    "                             'obj2', 'bbox2',\n",
    "                             'obj3', 'bbox3',\n",
    "                             'obj4', 'bbox4'])\n",
    "\n",
    "        for num_objs in [1, 2, 3, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects (small bboxes)...\")\n",
    "\n",
    "            image_area = IMAGE_SIZE * IMAGE_SIZE\n",
    "\n",
    "            min_area = int(image_area * 0.03)\n",
    "            max_area = int(image_area * 0.10)\n",
    "\n",
    "            if num_objs == 1:\n",
    "                combinations = [[obj] for obj in obj_with_articles]\n",
    "            else:\n",
    "                combinations = list(itertools.combinations(obj_with_articles, num_objs))\n",
    "\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            total_needed = prompts_per_diff_level\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            for objs in tqdm(used_combinations):\n",
    "                # IMPORTANT → generate small non-overlapping boxes\n",
    "                boxes = generate_non_overlapping_boxes_with_area(num_objs, min_area, max_area)\n",
    "\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'small_bboxes'\n",
    "                prompt = build_prompt(objs)\n",
    "\n",
    "                row = [id_str, category, prompt]\n",
    "\n",
    "                for obj, box in itertools.zip_longest(objs, boxes, fillvalue=''):\n",
    "                    row.append(obj)\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "\n",
    "    return id_counter\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4CybmSdNMIfd"
   },
   "source": [
    "### Object relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 59,
     "status": "ok",
     "timestamp": 1750689248725,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "wySeay2YUBDn"
   },
   "outputs": [],
   "source": [
    "def generate_boxes_for_relation(obj1, obj2, relation, min_size, max_size):\n",
    "    for attempt in range(400):\n",
    "        corner = '' # needed if relation is far from\n",
    "        on_offset = 0 # needed if relation is on\n",
    "        # Generate first box conditioned on the relation\n",
    "        if relation in ['to the left of']:\n",
    "            # something will be on the right of the first box → leave space on the right\n",
    "            w1 = random.randint(min_size, min(max_size, 512 - min_size))\n",
    "            h1 = random.randint(min_size, max_size)\n",
    "            x1 = random.randint(0, 512 - w1 - min_size)  # leave space on right (- min_size)\n",
    "            y1 = random.randint(0, 512 - h1)\n",
    "\n",
    "        elif relation in ['to the right of']:\n",
    "            # something will be on the left of the first box → leave space on the left\n",
    "            w1 = random.randint(min_size, min(max_size, 512 - min_size))\n",
    "            h1 = random.randint(min_size, max_size)\n",
    "            x1 = random.randint(min_size, 512 - w1)  # leave space on left (minimum random is min_size)\n",
    "            y1 = random.randint(0, 512 - h1)\n",
    "\n",
    "        elif relation in ['below', 'under']:\n",
    "            # Place box1 in lower part of image, leaving space above for box2\n",
    "            w1 = random.randint(min_size, max_size)\n",
    "            h1 = random.randint(min_size, min(max_size, 512 - min_size))\n",
    "            x1 = random.randint(0, 512 - w1)\n",
    "            y1 = random.randint(min_size, 512 - h1) # Ensure space above (minimum random is min_size)\n",
    "\n",
    "        elif relation in ['above', 'over']:\n",
    "            # Place box1 in upper part of image, leaving space below for box2\n",
    "            w1 = random.randint(min_size, max_size)\n",
    "            h1 = random.randint(min_size, min(max_size, 512 - min_size))\n",
    "            x1 = random.randint(0, 512 - w1)\n",
    "            y1 = random.randint(0, 512 - h1 - min_size) # Ensure space below (- min_size)\n",
    "\n",
    "        elif relation in ['on']:\n",
    "            on_offset = random.randint(5, 30)\n",
    "            # Place box1 in upper part of image, leaving space below for box2\n",
    "            w1 = random.randint(min_size, max_size)\n",
    "            h1 = random.randint(min_size, min(max_size, 512 - min_size - on_offset))\n",
    "            x1 = random.randint(0, 512 - w1)\n",
    "            y1 = random.randint(0, 512 - h1 - min_size) # Ensure space below (- min_size)\n",
    "\n",
    "        elif relation in ['far from']:\n",
    "            # select width and height for box1\n",
    "            w1 = random.randint(min_size, max_size)\n",
    "            h1 = random.randint(min_size, max_size)\n",
    "\n",
    "            # define image halves\n",
    "            mid_x, mid_y = 256, 256\n",
    "\n",
    "            # choose a random corner\n",
    "            corner = random.choice(['top left', 'top right', 'bottom left', 'bottom right'])\n",
    "\n",
    "            if corner == 'top left':\n",
    "                x1 = random.randint(0, mid_x - w1)\n",
    "                y1 = random.randint(0, mid_y - h1)\n",
    "            elif corner == 'top right':\n",
    "                x1 = random.randint(mid_x, 512 - w1)\n",
    "                y1 = random.randint(0, mid_y - h1)\n",
    "            elif corner == 'bottom left':\n",
    "                x1 = random.randint(0, mid_x - w1)\n",
    "                y1 = random.randint(mid_y, 512 - h1)\n",
    "            elif corner == 'bottom right':\n",
    "                x1 = random.randint(mid_x, 512 - w1)\n",
    "                y1 = random.randint(mid_y, 512 - h1)\n",
    "\n",
    "        else:\n",
    "            # near, next to, beside → fully random, let's leave everything to the probability\n",
    "            w1 = random.randint(min_size, max_size)\n",
    "            h1 = random.randint(min_size, max_size)\n",
    "            x1 = random.randint(0, 512 - w1)\n",
    "            y1 = random.randint(0, 512 - h1)\n",
    "\n",
    "        box1 = (x1, y1, x1 + w1, y1 + h1)\n",
    "\n",
    "        # Now generate box2 according to relation\n",
    "        if relation in ['to the left of']:\n",
    "            available_width = 512 - box1[2]\n",
    "            if available_width < min_size:\n",
    "                continue  # not enough space\n",
    "\n",
    "            w2 = random.randint(min_size, available_width)\n",
    "            h2 = random.randint(min_size, max_size)\n",
    "\n",
    "            xmin = box1[2]\n",
    "            xmax = 512 - w2\n",
    "            if xmin > xmax:\n",
    "                continue  # invalid range\n",
    "            x2 = random.randint(xmin, xmax)\n",
    "            y2 = random.randint(0, 512 - h2)\n",
    "\n",
    "        elif relation in ['to the right of']:\n",
    "            available_width = box1[0]\n",
    "            if available_width < min_size:\n",
    "                continue  # not enough space\n",
    "\n",
    "            w2 = random.randint(min_size, available_width)\n",
    "            h2 = random.randint(min_size, max_size)\n",
    "\n",
    "            xmin = 0\n",
    "            xmax = box1[0] - w2\n",
    "            if xmin > xmax:\n",
    "                continue  # invalid range\n",
    "            x2 = random.randint(xmin, xmax)\n",
    "            y2 = random.randint(0, 512 - h2)\n",
    "\n",
    "        elif relation in ['below', 'under']:\n",
    "            available_height = box1[1]\n",
    "            if available_height < min_size:\n",
    "                continue  # not enough space\n",
    "\n",
    "            h2 = random.randint(min_size, available_height)\n",
    "            w2 = random.randint(min_size, max_size)\n",
    "\n",
    "            ymin = 0\n",
    "            ymax = box1[1] - h2\n",
    "            if ymin > ymax:\n",
    "                continue  # invalid range\n",
    "            y2 = random.randint(ymin, ymax)\n",
    "            x2 = random.randint(0, 512 - w2)\n",
    "\n",
    "        elif relation in ['above', 'over']:\n",
    "            # box1 is in the top part of the image — box2 should be below\n",
    "            available_height = 512 - box1[3]\n",
    "            if available_height < min_size:\n",
    "                continue  # not enough space\n",
    "\n",
    "            h2 = random.randint(min_size, available_height)\n",
    "            w2 = random.randint(min_size, max_size)\n",
    "\n",
    "            ymin = box1[3]\n",
    "            ymax = 512 - h2\n",
    "            if ymin > ymax:\n",
    "                continue  # invalid range\n",
    "            y2 = random.randint(ymin, ymax)\n",
    "            x2 = random.randint(0, 512 - w2)\n",
    "\n",
    "        elif relation in [ 'on']:\n",
    "            # box1 is in the top part of the image — box2 should be below\n",
    "            available_space = 512 - box1[3] - on_offset\n",
    "\n",
    "            if(available_space < min_size):\n",
    "                continue  # not enough space\n",
    "\n",
    "            w2 = random.randint(min_size, max_size)\n",
    "            h2 = random.randint(min_size, available_space)\n",
    "            x2 = random.randint(max(0, box1[0] - w2), min(512 - w2, box1[0] + w2)) # limit the chosing of x so that box2 is somehow under box1\n",
    "            y2 = box1[3] + on_offset\n",
    "\n",
    "\n",
    "        elif relation in ['far from']:\n",
    "            # position box2 in the furthest corner from box1\n",
    "            w2 = random.randint(min_size, max_size)\n",
    "            h2 = random.randint(min_size, max_size)\n",
    "\n",
    "            if corner == 'top left':\n",
    "                x2 = random.randint(256, 512 - w2)\n",
    "                y2 = random.randint(256, 512 - h2)\n",
    "            elif corner == 'top right':\n",
    "                x2 = random.randint(0, 256 - w2)\n",
    "                y2 = random.randint(256, 512 - h2)\n",
    "            elif corner == 'bottom left':\n",
    "                x2 = random.randint(256, 512 - w2)\n",
    "                y2 = random.randint(0, 256 - h2)\n",
    "            elif corner == 'bottom right':\n",
    "                x2 = random.randint(0, 256 - w2)\n",
    "                y2 = random.randint(0, 256 - h2)\n",
    "\n",
    "        elif relation in ['near']:\n",
    "            # Place box2 near box1, all directions that have enough space are okay, no overlapping\n",
    "            # step1: calculate all directions from box1 that have enough space (i.e. that have at least min_size + offset from the border)\n",
    "            # step2: chose a random direction from the ones that can be used\n",
    "            # step3: place box2 in that space\n",
    "            offset = random.randint(5, 30)\n",
    "            directions = []\n",
    "\n",
    "            # Determine which directions have enough space for box2\n",
    "            if box1[0] >= min_size + offset:\n",
    "                directions.append('left')\n",
    "            if box1[2] + min_size + offset <= 512:\n",
    "                directions.append('right')\n",
    "            if box1[1] >= min_size + offset:\n",
    "                directions.append('up')\n",
    "            if box1[3] + min_size + offset <= 512:\n",
    "                directions.append('down')\n",
    "\n",
    "            if not directions:\n",
    "                # If somehow no direction is available, try again\n",
    "                continue\n",
    "            else:\n",
    "                direction = random.choice(directions)\n",
    "\n",
    "                if direction == 'left': # on the left of box1\n",
    "                    available_space = box1[0] - offset\n",
    "                    w2 = random.randint(min_size, available_space)\n",
    "                    h2 = random.randint(min_size, max_size)\n",
    "                    x2 = box1[0] - w2 - offset\n",
    "                    y2 = random.randint(max(0, box1[1] - h2), min(512 - h2, box1[1] + h2)) # limit the chosing of y so that they are actually next to each other\n",
    "                elif direction == 'right': # on the rigth of box1\n",
    "                    available_space = 512 - box1[2] - offset\n",
    "                    w2 = random.randint(min_size, available_space)\n",
    "                    h2 = random.randint(min_size, max_size)\n",
    "                    x2 = box1[2] + offset\n",
    "                    y2 = random.randint(max(0, box1[1] - h2), min(512 - h2, box1[1] + h2))\n",
    "                elif direction == 'up': # above box1\n",
    "                    available_space = box1[1] - offset\n",
    "                    w2 = random.randint(min_size, max_size)\n",
    "                    h2 = random.randint(min_size, available_space)\n",
    "                    x2 = random.randint(max(0, box1[0] - w2), min(512 - w2, box1[0] + w2))\n",
    "                    y2 = box1[1] - h2 - offset\n",
    "                elif direction == 'down': # below box1\n",
    "                    available_space = 512 - box1[3] - offset\n",
    "                    w2 = random.randint(min_size, max_size)\n",
    "                    h2 = random.randint(min_size, available_space)\n",
    "                    x2 = random.randint(max(0, box1[0] - w2), min(512 - w2, box1[0] + w2)) # limit the chosing of x so that box2 is somehow under box1\n",
    "                    y2 = box1[3] + offset\n",
    "\n",
    "        elif relation in ['next to', 'beside']:\n",
    "            # Place box2 near box1 but only left or right of each other, no overlapping\n",
    "            offset = random.randint(5, 30)\n",
    "            directions = []\n",
    "\n",
    "            # Determine which directions have enough space for box2\n",
    "            if box1[0] >= min_size + offset:\n",
    "                directions.append('left')\n",
    "            if box1[2] + min_size + offset <= 512:\n",
    "                directions.append('right')\n",
    "\n",
    "            if not directions:\n",
    "                # If somehow no direction is available, try again\n",
    "                continue\n",
    "            else:\n",
    "                direction = random.choice(directions)\n",
    "\n",
    "                if direction == 'left': # on the left of box1\n",
    "                    available_space = box1[0] - offset\n",
    "                    w2 = random.randint(min_size, available_space)\n",
    "                    h2 = random.randint(min_size, max_size)\n",
    "                    x2 = box1[0] - w2 - offset\n",
    "                    y2 = random.randint(max(0, box1[1] - h2), min(512 - h2, box1[1] + h2)) # limit the chosing of y so that they are actually next to each other\n",
    "                elif direction == 'right': # on the rigth of box1\n",
    "                    available_space = 512 - box1[2] - offset\n",
    "                    w2 = random.randint(min_size, available_space)\n",
    "                    h2 = random.randint(min_size, max_size)\n",
    "                    x2 = box1[2] + offset\n",
    "                    y2 = random.randint(max(0, box1[1] - h2), min(512 - h2, box1[1] + h2))\n",
    "\n",
    "        else:\n",
    "            # Fully random for other relations\n",
    "            w2 = random.randint(min_size, max_size)\n",
    "            h2 = random.randint(min_size, max_size)\n",
    "            x2 = random.randint(0, 512 - w2)\n",
    "            y2 = random.randint(0, 512 - h2)\n",
    "\n",
    "        box2 = (x2, y2, x2 + w2, y2 + h2)\n",
    "\n",
    "\n",
    "        if not boxes_overlap(box1, box2):\n",
    "            return [box1, box2]\n",
    "\n",
    "    else:\n",
    "        # If after 1000 attempts we couldn't place box2, retry placing box1 and box2\n",
    "        print(f\"Warning: could not find valid box2 for relation '{relation}' after 400 attempts.\")\n",
    "        return None\n",
    "\n",
    "    return [box1, box2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689248729,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "ocwWq20iMSI0"
   },
   "outputs": [],
   "source": [
    "def build_relationship_prompt(objs, relations):\n",
    "    if len(objs) == 2:\n",
    "        return f\"{objs[0]} {relations[0]} {objs[1]}\"\n",
    "    elif len(objs) == 4:\n",
    "        return f\"{objs[0]} {relations[0]} {objs[1]} and {objs[2]} {relations[1]} {objs[3]}\"\n",
    "    else:\n",
    "        raise ValueError(\"Only 2 or 4 objects supported\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689248730,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "IMdB5tAwMVH8"
   },
   "outputs": [],
   "source": [
    "def generate_csv_object_relationship(output_path, id_counter, spatial_relations):\n",
    "    mode = 'w' if id_counter == 0 else 'a'\n",
    "\n",
    "    with open(output_path, mode=mode, newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(['id', 'category', 'prompt',\n",
    "                             'obj1', 'bbox1',\n",
    "                             'obj2', 'bbox2',\n",
    "                             'obj3', 'bbox3',\n",
    "                             'obj4', 'bbox4'])\n",
    "\n",
    "        for num_objs in [2, 4]:\n",
    "            print(f\"Generating prompts with {num_objs} objects (object_relationship)...\")\n",
    "            total_needed = prompts_per_diff_level\n",
    "\n",
    "            min_size, max_size = NO_OVERLAPPING_RANGES[num_objs]\n",
    "\n",
    "            if num_objs == 2:\n",
    "                combinations = [random.sample(obj_with_articles, 2) for _ in range(total_needed)]\n",
    "            else:\n",
    "                combinations = [random.sample(obj_with_articles, 4) for _ in range(total_needed)]\n",
    "\n",
    "            random.shuffle(combinations)\n",
    "\n",
    "            repeats_needed = math.ceil(total_needed / len(combinations))\n",
    "            used_combinations = (combinations * repeats_needed)[:total_needed]\n",
    "\n",
    "            failure_count = 0\n",
    "            successful = 0\n",
    "            attempts = 0\n",
    "            max_total_attempts = prompts_per_diff_level * 20  # avoid infinite loop\n",
    "\n",
    "            while successful < prompts_per_diff_level and attempts < max_total_attempts:\n",
    "                attempts += 1\n",
    "                objs = random.sample(obj_with_articles, num_objs)\n",
    "                relations = random.choices(spatial_relations, k=num_objs // 2)\n",
    "                boxes = []\n",
    "\n",
    "                if num_objs == 2:\n",
    "                    result = generate_boxes_for_relation(objs[0], objs[1], relations[0], min_size, max_size)\n",
    "                    if result is None:\n",
    "                        continue\n",
    "                    boxes += result\n",
    "\n",
    "                elif num_objs == 4:\n",
    "                    result1 = generate_boxes_for_relation(objs[0], objs[1], relations[0], min_size, max_size)\n",
    "                    if result1 is None:\n",
    "                        continue\n",
    "\n",
    "                    for _ in range(10):  # max_attempts\n",
    "                        result2 = generate_boxes_for_relation(objs[2], objs[3], relations[1], min_size, max_size)\n",
    "                        if result2 is None:\n",
    "                            continue\n",
    "                        b3, b4 = result2\n",
    "                        if not any(boxes_overlap(b, b3) or boxes_overlap(b, b4) for b in result1):\n",
    "                            boxes = result1 + result2\n",
    "                            break\n",
    "                    else:\n",
    "                        continue  # all 10 attempts failed, retry outer loop\n",
    "\n",
    "                # build prompt and write\n",
    "                id_str = str(id_counter).zfill(4)\n",
    "                category = 'object_relationship'\n",
    "                prompt = build_relationship_prompt(objs, relations)\n",
    "\n",
    "                row = [id_str, category, prompt]\n",
    "                for obj, box in itertools.zip_longest(objs, boxes, fillvalue=''):\n",
    "                    row.append(obj)\n",
    "                    if box:\n",
    "                        box_str = f'{box[0]},{box[1]},{box[2]},{box[3]}'\n",
    "                    else:\n",
    "                        box_str = ''\n",
    "                    row.append(box_str)\n",
    "\n",
    "                while len(row) < 11:\n",
    "                    row.append('')\n",
    "\n",
    "                writer.writerow(row)\n",
    "                id_counter += 1\n",
    "                successful += 1\n",
    "\n",
    "    return id_counter\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eJ5BrT0iEhXG"
   },
   "source": [
    "### Complex composition functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pudcfJ4OE8E_"
   },
   "source": [
    "First, let's load the csv file containing the complex prompts.\n",
    "To follow the rest of the tasks, the csv should contain a total of 2048 prompts, 512 for each level of difficulty (i.e. for each number of objects 1, 2, 3 or 4).\n",
    "\n",
    "The csv should have the following structure:\n",
    "\n",
    "\n",
    "```\n",
    "prompt, objects\n",
    "```\n",
    "here an example:\n",
    "\n",
    "```\n",
    "prompt,objects\n",
    "A happy horse is behind a white toolbox.,\"a happy horse, a white toolbox\"\n",
    "A small guitar is in front of a red pan.,\"a small guitar, a red pan\"\n",
    "A green bottle is between a white skateboard.,\"a green bottle, a white skateboard\"\n",
    "```\n",
    "\n",
    "I generated them using chatGPT with the following prompts:\n",
    "\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "For 1 obj prompts:\n",
    "\n",
    "\n",
    "```\n",
    "## 1 object prompt:\n",
    "\n",
    "Generate 512 natural compositional phrases with various structures and creativity.\n",
    "Each prompt must describe one unique object from the list below. Use all the objects at least one time.\n",
    "Each object should be enriched with at least one descriptive attribute that refers to its color, shape, material, appearance, or dimension. Attributes can be qualitative (e.g., \"fluffy\", \"shiny\") or refer to physical characteristics (e.g., \"large\", \"tall\").\n",
    "Each prompt should consist of a short, vivid sentence that situates the object in a dynamic or descriptive context, similar to the following examples:\n",
    "A fish swam swiftly through the clear water.\n",
    "A bear lumbered through the dense forest.\n",
    "A tulip bloomed brightly in the garden.\n",
    "Avoid passive or overly generic constructions—aim for imaginative, specific scenarios.\n",
    "Adjust the articles of the objects if needed.\n",
    "\n",
    "Object List (use each exactly once):\n",
    "obj_with_articles = [\n",
    "'a rose', 'an oak', 'a beetle', 'a skyscraper', 'a tree', 'a baby', 'a bed', 'a lamp', 'a dog', 'a laptop',\n",
    "'a bicycle', 'a person', 'a car', 'a bus', 'a cat', 'a book', 'a chair', 'a boy', 'a couch', 'a table',\n",
    "'a plant', 'a toilet', 'a cellphone', 'a microwave', 'a sheep', 'a boat', 'a banana', 'a stop sign',\n",
    "'a donut', 'a cow', 'a clock', 'a bottle', 'an umbrella', 'a bird', 'a guitar', 'a toothbrush', 'a parking meter',\n",
    "'a bench', 'a platypus', 'a keyboard', 'a baseball bat', 'a vase', 'a surfboard', 'a tiger', 'a train', 'a flower',\n",
    "'a sandwich', 'a spoon', 'a pizza', 'a carrot', 'a teddy bear', 'an hot-dog', 'a skateboard', 'a kite', 'a broom',\n",
    "'an apple', 'a handbag', 'a horse', 'a snowboard', 'a giraffe', 'a tie', 'a shower', 'a traffic light', 'a bear',\n",
    "'a toaster', 'a knife', 'a baseball glove', 'a crocodile', 'a suitcase', 'a fork', 'a cake', 'a cup', 'a bowl',\n",
    "'a hair drier', 'an elephant', 'a mouse', 'a mushroom', 'a motorcycle', 'a turtle', 'a tennis racket', 'a truck',\n",
    "'a zebra', 'a fire hydrant', 'an oven', 'a sink', 'a frisbee', 'a hat', 'a ruler', 'a shoe', 'a ball', 'a candle',\n",
    "'a ladder', 'a charger', 'a mug', 'a tape', 'a shirt', 'a pillow', 'a pan', 'a plate', 'a shampoo', 'a hammer',\n",
    "'a blender', 'a basket', 'a screwdriver', 'a wallet', 'a bin', 'a leaf', 'a bucket', 'a monitor', 'a watch',\n",
    "'a flashlight', 'a sock', 'a door', 'a scarf', 'a speaker', 'a desk', 'a backpack', 'a printer', 'a remote',\n",
    "'a glass', 'a curtain', 'a toolbox', 'a drill', 'a notebook', 'a television', 'a soap', 'a ring', 'a refrigerator'\n",
    "]\n",
    "\n",
    "Allowed Attributes:\n",
    "attributes = [\n",
    "'aggressive', 'black', 'blue', 'bright', 'clean', 'crowded', 'dark', 'fast', 'fluffy', 'fuzzy', 'green', 'happy',\n",
    "'large', 'pink', 'red', 'rotten', 'rough', 'shiny', 'short', 'silver', 'small', 'smooth', 'snowy', 'soft',\n",
    "'tall', 'warm', 'white', 'wooden', 'yellow'\n",
    "]\n",
    "\n",
    "Allowed Colors:\n",
    "colors = ['black', 'blue', 'brown', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow', 'orange']\n",
    "\n",
    "Output Format:\n",
    "Return the result as a CSV with two columns:\n",
    "prompt, object1\n",
    "Each row should contain:\n",
    "The generated sentence\n",
    "The noun chunk used\n",
    "\n",
    "Example output format:\n",
    "prompt,object1\n",
    "A blue fish swam swiftly through the clear water, a blue fish\n",
    "A large brown bear lumbered through the dense forest, a large brown bear\n",
    "A pink tulip bloomed brightly in the garden, a pink tulip\n",
    "\n",
    "The prompts should be inside quotes if needed to keep the correct numbers of columns in the csv.\n",
    "In the sentence, keep noun chunks unbroken—adjectives modifying a noun should not be split by commas. Treat the entire noun chunk as a single unit (e.g., \"a small red ball\", not \"a small, red ball\").\n",
    "The articles should remain consistent between the prompt and the noun chunk.\n",
    "\n",
    "\n",
    "\n",
    "```\n",
    "---\n",
    "For 2,3,4 objects prompts:\n",
    "\n",
    "\n",
    "```\n",
    "# 2,3,4 objects prompt:\n",
    "\n",
    "Generate 512 natural compositional phrases with various structures and creativity.\n",
    "Each prompt must describe a scene involving exactly 4 unique objects.\n",
    "Objects can be reused across multiple prompts, but each object must appear only once within any given prompt.\n",
    "Each object must be enriched with at least one descriptive attribute, which may describe:\n",
    "Color, shape, material, appearance, or dimension\n",
    "Or a spatial relation between objects in the same prompt\n",
    "Each prompt should be a short, vivid sentence that situates the objects in a dynamic or descriptive context, similar to the following examples:\n",
    "A bright pink flower swayed gently under the tall oak tree.\n",
    "A black laptop rested beside a green coffee mug on the messy desk.\n",
    "\n",
    "Avoid passive or overly generic constructions—aim for imaginative, specific scenarios.\n",
    "Adjust the articles of the objects if needed.\n",
    "\n",
    "Object List\n",
    "(You can freely reuse any of these objects across different prompts, but not within the same prompt.)\n",
    "obj_with_articles = [\n",
    "'a rose', 'an oak', 'a beetle', 'a skyscraper', 'a tree', 'a baby', 'a bed', 'a lamp', 'a dog', 'a laptop',\n",
    "'a bicycle', 'a person', 'a car', 'a bus', 'a cat', 'a book', 'a chair', 'a boy', 'a couch', 'a table',\n",
    "'a plant', 'a toilet', 'a cellphone', 'a microwave', 'a sheep', 'a boat', 'a banana', 'a stop sign',\n",
    "'a donut', 'a cow', 'a clock', 'a bottle', 'an umbrella', 'a bird', 'a guitar', 'a toothbrush', 'a parking meter',\n",
    "'a bench', 'a platypus', 'a keyboard', 'a baseball bat', 'a vase', 'a surfboard', 'a tiger', 'a train', 'a flower',\n",
    "'a sandwich', 'a spoon', 'a pizza', 'a carrot', 'a teddy bear', 'an hot-dog', 'a skateboard', 'a kite', 'a broom',\n",
    "'an apple', 'a handbag', 'a horse', 'a snowboard', 'a giraffe', 'a tie', 'a shower', 'a traffic light', 'a bear',\n",
    "'a toaster', 'a knife', 'a baseball glove', 'a crocodile', 'a suitcase', 'a fork', 'a cake', 'a cup', 'a bowl',\n",
    "'a hair drier', 'an elephant', 'a mouse', 'a mushroom', 'a motorcycle', 'a turtle', 'a tennis racket', 'a truck',\n",
    "'a zebra', 'a fire hydrant', 'an oven', 'a sink', 'a frisbee', 'a hat', 'a ruler', 'a shoe', 'a ball', 'a candle',\n",
    "'a ladder', 'a charger', 'a mug', 'a tape', 'a shirt', 'a pillow', 'a pan', 'a plate', 'a shampoo', 'a hammer',\n",
    "'a blender', 'a basket', 'a screwdriver', 'a wallet', 'a bin', 'a leaf', 'a bucket', 'a monitor', 'a watch',\n",
    "'a flashlight', 'a sock', 'a door', 'a scarf', 'a speaker', 'a desk', 'a backpack', 'a printer', 'a remote',\n",
    "'a glass', 'a curtain', 'a toolbox', 'a drill', 'a notebook', 'a television', 'a soap', 'a ring', 'a refrigerator'\n",
    "]\n",
    "\n",
    "Allowed Attributes (for appearance or material):\n",
    "attributes = [\n",
    "'aggressive', 'black', 'blue', 'bright', 'clean', 'crowded', 'dark', 'fast', 'fluffy', 'fuzzy', 'green', 'happy',\n",
    "'large', 'pink', 'red', 'rotten', 'rough', 'shiny', 'short', 'silver', 'small', 'smooth', 'snowy', 'soft',\n",
    "'tall', 'warm', 'white', 'wooden', 'yellow'\n",
    "]\n",
    "\n",
    "Allowed Colors (subset of attributes):\n",
    "colors = ['black', 'blue', 'brown', 'gray', 'green', 'pink', 'purple', 'red', 'white', 'yellow', 'orange']\n",
    "\n",
    "Spatial Relations (to be used as part of attributes or composition logic):\n",
    "spatial_relations = [\n",
    "'on top of', 'beside', 'under', 'above', 'next to', 'beneath', 'behind', 'in front of', 'between', 'leaning on',\n",
    "'inside', 'resting on', 'attached to', 'surrounded by', 'placed near'\n",
    "]\n",
    "\n",
    "Output Format:\n",
    "Return the result as a CSV with one column for the prompt followed by one column per noun chunk:\n",
    "prompt,object1,object2,...\n",
    "Each row should contain:\n",
    "The generated sentence\n",
    "The noun chunks used\n",
    "\n",
    "Example output format (for N=2):\n",
    "prompt,object1,object2\n",
    "\"A fluffy cat jumped onto a soft couch near the window\",a fluffy cat,a soft couch\n",
    "A red bicycle leaned against a wooden bench in the park,a red bicycle,a wooden bench\n",
    "\n",
    "Wrap a prompt in double quotes whenever needed to keep the correct number of columns in the CSV (e.g. when the prompt contains a comma).\n",
    "In the sentence, keep noun chunks unbroken—adjectives modifying a noun should not be split by commas. Treat the entire noun chunk as a single unit (e.g., \"a small red ball\", not \"a small, red ball\").\n",
    "The articles should remain consistent between the prompt and the noun chunks.\n",
    "\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 547,
     "status": "ok",
     "timestamp": 1750689249275,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "NJfsgV87T4-I"
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "from spacy.matcher import Matcher\n",
    "\n",
    "# English spaCy pipeline. extract_spatial_relations() relies on its dependency\n",
    "# labels (prep / pobj / nsubj), so the loaded model must include a parser.\n",
    "nlp = spacy.load(\"en_core_web_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689249276,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "FxM6Pyg1xneK"
   },
   "outputs": [],
   "source": [
    "def extract_spatial_relations(prompt: str, local_prompts: list[str]):\n",
    "    \"\"\"Extract [subject_chunk, object_chunk, relation] triplets from a prompt.\n",
    "\n",
    "    The prompt is parsed with the global spaCy pipeline `nlp`. A relation is\n",
    "    recorded when a phrase from the global `spatial_relations` list ends on a\n",
    "    preposition token (dependency 'prep', or 'agent' for the passive 'by' in\n",
    "    'surrounded by') that has a prepositional object ('pobj'), and a subject\n",
    "    can be found for the phrase. Multi-word phrases such as 'in front of' or\n",
    "    'next to' are matched on the tokens that end at the preposition.\n",
    "\n",
    "    Args:\n",
    "        prompt: The sentence to analyse.\n",
    "        local_prompts: The noun chunks of the prompt (e.g. 'a fluffy cat').\n",
    "\n",
    "    Returns:\n",
    "        A list of [subject_chunk, object_chunk, relation] lists. The chunks are\n",
    "        returned exactly as given in `local_prompts` (original case kept), so\n",
    "        they can be used as keys for those objects downstream.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    doc = nlp(prompt)\n",
    "    relations = []\n",
    "\n",
    "    # Try longer phrases first so the most specific relation wins.\n",
    "    phrases = sorted(spatial_relations, key=lambda rel: len(rel.split()), reverse=True)\n",
    "\n",
    "    def find_chunk(word):\n",
    "        \"\"\"Return the first local prompt containing `word` as a whole word (case-insensitive).\"\"\"\n",
    "        pattern = re.compile(rf\"\\b{re.escape(word.lower())}\\b\")\n",
    "        return next((lp for lp in local_prompts if pattern.search(lp.lower())), None)\n",
    "\n",
    "    for token in doc:\n",
    "        # Spatial phrases end on a preposition ('agent' is spaCy's label for passive 'by').\n",
    "        if token.dep_ not in (\"prep\", \"agent\"):\n",
    "            continue\n",
    "\n",
    "        # Find the phrase ending at this token, e.g. 'in front of' ending at 'of'.\n",
    "        relation, start = None, None\n",
    "        for phrase in phrases:\n",
    "            first = token.i - len(phrase.split()) + 1\n",
    "            if first >= 0 and doc[first:token.i + 1].text.lower() == phrase:\n",
    "                relation, start = phrase, doc[first]\n",
    "                break\n",
    "        if relation is None:\n",
    "            continue\n",
    "\n",
    "        # Object of the preposition (e.g. 'mouse' in 'under the mouse').\n",
    "        pobj = next((child for child in token.children if child.dep_ == \"pobj\"), None)\n",
    "        if pobj is None:\n",
    "            continue\n",
    "\n",
    "        # The phrase hangs either off a verb ('the cat sits under ...'), whose nominal\n",
    "        # subject is the relation subject, or directly off a noun ('a cat under the\n",
    "        # table'), which is then the subject itself.\n",
    "        anchor = start.head\n",
    "        subject = None\n",
    "        for child in anchor.children:\n",
    "            if child.dep_ in (\"nsubj\", \"nsubjpass\"):\n",
    "                subject = child\n",
    "        if subject is None and anchor.pos_ in (\"NOUN\", \"PROPN\"):\n",
    "            subject = anchor\n",
    "        if subject is None:\n",
    "            continue\n",
    "\n",
    "        subj_match = find_chunk(subject.text)\n",
    "        obj_match = find_chunk(pobj.text)\n",
    "        # Skip words that map to no chunk, and self-relations (both words in one chunk).\n",
    "        if subj_match and obj_match and subj_match != obj_match:\n",
    "            relations.append([subj_match, obj_match, relation])\n",
    "\n",
    "    return relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3292,
     "status": "ok",
     "timestamp": 1750689252566,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "oZgGq1vtVnih"
   },
   "outputs": [],
   "source": [
    "def extract_object_list(row, max_objects=4):\n",
    "    \"\"\"Return the non-empty object labels (noun chunks) of a complex-prompt row.\n",
    "\n",
    "    Reads the columns `object1` .. `object{max_objects}`. Missing columns and\n",
    "    NaN/blank cells are skipped, and surrounding whitespace is stripped. The\n",
    "    default of 4 matches the four obj/bbox slots of the generated CSV.\n",
    "\n",
    "    Args:\n",
    "        row: A DataFrame row (or any mapping with `.get`).\n",
    "        max_objects: Number of `objectN` columns to read.\n",
    "\n",
    "    Returns:\n",
    "        The labels in column order.\n",
    "    \"\"\"\n",
    "    labels = []\n",
    "    for i in range(1, max_objects + 1):\n",
    "        value = row.get(f'object{i}', '')\n",
    "        if pd.notna(value):\n",
    "            text = str(value).strip()\n",
    "            if text:\n",
    "                labels.append(text)\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "executionInfo": {
     "elapsed": 3253,
     "status": "ok",
     "timestamp": 1750689255820,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "fuLDle2RyoIV",
    "outputId": "c44f9138-8d91-4109-a047-5d14908b2a87"
   },
   "outputs": [],
   "source": [
    "def _row_triplets(row):\n",
    "    \"\"\"Spatial-relation triplets of one prompt row, matched to its object columns.\"\"\"\n",
    "    return extract_spatial_relations(row[\"prompt\"], extract_object_list(row))\n",
    "\n",
    "\n",
    "# One spaCy parse per prompt; generate_csv_complex_composition reads this column.\n",
    "complex_prompts_df[\"triplet\"] = complex_prompts_df.apply(_row_triplets, axis=1)\n",
    "complex_prompts_df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689255827,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "ZqJxAUlCOt5G"
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "import random\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "# Side-length bounds (in pixels) for randomly generated boxes.\n",
    "# NOTE(review): these are module-level globals, so generators defined in other\n",
    "# cells that read MIN_SIZE/MAX_SIZE at call time also see these values.\n",
    "MIN_SIZE = 80\n",
    "MAX_SIZE = min(510, IMAGE_SIZE)  # a box side can never exceed the canvas\n",
    "\n",
    "def random_bbox(min_size=MIN_SIZE, max_size=MAX_SIZE, image_size=None):\n",
    "    \"\"\"Generate a random box [x1, y1, x2, y2] that lies fully inside the image.\n",
    "\n",
    "    Width and height are drawn independently and uniformly from\n",
    "    [min_size, max_size]. The top-left corner is then drawn uniformly among the\n",
    "    positions that keep the whole box on the square canvas.\n",
    "\n",
    "    Args:\n",
    "        min_size: Minimum side length in pixels (must be >= 1).\n",
    "        max_size: Maximum side length in pixels; clamped to `image_size`.\n",
    "        image_size: Canvas side length; defaults to the global IMAGE_SIZE.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if min_size < 1 or min_size exceeds the clamped max_size.\n",
    "    \"\"\"\n",
    "    if image_size is None:\n",
    "        image_size = IMAGE_SIZE\n",
    "    # A side longer than the canvas would leave no valid corner position.\n",
    "    max_size = min(max_size, image_size)\n",
    "    if min_size < 1 or min_size > max_size:\n",
    "        raise ValueError(f'Invalid box size range [{min_size}, {max_size}] for a {image_size}px image')\n",
    "    width = random.randint(min_size, max_size)\n",
    "    height = random.randint(min_size, max_size)\n",
    "    x1 = random.randint(0, image_size - width)\n",
    "    y1 = random.randint(0, image_size - height)\n",
    "    return [x1, y1, x1 + width, y1 + height]\n",
    "\n",
    "def detect_spatial_relations(prompt, relations=None):\n",
    "    \"\"\"Return the spatial relations that occur in `prompt` as whole words.\n",
    "\n",
    "    Matching is case-insensitive and word-bounded, so 'under' is not reported\n",
    "    for 'thunder', nor 'inside' for 'insider'. Results follow the order of the\n",
    "    relation list.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text to scan.\n",
    "        relations: Candidate (lowercase) phrases; defaults to the global\n",
    "            `spatial_relations`, looked up at call time.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    if relations is None:\n",
    "        relations = spatial_relations\n",
    "    lowered = prompt.lower()\n",
    "    return [rel for rel in relations if re.search(rf\"\\b{re.escape(rel)}\\b\", lowered)]\n",
    "\n",
    "def _label_key(label):\n",
    "    \"\"\"Case- and whitespace-insensitive key used to match triplet labels to object labels.\"\"\"\n",
    "    return str(label).strip().lower()\n",
    "\n",
    "\n",
    "def generate_csv_complex_composition(df, output_path, id_counter):\n",
    "    \"\"\"Write one CSV row (prompt, objects, boxes) per complex-composition prompt.\n",
    "\n",
    "    Objects come from the `object1`..`object4` columns and spatial triplets from\n",
    "    the `triplet` column ([[subject, object, relation], ...]). Each related pair\n",
    "    gets boxes from generate_boxes_for_relation(). If that raises, and for every\n",
    "    object without a relation, random_bbox() is used. Boxes are written as\n",
    "    'x1,y1,x2,y2' strings.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with `prompt`, `triplet` and `object1`..`object4` columns.\n",
    "        output_path: CSV file to write. It is truncated (and given a header row)\n",
    "            when `id_counter` is 0, and appended to otherwise.\n",
    "        id_counter: Id assigned to the first row written here.\n",
    "\n",
    "    Returns:\n",
    "        The next free id, i.e. id_counter plus the number of rows written.\n",
    "    \"\"\"\n",
    "    header = ['id', 'category', 'prompt', 'obj1', 'bbox1', 'obj2', 'bbox2', 'obj3', 'bbox3', 'obj4', 'bbox4']\n",
    "    max_objects = (len(header) - 3) // 2  # obj/bbox pairs after id, category, prompt\n",
    "    mode = 'w' if id_counter == 0 else 'a'\n",
    "\n",
    "    with open(output_path, mode=mode, newline='') as csv_file:\n",
    "        writer = csv.writer(csv_file)\n",
    "\n",
    "        if id_counter == 0:\n",
    "            writer.writerow(header)\n",
    "\n",
    "        for _, row in tqdm(df.iterrows(), total=len(df)):\n",
    "            prompt = row['prompt']\n",
    "            relations = row['triplet']  # [['obj1', 'obj2', 'relation'], ...]\n",
    "\n",
    "            # Object labels, with surrounding whitespace (e.g. after ', ') dropped.\n",
    "            objects = []\n",
    "            for i in range(1, max_objects + 1):\n",
    "                value = row.get(f'object{i}')\n",
    "                if pd.notna(value) and str(value).strip() != '':\n",
    "                    objects.append(str(value).strip())\n",
    "\n",
    "            if not objects:\n",
    "                continue\n",
    "\n",
    "            # Triplet labels can differ from the object columns in case or surrounding\n",
    "            # whitespace. Map them back to the exact labels so relation boxes are used.\n",
    "            label_of = {_label_key(obj): obj for obj in objects}\n",
    "\n",
    "            boxes = {}\n",
    "            if isinstance(relations, (list, tuple)):\n",
    "                for subj_label, obj_label, relation in relations:\n",
    "                    subj = label_of.get(_label_key(subj_label))\n",
    "                    obj = label_of.get(_label_key(obj_label))\n",
    "                    # Skip unknown objects and self-relations.\n",
    "                    if subj is None or obj is None or subj == obj:\n",
    "                        continue\n",
    "                    # Pairs already placed by an earlier relation keep their boxes.\n",
    "                    if subj in boxes and obj in boxes:\n",
    "                        continue\n",
    "                    try:\n",
    "                        box1, box2 = generate_boxes_for_relation(subj, obj, relation, MIN_SIZE, MAX_SIZE)\n",
    "                    except Exception as e:\n",
    "                        # Best effort: one bad relation must not abort the whole file.\n",
    "                        print(f'Error generating boxes for {subj} and {obj}: {e}. Falling back to random.')\n",
    "                        box1, box2 = random_bbox(), random_bbox()\n",
    "                    boxes[subj] = box1\n",
    "                    boxes[obj] = box2\n",
    "\n",
    "            # Objects not involved in any relation get independent random boxes.\n",
    "            for obj in objects:\n",
    "                if obj not in boxes:\n",
    "                    boxes[obj] = random_bbox()\n",
    "\n",
    "            row_out = [str(id_counter).zfill(4), 'complex_composition', prompt]\n",
    "            for obj in objects:\n",
    "                box = boxes[obj]\n",
    "                row_out += [obj, f'{box[0]},{box[1]},{box[2]},{box[3]}']\n",
    "            # Pad the unused obj/bbox slots so every row has len(header) columns.\n",
    "            row_out += [''] * (len(header) - len(row_out))\n",
    "\n",
    "            writer.writerow(row_out)\n",
    "            id_counter += 1\n",
    "\n",
    "    return id_counter\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ilDBck-S8Gfy"
   },
   "source": [
    "### Generation of the full CSV and download"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1750689255828,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "WfZJtCwn4ETJ"
   },
   "outputs": [],
   "source": [
    "# Shared state for the generator cells below. Each call receives output_path and\n",
    "# id_counter and returns the next id. generate_csv_complex_composition truncates\n",
    "# output_path when id_counter is 0 and appends otherwise (presumably the other\n",
    "# generators do the same), so re-run this cell before re-running the generators.\n",
    "id_counter = 0\n",
    "output_path = 'extendedDataset.csv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0kxtyR2YiE-d"
   },
   "source": [
    "#### Object binding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 14090,
     "status": "ok",
     "timestamp": 1750689269920,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "dvECmLOHVNIv",
    "outputId": "0e7d84b6-bcaa-4ab8-9cad-d019ffcabc89"
   },
   "outputs": [],
   "source": [
    "# object binding\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_object_binding(output_path, id_counter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2PGbnOwJiJjd"
   },
   "source": [
    "#### Color binding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13649,
     "status": "ok",
     "timestamp": 1750689283577,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "su1EIRvSvcaD",
    "outputId": "feb5bb0a-0c5c-4a96-fe51-8ca0ef825ce2"
   },
   "outputs": [],
   "source": [
    "# color binding\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_color_binding(output_path, id_counter, colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nLebv6v3iNqy"
   },
   "source": [
    "#### Attribute binding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13066,
     "status": "ok",
     "timestamp": 1750689296640,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "llQwBO7k6C61",
    "outputId": "c9409bb9-1cd7-4bca-80a1-46efa6e880c9"
   },
   "outputs": [],
   "source": [
    "# attribute binding\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_attribute_binding(output_path, id_counter, attributes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Omdz58YeiRw7"
   },
   "source": [
    "#### Overlapping bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12242,
     "status": "ok",
     "timestamp": 1750689308883,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "7MlrE_5pEgFI",
    "outputId": "1f75cf60-1b24-421f-b45d-bc87c399d71e"
   },
   "outputs": [],
   "source": [
    "# overlapping bounding boxes\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_overlapping_bboxes(output_path, id_counter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4TrIe4joiXF5"
   },
   "source": [
    "#### Small bounding boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12127,
     "status": "ok",
     "timestamp": 1750689321020,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "dIf4nMBHHFP6",
    "outputId": "988ec445-5b27-4e2c-ad1b-145302d4c3c4"
   },
   "outputs": [],
   "source": [
    "# small bounding boxes\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_small_bboxes(output_path, id_counter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7hN6ScAbiaee"
   },
   "source": [
    "#### Object relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10,
     "status": "ok",
     "timestamp": 1750689321029,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "zfcy5ouPNyJy",
    "outputId": "e6e50986-81dd-4c34-f1dc-19a69041bbe0"
   },
   "outputs": [],
   "source": [
    "# object relations\n",
    "# The returned value is the id where the next category starts.\n",
    "# NOTE(review): re-running this cell alone presumably appends a second batch of rows.\n",
    "id_counter = generate_csv_object_relationship(output_path, id_counter, spatial_relations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xZsHOZpKQSt2"
   },
   "source": [
    "#### Complex composition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 90,
     "status": "ok",
     "timestamp": 1750689321119,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "EtSqgZnWQVcu",
    "outputId": "9834c5a2-6238-47e6-e0e8-ccde8e2218e6"
   },
   "outputs": [],
   "source": [
    "# complex composition\n",
    "# Appends to output_path whenever id_counter > 0 and returns the next free id, so\n",
    "# re-running this cell alone writes the complex prompts a second time.\n",
    "id_counter = generate_csv_complex_composition(complex_prompts_df, output_path, id_counter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XgL56x7Hie_n"
   },
   "source": [
    "#### Export file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OfWDFm8Ujsl7"
   },
   "source": [
    "Export the generated prompts and boxes as a CSV file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17
    },
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1750689321127,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "S3geQrn7VaRf",
    "outputId": "24076afe-9480-430f-87c1-c3fe1c1b5663"
   },
   "outputs": [],
   "source": [
    "from google.colab import files\n",
    "\n",
    "# Download the CSV written by the generator cells (the same output_path they used).\n",
    "files.download(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KlDdUdj4jT3K"
   },
   "source": [
    "# Visualize boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1750689321138,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "ZVV9R2UhcOxG"
   },
   "outputs": [],
   "source": [
    "def parse_bbox(bbox_str):\n",
    "    \"\"\"Parse an 'x1,y1,x2,y2' cell of the generated CSV into a tuple of ints.\n",
    "\n",
    "    Returns None for missing cells: NaN as produced by read_csv, None, or a\n",
    "    blank string.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the cell does not hold exactly four integers.\n",
    "    \"\"\"\n",
    "    if pd.isna(bbox_str) or str(bbox_str).strip() == '':\n",
    "        return None\n",
    "    values = tuple(int(part) for part in str(bbox_str).split(','))\n",
    "    # Callers unpack the result as x1, y1, x2, y2.\n",
    "    if len(values) != 4:\n",
    "        raise ValueError(f'Expected 4 comma-separated integers, got {bbox_str!r}')\n",
    "    return values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8xmnlyqqj8ky"
   },
   "source": [
    "Randomly sample the chosen number of prompts from the generated CSV and plot their bounding boxes; `boxes_per_row` sets how many plots are shown per row."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s5Od4VqoLycK"
   },
   "source": [
    "## Visualize random prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1750689321140,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "FKC9A7EncLoD"
   },
   "outputs": [],
   "source": [
    "def _draw_prompt_boxes(ax, row, canvas_size=512):\n",
    "    \"\"\"Draw the labelled obj/bbox pairs of one generated-CSV row on `ax`.\"\"\"\n",
    "    ax.set_xlim(0, canvas_size)\n",
    "    ax.set_ylim(0, canvas_size)\n",
    "    ax.invert_yaxis()  # image coordinates: y grows downwards\n",
    "\n",
    "    for i in range(1, 5):\n",
    "        label = row[f'obj{i}']\n",
    "        box = parse_bbox(row[f'bbox{i}'])\n",
    "        # Empty slots are read back as NaN, which is not a str.\n",
    "        if not (isinstance(label, str) and label and box):\n",
    "            continue\n",
    "        x1, y1, x2, y2 = box\n",
    "        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')\n",
    "        ax.add_patch(rect)\n",
    "        ax.text(x1 + 3, y1 - 5, label, color='blue', fontsize=8)\n",
    "\n",
    "    ax.set_title(f'ID {row[\"id\"]}\\n{row[\"prompt\"]}\\n{row[\"category\"]}', fontsize=10)\n",
    "    ax.grid(True)\n",
    "\n",
    "\n",
    "def visualize_csv_boxes_grid(csv_path, num_samples=35, boxes_per_row=5, random_seed=42, canvas_size=512):\n",
    "    \"\"\"Plot the boxes of randomly sampled rows of a generated CSV in a grid.\n",
    "\n",
    "    Args:\n",
    "        csv_path: CSV produced by the generator cells.\n",
    "        num_samples: Number of rows to plot; capped at the number of rows in the file.\n",
    "        boxes_per_row: Number of subplots per grid row.\n",
    "        random_seed: Seed for the row sampling, for reproducible figures.\n",
    "        canvas_size: Side length of the plotted canvas in pixels.\n",
    "    \"\"\"\n",
    "    df = pd.read_csv(csv_path)\n",
    "    # df.sample raises if asked for more rows than exist.\n",
    "    num_samples = min(num_samples, len(df))\n",
    "    if num_samples == 0:\n",
    "        print('No rows to visualize.')\n",
    "        return\n",
    "    sampled_rows = df.sample(n=num_samples, random_state=random_seed)\n",
    "\n",
    "    num_rows = math.ceil(num_samples / boxes_per_row)\n",
    "    # squeeze=False always yields a 2-D array of axes, even for a single subplot.\n",
    "    fig, axes = plt.subplots(num_rows, boxes_per_row, figsize=(boxes_per_row * 4, num_rows * 4), squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for ax, (_, row) in zip(axes, sampled_rows.iterrows()):\n",
    "        _draw_prompt_boxes(ax, row, canvas_size)\n",
    "\n",
    "    # Hide the unused subplots of an incomplete last grid row.\n",
    "    for ax in axes[num_samples:]:\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "executionInfo": {
     "elapsed": 4424,
     "status": "ok",
     "timestamp": 1750689325571,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "OCr7VjkCcSqc",
    "outputId": "9cad56f4-5c40-4bae-e6fa-9a425e3bfcdd"
   },
   "outputs": [],
   "source": [
    "# Plot 35 randomly sampled prompts from the generated CSV, 5 per row.\n",
    "visualize_csv_boxes_grid(output_path, num_samples=35, boxes_per_row=5, random_seed=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3R5alzFPL4nz"
   },
   "source": [
    "## Visualize prompts by id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1750689325575,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "q0_WzzvrKtEY"
   },
   "outputs": [],
   "source": [
    "def visualize_csv_boxes_by_ids(csv_path, selected_ids, boxes_per_row=5, canvas_size=512):\n",
    "    \"\"\"Plot the boxes of the rows of a generated CSV whose id is in `selected_ids`.\n",
    "\n",
    "    Ids may be given as ints (929) or as zero-padded strings ('0929'). read_csv\n",
    "    parses the zero-padded id column as integers, so ids are compared numerically.\n",
    "    Rows are plotted in file order.\n",
    "\n",
    "    Args:\n",
    "        csv_path: CSV produced by the generator cells.\n",
    "        selected_ids: Iterable of ids to plot.\n",
    "        boxes_per_row: Number of subplots per grid row.\n",
    "        canvas_size: Side length of the plotted canvas in pixels.\n",
    "    \"\"\"\n",
    "    def parse_bbox_or_none(bbox_str):\n",
    "        \"\"\"Lenient 'x1,y1,x2,y2' parser: None for missing or malformed cells.\"\"\"\n",
    "        try:\n",
    "            return list(map(int, bbox_str.split(',')))\n",
    "        except (AttributeError, ValueError):\n",
    "            # AttributeError: NaN/None cell; ValueError: non-integer content.\n",
    "            return None\n",
    "\n",
    "    df = pd.read_csv(csv_path)\n",
    "\n",
    "    wanted = {int(str(id_).strip()) for id_ in selected_ids}\n",
    "    numeric_ids = pd.to_numeric(df['id'], errors='coerce')\n",
    "    selected_rows = df[numeric_ids.isin(wanted)]\n",
    "\n",
    "    if selected_rows.empty:\n",
    "        print(\"⚠️ No matching IDs found.\")\n",
    "        return\n",
    "\n",
    "    num_samples = len(selected_rows)\n",
    "    num_rows = math.ceil(num_samples / boxes_per_row)\n",
    "    # squeeze=False always yields a 2-D array of axes, even for a single subplot.\n",
    "    fig, axes = plt.subplots(num_rows, boxes_per_row, figsize=(boxes_per_row * 4, num_rows * 4), squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for ax, (_, row) in zip(axes, selected_rows.iterrows()):\n",
    "        ax.set_xlim(0, canvas_size)\n",
    "        ax.set_ylim(0, canvas_size)\n",
    "        ax.invert_yaxis()  # image coordinates: y grows downwards\n",
    "\n",
    "        for i in range(1, 5):\n",
    "            label = row[f'obj{i}']\n",
    "            box = parse_bbox_or_none(row[f'bbox{i}'])\n",
    "            # Empty slots are read back as NaN, which is not a str.\n",
    "            if not (isinstance(label, str) and label and box):\n",
    "                continue\n",
    "            x1, y1, x2, y2 = box\n",
    "            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')\n",
    "            ax.add_patch(rect)\n",
    "            ax.text(x1 + 3, y1 - 5, label, color='blue', fontsize=8)\n",
    "\n",
    "        ax.set_title(f'ID {row[\"id\"]}\\n{row[\"prompt\"]}\\n{row[\"category\"]}', fontsize=10)\n",
    "        ax.grid(True)\n",
    "\n",
    "    # Hide the unused subplots of an incomplete last grid row.\n",
    "    for ax in axes[num_samples:]:\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 330
    },
    "executionInfo": {
     "elapsed": 160,
     "status": "ok",
     "timestamp": 1750689325736,
     "user": {
      "displayName": "Nicla Faccioli",
      "userId": "11628056964833727666"
     },
     "user_tz": -120
    },
    "id": "EVvZptXhKwvk",
    "outputId": "9740b448-a49c-4dfc-c5b5-27296b2b6689"
   },
   "outputs": [],
   "source": [
    "# Visualize specific prompts by id (here ID 929; pass e.g. [0, 5, 18, 29] for several)\n",
    "visualize_csv_boxes_by_ids(output_path, selected_ids=[929])\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyMzaUewqzRquP9sR/LcilPh",
   "collapsed_sections": [
    "EXbFg5ss6MNs",
    "1XOMRz7q6RQn",
    "JwmkqkAM6Vj6",
    "byCqk_2c91ln",
    "zsSXe6SqGv7d",
    "0kxtyR2YiE-d",
    "2PGbnOwJiJjd",
    "nLebv6v3iNqy",
    "Omdz58YeiRw7",
    "4TrIe4joiXF5"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
