import json
import os
import matplotlib.pyplot as plt
from collections import Counter

# === Configuration ===
# Replace with the path to your JSON file.
json_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_combines_4dir_10to50perdir.json"
# Replace with your output directory.
output_dir = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output"


def normalize_base_dataset(raw):
    """Return the histogram label for a raw ``base_dataset`` value.

    Args:
        raw: The value stored under ``"base_dataset"`` in a record. May be
            ``None`` (key missing or JSON ``null``) or a non-string value.

    Returns:
        ``"unknown"`` for ``None`` or blank values, ``"ImageNet"`` for any
        name that starts with "imagenet" (case-insensitive), otherwise the
        value converted to ``str`` with surrounding whitespace removed.
    """
    if raw is None:
        # JSON null or missing key: .strip() on None would raise.
        return "unknown"
    name = str(raw).strip()
    if not name:
        # A blank label would otherwise produce an unlabeled bar.
        return "unknown"
    # Merge similar dataset names (all ImageNet variants share one bar).
    if name.lower().startswith("imagenet"):
        return "ImageNet"
    return name


def count_base_datasets(records):
    """Count records per normalized base dataset name.

    Args:
        records: Iterable of dict-like records, each optionally carrying a
            ``"base_dataset"`` key.

    Returns:
        A ``Counter`` mapping normalized dataset name to number of records,
        in first-seen order.
    """
    return Counter(
        normalize_base_dataset(item.get("base_dataset")) for item in records
    )


def plot_histogram(base_dataset_counter, save_path):
    """Draw the per-dataset bar chart and write it to ``save_path``.

    Args:
        base_dataset_counter: Mapping of dataset name to image count.
        save_path: Destination file path for the saved figure.
    """
    datasets = list(base_dataset_counter.keys())
    counts = list(base_dataset_counter.values())

    plt.figure(figsize=(10, 6))
    plt.bar(datasets, counts, color='steelblue')
    plt.title("Number of Images per Base Dataset (Merged ImageNet)", fontsize=16)
    plt.xlabel("Base Dataset", fontsize=14)
    plt.ylabel("Number of Images", fontsize=14)
    plt.xticks(rotation=30, ha="right")
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    plt.tight_layout()

    plt.savefig(save_path, dpi=300)
    # Close the figure so it does not stay registered with pyplot.
    plt.close()


def main():
    """Load the JSON file, count base datasets, and save the histogram."""
    # JSON text is UTF-8; do not rely on the locale's default encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    base_dataset_counter = count_base_datasets(data)

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "base_dataset_histogram_merged_4dir_10to50.png")
    plot_histogram(base_dataset_counter, save_path)

    print(f"✅ Histogram saved to: {save_path}")


if __name__ == "__main__":
    main()
