import pandas as pd
import os
import matplotlib.pyplot as plt
from collections import defaultdict

# === Create output folder ===
os.makedirs("histogram_all", exist_ok=True)

# === Load master label mapping ===
master_df = pd.read_csv("master_label_mapping_table.csv")

# Build raw label → unified_label map.
# Raw labels are normalized (str, stripped, lower-cased) the same way the
# target_object values are normalized when the pair CSVs are processed.
# Rows with a blank unified_label are skipped: pandas reads the empty cell as
# NaN, and NaN is truthy, so it would otherwise be recorded as a "nan" label.
raw_to_unified = {}
for _, row in master_df.iterrows():
    unified = row["unified_label"]
    if pd.isna(unified):
        continue
    for source in ["object365_label", "nyu_label", "lvis_label"]:
        raw = row[source]
        if pd.isna(raw):
            continue
        # str() guards against non-text cells (e.g. a numeric-looking label),
        # which would otherwise raise AttributeError on .strip().
        key = str(raw).strip().lower()
        if key:
            # Later rows overwrite earlier ones for a duplicated raw label.
            raw_to_unified[key] = unified

# === Initialize counter: unified_label → spatial direction counts ===
direction_counts = defaultdict(lambda: defaultdict(int))

# === Helper to process each CSV ===
def process_file(file_path, mapping=None, counts=None):
    """Tally quadrant spatial answers per unified label for one pairs CSV.

    Reads the ``target_object`` and ``spatial_answer`` columns and normalizes
    both (converted to str, whitespace-stripped, lower-cased). It keeps only
    the four quadrant answers and maps each raw target label to its unified
    label. For each kept row it increments ``counts[unified][direction]``.
    Rows whose label has no usable unified mapping (missing, empty or NaN)
    are skipped.

    Args:
        file_path: Anything ``pd.read_csv`` accepts (a path or file-like).
        mapping: raw label -> unified label. Defaults to the module-level
            ``raw_to_unified`` map.
        counts: Nested mapping unified label -> {direction: count}, updated
            in place. May be a plain dict or a defaultdict. Defaults to the
            module-level ``direction_counts``.

    Returns:
        The ``counts`` mapping that was updated.
    """
    if mapping is None:
        mapping = raw_to_unified
    if counts is None:
        counts = direction_counts
    valid_directions = frozenset(
        {"left up", "right up", "left bottom", "right bottom"}
    )

    df = pd.read_csv(file_path)
    if df.empty:
        return counts

    # Vectorized normalization instead of a per-row iterrows() loop.
    # astype(str) mirrors the original str(...) conversion (NaN -> "nan").
    labels = df["target_object"].astype(str).str.strip().str.lower()
    directions = df["spatial_answer"].astype(str).str.strip().str.lower()

    for raw_label, direction in zip(labels, directions):
        if direction not in valid_directions:
            continue
        unified = mapping.get(raw_label)
        # NaN (a blank unified_label cell) is truthy, so reject it explicitly.
        if not unified or pd.isna(unified):
            continue
        try:
            # For a defaultdict this creates the proper nested default type.
            per_direction = counts[unified]
        except KeyError:
            per_direction = counts[unified] = {}
        per_direction[direction] = per_direction.get(direction, 0) + 1
    return counts

# === Process all three datasets ===
# Pair CSVs to aggregate, processed in this fixed order.
DATASET_FILES = (
    "object365_valid_nonoverlapping_pairs.csv",
    "nyu_strictly_nonoverlapping_pairs.csv",
    "lvis_pairwise_spatial_relations.csv",
)
for csv_path in DATASET_FILES:
    process_file(csv_path)

# === Build totals and sort ===
# Total number of recorded spatial answers for each unified label.
label_totals = {name: sum(per_dir.values()) for name, per_dir in direction_counts.items()}
# Descending by total; sorting is stable with reverse=True, so ties keep
# their insertion order exactly as with key=-total.
sorted_labels = sorted(label_totals.items(), key=lambda pair: pair[1], reverse=True)
print(f"in total {len(label_totals)} of labels")
top_5 = sorted_labels[:5]
# Ascending by total, computed separately so ties again keep insertion order.
ascending_labels = sorted(label_totals.items(), key=lambda pair: pair[1])
bottom_5 = [pair for pair in ascending_labels if pair[1] > 0][:5]

# === Print top and bottom 5 ===
for heading, ranking in (
    ("\n🔝 Top 5 most frequent target objects:", top_5),
    ("\n🔻 Bottom 5 least frequent target objects:", bottom_5),
):
    print(heading)
    for label, count in ranking:
        print(f"  {label}: {count}")

# === Generate and save histograms ===
# Quadrant order used on every plot's x-axis (hoisted out of the loop).
directions = ["left up", "right up", "left bottom", "right bottom"]
# Unified labels are free text and may contain path separators or characters
# illegal in file names (e.g. "cup/mug"). Used verbatim, such a label would make
# savefig target a non-existent sub-directory, so those characters are replaced.
# NOTE(review): distinct labels could collide after sanitizing (e.g. "a/b"
# and "a_b"); the later one overwrites the earlier image.
_FILENAME_UNSAFE = str.maketrans({ch: "_" for ch in '/\\:*?"<>|\0'})

for label, dir_counts in direction_counts.items():
    counts = [dir_counts.get(d, 0) for d in directions]

    plt.figure(figsize=(6, 4))
    plt.bar(directions, counts, color="skyblue")
    plt.title(f"{label} — Target Object Distribution")
    plt.ylabel("Count")
    plt.xticks(rotation=30)
    plt.tight_layout()
    safe_name = str(label).translate(_FILENAME_UNSAFE) or "_unnamed"
    plt.savefig(os.path.join("histogram_all", f"{safe_name}.png"))
    plt.close()

print("\n✅ Histograms saved to folder: histogram_all/")
