#%%
import json
import matplotlib.pyplot as plt
from collections import Counter

# === Step 1: Load the combined general_complex dataset ===
# Alternative inputs used in earlier runs (swap into json_path as needed):
#   /fs/scratch/PAS2099/Jiacheng/EgoOrientBench/EgoOrientBench/all_data/EgocentricDataset/train_benchmark/benchmark.json
#   /fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_4dir_10to50perdir.json
json_path = "/fs/scratch/PAS2099/Jiacheng/Orientation/output/combined_general_complex_with_top_2devices_v2.json"
sorted_plot_path = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output2/final_combination.png"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# === Step 2: Count category_name under general_complex ===
# Names are trimmed and lower-cased so variants like "Chair " and "chair"
# are counted as the same category.
category_counts = Counter(
    item["category_name"].strip().lower()
    for item in data
    if item["type"] == "general_complex"
)

# === Step 3: Extract labels and counts (most frequent first) ===
# most_common() sorts by count descending with the same tie order as
# sorted(..., reverse=True). Building two lists (rather than zip(*...))
# avoids a ValueError when there are no general_complex items.
sorted_items = category_counts.most_common()
labels = [label for label, _ in sorted_items]
counts = [count for _, count in sorted_items]
if not sorted_items:
    print(f"Warning: no general_complex items found in {json_path}; plot will be empty.")

# === Step 4: Plot histogram ===
plt.figure(figsize=(22, 6))
plt.bar(labels, counts, color='steelblue')
plt.title("Number of general_complex Questions per Object Category", fontsize=16)
plt.xlabel("Object Category", fontsize=12)
plt.ylabel("Number of Questions", fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=9)
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
# Save BEFORE show(): at the end of a blocking/inline show() the figure is
# closed, so a later savefig() would write a new, empty figure.
plt.savefig(sorted_plot_path, dpi=300, bbox_inches='tight')
plt.show()

# %%
#%%
import json
import matplotlib.pyplot as plt
from collections import OrderedDict

# === Step 1: Load the same dataset as the previous cell ===
# NOTE: json_path is defined in the previous cell; run that cell first.
from collections import Counter

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# === Step 2: Count category_name under general_complex (preserve first-seen order) ===
# Counter is a dict subclass, so (Python 3.7+) its keys keep first-insertion
# order, giving the same ordering the hand-rolled OrderedDict count produced.
category_counts = Counter(
    item["category_name"].strip().lower()
    for item in data
    if item["type"] == "general_complex"
)

# === Step 3: Extract labels and counts without sorting ===
labels = list(category_counts.keys())
counts = list(category_counts.values())

# === Step 4: Plot histogram ===
plt.figure(figsize=(22, 6))
plt.bar(labels, counts, color='steelblue')
plt.title("Number of general_complex Questions per Object Category (Unsorted)", fontsize=16)
plt.xlabel("Object Category", fontsize=12)
plt.ylabel("Number of Questions", fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=9)
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
# Save before show(): show() may close the figure, leaving nothing to save.
# NOTE(review): the "20to80" suffix may be left over from an earlier input
# file — confirm it still describes json_path.
plt.savefig("/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/output/general_complex_object_counts_unsorted_20to80.png", dpi=300, bbox_inches='tight')
plt.show()
