import os
import json
from collections import Counter


def check_same_object(box1, box2, iou_threshold=0.9):
    """Return True when two boxes overlap with IoU strictly above ``iou_threshold``.

    Each box is an ``(x, y, w, h)`` sequence whose extent along x is
    ``[x, x + w]`` and along y is ``[y, y + h]``. Boxes that do not overlap, or
    that only touch along an edge, are never the same object.
    """
    left_a, top_a, width_a, height_a = box1
    left_b, top_b, width_b, height_b = box2

    # Overlap length along each axis; a non-positive value means the boxes
    # are disjoint (or merely touching) on that axis.
    overlap_w = min(left_a + width_a, left_b + width_b) - max(left_a, left_b)
    overlap_h = min(top_a + height_a, top_b + height_b) - max(top_a, top_b)
    if overlap_w <= 0 or overlap_h <= 0:
        return False

    inter = overlap_w * overlap_h
    # Positive overlap implies both widths/heights are positive, so the union
    # is strictly positive here.
    union = width_a * height_a + width_b * height_b - inter
    # bool() keeps the result a Python bool; callers compare with ``is False``.
    return bool(inter / union > iou_threshold)


def check_same_depth(depth1, depth2, depth_threshold=0.01):
    """Return True when the two depths differ by strictly less than ``depth_threshold``."""
    gap = abs(depth1 - depth2)
    return bool(gap < depth_threshold)


def pop_entry_via_name(obj_name, det_list):
    """Remove and return the first entry of ``det_list`` whose element 0 equals ``obj_name``.

    ``det_list`` is modified in place. Returns None (leaving the list untouched)
    when no entry matches.
    """
    match_idx = next(
        (i for i, entry in enumerate(det_list) if obj_name == entry[0]),
        None,
    )
    if match_idx is None:
        return None
    return det_list.pop(match_idx)


def pop_entry_via_box(bbox, det_list, iou_threshold=0.9):
    """Remove and return the first entry matching ``bbox`` by name and box overlap.

    ``bbox`` and every entry of ``det_list`` are records whose element 0 is a
    name and element 1 a box accepted by ``check_same_object``. An entry matches
    when its name equals ``bbox[0]`` and the two boxes have IoU strictly above
    ``iou_threshold`` (default 0.9, the same default as ``check_same_object``).

    ``det_list`` is modified in place. Returns None when nothing matches.
    """
    for idx, obj in enumerate(det_list):
        if bbox[0] == obj[0] and check_same_object(bbox[1], obj[1], iou_threshold):
            return det_list.pop(idx)
    return None


def peak_bbox_via_name(target_base_name, det_results):
    """Return every entry of ``det_results`` whose base name equals ``target_base_name``.

    The base name is the text of ``entry[0]`` before the first ``" #"``
    (``"cup #2"`` -> ``"cup"``). Whole entries (not just boxes) are returned in
    their original order; ``det_results`` is not modified.
    """
    return [
        entry
        for entry in det_results
        if entry[0].partition(" #")[0] == target_base_name
    ]


def parse_list(det_results, llm_suggestions, iou_threshold=0.9, depth_threshold=0.01):
    """
    Classify the differences between a detection result and LLM suggestions.

    Both inputs are lists of records indexed as ``(name, box[, depth, facing])``.
    Element 0 is the object name (``"<base name> #<n>"`` style). Element 1 is a
    box accepted by ``check_same_object``. Elements 2 (depth) and 3 (facing
    direction) are only compared when the suggestion has exactly four elements.

    Records are paired by name:

    * Names in both lists are checked in this order:
      - depth differs by at least ``depth_threshold`` and the box IoU is not
        above ``iou_threshold``: depth change;
      - depth differs but the box is the same: dropped, not returned anywhere;
      - box IoU not above ``iou_threshold``: move;
      - facing given, not None and different: facing-direction change;
      - otherwise: preserved.
    * A name only in ``llm_suggestions`` is an attribute change when a
      still-unpaired detected object overlaps it (IoU above ``iou_threshold``)
      and that object's base name equals the last word of the suggestion's base
      name (``"red apple #1"`` -> ``"apple"``). Otherwise it is an addition.
    * A name only in ``det_results`` and not consumed by an attribute change
      is a removal.

    The input lists are not modified (they are copied internally).

    Returns:
        A 7-tuple ``(preserve_objects, remove_objects, add_objects,
        move_objects, change_attr_object, facing_dir_change,
        change_depth_object)``. Preserved, removed, added, attribute-changed
        and facing-changed items are record tuples. Moved and depth-changed
        items are ``(old_record, new_record)`` tuple pairs.
    """
    # Work on copies: entries are consumed (popped) below, and the caller's
    # lists must not be mutated as a side effect.
    det_results = list(det_results)
    llm_suggestions = list(llm_suggestions)

    key_curr = {obj[0] for obj in det_results}
    key_goal = {obj[0] for obj in llm_suggestions}
    add_keys = key_goal - key_curr  # Add / Change Attr
    sub_keys = key_curr - key_goal  # Remove / Change Attr
    same_keys = key_curr & key_goal  # Possible Move

    remove_objects = []
    add_objects = []
    move_objects = []
    change_attr_object = []
    preserve_objects = []
    change_depth_object = []
    # Depth changed but the box did not move. NOTE(review): these pairs are
    # collected but not returned, so they appear in no output list.
    only_depth_change = []
    facing_dir_change = []

    # Keys are iterated sorted so the order of the returned lists is
    # reproducible (string-set iteration order varies between interpreter runs).
    for key in sorted(same_keys):
        old_entry = pop_entry_via_name(key, det_results)
        new_entry = pop_entry_via_name(key, llm_suggestions)
        same_box = check_same_object(old_entry[1], new_entry[1], iou_threshold)
        has_extras = len(new_entry) == 4

        # Move and change depth should not happen at the same time
        if has_extras and not check_same_depth(old_entry[2], new_entry[2], depth_threshold):
            if not same_box:
                change_depth_object.append((tuple(old_entry), tuple(new_entry)))
            else:
                only_depth_change.append((tuple(old_entry), tuple(new_entry)))
            continue
        if not same_box:  # Move
            move_objects.append((tuple(old_entry), tuple(new_entry)))
            continue
        # A detected record without a facing field is treated as facing None
        # instead of raising IndexError.
        old_facing = old_entry[3] if len(old_entry) > 3 else None
        if has_extras and new_entry[3] is not None and old_facing != new_entry[3]:
            facing_dir_change.append(tuple(new_entry))
        else:
            preserve_objects.append(tuple(old_entry))

    # Add or change attribute
    for key in sorted(add_keys):
        new_entry = pop_entry_via_name(key, llm_suggestions)
        # Last word of the base name, e.g. "red apple #1" -> "apple".
        base_object = key.split(" #")[0].split(" ")[-1]
        # NOTE(review): candidates need a full base name equal to this single
        # word, so a detected "green apple #1" is not a candidate for
        # "red apple #1". Confirm that is intended.
        candidates = peak_bbox_via_name(base_object, det_results)
        matched = next(
            (obj for obj in candidates if check_same_object(obj[1], new_entry[1], iou_threshold)),
            None,
        )
        if matched is None:
            add_objects.append(tuple(new_entry))
        else:
            change_attr_object.append(tuple(new_entry))
            # The detected object is re-labelled, not removed. discard() rather
            # than remove(): the name may already have left sub_keys (claimed by
            # an earlier suggestion) or never been in it (duplicate-name entry).
            sub_keys.discard(matched[0])

    # Removal Part
    for key in sorted(sub_keys):
        entry = pop_entry_via_name(key, det_results)
        remove_objects.append(tuple(entry))

    return (
        preserve_objects,
        remove_objects,
        add_objects,
        move_objects,
        change_attr_object,
        facing_dir_change,
        change_depth_object,
    )


if __name__ == "__main__":
    import sys

    # Result file: first CLI argument. The original placeholder is kept as the
    # fallback so running without an argument still points at what to fill in.
    result_path = sys.argv[1] if len(sys.argv) > 1 else "FILL RESULT FILE HERE"
    # JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(result_path, "r", encoding="utf-8") as file:
        reports = json.load(file)["results"]

    count_operation = {
        "remove": 0,
        "add": 0,
        "move": 0,
        "attribute change": 0,
        "facing direction change": 0,
        "depth change": 0,
        "total": 0
    }
    for report in reports:
        det_result = report["detection_result"]
        llm_suggestion = report["llm_suggestion"]
        # Reports without a suggestion contribute nothing.
        if llm_suggestion is None:
            continue
        (
            preserve_objects,
            remove_objects,
            add_objects,
            move_objects,
            change_attr_object,
            facing_dir_change,
            change_depth_object,
        ) = parse_list(det_result, llm_suggestion)
        per_category = {
            "remove": len(remove_objects),
            "add": len(add_objects),
            "move": len(move_objects),
            "attribute change": len(change_attr_object),
            "facing direction change": len(facing_dir_change),
            "depth change": len(change_depth_object),
        }
        for category, n_ops in per_category.items():
            count_operation[category] += n_ops
        # Preserved objects are not operations, so they are not part of the total.
        count_operation["total"] += sum(per_category.values())
    print(count_operation)