import os
import json

# --- Run configuration ------------------------------------------------------
# TODO: fill out all required placeholder paths below before running.
init_round = 1
compare_round = 1
reports_round = []

# JSON results file; only its top-level "results" entry is used.
# Original note: GPT_GraPE (presumably the name of a results run -- confirm).
result_file = "FILL OUT YOUR RESULTS FILE"
with open(result_file) as file:
    reports = json.load(file)["results"]

# JSON benchmark file; only its top-level "data" entry is used.
benchmark_file = "FILL OUT YOUR BENCHMARKS FILE"
with open(benchmark_file) as file:
    datafiles = json.load(file)["data"]

from collections import Counter


def check_orientation(obj_infos, obj_direction):
    """Check a detected orientation against the expected direction.

    Args:
        obj_infos: Indexable object description whose element at index 2 is
            the detected orientation. That element may be ``None`` when no
            orientation is available.
        obj_direction: Expected direction. ``None`` or the string ``"nan"``
            means there is no orientation constraint to check.

    Returns:
        A ``(ok, feedback)`` tuple. ``ok`` is ``False`` with feedback
        ``"Orientation wrong"`` when a direction is expected but is not
        contained in the detected orientation (or no orientation was
        detected); otherwise ``ok`` is ``True``.
    """
    # No constraint requested: nothing to verify.
    if obj_direction is None or obj_direction == "nan":
        return True, "correct orintation"

    detected = obj_infos[2]
    # A missing orientation cannot satisfy an explicit direction. Guarding
    # here also avoids the TypeError raised by ``x not in None``.
    if detected is None or obj_direction not in detected:
        return False, "Orientation wrong"
    return True, "correct orintation"


def check_spatial_relation(spatial_relation, frame_of_referece, obj1_infos, obj2_infos, obj_direction,
                           orientation_correction=True):
    """Check whether object 1 stands in ``spatial_relation`` to object 2.

    Args:
        spatial_relation: Relation to verify ("front", "back", "left"; any
            other value is evaluated with the "right" test). ``None`` means
            there is nothing to check.
        frame_of_referece: Frame-of-reference label. When it contains
            "intrinsic" and object 2 has an orientation, the relation is
            first remapped according to that orientation.
        obj1_infos: ``(bbox, depth, orient)`` of object 1.
        obj2_infos: ``(bbox, depth, orient)`` of object 2 (the reference).
        obj_direction: Expected direction of object 2, forwarded to
            ``check_orientation`` when ``orientation_correction`` is set.
        orientation_correction: Whether to verify object 2's orientation
            before evaluating the relation.

    Returns:
        ``(ok, feedback, effective_relation)``. ``effective_relation`` is the
        relation actually evaluated after any intrinsic remapping, or
        ``None`` when ``spatial_relation`` is ``None``.
    """
    if spatial_relation is None:
        return True, "no spatial relation to check", None

    bbox1, depth1, orient1 = obj1_infos
    reference_infos = list(obj2_infos)
    bbox2, depth2, orient2 = reference_infos

    # Ordered (orientation keywords, relation remap) pairs. The first entry
    # with a keyword contained in the reference orientation wins; relations
    # missing from its remap are left untouched.
    intrinsic_remap = (
        (("front", "forward"), {"left": "right", "right": "left"}),
        (("back",), {"front": "back", "back": "front"}),
        (("right",), {"front": "right", "right": "front", "back": "left", "left": "back"}),
        (("left",), {"front": "left", "right": "back", "back": "right", "left": "front"}),
    )

    rel_relation = spatial_relation
    if "intrinsic" in frame_of_referece and orient2 is not None:
        for keywords, remap in intrinsic_remap:
            if any(keyword in orient2 for keyword in keywords):
                rel_relation = remap.get(spatial_relation, spatial_relation)
                break

    if orientation_correction:
        orient_correct, feedback = check_orientation(reference_infos, obj_direction)
        if not orient_correct:
            return orient_correct, feedback, rel_relation

    def center_x(bbox):
        # bbox[0] plus half of bbox[2]; presumably x + width / 2 for an
        # (x, y, w, h) box -- confirm against the detector output format.
        return bbox[0] + bbox[2] / 2

    if rel_relation == "front":
        # "front" holds when object 1 has the larger depth value.
        check_condition = depth1 - depth2 > 0.0
    elif rel_relation == "back":
        check_condition = depth1 - depth2 < 0.0
    elif rel_relation == "left":
        check_condition = (center_x(bbox2) - center_x(bbox1)) > 0.0
    else:
        # "right", and every relation not handled above.
        check_condition = (center_x(bbox1) - center_x(bbox2)) > 0.0

    if check_condition:
        return check_condition, "Correct", rel_relation
    return check_condition, "Incorrect " + rel_relation, rel_relation


def verify_spatial_info(datafile, spatial_info):
    """Verify the relations of one benchmark entry against object descriptions.

    Args:
        datafile: Benchmark entry. Reads the keys "obj1", "obj2", "obj2_dir",
            "rel1", "rel2", "ref_obj1" and "ref_obj2". The entry is not
            modified.
        spatial_info: Sequence of ``(name, bbox, depth, orient)`` entries, or
            ``None`` when no description could be produced.

    Returns:
        ``(ok, feedback, relations)``. On early failure ``feedback`` is a
        one-element list ("Syntax error", "Multiple objects detect" or
        "Missing objects") and ``relations`` is ``None``. Otherwise
        ``feedback`` and ``relations`` are 2-tuples holding the results of
        the two ``check_spatial_relation`` calls.
    """
    if spatial_info is None:
        return False, ["Syntax error"], None

    obj1 = datafile["obj1"]
    obj2 = datafile["obj2"]
    # An entry belongs to an object when the object's name is contained in
    # the entry's name.
    obj1_infos = [(bbox, depth, orient) for obj, bbox, depth, orient in spatial_info if obj1 in obj]
    obj2_infos = [(bbox, depth, orient) for obj, bbox, depth, orient in spatial_info if obj2 in obj]

    if len(obj1_infos) > 1 or len(obj2_infos) > 1:
        return False, ["Multiple objects detect"], None
    if not obj1_infos or not obj2_infos:
        return False, ["Missing objects"], None

    obj2_dir = datafile["obj2_dir"]
    # Collapse diagonal directions onto their front/back component; values
    # not listed here are passed through unchanged.
    # "backward-right" maps to "back" to mirror "backward-left" (it was
    # previously mapped to "right", inconsistent with the other diagonals).
    convert_orient = {
        "front": "front",
        "forward-left": "front",
        "forward-right": "front",
        "backward-left": "back",
        "backward-right": "back",
        "left": "left",
        "right": "right",
        "back": "back",
        None: None,
    }

    ref1 = datafile["ref_obj1"]
    ref2 = datafile["ref_obj2"]
    # Work on local copies so the caller's benchmark entry is not mutated.
    expected_rel1 = datafile["rel1"]
    expected_rel2 = datafile["rel2"]
    # With mixed frames of reference, only the intrinsic-frame relation is
    # checked; the relative-frame one is skipped.
    if ref1 == "relative" and ref2 == "intrinsic":
        expected_rel1 = None
    if ref1 == "intrinsic" and ref2 == "relative":
        expected_rel2 = None

    rel1_check, feedback1, rel1 = check_spatial_relation(expected_rel1, ref1, obj1_infos[0],
                                                         obj2_infos[0],
                                                         obj_direction=convert_orient.get(obj2_dir, obj2_dir))
    rel2_check, feedback2, rel2 = check_spatial_relation(expected_rel2, ref2, obj2_infos[0],
                                                         obj1_infos[0], obj_direction=None)

    return rel1_check and rel2_check, (feedback1, feedback2), (rel1, rel2)


import copy

# Accumulators for the evaluation loop below.
all_feedback_det = []
all_feedback_llm = []
acc_img = 0
acc_llm = 0
total = 0

# Benchmark entries extended with the LLM layout suggestion of each report.
update_files = []
# Detection accuracy split by whether an intrinsic frame of reference is used.
intrinsic_acc = 0
intrinsic_total = 0
relative_acc = 0
relative_total = 0


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


for report in reports:
    # The second "_"-separated field of the report id is a 1-based index
    # into the benchmark entries.
    file_id = int(report["id"].split("_")[1]) - 1

    data_info = datafiles[file_id]

    # Copy before verification so the stored entry is independent of
    # anything done to data_info afterwards.
    new_data_info = copy.deepcopy(data_info)
    new_data_info["llm_layout_suggestions"] = report["llm_suggestion"]
    update_files.append(new_data_info)

    verify_detection, feedback_det, relations1 = verify_spatial_info(data_info, report["detection_result"])
    verify_LLM, feedback_llm, relations2 = verify_spatial_info(data_info, report["llm_suggestion"])

    if (data_info["ref_obj1"] and "intrinsic" in data_info["ref_obj1"]) or (
            data_info["ref_obj2"] and "intrinsic" in data_info["ref_obj2"]):
        intrinsic_acc += int(verify_detection)
        intrinsic_total += 1
    else:
        relative_acc += int(verify_detection)
        relative_total += 1

    all_feedback_det.extend(feedback_det)
    all_feedback_llm.extend(feedback_llm)

    acc_img += int(verify_detection)
    acc_llm += int(verify_LLM)
    total += 1

# Detection / LLM are printed as percentages, Relative / Intrinsic as
# fractions. An empty group prints NaN instead of raising ZeroDivisionError.
print("Detection:", _safe_ratio(acc_img * 100, total))
print("LLM:", _safe_ratio(acc_llm * 100, total))
print("Relative:", _safe_ratio(relative_acc, relative_total))
print("Intrinsic:", _safe_ratio(intrinsic_acc, intrinsic_total))

# Histogram of detection feedback messages, sorted by message text.
feedbacks = [str(feedback) for feedback in all_feedback_det]
for k, v in sorted(Counter(feedbacks).items()):
    print(k, v)