import json
import numpy as np
import matplotlib.pyplot as plt

# Input file path
input_file = "filtered_label.json"

# Depth bin edges: 8 to 80 inclusive with a step of 1 (i.e. 72 unit-wide bins)
bin_edges = np.arange(8, 81, 1)


def load_labels(path):
    """Load and return the filtered label data stored as JSON at *path*."""
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def group_depths_by_class(data):
    """Group the non-null ``closest_depth`` values by object class.

    Args:
        data: Mapping of image id to a list of object dicts. Each object is
            read through its ``"class"`` and ``"closest_depth"`` keys.

    Returns:
        Dict mapping each class to the list of its depths, in input order.
        Objects whose depth is ``None`` or missing are skipped.

    Raises:
        KeyError: If an object that has a depth lacks the ``"class"`` key.
    """
    classes = {}
    for objects in data.values():
        for obj in objects:
            # .get() treats a missing key the same way as an explicit null.
            depth = obj.get("closest_depth")
            if depth is None:
                continue
            classes.setdefault(obj["class"], []).append(depth)
    return classes


def compute_histogram(depths, edges=bin_edges):
    """Return the per-bin counts of *depths* for the given bin *edges*.

    Values outside ``[edges[0], edges[-1]]`` are not counted by
    ``np.histogram``; the last bin includes its right edge.
    """
    counts, _ = np.histogram(depths, bins=edges)
    return counts


def histogram_filename(obj_class):
    """Return the output image filename for *obj_class*.

    Path separators in the class name are replaced with ``_`` so the image is
    always written to the current directory instead of failing on (or writing
    into) a sub-directory derived from the class name.
    """
    safe_name = str(obj_class).replace("/", "_").replace("\\", "_")
    return f"{safe_name}_histogram.png"


def plot_class_histogram(obj_class, depths, edges=bin_edges):
    """Plot the depth histogram of one class and save it as a PNG.

    Args:
        obj_class: Class label, used in the title and the filename.
        depths: Depth values for this class.
        edges: Histogram bin edges.

    Returns:
        The filename the figure was saved to.
    """
    counts = compute_histogram(depths, edges)

    # np.histogram silently drops out-of-range values; make that visible.
    skipped = len(depths) - int(counts.sum())
    if skipped:
        print(
            f"Note: {skipped} depth value(s) for class {obj_class} fall "
            f"outside [{edges[0]}, {edges[-1]}] and are not plotted."
        )

    output_path = histogram_filename(obj_class)
    fig = plt.figure(figsize=(10, 6))
    try:
        # Bar widths follow the bin widths (all 1 for the default edges).
        plt.bar(edges[:-1], counts, width=np.diff(edges), edgecolor="black", align="edge")
        plt.title(f"Histogram of Depths for Class: {obj_class}")
        plt.xlabel("Depth (m)")
        plt.ylabel("Count")
        plt.xticks(edges, rotation=90)
        plt.grid(axis="y", linestyle="--", alpha=0.7)

        # Save the histogram as an image
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return output_path


def main():
    """Load the labels, then save one depth histogram image per class."""
    # Step 1: Load the filtered JSON data
    data = load_labels(input_file)

    # Step 2: Extract classes and depths
    classes = group_depths_by_class(data)

    # Steps 3-5: Bin, plot and save a histogram for each class
    for obj_class, depths in classes.items():
        plot_class_histogram(obj_class, depths)

    print("Histograms have been saved for each class.")


if __name__ == "__main__":
    main()