import json
from PIL import Image



# Input annotations (per-image records with dense and OFA captions) and the
# destination for the same records augmented with a natural-language caption.
src_file, dst_file = (
    "/coco_karpathy_train_dense_caption_w_ofa_caption.json",
    "/coco_karpathy_train_dense_caption_w_ofa_caption_w_natural_language.json",
)



# Upper bounds (exclusive) of the first four position buckets; anything at or
# above the last threshold falls into the fifth bucket.
_POSITION_THRESHOLDS = (0.2, 0.4, 0.6, 0.8)
_X_POSITIONS = (
    "on the left",
    "in the middle left",
    "in the center",
    "in the middle right",
    "on the right",
)
# Small y maps to "top", i.e. y is measured downward from the top edge.
_Y_POSITIONS = (
    "at the top",
    "near the top",
    "in the middle",
    "near the bottom",
    "at the bottom",
)


def _position_label(fraction, labels):
    """Map a normalised coordinate to one of five position labels.

    Buckets are ``< 0.2``, ``< 0.4``, ``< 0.6``, ``< 0.8`` and everything
    else; a value exactly on a threshold lands in the higher bucket.
    """
    for threshold, label in zip(_POSITION_THRESHOLDS, labels):
        if fraction < threshold:
            return label
    return labels[-1]


def get_object_positions(image_resolution, bounding_boxes_with_captions):
    """Describe where each captioned bounding box sits in the image.

    Args:
        image_resolution: ``(width, height)`` of the image in pixels.
        bounding_boxes_with_captions: iterable of strings shaped like
            ``"caption: [x1, y1, x2, y2]"``. Coordinates may be integers or
            decimals.

    Returns:
        A list with one string per box, ``"<caption> <x position> <y position>"``,
        e.g. ``"a dog on the left at the bottom"``. The position is taken from
        the box centre, normalised by the image size.

    Raises:
        ValueError: if an entry has no ':' or its box is not four numbers.
        ZeroDivisionError: if the image width or height is 0.
    """
    width, height = image_resolution
    object_positions = []
    for box_with_caption in bounding_boxes_with_captions:
        # Split on the LAST colon only, so captions that themselves contain
        # ':' (e.g. "clock showing 12:30") stay intact.
        caption, box_str = box_with_caption.rsplit(":", 1)
        caption = caption.strip()
        # Strip the surrounding brackets, then parse the four coordinates.
        x1, y1, x2, y2 = (float(coord) for coord in box_str.strip()[1:-1].split(","))
        object_x = (x1 + x2) / 2 / width
        object_y = (y1 + y2) / 2 / height
        x_pos = _position_label(object_x, _X_POSITIONS)
        y_pos = _position_label(object_y, _Y_POSITIONS)
        object_positions.append(caption + " " + x_pos + " " + y_pos)
    return object_positions





def read_image_width_height(image_path):
    """Return the ``(width, height)`` of the image at *image_path* in pixels.

    ``Image.open`` opens the file lazily and keeps it open. The ``with`` block
    closes the handle as soon as the size has been read, instead of leaving it
    to garbage collection. That matters because this is called once per record
    over the whole dataset.
    """
    with Image.open(image_path) as image:
        width, height = image.size
    return width, height

dense_caption_w_ofa_path = src_file
data_root = "/COCO2014/"

# Load the source annotations: a list of per-image records.
with open(dense_caption_w_ofa_path, 'r', encoding='utf-8') as f1:
    data1 = json.load(f1)

output_data = []

for record in data1:
    # The image size is needed to normalise box centres to fractions of the image.
    image_resolution = read_image_width_height(data_root + record['image'])

    # 'dense_caption' holds ';'-separated "caption: [x1, y1, x2, y2]" entries.
    # Drop blank segments instead of always slicing off the last one, so a
    # missing trailing ';' does not silently lose the final object.
    box_entries = [
        entry for entry in record['dense_caption'].split(";") if entry.strip()
    ]
    object_positions = get_object_positions(image_resolution, box_entries)
    natural_dense_caption = ", ".join(object_positions) + " ."

    # Keys are inserted in a fixed order, which is the order json.dump writes.
    row_data = {'image': record['image']}
    # Test/val records are expected to lack 'image_id'. Copy it only when it is
    # present, rather than editing this script for each split.
    if 'image_id' in record:
        row_data['image_id'] = record['image_id']
    row_data['caption'] = record['caption']
    row_data['ofa_caption'] = record['ofa_caption']
    row_data['dense_caption'] = record['dense_caption']
    row_data['natural_dense_caption'] = natural_dense_caption

    output_data.append(row_data)

print("Total number of rows: ", len(output_data))
# Save the output_data list to a JSON file.
with open(dst_file, 'w', encoding='utf-8') as outfile:
    json.dump(output_data, outfile, indent=2)
