import pandas as pd
import random
import os
import matplotlib.pyplot as plt
from collections import defaultdict
import re

# === Config ===
# Per-direction row bounds: an object needs at least MIN_LIMIT rows in each
# direction to be kept, and at most MAX_LIMIT rows per direction are retained.
MIN_LIMIT, MAX_LIMIT = 50, 200

# Fixed seed so repeated runs draw the same samples.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Output locations; the histogram directory is created up front if missing.
OUTPUT_CSV = "ran_min50_max200_balanced_target_samples.csv"
HIST_DIR = "histogram_min50_max200"
os.makedirs(HIST_DIR, exist_ok=True)

# === Load unified label mapping
# Build raw_to_unified: lowercased raw dataset label -> unified label.
# Each source column may hold several comma-separated raw labels per row.
master_df = pd.read_csv("/fs/scratch/PAS2099/Lemeng/Spatial_combined/fifth_master_label_mapping_table_corrected.csv")
raw_to_unified = {}
for _, row in master_df.iterrows():
    unified = row["unified_label"]
    for source in ["object365_label", "nyu_label", "lvis_label"]:
        raw = row[source]
        if pd.isna(raw):
            continue
        # Convert to str before any string method: a cell that pandas parsed
        # as a number would otherwise fail on .strip().
        for piece in str(raw).split(","):
            label = piece.strip().lower()
            if label:
                # If the same raw label appears more than once, the last
                # occurrence wins.
                raw_to_unified[label] = unified

# === Load and combine all spatial datasets
# Every frame is tagged with the file it came from in a "source" column,
# then all frames are stacked under a fresh 0..n-1 index.
dataset_files = (
    "object365_valid_nonoverlapping_pairs.csv",
    "nyu_strictly_nonoverlapping_pairs.csv",
    "lvis_pairwise_spatial_relations.csv",
)
dfs = [pd.read_csv(path).assign(source=path) for path in dataset_files]

full_df = pd.concat(dfs, ignore_index=True)

# === Map to unified labels
def map_label(x):
    """Return the unified label for raw label ``x``, or None if unmapped.

    The lookup key is ``str(x)`` with surrounding whitespace removed and
    lowercased, matching how keys are stored in ``raw_to_unified``.
    """
    key = str(x).strip().lower()
    return raw_to_unified.get(key)

full_df["unified_target"] = full_df["target_object"].apply(map_label)
full_df["unified_arche"] = full_df["arche_object"].apply(map_label)

# Keep only pairs where both labels resolved to a unified label.
both_mapped = full_df["unified_target"].notna() & full_df["unified_arche"].notna()
# Explicit copy: the filtered result gets a new column assigned later, which
# would raise SettingWithCopyWarning if it were still a slice of the
# unfiltered frame.
full_df = full_df[both_mapped].copy()

# === Resolve image field
def resolve_image_field(row):
    """Return the first populated image reference found in ``row``.

    Candidate keys are tried in the order ``image_path``, ``image_id``,
    ``image_url``. A key counts as populated when it is present in ``row``
    and its value is not NA according to ``pd.notna``. Returns None when no
    candidate qualifies.
    """
    candidates = ("image_path", "image_id", "image_url")
    populated = (
        row[name]
        for name in candidates
        if name in row and pd.notna(row[name])
    )
    return next(populated, None)

# Evaluate resolve_image_field once per row and store the result as "image".
resolved_images = full_df.apply(resolve_image_field, axis=1)
full_df["image"] = resolved_images

# === Identify valid objects (at least MIN_LIMIT rows in all 4 directions)
directions = ["left up", "right up", "left bottom", "right bottom"]
# Rows: unified target label; columns: spatial_answer value; cells: row count.
counts = full_df.groupby(["unified_target", "spatial_answer"]).size().unstack(fill_value=0)
# unstack() only creates columns for answers that actually occur. Reindexing
# onto the four directions adds any absent one as an all-zero column, so the
# threshold check yields "no valid objects" rather than a KeyError.
per_direction = counts.reindex(columns=directions, fill_value=0)
meets_minimum = (per_direction >= MIN_LIMIT).all(axis=1)
valid_objects = per_direction[meets_minimum].index.tolist()

# === Sample up to MAX_LIMIT rows per (object, direction)
# Groups at or below the cap are kept whole (as a copy); larger groups are
# down-sampled to exactly MAX_LIMIT rows with a fixed random_state.
grouped = full_df.groupby(["unified_target", "spatial_answer"])
final_samples = []

for obj in valid_objects:
    for direction in directions:
        bucket = grouped.get_group((obj, direction))
        if len(bucket) <= MAX_LIMIT:
            final_samples.append(bucket.copy())
            continue
        final_samples.append(
            bucket.sample(n=MAX_LIMIT, random_state=RANDOM_SEED)
        )

# === Save to CSV
final_df = pd.concat(final_samples, ignore_index=True)

# Keep only the output columns, publishing the unified labels under the
# plain arche_object / target_object names.
kept_columns = [
    "source", "image", "unified_arche", "unified_target",
    "spatial_answer", "arche_bbox", "target_bbox",
]
final_df = final_df[kept_columns].rename(
    columns={
        "unified_arche": "arche_object",
        "unified_target": "target_object",
    }
)

final_df.to_csv(OUTPUT_CSV, index=False)
print(f"✅ Saved {len(final_df)} rows to {OUTPUT_CSV}")

# === Generate histograms per object
# Count rows per (target label, direction) in one grouped pass instead of
# iterating row by row. sort=False keeps labels in first-appearance order.
direction_counts = defaultdict(lambda: defaultdict(int))
pair_sizes = final_df.groupby(["target_object", "spatial_answer"], sort=False).size()
for (label, direction), n_rows in pair_sizes.items():
    direction_counts[label][direction] = int(n_rows)

# Anything outside [a-zA-Z0-9_-] is replaced so the label is a safe file name.
unsafe_chars = re.compile(r"[^a-zA-Z0-9_\-]")

# The loop variable is named label_counts so it does not overwrite the
# module-level `counts` table.
for label, label_counts in direction_counts.items():
    values = [label_counts.get(d, 0) for d in directions]

    plt.figure(figsize=(6, 4))
    plt.bar(directions, values, color="green")
    plt.title(f"{label} — Target Direction Frequency")
    plt.ylabel("Target Count")
    plt.xticks(rotation=30)
    plt.tight_layout()

    # str() guards against a non-string label, which re.sub would reject.
    safe_name = unsafe_chars.sub("_", str(label))
    plt.savefig(os.path.join(HIST_DIR, f"{safe_name}.png"))
    # Close each figure so they do not accumulate in memory across labels.
    plt.close()

print(f"📊 Histograms saved to: {HIST_DIR}/")

# === Summary stats ===
# Report how many distinct target labels and distinct image references
# ended up in the saved sample.
target_labels = final_df["target_object"]
image_refs = final_df["image"]
print(f"🔢 Unique target labels: {target_labels.nunique()}")
print(f"🖼️ Unique images: {image_refs.nunique()}")
