import json
import matplotlib.pyplot as plt

# Load the label.json file.
# JSON text is UTF-8 by specification; pass the encoding explicitly so the
# result does not depend on the platform's locale default.
with open("label.json", "r", encoding="utf-8") as label_file:
    data = json.load(label_file)

# Initialize counters
valid_image_counts = []  # one entry per valid image: its number of objects
total_images_generated = 0  # sum of object counts over all valid images
image_count_by_objects = {}  # object count -> number of valid images with that count

# Analyze each image
image_height = 100  # Replace with actual image height if available (not used by the check below)
min_gap = 10  # smallest allowed horizontal distance between neighbouring boxes

for bboxes in data.values():
    if not bboxes:
        # No annotations for this image: skip it rather than recording a
        # spurious "0 objects" entry in the statistics and histogram.
        continue

    # Order boxes left to right (index 0 of 'bbox_2d' is treated as x_min and
    # index 2 as x_max), so only neighbouring pairs need to be compared.
    boxes_by_x = sorted(bboxes, key=lambda bbox: bbox['bbox_2d'][0])

    # An image is rejected as soon as one neighbouring pair is closer than
    # min_gap. A single box has no pairs, so it is always accepted.
    too_close = any(
        right['bbox_2d'][0] - left['bbox_2d'][2] < min_gap
        for left, right in zip(boxes_by_x, boxes_by_x[1:])
    )
    if too_close:
        continue

    num_objects = len(boxes_by_x)
    valid_image_counts.append(num_objects)
    total_images_generated += num_objects
    image_count_by_objects[num_objects] = image_count_by_objects.get(num_objects, 0) + 1

# Report how many images passed the spacing check.
valid_total = len(valid_image_counts)
print(f"Total valid images: {valid_total}")

# Per object count, report (valid images with that count) * (object count).
print("Images generated by object count:")
for n_objects in sorted(image_count_by_objects):
    generated = image_count_by_objects[n_objects] * n_objects
    print(f"{n_objects} objects: {generated} images")

# Report the grand total accumulated during the scan.
print(f"Total images generated: {total_images_generated}")

# Plot the histogram of valid image counts.
# max() raises ValueError on an empty list, so fall back to 0 when no image
# passed the check; the script then still finishes and writes a plot.
max_objects = max(valid_image_counts, default=0)

plt.figure(figsize=(10, 6))
plt.hist(
    valid_image_counts,
    bins=range(max_objects + 2),  # integer edges 0..max+1: one bin per object count
    align='left',  # centre each bar on its integer tick instead of between ticks
    edgecolor='black',
    alpha=0.7,
)
plt.title("Histogram of Valid Image Counts per Number of Objects", fontsize=16)
plt.xlabel("Number of Objects", fontsize=14)
plt.ylabel("Frequency", fontsize=14)
plt.xticks(range(max_objects + 1))
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Save the plot as a file (before show(), which may block until the window is closed)
plt.savefig("valid_image_counts_histogram.png", dpi=300, bbox_inches='tight')

# Optionally, display the plot
plt.show()