#%%
import json
import matplotlib.pyplot as plt
from collections import Counter

# === Load JSON file ===
# Change this to the path of your JSON file.
json_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_4dir_10to30perdir.json"
# JSON text is UTF-8; be explicit so decoding does not depend on the locale.
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# === Count number of images per domain ===
# Each item is expected to be a dict; items without a "domain" key are
# grouped under "unknown". Labels are normalized (trimmed, lower-cased) so
# that e.g. "Kitchen " and "kitchen" land in the same bucket.
domain_counter = Counter(
    item.get("domain", "unknown").strip().lower() for item in data
)

# === Prepare data for plotting ===
# Counter preserves first-seen order, so bars appear in the order domains
# are first encountered in the JSON file.
domains = list(domain_counter)
counts = list(domain_counter.values())

# === Plot histogram ===
plt.figure(figsize=(8, 5))
plt.bar(domains, counts, color="steelblue")
plt.title("Number of Images per Domain", fontsize=14)
plt.xlabel("Domain", fontsize=12)
plt.ylabel("Number of Images", fontsize=12)
plt.xticks(rotation=30)
plt.grid(axis="y", linestyle="--", alpha=0.6)
plt.tight_layout()

# Save BEFORE showing: plt.show() may close/consume the current figure
# (blocking GUI backends and the inline notebook backend both do), after
# which plt.savefig() would write out a blank image.
# NOTE(review): the input file is named "4dir" but this output is named
# "3dir" -- confirm the output filename is the intended one.
plt.savefig("/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/domain_count_histo_3dir_10to30.png", dpi=300)
plt.show()

print("✅ Done! Histogram plotted.")

# %%
