import math
import random
import itertools
from PIL import Image, ImageDraw, ImageFont
import json
import os
import argparse
import numpy as np
import hashlib
import pickle
import sys


# Side length, in pixels, of the square canvas every diagram is drawn on.
CANVAS_SIZE = 900
# Inclusive bounds on how many icons a randomly generated diagram contains.
MIN_SHAPE_COUNT = 2
MAX_SHAPE_COUNT = 9
# Minimum number of arrows in a randomly generated diagram.
MIN_ARROW_COUNT = 1

# Offsets of the two dotted grid lines (drawn both horizontally and
# vertically) that split the canvas into a 3x3 grid; [300, 600] for a
# 900-pixel canvas, matching the thirds used by calculate_position_name.
grid_coords = [CANVAS_SIZE // 3, 2 * CANVAS_SIZE // 3]

# Directory of this script, so the icon index is found regardless of the
# current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Icon index: maps each icon name to a list of candidate image paths
# (get_random_icon picks one of them).
json_file_path = os.path.join(script_dir, 'cosine_sim_chosen_icon_paths.json')

# JSON text is UTF-8; be explicit so the platform default encoding cannot
# break decoding of non-ASCII names or paths.
with open(json_file_path, 'r', encoding='utf-8') as file:
    ICON_DICT = json.load(file)


# All icon names, sorted so random sampling is reproducible under a seed.
ICON_LIST = sorted(ICON_DICT)
# Width and height, in pixels, that every icon is resized to.
SIZE = 100
# Number of distractor answers in each multiple-choice question.
WRONG_CHOICE_COUNT = 3

# The nine 3x3 grid cells in row-major order, in the "<row> <column>" format
# produced by calculate_position_name.
# NOTE(review): not referenced in the visible part of this file.
POSITION_ORDER = [
    'top left', 'top center', 'top right',
    'center left', 'center center', 'center right',
    'bottom left', 'bottom center', 'bottom right'
]

# Column and row names of the 3x3 grid.
X_POSITIONS_GRID = ["left", "center", "right"]
Y_POSITIONS_GRID = ["top", "center", "bottom"]

# Directions used by the relative-position questions.
RELATIVE_OPTIONS = ["left", "right", "top", "bottom"]

# Arrow directions accepted by get_relationships_with_object ("from" =
# outgoing, "to" = incoming).
RELATIONSHIP_OPTIONS = ["from", "to"]

# Pools sampled per diagram when create_sample_dataset(random_attributes != 0).
ARROWHEAD_SIZES = [14, 20, 26]
ARROW_WIDTHS = [1, 2, 4, 8]
ARROW_COLORS = ["black", "red", "blue"]


def calculate_position_name(coordinates):
    """Name the 3x3 grid cell that contains a canvas point.

    Args:
        coordinates: ``(x, y)`` pixel position on the canvas.

    Returns:
        A ``"<row> <column>"`` string such as ``"top left"`` or
        ``"center right"``. Cell boundaries lie at thirds of CANVAS_SIZE, and
        a point exactly on a boundary belongs to the later cell.
    """
    third = CANVAS_SIZE / 3

    def bucket(value, labels):
        if value < third:
            return labels[0]
        if value < 2 * third:
            return labels[1]
        return labels[2]

    column = bucket(coordinates[0], ("left", "center", "right"))
    row = bucket(coordinates[1], ("top", "center", "bottom"))
    return row + " " + column


def relationship_to_text_directionless(relationship):
    """Render a relationship dict as ``"<origin> and <destination>"``.

    The connective "and" deliberately hides which end is the arrow origin.
    """
    return "{} and {}".format(relationship["origin"], relationship["destination"])


def negative_object_sample(positive_objects, count):
    """Draw ``count`` distinct icon names that are not in ``positive_objects``.

    Candidates come from the module-level ICON_LIST. They are de-duplicated
    and sorted before sampling, so the draw does not depend on set iteration
    order. ``random.sample`` raises ValueError when fewer than ``count``
    candidates exist.
    """
    candidates = sorted(set(ICON_LIST).difference(positive_objects))
    return random.sample(candidates, count)


def create_relationship(origin, destination):
    """Return a relationship record linking ``origin`` to ``destination``."""
    return {"origin": origin, "destination": destination}


def get_all_positive_relationships(selected_pairs):
    """Convert arrow pairs into relationship dicts keyed by icon name.

    Args:
        selected_pairs: iterable of ``(origin_info, destination_info)``
            shape-info dicts, each carrying a ``'shape'`` (icon name) entry.

    Returns:
        A list of ``{"origin": ..., "destination": ...}`` dicts, one per pair,
        in input order.
    """
    return [
        {"origin": origin["shape"], "destination": destination["shape"]}
        for origin, destination in selected_pairs
    ]


def negative_pairs_sample(selected_pairs, all_objects, sample_size):
    """Sample ``sample_size`` icon-name pairs that are NOT joined by an arrow.

    Candidates are every unordered pair of ``all_objects`` whose members are
    not connected, in either direction, by one of ``selected_pairs``. If that
    yields too few, sorted pairs of icons absent from the diagram (drawn with
    negative_object_sample) are appended until there are enough.

    Args:
        selected_pairs: ``(shape1, shape2)`` shape-info dicts with a ``'shape'`` key.
        all_objects: icon names present in the diagram.
        sample_size: number of pairs to return.

    Returns:
        A list of ``sample_size`` 2-tuples of icon names.
    """
    connected = {
        frozenset((first['shape'], second['shape']))
        for first, second in selected_pairs
    }
    candidates = []
    for pair in itertools.combinations(all_objects, 2):
        if frozenset(pair) not in connected:
            candidates.append(pair)

    # Top up with pairs made of icons that are not in the diagram at all.
    while len(candidates) < sample_size:
        extra = tuple(sorted(negative_object_sample(all_objects, 2)))
        if extra not in candidates:
            candidates.append(extra)

    return random.sample(candidates, sample_size)


def get_items_in_row_or_column(selected_icons, selected_coordinates, row_or_column, asked_position):
    """Return the icons lying in one row or one column of the 3x3 grid.

    Args:
        selected_icons: icon names, parallel to ``selected_coordinates``.
        selected_coordinates: ``(x, y)`` centres of the icons.
        row_or_column: ``"row"`` matches the row word (top/center/bottom);
            any other value matches the column word (left/center/right).
        asked_position: the row or column word to match.

    Returns:
        The matching icon names, in their original order.
    """
    # calculate_position_name yields "<row> <column>".
    word_index = 0 if row_or_column == "row" else 1
    matches = []
    for icon, coord in zip(selected_icons, selected_coordinates):
        if calculate_position_name(coord).split()[word_index] == asked_position:
            matches.append(icon)
    return matches


def get_items_in_relative_position(zipped_list, rel_position, picked_icon_coord):
    """Return the icons lying strictly in one direction from a reference point.

    Args:
        zipped_list: ``(icon_name, (x, y))`` pairs.
        rel_position: ``"left"``/``"right"`` compare x; ``"top"``/``"bottom"``
            compare y, where a smaller y is higher on the canvas.
        picked_icon_coord: ``(x, y)`` of the reference icon. Icons sharing
            that x (or y) are not counted.

    Returns:
        Matching icon names in input order; ``[]`` for an unknown direction.
    """
    predicates = {
        "left": lambda coord: coord[0] < picked_icon_coord[0],
        "right": lambda coord: coord[0] > picked_icon_coord[0],
        "top": lambda coord: coord[1] < picked_icon_coord[1],
        "bottom": lambda coord: coord[1] > picked_icon_coord[1],
    }
    is_match = predicates.get(rel_position)
    if is_match is None:
        return []
    return [icon for icon, coord in zipped_list if is_match(coord)]


def get_relationships_with_object(random_icon, asked_direction, selected_pairs):
    """List the icons linked to ``random_icon`` by an arrow.

    Args:
        random_icon: icon name to look up.
        asked_direction: ``"from"`` for arrows leaving the icon (returns their
            destinations), ``"to"`` for arrows entering it (returns their
            origins), or ``"all"`` for both.
        selected_pairs: ``(origin_info, destination_info)`` shape-info dicts.

    Returns:
        Neighbour names in pair order, one entry per matching pair.
    """
    want_outgoing = asked_direction in ("from", "all")
    want_incoming = asked_direction in ("to", "all")
    neighbours = []
    for origin, destination in selected_pairs:
        if want_outgoing and origin['shape'] == random_icon:
            neighbours.append(destination['shape'])
        elif want_incoming and destination['shape'] == random_icon:
            neighbours.append(origin['shape'])
    return neighbours


def get_entity_count_qa(correct_count, just_text=False):
    """Build a "how many icons / text labels are there?" question.

    Distractors are WRONG_CHOICE_COUNT other counts drawn from
    MIN_SHAPE_COUNT..MAX_SHAPE_COUNT. ``correct_count`` must lie in that
    range (``list.remove`` raises ValueError otherwise).

    Returns:
        ``(question, shuffled choices, correct_count)``.
    """
    if just_text:
        question = "How many text labels are there in the diagram?"
    else:
        question = "How many icons are there in the diagram?"

    distractor_pool = list(range(MIN_SHAPE_COUNT, MAX_SHAPE_COUNT + 1))
    distractor_pool.remove(correct_count)
    choices = random.sample(distractor_pool, WRONG_CHOICE_COUNT)
    choices.append(correct_count)
    random.shuffle(choices)
    return question, choices, correct_count


def get_entity_existence_qa(selected_icons):
    """Build a "which entity is in the diagram?" multiple-choice question.

    The answer is a random icon from ``selected_icons``. Distractors are
    icons that are absent from the diagram. Underscores are shown as spaces.

    Returns:
        ``(question, shuffled choices, correct choice)``.
    """
    question = "Which one of the entities exists in the diagram?"
    answer = random.choice(selected_icons).replace("_", " ")
    distractors = negative_object_sample(selected_icons, WRONG_CHOICE_COUNT)
    choices = [name.replace("_", " ") for name in distractors + [answer]]
    random.shuffle(choices)
    return question, choices, answer


def get_abs_position_count_qa(selected_icons, selected_coordinates):
    """Build a "how many labels are in the <row/column>?" question.

    A random row (top/center/bottom) or column (left/center/right) is asked
    about. The choices are every count from 0 to len(X_POSITIONS_GRID).

    Returns:
        ``(question, shuffled choices, correct count)``.
    """
    axis = random.choice(["row", "column"])
    labels = Y_POSITIONS_GRID if axis == "row" else X_POSITIONS_GRID
    asked_position = random.choice(labels)
    in_position = get_items_in_row_or_column(
        selected_icons, selected_coordinates, axis, asked_position)

    choices = list(range(len(X_POSITIONS_GRID) + 1))
    random.shuffle(choices)

    question = f"How many text labels are there in the {asked_position} {axis} of the diagram?"
    return question, choices, len(in_position)


def get_abs_position_existence_qa(selected_icons, selected_coordinates):
    """Build a "which label is in the <row/column>?" multiple-choice question.

    A random row or column of the 3x3 grid is re-drawn until it contains at
    least one icon. The answer is one icon from it. Distractors are icons
    elsewhere in the diagram; when fewer than WRONG_CHOICE_COUNT of those
    exist, icons absent from the diagram make up the rest.

    Returns:
        ``(question, shuffled choices, correct choice)``, with underscores in
        icon names shown as spaces.
    """
    entities_in_position = []
    while not entities_in_position:
        row_or_column = random.choice(["row", "column"])
        candidate_labels = Y_POSITIONS_GRID if row_or_column == "row" else X_POSITIONS_GRID
        asked_position = random.choice(candidate_labels)
        entities_in_position = get_items_in_row_or_column(
            selected_icons, selected_coordinates, row_or_column, asked_position)

    question = f"Which one of the text labels exists in the {asked_position} {row_or_column} of the diagram?"
    correct_choice = random.choice(entities_in_position).replace("_", " ")

    outside_icons = selected_icons.copy()
    for entity in entities_in_position:
        outside_icons.remove(entity)

    shortfall = WRONG_CHOICE_COUNT - len(outside_icons)
    if shortfall > 0:
        wrong_choices = outside_icons + negative_object_sample(selected_icons, shortfall)
    else:
        wrong_choices = random.sample(outside_icons, WRONG_CHOICE_COUNT)

    all_choices = [choice.replace("_", " ") for choice in wrong_choices + [correct_choice]]
    random.shuffle(all_choices)
    return question, all_choices, correct_choice


def get_rel_position_count_qa(selected_icons, selected_coordinates):
    """Build a "how many labels are <left/right/top/bottom> of X?" question.

    Distractors are preferably other counts in 0..n-1 (n = number of icons).
    When there are not enough of those, counts from n..MAX_SHAPE_COUNT make
    up the rest.

    Returns:
        ``(question, shuffled choices, correct count)``.
    """
    icon_positions = list(zip(selected_icons, selected_coordinates))
    anchor_icon, anchor_coord = random.choice(icon_positions)
    anchor_name = anchor_icon.replace("_", " ")
    direction = random.choice(RELATIVE_OPTIONS)

    correct_count = len(get_items_in_relative_position(
        icon_positions, direction, anchor_coord))

    question = f"How many text labels are placed on the {direction} of the entity {anchor_name}?"

    icon_total = len(selected_icons)
    in_range = list(range(0, icon_total))
    in_range.remove(correct_count)

    if len(in_range) < WRONG_CHOICE_COUNT:
        out_of_range = list(range(icon_total, MAX_SHAPE_COUNT + 1))
        wrong_choices = in_range + random.sample(
            out_of_range, WRONG_CHOICE_COUNT - len(in_range))
    else:
        wrong_choices = random.sample(in_range, WRONG_CHOICE_COUNT)

    all_choices = wrong_choices + [correct_count]
    random.shuffle(all_choices)
    return question, all_choices, correct_count


def get_rel_position_existence_qa(selected_icons, selected_coordinates):
    """Build a "which label is <left/right/top/bottom> of X?" question.

    The anchor icon and direction are re-drawn until at least one icon lies
    in that direction. Distractors exclude the anchor and every icon in that
    direction; when too few remain, icons absent from the diagram make up
    the rest.

    Returns:
        ``(question, shuffled choices, correct choice)``, with underscores
        shown as spaces.
    """
    icon_positions = list(zip(selected_icons, selected_coordinates))

    relative_icons = []
    while not relative_icons:
        anchor = random.choice(icon_positions)
        asked_position = random.choice(RELATIVE_OPTIONS)
        relative_icons = get_items_in_relative_position(
            icon_positions, asked_position, anchor[1])

    anchor_name = anchor[0].replace("_", " ")
    question = f"Which one of the text labels is placed on the {asked_position} of the entity {anchor_name}?"
    correct_choice = random.choice(relative_icons).replace("_", " ")

    remaining_icons = selected_icons.copy()
    remaining_icons.remove(anchor[0])
    for entity in relative_icons:
        remaining_icons.remove(entity)

    shortfall = WRONG_CHOICE_COUNT - len(remaining_icons)
    if shortfall > 0:
        wrong_choices = remaining_icons + negative_object_sample(selected_icons, shortfall)
    else:
        wrong_choices = random.sample(remaining_icons, WRONG_CHOICE_COUNT)

    all_choices = [choice.replace("_", " ") for choice in wrong_choices + [correct_choice]]
    random.shuffle(all_choices)
    return question, all_choices, correct_choice


def get_relationship_count_directionless_qa(selected_icons, selected_pairs):
    """Build a "how many entities are connected to X?" question.

    Arrow direction is ignored, and each neighbouring entity is counted once
    even if several arrows (e.g. duplicated or reversed relationship tuples)
    link it to X.

    Distractors are preferably other counts in 0..n-1 (n = number of icons).
    When there are not enough, counts from n..MAX_SHAPE_COUNT make up the
    rest. The correct count is never offered as a distractor.

    Args:
        selected_icons: icon names in the diagram.
        selected_pairs: ``(origin_info, destination_info)`` shape-info dicts.

    Returns:
        ``(question, shuffled choices, correct count)``.
    """
    random_icon = random.choice(selected_icons)
    random_icon_name = random_icon.replace("_", " ")

    question = f"How many entities are connected to {random_icon_name}?"

    connected_icons = get_relationships_with_object(
        random_icon, "all", selected_pairs)
    # The question asks about distinct entities, so duplicates count once.
    correct_count = len(set(connected_icons))

    icon_total = len(selected_icons)
    # Filtering (instead of list.remove) stays safe if correct_count falls
    # outside 0..n-1, e.g. when the relationship data contains a self-loop.
    possible_counts = [count for count in range(0, icon_total)
                       if count != correct_count]
    impossible_counts = [count for count in range(icon_total, MAX_SHAPE_COUNT + 1)
                         if count != correct_count]

    if len(possible_counts) >= WRONG_CHOICE_COUNT:
        wrong_choices = random.sample(possible_counts, WRONG_CHOICE_COUNT)
    else:
        wrong_choices = possible_counts.copy()
        remaining_count = WRONG_CHOICE_COUNT - len(wrong_choices)
        wrong_choices.extend(random.sample(impossible_counts, remaining_count))

    all_choices = wrong_choices + [correct_count]
    random.shuffle(all_choices)

    return question, all_choices, correct_count


def get_relationship_existence_directionless_qa(selected_pairs, selected_icons):
    """Build a "which pair is connected?" question, ignoring arrow direction.

    The answer is a random arrow rendered as ``"A and B"``. Distractors come
    from negative_pairs_sample (unconnected pairs). Underscores are shown as
    spaces.

    Returns:
        ``(question, shuffled choices, correct choice)``.
    """
    question = "Which one of the pairs are connected in the diagram?"
    positives = get_all_positive_relationships(selected_pairs)
    negatives = negative_pairs_sample(
        selected_pairs, selected_icons, WRONG_CHOICE_COUNT)
    answer = relationship_to_text_directionless(
        random.choice(positives)).replace("_", " ")

    options = [f"{first} and {second}".replace("_", " ") for first, second in negatives]
    options.append(answer.replace("_", " "))
    random.shuffle(options)
    return question, options, answer


def get_random_icon(shape):
    """Return one image path, chosen at random, from ``ICON_DICT[shape]``."""
    return random.choice(ICON_DICT[shape])


def get_random_center_coordinates(n_shapes):
    """Pick ``n_shapes`` distinct 3x3-grid cell centres at random.

    The candidates are the centres of the nine 300-pixel cells of a
    900-pixel canvas, listed column by column.
    """
    cell_centres = [(x, y) for x in (150, 450, 750) for y in (150, 450, 750)]
    return random.sample(cell_centres, n_shapes)


def get_random_coordinates(n_shapes, min_distance=220, max_attempts_inner=1000, border_x=100, border_y=100):
    """Rejection-sample ``n_shapes`` integer points that are far apart.

    Points are drawn uniformly inside the canvas minus the borders. A
    candidate closer than ``min_distance`` to an accepted point is rejected.
    After ``max_attempts_inner`` rejections the function gives up.

    Returns:
        The list of ``(x, y)`` points, or ``[]`` if not all could be placed.
    """
    points = []
    rejections = 0
    while len(points) < n_shapes and rejections < max_attempts_inner:
        candidate_x = random.randint(border_x, CANVAS_SIZE - border_x)
        candidate_y = random.randint(border_y, CANVAS_SIZE - border_y)
        too_close = any(
            math.hypot(candidate_x - px, candidate_y - py) < min_distance
            for px, py in points)
        if too_close:
            rejections += 1
        else:
            points.append((candidate_x, candidate_y))
    return points if len(points) == n_shapes else []


def line_segment_circle_intersection(x1, y1, x2, y2, xc, yc, r):
    """Test whether segment (x1, y1)-(x2, y2) meets the circle of radius r at (xc, yc).

    The function solves ``|P(t) - C| = r`` for ``P(t) = P1 + t (P2 - P1)``.
    It returns True when a root lies in ``[0, 1]``, i.e. the segment touches
    or crosses the circle's boundary. A segment lying entirely inside the
    circle therefore returns False.

    A segment that starts or ends exactly at the circle centre returns False.
    Callers test an arrow against every icon centre, including the two it
    connects.

    A degenerate segment (P1 == P2) is treated as a point and meets the
    circle only when it lies exactly on it.

    Returns:
        bool
    """
    if (x1 == xc and y1 == yc) or (x2 == xc and y2 == yc):
        return False
    dx = x2 - x1
    dy = y2 - y1
    a = dx ** 2 + dy ** 2
    b = 2 * (dx * (x1 - xc) + dy * (y1 - yc))
    c = (x1 - xc) ** 2 + (y1 - yc) ** 2 - r ** 2

    if a == 0:
        # Single point: the quadratic reduces to c == 0. Handled separately
        # because the formula below divides by a.
        return c == 0

    # Discriminant
    delta = b ** 2 - 4 * a * c

    if delta < 0:
        return False  # No intersection
    if delta == 0:
        t = -b / (2 * a)
        return 0 <= t <= 1  # Tangent to the circle
    sqrt_delta = math.sqrt(delta)
    t1 = (-b + sqrt_delta) / (2 * a)
    t2 = (-b - sqrt_delta) / (2 * a)
    return (0 <= t1 <= 1) or (0 <= t2 <= 1)  # Intersection at two points


def add_gaussian_noise(canvas, noise_percentage):
    """Return a new image equal to ``canvas`` plus zero-mean Gaussian noise.

    The standard deviation is ``noise_percentage`` percent of 255. Noise is
    drawn with ``np.random.normal`` for every element of the image array,
    i.e. every channel (presumably including alpha for RGBA input). Results
    are clipped to [0, 255] and cast to uint8.
    """
    pixels = np.asarray(canvas, dtype=np.uint8)
    sigma = (noise_percentage / 100) * 255
    noise = np.random.normal(0, sigma, pixels.shape)
    noisy_pixels = np.clip(pixels + noise, 0, 255).astype(np.uint8)
    return Image.fromarray(noisy_pixels)


def apply_center_mask(image, object_center, object_size, mask_size_percentage):
    """Return a copy of ``image`` with a white square painted over it.

    The square is centred on ``object_center``. Its side is chosen so that
    its area is ``mask_size_percentage`` percent of an ``object_size`` x
    ``object_size`` square, and it is clipped to the image bounds. All
    channels inside it are set to 255.
    """
    pixels = np.array(image, dtype=np.uint8, copy=True)

    side = int(object_size * (math.sqrt(mask_size_percentage / 100)))
    half_side = side / 2

    left = max(0, int(object_center[0] - half_side))
    right = min(image.width, int(object_center[0] + half_side))
    top = max(0, int(object_center[1] - half_side))
    bottom = min(image.height, int(object_center[1] + half_side))

    # The trailing Ellipsis covers the channel axis for colour images and
    # expands to nothing for single-channel (2-D) images.
    pixels[top:bottom, left:right, ...] = 255

    return Image.fromarray(pixels)


def render_only_text(d, text, coordinates, font_path='arial.ttf', font_size=30):
    """Draw ``text`` in black, centred on ``coordinates``.

    Underscores are shown as spaces. Horizontal centring uses the measured
    text width; vertical centring is approximated as half the font size.
    ``ImageFont.truetype`` loads ``font_path`` on every call.
    """
    label = text.replace("_", " ")
    font = ImageFont.truetype(font_path, font_size)
    half_width = d.textlength(label, font=font) // 2
    anchor_xy = (coordinates[0] - half_width, coordinates[1] - font_size // 2)
    d.text(anchor_xy, label, font=font, fill="black")


def render_text(d, text, coordinates):
    """Draw ``text`` in black with Pillow's default font.

    The text is horizontally centred on ``coordinates[0]`` with its top at
    ``coordinates[1]``. Underscores are shown as spaces.
    """
    label = text.replace("_", " ")
    font = ImageFont.load_default()
    left_x = coordinates[0] - d.textlength(label, font=font) // 2
    d.text((left_x, coordinates[1]), label, font=font, fill="black")


def draw_icon(canvas, d, shape_info, text, gaussian_noise_percentage=0, mask_percentage=0):
    """Render one diagram entity onto ``canvas``.

    Args:
        canvas: target image; the icon is pasted onto it.
        d: drawing context for ``canvas``, used for text.
        shape_info: dict with ``'shape'`` (name), ``'coordinates'`` (centre),
            ``'size'`` (icon side in pixels) and ``'image_path'`` (relative to
            the current working directory).
        text: ``"just_text"`` draws only the name, centred at the
            coordinates. ``"w_text"`` draws the icon with its name underneath.
            ``"random_text"`` draws the icon with a different, randomly chosen
            icon name underneath. Any other value (e.g. ``"wout_text"``) draws
            the icon alone.
        gaussian_noise_percentage: forwarded to add_gaussian_noise.
        mask_percentage: forwarded to apply_center_mask.
    """
    icon_name = shape_info['shape']
    coordinates = shape_info['coordinates']
    size = shape_info['size']
    image_path = shape_info['image_path']

    if text == "just_text":
        render_only_text(d, icon_name, coordinates)
        return

    # Top-left corner that centres the size x size icon on `coordinates`.
    centered_coordinates = (
        int(coordinates[0]-(size/2)), int(coordinates[1]-(size/2)))
    if text == "w_text" or text == "random_text":
        # The label goes below the icon, 2/3 of its size under the centre.
        text_coordinates = (
            int(coordinates[0]), int(coordinates[1]+(2*size/3)))

        displayed_text = icon_name
        if text == "random_text":
            # Deliberately mislabel: redraw until the name differs.
            while displayed_text == icon_name:
                displayed_text = random.choice(ICON_LIST)

        render_text(d, displayed_text, text_coordinates)

    image_full_path = os.path.join(os.getcwd(), image_path)
    # Image.open keeps the file open until the image is closed; resize()
    # returns an independent, fully loaded copy, so the source can be
    # released immediately.
    with Image.open(image_full_path) as source_icon:
        icon = source_icon.resize((size, size))
    icon_noisy = add_gaussian_noise(icon, gaussian_noise_percentage)

    icon_noisy = apply_center_mask(
        icon_noisy, (size/2, size/2), size, mask_percentage)
    canvas.paste(icon_noisy, centered_coordinates)


def draw_arrow(draw, start, end, size_start, size_end, arrow_width=2, arrowhead_size=10, arrow_color="black"):
    """
    Draw an arrow from ``start`` towards ``end`` with a filled arrowhead.

    The shaft is shortened by ``size_start`` pixels at the start and by
    ``size_end`` pixels at the end. The arrowhead is a filled triangle whose
    tip lies ``arrowhead_size / 2`` pixels beyond the shortened shaft end.

    Parameters:
    draw: ImageDraw object (only its ``line`` and ``polygon`` methods are used)
    start, end: (x, y) tuples of the points being connected.
    size_start, size_end: pixels trimmed from the start / end of the shaft.
    arrow_width: line width passed to ``draw.line``.
    arrowhead_size: length in pixels of the arrowhead's two sides.
    arrow_color: fill colour of shaft and head.

    If ``start`` and ``end`` coincide, the arrow has no direction and
    nothing is drawn.
    """
    def draw_arrowhead(point, angle, arrowhead_size, arrow_color):
        # Triangle with its tip at `point` and two sides of length
        # arrowhead_size, each 30 degrees (pi/6) off the line direction.
        arrowhead_1 = (point[0] - arrowhead_size * math.cos(angle - math.pi / 6),
                       point[1] - arrowhead_size * math.sin(angle - math.pi / 6))

        arrowhead_2 = (point[0] - arrowhead_size * math.cos(angle + math.pi / 6),
                       point[1] - arrowhead_size * math.sin(angle + math.pi / 6))

        draw.polygon([point, arrowhead_1, arrowhead_2], fill=arrow_color)

    def draw_aa_line(draw_aa, start_aa, end_aa, width, fill):
        # Not true anti-aliasing: the line is drawn at its position and again
        # shifted by one pixel left, right, up and down, which thickens it.
        draw_aa.line([start_aa, end_aa], fill=fill, width=width)
        draw_aa.line([(start_aa[0]-1, start_aa[1]), (end_aa[0]-1, end_aa[1])],
                     fill=fill, width=width)
        draw_aa.line([(start_aa[0]+1, start_aa[1]), (end_aa[0]+1, end_aa[1])],
                     fill=fill, width=width)
        draw_aa.line([(start_aa[0], start_aa[1]-1), (end_aa[0], end_aa[1]-1)],
                     fill=fill, width=width)
        draw_aa.line([(start_aa[0], start_aa[1]+1), (end_aa[0], end_aa[1]+1)],
                     fill=fill, width=width)

    dx = end[0] - start[0]
    dy = end[1] - start[1]
    length = math.sqrt(dx**2 + dy**2)
    if length == 0:
        # start == end: there is no direction to draw in (and the unit
        # vector below would divide by zero).
        return

    # Trim both ends along the unit direction (dx / length, dy / length).
    short_start = (start[0] + size_start * (dx / length),
                   start[1] + size_start * (dy / length))
    short_end = (end[0] - size_end * (dx / length),
                 end[1] - size_end * (dy / length))

    # The arrowhead tip sits half an arrowhead beyond the shaft end.
    arrowhead_end = (short_end[0] + (arrowhead_size)/2 * (dx / length),
                     short_end[1] + (arrowhead_size)/2 * (dy / length))

    draw_aa_line(draw, short_start, short_end,
                 width=arrow_width, fill=arrow_color)

    # Direction of the (trimmed) shaft, used to orient the arrowhead.
    angle = math.atan2(short_end[1] - short_start[1],
                       short_end[0] - short_start[0])

    draw_arrowhead(arrowhead_end, angle, arrowhead_size, arrow_color)


def draw_dotted_line(draw, start, end, color="black", dot_length=5, space_length=5):
    """Draw a dashed line from ``start`` to ``end``.

    Dashes of ``dot_length`` pixels, separated by gaps of ``space_length``
    pixels, are laid out from ``start`` towards ``end``. A trailing dash that
    would not fit completely is omitted. Any direction is supported:
    left-to-right, right-to-left, top-to-bottom, bottom-to-top and diagonal.
    A zero-length line draws nothing.

    Args:
        draw: object with a ``line(xy, fill=...)`` method (e.g. ImageDraw).
        start, end: ``(x, y)`` endpoints.
        color: dash colour.
        dot_length, space_length: dash and gap lengths in pixels.
    """
    x1, y1 = start
    x2, y2 = end
    length = math.hypot(x2 - x1, y2 - y1)
    if length == 0:
        return

    # Unit direction; exactly (+-1, 0) or (0, +-1) for axis-aligned lines.
    ux = (x2 - x1) / length
    uy = (y2 - y1) / length
    step = dot_length + space_length
    for i in range(int(length // step)):
        offset = i * step
        dot_start = (x1 + offset * ux, y1 + offset * uy)
        dot_end = (dot_start[0] + dot_length * ux, dot_start[1] + dot_length * uy)
        draw.line([dot_start, dot_end], fill=color)


def create_random_diagram_metadata(with_grid=False, bidirectional_arrows=False, max_attempts_positioning=100):
    """Randomly generate the layout and connectivity of one diagram.

    The function repeatedly draws a shape count, an arrow count and a set of
    icon centres until at least ``arrow_count`` icon pairs can be joined by a
    straight segment that does not meet any icon's radius-100 circle. It then
    assigns random icons and images to the centres and picks ``arrow_count``
    such pairs, each in a random direction, as arrows.

    Args:
        with_grid: place icons at 3x3 grid cell centres instead of at random,
            well-separated positions.
        bidirectional_arrows: if True, each arrow is randomly marked as
            bidirectional.
        max_attempts_positioning: coordinate draws tried for one
            (shape_count, arrow_count) choice before both are re-drawn.

    Returns:
        ``(shape_count, selected_icons, selected_coordinates, arrow_count,
        selected_pairs, shapes_info, arrows_info, bidirectionality_info)``:

        - ``shapes_info``: one dict per icon (id, shape, image_path,
          coordinates, size).
        - ``selected_pairs``: ``(origin_info, destination_info)`` tuples of
          those dicts.
        - ``arrows_info``: each arrow described by shape ids.
        - ``bidirectionality_info``: per-arrow booleans (empty unless
          ``bidirectional_arrows``).
    """
    while True:
        shape_count = random.randint(
            MIN_SHAPE_COUNT, MAX_SHAPE_COUNT)

        # Upper bound on arrows: one per unordered pair of icons.
        reasonable_max = shape_count * (shape_count - 1) // 2
        arrow_count = random.randint(MIN_ARROW_COUNT, reasonable_max)

        selected_coordinates = []
        layout_found = False
        attempt_count = 0
        while attempt_count < max_attempts_positioning:
            selected_coordinates = []
            # get_random_coordinates returns [] when it cannot place every icon.
            while not selected_coordinates and attempt_count < max_attempts_positioning:
                attempt_count += 1
                selected_coordinates = get_random_center_coordinates(
                    shape_count) if with_grid else get_random_coordinates(shape_count)

            if not selected_coordinates:
                continue

            # Count pairs whose straight arrow avoids every icon circle, and
            # stop as soon as there are enough for arrow_count arrows.
            drawable_arrows = 0
            for (x1, y1), (x2, y2) in itertools.combinations(selected_coordinates, 2):
                intersection_check = [line_segment_circle_intersection(
                    x1, y1, x2, y2, xc, yc, 100) for (xc, yc) in selected_coordinates]
                if all(not item for item in intersection_check):
                    drawable_arrows += 1
                    if drawable_arrows == arrow_count:
                        break
            if drawable_arrows == arrow_count:
                layout_found = True
                break

        # An explicit flag, instead of comparing attempt_count with the
        # limit, keeps a layout that was found on the very last attempt.
        if layout_found:
            break

    selected_icons = random.sample(ICON_LIST, shape_count)
    chosen_image_paths = [get_random_icon(icon) for icon in selected_icons]
    selected_sizes = [SIZE] * shape_count

    shapes_info = []
    for ind, (icon, chosen_image_path, coordinates, size) in enumerate(zip(selected_icons, chosen_image_paths, selected_coordinates, selected_sizes)):
        shape_id = "shape_" + str(ind)
        shape_info = {'id': shape_id, 'shape': icon, 'image_path': chosen_image_path,
                      'coordinates': coordinates, 'size': size}
        shapes_info.append(shape_info)

    combinations = list(itertools.combinations(shapes_info, 2))
    random.shuffle(combinations)
    selected_pairs = []

    for selected_pair in combinations:
        # Randomise the arrow direction.
        selected_pair = tuple(reversed(selected_pair)) if random.choice(
            [True, False]) else selected_pair
        (shape1, shape2) = selected_pair
        (x1, y1) = shape1['coordinates']
        (x2, y2) = shape2['coordinates']
        intersection_check = [line_segment_circle_intersection(
            x1, y1, x2, y2, xc, yc, 100) for (xc, yc) in selected_coordinates]
        if all(not item for item in intersection_check):
            selected_pairs.append(selected_pair)
        if len(selected_pairs) == arrow_count:
            break

    arrows_info = []
    bidirectionality_info = []
    for ind, (shape1, shape2) in enumerate(selected_pairs):
        arrow_id = "arrow_" + str(ind)
        arrow_info = {'id': arrow_id, 'origin': shape1['id'],
                      'destination': shape2['id']}
        if bidirectional_arrows:
            bidirectional = random.choice([True, False])
            bidirectionality_info.append(bidirectional)
            arrow_info["bidirectional"] = bidirectional
        arrows_info.append(arrow_info)

    return shape_count, selected_icons, selected_coordinates, arrow_count, selected_pairs, shapes_info, arrows_info, bidirectionality_info


def create_diagram_metadata_from_relationship_tuples(rel_tuples, max_attempts_positioning=1000):
    """Lay out a diagram for a given list of ``(origin, destination)`` names.

    Each distinct name becomes one icon. Random, well-separated centres are
    drawn with get_random_coordinates. For each set of centres, up to 100
    random assignments of centres to names are tried until every
    relationship can be drawn as a straight segment that does not meet any
    icon's radius-100 circle.

    Args:
        rel_tuples: sequence of ``(origin_name, destination_name)`` pairs.
        max_attempts_positioning: maximum number of coordinate draws.

    Returns:
        On success, ``(shape_count, names, coordinates, arrow_count,
        selected_pairs, shapes_info, arrows_info, bidirectionality_info)``.
        ``image_path`` in ``shapes_info`` is empty and
        ``bidirectionality_info`` is always ``[]``. If no layout is found, a
        message is printed and a tuple of eight empty strings is returned.
    """
    given_arrow_count = len(rel_tuples)
    # dict.fromkeys keeps first-appearance order. Iterating a set of strings
    # would make the name -> position/id assignment depend on PYTHONHASHSEED,
    # so results would not be reproducible from a fixed random seed.
    unique_items_list = list(dict.fromkeys(
        item for tup in rel_tuples for item in tup))
    given_shape_count = len(unique_items_list)
    index_of = {item: ind for ind, item in enumerate(unique_items_list)}

    def arrow_is_clear(origin_xy, destination_xy, centres):
        # True when the segment meets none of the radius-100 icon circles.
        (x1, y1), (x2, y2) = origin_xy, destination_xy
        return not any(line_segment_circle_intersection(
            x1, y1, x2, y2, xc, yc, 100) for (xc, yc) in centres)

    selected_coordinates = []
    layout_found = False
    attempt_count = 0
    while attempt_count < max_attempts_positioning and not layout_found:
        selected_coordinates = []
        while not selected_coordinates and attempt_count < max_attempts_positioning:
            attempt_count += 1
            selected_coordinates = get_random_coordinates(
                given_shape_count)

        if not selected_coordinates:
            continue

        # Try several name -> centre assignments for this set of centres.
        for _ in range(100):
            random.shuffle(selected_coordinates)
            if all(arrow_is_clear(selected_coordinates[index_of[ent_1]],
                                  selected_coordinates[index_of[ent_2]],
                                  selected_coordinates)
                   for ent_1, ent_2 in rel_tuples):
                layout_found = True
                break

    # The explicit flag keeps a layout found on the very last attempt.
    if not layout_found:
        print("Max attempts reached")
        print(
            f"Obj count: {given_shape_count}, Rel count: {given_arrow_count}")
        return "", "", "", "", "", "", "", ""
    selected_sizes = [SIZE] * given_shape_count

    shapes_info = []
    for ind, (icon, coordinates, size) in enumerate(zip(unique_items_list, selected_coordinates, selected_sizes)):
        shape_id = "shape_" + str(ind)
        shape_info = {'id': shape_id, 'shape': icon, 'image_path': '',
                      'coordinates': coordinates, 'size': size}
        shapes_info.append(shape_info)

    shape_by_name = {info['shape']: info for info in shapes_info}
    selected_pairs = []
    for ent_1, ent_2 in rel_tuples:
        shape1 = shape_by_name[ent_1]
        shape2 = shape_by_name[ent_2]
        if arrow_is_clear(shape1['coordinates'], shape2['coordinates'], selected_coordinates):
            selected_pairs.append((shape1, shape2))

    arrows_info = []
    bidirectionality_info = []
    for ind, (shape1, shape2) in enumerate(selected_pairs):
        arrow_id = "arrow_" + str(ind)
        arrow_info = {'id': arrow_id, 'origin': shape1['id'],
                      'destination': shape2['id']}
        arrows_info.append(arrow_info)

    return given_shape_count, unique_items_list, selected_coordinates, given_arrow_count, selected_pairs, shapes_info, arrows_info, bidirectionality_info


def draw_diagram(shapes_info, selected_pairs, arrowhead_size, arrow_thickness, arrow_color, with_grid=False, text="w_text", gaussian_noise_percentage=0, mask_percentage=0, bidirectionality_info=None):
    """Render a diagram on a new white CANVAS_SIZE x CANVAS_SIZE RGB image.

    Args:
        shapes_info: shape-info dicts, each drawn with draw_icon.
        selected_pairs: ``(origin_info, destination_info)`` pairs, each drawn
            as an arrow between the two centres, trimmed by SIZE at each end.
        arrowhead_size, arrow_thickness, arrow_color: arrow appearance.
        with_grid: also draw dotted lines at ``grid_coords`` in both axes.
        text: label mode forwarded to draw_icon.
        gaussian_noise_percentage, mask_percentage: forwarded to draw_icon.
        bidirectionality_info: optional per-arrow booleans. A True entry also
            draws the reverse arrow. ``None`` or empty means one-way arrows.

    Returns:
        The rendered PIL image.
    """
    canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), color='white')
    d = ImageDraw.Draw(canvas)

    if with_grid:
        for coord in grid_coords:
            draw_dotted_line(d, (0, coord), (CANVAS_SIZE, coord), "black")
            draw_dotted_line(d, (coord, 0), (coord, CANVAS_SIZE), "black")

    for shape_info in shapes_info:
        draw_icon(canvas, d, shape_info, text,
                  gaussian_noise_percentage, mask_percentage)

    for index, (shape1, shape2) in enumerate(selected_pairs):
        draw_arrow(d, shape1['coordinates'], shape2['coordinates'],
                   SIZE, SIZE, arrow_thickness, arrowhead_size, arrow_color)
        if bidirectionality_info and bidirectionality_info[index]:
            draw_arrow(d, shape2['coordinates'], shape1['coordinates'],
                       SIZE, SIZE, arrow_thickness, arrowhead_size, arrow_color)
    return canvas


def create_sample_dataset(size=20, dataset_name="icon_dataset", subtask_string="image", ICON_LIST_path="", gaussian_noise_percentage=0, mask_percentage=0, random_attributes=0):
    """Generate ``size`` random diagrams with multiple-choice QA and save them.

    Output goes under ``datasets/synthetic_datasets/<dataset_name>/``,
    relative to the current working directory:

    - ``images/<diagram_type>/<i>.png``: the "wout_text" rendering for the
      "image" subtask, the "just_text" rendering otherwise.
    - ``diagram_info.json``: layout per diagram.
    - ``diagram_qa_info.json``: count and existence questions per diagram.
    - ``diagram_stats.json``: summary statistics.

    Args:
        size: number of diagrams.
        dataset_name: output sub-directory name.
        subtask_string: one of "image", "text", "abs_position",
            "rel_position", "relationship_directionless" (selects the QA
            generators; "abs_position" also places icons on a 3x3 grid).
            "relationship_bi" enables bidirectional arrows in the layout but
            has no QA generator here, so it is rejected.
        ICON_LIST_path: optional JSON file whose list replaces the module-level
            ICON_LIST. Names must still exist in ICON_DICT, because
            get_random_icon looks them up there.
        gaussian_noise_percentage, mask_percentage: icon corruption settings
            forwarded to draw_diagram.
        random_attributes: if non-zero, arrowhead size, arrow width and colour
            are randomised per diagram and recorded in diagram_info.json.

    Raises:
        ValueError: if ``subtask_string`` has no QA generator.
    """
    global ICON_LIST
    images_directory = os.path.join(
        'datasets', 'synthetic_datasets', dataset_name, 'images')
    os.makedirs(images_directory, exist_ok=True)

    statistics = {"shape_counts": [], "selected_icons": [],
                  "arrow_counts": []}

    if ICON_LIST_path:
        with open(ICON_LIST_path, 'r', encoding='utf-8') as file:
            ICON_LIST = json.load(file)

    with_grid = subtask_string == "abs_position"
    bidirectional_arrows = subtask_string == "relationship_bi"
    required_diagrams = ["w_text", "wout_text", "random_text", "just_text"]

    qa_info = {}
    all_shapes_info = {}
    for i in range(size):
        shape_count, selected_icons, selected_coordinates, arrow_count, selected_pairs, shapes_info, arrows_info, bidirectionality_info = create_random_diagram_metadata(
            with_grid, bidirectional_arrows)

        if subtask_string == "image":
            count_question, count_choices, count_answer = get_entity_count_qa(
                shape_count, just_text=False)
            existence_question, existence_choices, existence_answer = get_entity_existence_qa(
                selected_icons)
        elif subtask_string == "text":
            count_question, count_choices, count_answer = get_entity_count_qa(
                shape_count, just_text=True)
            existence_question, existence_choices, existence_answer = get_entity_existence_qa(
                selected_icons)
        elif subtask_string == "abs_position":
            count_question, count_choices, count_answer = get_abs_position_count_qa(
                selected_icons, selected_coordinates)
            existence_question, existence_choices, existence_answer = get_abs_position_existence_qa(
                selected_icons, selected_coordinates)
        elif subtask_string == "rel_position":
            count_question, count_choices, count_answer = get_rel_position_count_qa(
                selected_icons, selected_coordinates)
            existence_question, existence_choices, existence_answer = get_rel_position_existence_qa(
                selected_icons, selected_coordinates)
        elif subtask_string == "relationship_directionless":
            count_question, count_choices, count_answer = get_relationship_count_directionless_qa(
                selected_icons, selected_pairs)
            existence_question, existence_choices, existence_answer = get_relationship_existence_directionless_qa(
                selected_pairs, selected_icons)
        else:
            # Without this, the QA variables below would be unbound.
            raise ValueError(
                f"Unsupported subtask_string {subtask_string!r}; expected one of "
                "'image', 'text', 'abs_position', 'rel_position', "
                "'relationship_directionless'")

        picked_q_type = random.choice(["count", "existence"])

        qa_info_img = {"count": {"question": count_question,
                                 "choices": count_choices,
                                 "answer": count_answer},
                       "existence": {"question": existence_question,
                                     "choices": existence_choices,
                                     "answer": existence_answer},
                       "random_pick": picked_q_type
                       }
        qa_info[str(i)] = qa_info_img

        statistics["shape_counts"].append(shape_count)
        statistics["selected_icons"].append(selected_icons)
        statistics["arrow_counts"].append(arrow_count)

        arrowhead_size, arrow_thickness, arrow_color = 10, 2, "black"
        if random_attributes != 0:
            arrowhead_size = random.choice(ARROWHEAD_SIZES)
            arrow_thickness = random.choice(ARROW_WIDTHS)
            arrow_color = random.choice(ARROW_COLORS)

        # NOTE(review): every variant is rendered although only one is saved.
        # Rendering consumes random state ("random_text" labels, Gaussian
        # noise), so skipping the unsaved variants would change the dataset
        # generated for a given seed.
        default_image_path = ""
        for diagram_type in required_diagrams:
            img = draw_diagram(shapes_info, selected_pairs, arrowhead_size, arrow_thickness, arrow_color, with_grid, text=diagram_type,
                               gaussian_noise_percentage=gaussian_noise_percentage, mask_percentage=mask_percentage, bidirectionality_info=bidirectionality_info)
            curr_images_directory = os.path.join(
                images_directory, diagram_type)
            if (subtask_string == "image" and diagram_type == "wout_text") or (subtask_string != "image" and diagram_type == "just_text"):
                os.makedirs(curr_images_directory, exist_ok=True)
                image_filename = os.path.join(
                    curr_images_directory, f'{i}.png')
                img.save(image_filename)
                if not default_image_path:
                    default_image_path = image_filename

        base_path = os.getcwd()
        relative_image_path = os.path.relpath(default_image_path, base_path)

        # Add shapes and their descriptions to the all_shapes_info dict
        all_shapes_info[str(i)] = {
            'image_path': relative_image_path,
            'shapes': shapes_info,
            'arrows': arrows_info
        }

        if random_attributes != 0:
            all_shapes_info[str(i)].update({"arrowhead_size": arrowhead_size,
                                            "arrow_thickness": arrow_thickness,
                                            "arrow_color": arrow_color})

    # Save all the shape information and descriptions to a JSON file
    info_json_filename = os.path.join(
        'datasets', 'synthetic_datasets', dataset_name, 'diagram_info.json')
    with open(info_json_filename, 'w', encoding='utf-8') as json_file:
        json.dump(all_shapes_info, json_file, indent=4)

    qa_info_json_filename = os.path.join(
        'datasets', 'synthetic_datasets', dataset_name, 'diagram_qa_info.json')
    with open(qa_info_json_filename, 'w', encoding='utf-8') as json_file:
        json.dump(qa_info, json_file, indent=4)

    stats_json_filename = os.path.join(
        'datasets', 'synthetic_datasets', dataset_name, 'diagram_stats.json')
    with open(stats_json_filename, 'w', encoding='utf-8') as json_file:
        json.dump(statistics, json_file, indent=4)

    print(f"Saved {size} images and their descriptions to {info_json_filename}")


def create_dataset_from_relationship_file(relationship_pickle_path="knowledge_g.txt", dataset_name="relationship_dataset_with_knowledge", subtask_string="relationship_directionless", gaussian_noise_percentage=0, mask_percentage=0, random_attributes=0):
    """Build a diagram dataset from pre-computed relationship tuples.

    Each entry loaded from ``relationship_pickle_path`` is turned into diagram
    metadata, a set of question/answer pairs and one rendered image per
    diagram type. Results are written under
    ``datasets/synthetic_datasets/<dataset_name>/``:

    * ``images/<diagram_type>/<i>.png`` -- rendered diagrams,
    * ``diagram_info.json`` -- image path, shapes and arrows per diagram,
    * ``diagram_qa_info.json`` -- generated questions, choices and answers,
    * ``diagram_stats.json`` -- per-diagram shape/arrow counts and icons.

    Args:
        relationship_pickle_path: Path to a pickle file holding an iterable
            with one entry per diagram; each entry is passed unchanged to
            ``create_diagram_metadata_from_relationship_tuples``.
        dataset_name: Name of the output directory under
            ``datasets/synthetic_datasets``.
        subtask_string: QA subtask to generate. Only
            ``"relationship_directionless"`` is supported.
        gaussian_noise_percentage: Forwarded to ``draw_diagram``.
        mask_percentage: Forwarded to ``draw_diagram``.
        random_attributes: When non-zero, arrowhead size, arrow thickness and
            arrow colour are sampled per diagram and stored in the info JSON.

    Raises:
        ValueError: If ``subtask_string`` is not supported. Previously this
            surfaced as a ``NameError`` on the first entry.
    """
    supported_subtasks = ("relationship_directionless",)
    if subtask_string not in supported_subtasks:
        raise ValueError(
            f"Unsupported subtask_string {subtask_string!r} for relationship "
            f"files; expected one of {list(supported_subtasks)}")

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only load relationship files from a trusted source.
    with open(relationship_pickle_path, "rb") as fp:
        given_relationships = pickle.load(fp)

    dataset_directory = os.path.join(
        'datasets', 'synthetic_datasets', dataset_name)
    images_directory = os.path.join(dataset_directory, 'images')
    os.makedirs(images_directory, exist_ok=True)

    statistics = {"shape_counts": [], "selected_icons": [],
                  "arrow_counts": []}

    # Only the text-labelled rendering is produced for relationship datasets.
    required_diagrams = ["just_text"]
    with_grid = False

    qa_info = {}
    all_shapes_info = {}
    for i, rel_tuples in enumerate(given_relationships):
        (shape_count, selected_icons, _selected_coordinates, arrow_count,
         selected_pairs, shapes_info, arrows_info,
         bidirectionality_info) = create_diagram_metadata_from_relationship_tuples(
            rel_tuples)
        # The metadata helper signals an unusable entry with an empty
        # shape_count; skip it and report the index.
        if shape_count == "":
            print(f"Skipping relationship entry {i}: empty diagram metadata")
            continue

        # subtask_string was validated above, so the directionless
        # relationship questions are always the ones generated here.
        count_question, count_choices, count_answer = get_relationship_count_directionless_qa(
            selected_icons, selected_pairs)
        existence_question, existence_choices, existence_answer = get_relationship_existence_directionless_qa(
            selected_pairs, selected_icons)

        # Keep the call order below unchanged so seeded runs reproduce the
        # same dataset.
        picked_q_type = random.choice(["count", "existence"])

        qa_info[str(i)] = {
            "count": {"question": count_question,
                      "choices": count_choices,
                      "answer": count_answer},
            "existence": {"question": existence_question,
                          "choices": existence_choices,
                          "answer": existence_answer},
            "random_pick": picked_q_type,
        }

        statistics["shape_counts"].append(shape_count)
        statistics["selected_icons"].append(selected_icons)
        statistics["arrow_counts"].append(arrow_count)

        arrowhead_size, arrow_thickness, arrow_color = 10, 2, "black"
        if random_attributes != 0:
            arrowhead_size = random.choice(ARROWHEAD_SIZES)
            arrow_thickness = random.choice(ARROW_WIDTHS)
            arrow_color = random.choice(ARROW_COLORS)

        default_image_path = ""
        for diagram_type in required_diagrams:
            img = draw_diagram(shapes_info, selected_pairs, arrowhead_size, arrow_thickness, arrow_color, with_grid, text=diagram_type,
                               gaussian_noise_percentage=gaussian_noise_percentage, mask_percentage=mask_percentage, bidirectionality_info=bidirectionality_info)
            curr_images_directory = os.path.join(
                images_directory, diagram_type)
            os.makedirs(curr_images_directory, exist_ok=True)
            image_filename = os.path.join(curr_images_directory, f'{i}.png')
            img.save(image_filename)
            # The first rendered diagram type is recorded as the image path.
            if not default_image_path:
                default_image_path = image_filename

        diagram_info = {
            'image_path': default_image_path,
            'shapes': shapes_info,
            'arrows': arrows_info
        }
        if random_attributes != 0:
            diagram_info.update({"arrowhead_size": arrowhead_size,
                                 "arrow_thickness": arrow_thickness,
                                 "arrow_color": arrow_color})
        all_shapes_info[str(i)] = diagram_info

    # Save shape information, QA data and statistics as JSON files.
    info_json_filename = os.path.join(dataset_directory, 'diagram_info.json')
    outputs = (
        (info_json_filename, all_shapes_info),
        (os.path.join(dataset_directory, 'diagram_qa_info.json'), qa_info),
        (os.path.join(dataset_directory, 'diagram_stats.json'), statistics),
    )
    for json_path, payload in outputs:
        with open(json_path, 'w') as json_file:
            json.dump(payload, json_file, indent=4)

    print(f"Saved {len(all_shapes_info)} images and their descriptions to {info_json_filename}")


def string_to_int(s):
    """Map a string to a deterministic unsigned 32-bit integer.

    The result is the low 32 bits of the SHA-256 digest of ``s`` (encoded as
    UTF-8). This gives a reproducible seed that fits the range accepted by
    ``np.random.seed``.
    """
    digest = hashlib.sha256(s.encode()).digest()
    # Read as a big-endian integer, the digest's low 32 bits are its last
    # four bytes.
    return int.from_bytes(digest[-4:], "big")


def main():
    """Command-line entry point: parse arguments, seed the RNGs, build a dataset.

    Both ``random`` and ``numpy.random`` are seeded with a value derived from
    ``subtask_string`` (via ``string_to_int``), so each subtask is generated
    reproducibly. If ``--relationship_file`` is given, the dataset is built
    from that file with the ``"relationship_directionless"`` subtask.
    Otherwise ``create_sample_dataset`` generates ``size`` diagrams.
    """
    cli = argparse.ArgumentParser(
        description='Generate dataset')

    # Positional arguments (all required).
    cli.add_argument('dataset_name', type=str,
                     help='The name of the dataset to be generated.')
    cli.add_argument('size', type=int,
                     help='The size of the dataset.')
    cli.add_argument('subtask_string', type=str,
                     help='The subtask name. Options are: ["image", "text", "abs_position", "rel_position", "relationship_directionless"].')

    # Optional flags, declared in the order they appear in --help.
    optional_specs = (
        ('--ICON_LIST_path', str, '',
         'The path to the file containing the subset of icon names.'),
        ('--gaussian_noise_percentage', int, 0,
         'The Gaussian noise added to the icons.'),
        ('--mask_percentage', int, 0,
         'The mask added to the icons.'),
        ('--random_attributes', int, 0,
         'Give random attributes to arrows.'),
        ('--relationship_file', str, "",
         'File with relationship information'),
    )
    for flag, flag_type, default_value, help_text in optional_specs:
        cli.add_argument(flag, type=flag_type, default=default_value,
                         help=help_text)

    opts = cli.parse_args()

    seed = string_to_int(opts.subtask_string)
    print(seed)

    random.seed(seed)
    np.random.seed(seed)

    # Rendering options shared by both dataset builders.
    render_kwargs = {
        "gaussian_noise_percentage": opts.gaussian_noise_percentage,
        "mask_percentage": opts.mask_percentage,
        "random_attributes": opts.random_attributes,
    }

    if not opts.relationship_file:
        create_sample_dataset(opts.size, opts.dataset_name, opts.subtask_string,
                              opts.ICON_LIST_path, **render_kwargs)
        return

    create_dataset_from_relationship_file(opts.relationship_file, opts.dataset_name,
                                          "relationship_directionless",
                                          **render_kwargs)


if __name__ == "__main__":
    main()
