import json

# filepath: /fs/ess/PAS2099/sooyoung/perception_system_v2_local/kitti_analysis/FINAL/absolute_depth/first_prep_json.py

# Path of the label file to read, and the path the filtered labels are written
# to. Both are relative paths, so they resolve against the current working
# directory.
input_file = "label.json"
output_file = "filtered_label.json"

def is_overlapping(bbox1, bbox2):
    """
    Return True when two boxes share a region of positive area.

    Each box is unpacked as ``(x_min, y_min, x_max, y_max)``. Boxes that merely
    touch along an edge or at a corner are reported as NOT overlapping.
    """
    ax0, ay0, ax1, ay1 = bbox1
    bx0, by0, bx1, by1 = bbox2

    # The boxes are disjoint if a separating line exists on either axis.
    separated = (
        ax1 <= bx0      # box 1 ends at or before box 2 starts (x)
        or ax0 >= bx1   # box 1 starts at or after box 2 ends (x)
        or ay1 <= by0   # box 1 ends at or before box 2 starts (y)
        or ay0 >= by1   # box 1 starts at or after box 2 ends (y)
    )
    return not separated

def process_labels(data, *, min_depth=8, excluded_classes=("truck",)):
    """
    Filter each image's object list down to valid objects and count them.

    An object is dropped when any of the following holds:

    * its ``"closest_depth"`` is present, not None, and below ``min_depth``;
    * its ``"class"`` equals one of ``excluded_classes``;
    * its ``"bbox_2d"`` overlaps (per ``is_overlapping``) the box of any other
      object in the same image. This includes objects that are themselves
      dropped by the depth/class rules.

    Args:
        data: Mapping of image id -> list of object dicts. Each object is read
            for ``"closest_depth"`` (optional; missing or None means unknown
            depth and never triggers the depth rule), ``"class"`` and
            ``"bbox_2d"``. The mapping is modified in place: each image id is
            rebound to its filtered list.
        min_depth: Objects with a known depth strictly less than this are
            dropped. The default of 8 is in meters, per the original
            script's comment.
        excluded_classes: Iterable of class names that are always dropped.
            A single string is treated as one class name.

    Returns:
        ``(data, total_objects, valid_objects)``: the same (mutated) mapping,
        the number of objects seen before filtering, and the number kept.
    """
    # Guard against excluded_classes="truck" being iterated character by character.
    if isinstance(excluded_classes, str):
        excluded_classes = (excluded_classes,)
    excluded = tuple(excluded_classes)

    total_objects = 0
    valid_objects = 0

    # Rebinding existing keys (no insertions/deletions) is safe while iterating.
    for image_id, objects in data.items():
        total_objects += len(objects)
        kept = []

        for i, obj in enumerate(objects):
            depth = obj.get("closest_depth")
            if (depth is not None and depth < min_depth) or obj["class"] in excluded:
                continue

            bbox = obj["bbox_2d"]
            # Check against every other object, not only kept ones, so a
            # rejected object (e.g. a truck) still disqualifies what it overlaps.
            if any(
                is_overlapping(bbox, other["bbox_2d"])
                for j, other in enumerate(objects)
                if j != i
            ):
                continue

            kept.append(obj)

        data[image_id] = kept
        valid_objects += len(kept)

    return data, total_objects, valid_objects

def main():
    """
    Filter the labels in ``input_file`` and write the result to ``output_file``.

    Reads the JSON mapping from ``input_file``, filters it with
    ``process_labels``, writes the result as indented JSON to ``output_file``,
    and prints a short summary of object counts.
    """
    # JSON is UTF-8 by spec. Being explicit avoids depending on the platform's
    # locale encoding, which differs between machines.
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Drop close-range, excluded-class and overlapping objects.
    updated_data, total_objects, valid_objects = process_labels(data)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(updated_data, f, indent=2)

    print(f"Total objects originally: {total_objects}")
    print(f"Valid objects remaining: {valid_objects}")
    print(f"Filtered data saved to {output_file}")

# Run the filter only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()