import json
import os
import matplotlib.pyplot as plt
from collections import defaultdict, Counter

# === Fixed direction labels: 8 horizontal directions plus "top" (9 total) ===
# Order here is the x-axis order of every histogram below. Labels in the data
# that are not in this list are ignored by the counting loop.
DIRECTIONS = [
    "front", "front right", "right", "back right",
    "back", "back left", "left", "front left", "top"
]

# === Output directory (one PNG per object is written here) ===
output_dir = "/fs/scratch/PAS2099/Jiacheng/Orientation/output/object_fix_histo_v2"
os.makedirs(output_dir, exist_ok=True)

# === Load the combined general_complex benchmark JSON ===
# NOTE(review): the counting loop below iterates `data` and indexes each entry
# with "type", "category_name" and "label" — confirm against the file's schema.
json_path = "/fs/scratch/PAS2099/Jiacheng/Orientation/output/combined_general_complex_with_top_2devices_v2.json"
# JSON text is UTF-8; be explicit so decoding does not depend on the locale.
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# === Count labels for each object under general_complex type ===
# Maps a normalized object name (stripped, lower-cased) to a Counter keyed by
# direction. Each Counter starts with every entry of DIRECTIONS set to 0, so all
# directions are present as explicit keys even if never observed.
def _zeroed_direction_counter():
    return Counter(dict.fromkeys(DIRECTIONS, 0))


object_label_counts = defaultdict(_zeroed_direction_counter)

for item in data:
    # Only "general_complex" records contribute to the histograms.
    if item["type"] != "general_complex":
        continue
    obj = item["category_name"].strip().lower()
    label = item["label"].strip().lower()
    # Unknown labels are skipped and never create an entry for the object.
    if label not in DIRECTIONS:
        continue
    object_label_counts[obj][label] += 1

# === Plot one direction histogram per object ===
def _png_name_for(obj, taken):
    """Return a cleaned, not-yet-used PNG filename for object name *obj*.

    "/" and " " become "_" and "," is dropped. That mapping is lossy (e.g.
    "a, b" and "a b" both clean to "a_b"), so when the cleaned name is already
    in *taken* a numeric suffix is appended. Otherwise a later object would
    silently overwrite an earlier object's plot. The chosen name is added to
    *taken* before returning.
    """
    stem = obj.replace("/", "_").replace(" ", "_").replace(",", "")
    candidate = f"{stem}.png"
    suffix = 1
    while candidate in taken:
        candidate = f"{stem}_{suffix}.png"
        suffix += 1
    taken.add(candidate)
    return candidate


used_filenames = set()
for obj, label_counter in object_label_counts.items():
    labels = DIRECTIONS
    counts = [label_counter[direction] for direction in labels]

    plt.figure(figsize=(8, 5))
    plt.bar(labels, counts, color='steelblue')
    plt.title(f"Direction Counts for {obj}", fontsize=14)
    plt.xlabel("Direction", fontsize=12)
    plt.ylabel("Count", fontsize=12)
    # Right-align rotated tick labels so each one ends under its own bar
    # instead of being centered across neighboring bars.
    plt.xticks(rotation=30, ha="right")
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    # Lay out after all axis decorations are in place.
    plt.tight_layout()

    filename = _png_name_for(obj, used_filenames)
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()

print("✅ All object-wise direction count histograms saved to:", output_dir)


# import json
# import os
# import matplotlib.pyplot as plt
# from collections import defaultdict, Counter

# # === Fixed 8 directions ===
# DIRECTIONS = [
#     "front", "front right", "right", "back right",
#     "back", "back left", "left", "front left"
# ]

# # === Output directory ===
# output_dir = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/option_combined_6to24_histogram"
# os.makedirs(output_dir, exist_ok=True)

# # === Load JSON file ===
# json_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_combines_4dir_6to24perdir.json"
# with open(json_path, "r") as f:
#     data = json.load(f)

# # === Build a structure: object -> source (train/val) -> label counts ===
# object_source_label_counts = defaultdict(lambda: {
#     "train": Counter({dir: 0 for dir in DIRECTIONS}),
#     "val": Counter({dir: 0 for dir in DIRECTIONS})
# })

# for item in data:
#     if item["type"] == "general_complex":
#         obj = item["category_name"].strip().lower()
#         label = item["label"].strip().lower()
#         source = item.get("source", "train").strip().lower()  # Default to train if missing
#         if label in DIRECTIONS:
#             object_source_label_counts[obj][source][label] += 1

# # === Plot separately for train and val ===
# for obj, source_dict in object_source_label_counts.items():
#     for source in ["train", "val"]:
#         label_counter = source_dict[source]
#         counts = [label_counter[dir] for dir in DIRECTIONS]
        
#         # Skip if all counts are zero
#         if sum(counts) == 0:
#             continue

#         plt.figure(figsize=(8, 5))
#         color = "lightcoral" if source == "train" else "steelblue"
#         plt.bar(DIRECTIONS, counts, color=color)
#         plt.title(f"Direction Counts for {obj} ({source})", fontsize=14)
#         plt.xlabel("Direction", fontsize=12)
#         plt.ylabel("Count", fontsize=12)
#         plt.xticks(rotation=30)
#         plt.tight_layout()
#         plt.grid(axis="y", linestyle="--", alpha=0.6)

#         # Clean filename
#         filename = obj.replace("/", "_").replace(" ", "_").replace(",", "") + f"_{source}.png"
#         plt.savefig(os.path.join(output_dir, filename))
#         plt.close()

# print("✅ All object-wise direction histograms saved to:", output_dir)
