import pandas as pd
import random
import os
import matplotlib.pyplot as plt
from collections import defaultdict
import re

# === Config ===
# Sampling parameters: SAMPLE_LIMIT rows are drawn per (target, direction)
# group, reproducibly via RANDOM_SEED.
RANDOM_SEED = 42
SAMPLE_LIMIT = 20

# Output locations (CSV of sampled rows, folder of per-label bar charts).
HIST_DIR = "histogram_ran20"
OUTPUT_CSV = "ran20_balanced_target_samples.csv"

random.seed(RANDOM_SEED)
os.makedirs(HIST_DIR, exist_ok=True)

# === Load unified label mapping ===
# Build raw_to_unified: every raw label listed in the object365 / NYU / LVIS
# columns of the master table maps to that row's unified_label. A cell may hold
# several comma-separated raw labels; each is stripped and lower-cased. If the
# same raw label appears more than once, the last occurrence wins.
master_df = pd.read_csv("/fs/scratch/PAS2099/Lemeng/Spatial_combined/fifth_master_label_mapping_table_corrected.csv")
LABEL_SOURCE_COLUMNS = ["object365_label", "nyu_label", "lvis_label"]

raw_to_unified = {}
# itertuples(name=None) yields plain tuples and, unlike iterrows, does not
# upcast mixed-dtype rows into a single Series dtype.
for unified, *raw_cells in master_df[["unified_label", *LABEL_SOURCE_COLUMNS]].itertuples(
    index=False, name=None
):
    for raw in raw_cells:
        if pd.isna(raw):
            continue
        # str() guards against non-string cells (e.g. numbers parsed by read_csv),
        # which would otherwise fail on .strip().
        for label in str(raw).split(","):
            label = label.strip().lower()
            if label:
                raw_to_unified[label] = unified


# === Load and combine all 3 spatial datasets ===
# Each pairs CSV is tagged with its own filename in a "source" column, then the
# three frames are stacked into a single table with a fresh 0..n-1 index.
dfs = [
    pd.read_csv(path).assign(source=path)
    for path in (
        "object365_valid_nonoverlapping_pairs.csv",
        "nyu_strictly_nonoverlapping_pairs.csv",
        "lvis_pairwise_spatial_relations.csv",
    )
]
full_df = pd.concat(dfs, ignore_index=True)

# === Normalize target object → unified label ===
def map_label(x, mapping=None):
    """Return the unified label for a raw object label, or None if unmapped.

    The raw value is stringified, stripped and lower-cased before lookup, which
    matches how the keys of ``raw_to_unified`` are normalized. Missing values
    (None / NaN) return None directly instead of being looked up as the
    strings "none" / "nan".

    Args:
        x: Raw label value, e.g. a ``target_object`` / ``arche_object`` cell.
        mapping: Optional raw -> unified dict. Defaults to the module-level
            ``raw_to_unified`` built from the master mapping table.

    Returns:
        The unified label, or None when ``x`` is missing or not in the mapping.
    """
    if mapping is None:
        mapping = raw_to_unified
    if pd.isna(x):
        return None
    return mapping.get(str(x).strip().lower(), None)

full_df["unified_target"] = full_df["target_object"].apply(map_label)
full_df["unified_arche"] = full_df["arche_object"].apply(map_label)
# Keep only pairs where both the target and the arche object map to a unified
# label. `.copy()` makes the filtered frame explicitly independent of the
# concatenated parent, so later column assignments on full_df (e.g. the
# "image" column) are not ambiguous view-vs-copy writes.
full_df = full_df[full_df["unified_target"].notna() & full_df["unified_arche"].notna()].copy()

# === Resolve image field ===
def resolve_image_field(row):
    """Pick the image identifier for one pairs row.

    Checks "image_path", then "image_id", then "image_url" and returns the
    first one that is present in ``row`` and not null. Returns None if none
    qualifies (the three source CSVs may use different image columns).
    """
    available = (
        row[name]
        for name in ("image_path", "image_id", "image_url")
        if name in row and pd.notna(row[name])
    )
    return next(available, None)

# Per row, the first non-null of image_path / image_id / image_url.
full_df["image"] = full_df.apply(resolve_image_field, axis=1)

# === Filter objects with >= SAMPLE_LIMIT per direction ===
directions = ["left up", "right up", "left bottom", "right bottom"]
# Per-target counts of each spatial_answer. Reindexing to `directions` adds a
# zero-filled column for any direction that never occurs in the data (or when
# full_df is empty), so the threshold check cannot raise KeyError.
counts = (
    full_df.groupby(["unified_target", "spatial_answer"])
    .size()
    .unstack(fill_value=0)
    .reindex(columns=directions, fill_value=0)
)
# A target qualifies only if every direction has at least SAMPLE_LIMIT rows.
valid_objects = counts[(counts >= SAMPLE_LIMIT).all(axis=1)].index.tolist()

# === Sample and collect data ===
# For each qualifying target label, draw exactly SAMPLE_LIMIT rows from each
# direction group (fixed random_state, so reruns are reproducible). Also record
# the number drawn per direction for the histograms below.
grouped = full_df.groupby(["unified_target", "spatial_answer"])
final_samples = []
hist_data = {}

for target in valid_objects:
    per_direction = hist_data[target] = {}
    for direction_name in directions:
        drawn = grouped.get_group((target, direction_name)).sample(
            n=SAMPLE_LIMIT, random_state=RANDOM_SEED
        )
        per_direction[direction_name] = len(drawn)
        final_samples.append(drawn)

# === Save sampled rows to CSV ===
# Export the sampled rows. The unified labels are written under the column
# names "arche_object" / "target_object".
export_columns = [
    "source", "image", "unified_arche", "unified_target",
    "spatial_answer", "arche_bbox", "target_bbox",
]
if final_samples:
    final_df = pd.concat(final_samples, ignore_index=True)
else:
    # pd.concat([]) raises ValueError when no target qualified. Fall back to an
    # empty frame with full_df's columns so the CSV and summary are still written.
    print(f"⚠️ No target label has >= {SAMPLE_LIMIT} samples in every direction; writing an empty CSV.")
    final_df = full_df.iloc[0:0]

final_df = final_df[export_columns].rename(
    columns={"unified_arche": "arche_object", "unified_target": "target_object"}
)
final_df.to_csv(OUTPUT_CSV, index=False)
print(f"✅ Saved {len(final_df)} rows to {OUTPUT_CSV}")

# === Plot histograms ===
# One bar chart per sampled target label showing rows drawn per direction,
# saved as HIST_DIR/<sanitized label>.png.
for obj, dir_counts in hist_data.items():
    # Named bar_heights (not `counts`) to avoid clobbering the per-direction
    # counts DataFrame computed earlier in the script.
    bar_heights = [dir_counts.get(d, 0) for d in directions]
    plt.figure(figsize=(6, 4))
    plt.bar(directions, bar_heights, color="orange")
    plt.title(f"{obj} — Random {SAMPLE_LIMIT} Per Direction")
    plt.ylabel("Target Count")
    plt.xticks(rotation=30)
    plt.tight_layout()

    # Sanitize filename: replace path separators, whitespace, and characters
    # invalid in Windows filenames with "_", one character at a time, so names
    # that were already safe are unchanged.
    safe_name = re.sub(r'[\\/:*?"<>|\s]', "_", str(obj))
    plt.savefig(os.path.join(HIST_DIR, f"{safe_name}.png"))
    plt.close()

# === Print summary ===
# Distinct target labels and distinct images present in the exported rows.
num_unique_labels, num_unique_images = (
    final_df[column].nunique() for column in ("target_object", "image")
)
print(f"🔢 Unique target labels: {num_unique_labels}")
print(f"🖼️ Unique images: {num_unique_images}")