#%%
import pandas as pd
from sklearn.model_selection import train_test_split
import os
import sys

# === Configuration ===
INPUT_CSV = "ran40_balanced_target_samples.csv"
TRAIN_CSV = "train.csv"
VAL_CSV = "val.csv"
RANDOM_SEED = 42

# === Load the dataset ===
df = pd.read_csv(INPUT_CSV)

# === Group by (target_object, spatial_answer) ===
grouped = df.groupby(["target_object", "spatial_answer"])

# === Perform 80/20 split for each group ===
train_samples = []
val_samples = []
skipped_groups = 0
skipped_rows = 0

for _, group in grouped:
    if len(group) < 2:
        # A single-row group cannot be divided into a train and a val part,
        # so it is left out of both outputs; keep count so the loss is visible.
        skipped_groups += 1
        skipped_rows += len(group)
        continue
    train, val = train_test_split(group, test_size=0.2, random_state=RANDOM_SEED)
    train_samples.append(train)
    val_samples.append(val)

if skipped_groups:
    print(f"⚠️ Skipped {skipped_groups} group(s) with fewer than 2 rows ({skipped_rows} rows dropped)")

# === Concatenate all train and val samples ===
# pd.concat raises ValueError on an empty list, so fall back to an empty frame
# with the input's columns when no group was large enough to split.
train_df = pd.concat(train_samples, ignore_index=True) if train_samples else df.iloc[0:0].copy()
val_df = pd.concat(val_samples, ignore_index=True) if val_samples else df.iloc[0:0].copy()

# === Save to CSV files ===
train_df.to_csv(TRAIN_CSV, index=False)
val_df.to_csv(VAL_CSV, index=False)

print(f"✅ Saved {len(train_df)} rows to {TRAIN_CSV}")
print(f"✅ Saved {len(val_df)} rows to {VAL_CSV}")

# %%
# === Distinct image identifiers seen in each split ===
train_images = {*train_df["image"]}
val_images = {*val_df["image"]}

# === Images that appear on both sides of the split ===
overlapping_images = train_images.intersection(val_images)

# === Report how many images are shared ===
print(f"Number of overlapping images in train and val: {len(overlapping_images)}")

#%%
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import defaultdict
import random

# === Config ===
TRAIN_CSV, VAL_CSV = "train.csv", "val.csv"
SEED = 42
random.seed(SEED)  # seed the stdlib RNG used further down in this cell

# === Load data ===
train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)

# === Step 1: Find overlapping images ===
# An image counts as overlapping when its identifier occurs in both files.
train_images = {*train_df["image"]}
val_images = {*val_df["image"]}
overlapping_images = train_images.intersection(val_images)
print(f"Step 1️⃣ - Found {len(overlapping_images)} overlapping images.")

# === Step 2: Remove (image, target_object) duplicates ===
# Key columns identifying one annotated object inside one image.
KEY_COLS = ["image", "target_object"]

train_overlap = train_df[train_df["image"].isin(overlapping_images)]
val_overlap = val_df[val_df["image"].isin(overlapping_images)]

# Distinct key pairs present in BOTH splits. Joining only the de-duplicated key
# columns avoids the many-to-many row blow-up (and suffixed payload columns)
# that joining the full frames would produce just to be thrown away.
dup_pairs = pd.merge(
    train_overlap[KEY_COLS].drop_duplicates(),
    val_overlap[KEY_COLS].drop_duplicates(),
    on=KEY_COLS,
    how="inner",
)

# Build the lookup index once and reuse it for both splits.
dup_index = dup_pairs.set_index(KEY_COLS).index
train_clean = train_df[~train_df.set_index(KEY_COLS).index.isin(dup_index)]
val_clean = val_df[~val_df.set_index(KEY_COLS).index.isin(dup_index)]

# Combine remaining overlapping rows. Both inputs already exclude every
# duplicated pair, so no second filtering pass is needed.
overlap_rows = pd.concat([
    train_clean[train_clean["image"].isin(overlapping_images)],
    val_clean[val_clean["image"].isin(overlapping_images)],
], ignore_index=True)

print(f"Step 2️⃣ - Removed {len(dup_pairs)} duplicated (image, target_object).")
print(f"Remaining overlap rows: {len(overlap_rows)}")

# === Step 3: Per-group 80/20 split by (target_object, spatial_answer) ===
pre_train = []
pre_val = []

grouped = overlap_rows.groupby(["target_object", "spatial_answer"])
for _, group in grouped:
    if len(group) < 2:
        # A single row cannot be split; keep it on the train side.
        pre_train.append(group)
        continue
    train_split, val_split = train_test_split(group, test_size=0.2, random_state=SEED)
    pre_train.append(train_split)
    pre_val.append(val_split)

# pd.concat raises ValueError("No objects to concatenate") on an empty list.
# That happens when overlap_rows is empty (no overlapping images) or, for the
# val side, when every group has a single row. Fall back to an empty frame
# carrying overlap_rows' columns so the following steps keep working.
empty_overlap = overlap_rows.iloc[0:0]
pre_train_df = pd.concat(pre_train, ignore_index=True) if pre_train else empty_overlap.copy()
pre_val_df = pd.concat(pre_val, ignore_index=True) if pre_val else empty_overlap.copy()

print(f"Step 3️⃣ - Pre-train: {len(pre_train_df)}, Pre-val: {len(pre_val_df)}")

# === Step 4: Reassign only conflicting images in pre_train/pre_val ===
# An image is "conflicting" when the per-group split put some of its rows in
# pre_train and others in pre_val.
train_imgs = set(pre_train_df["image"])
val_imgs = set(pre_val_df["image"])
conflict_imgs = train_imgs.intersection(val_imgs)

# Row masks marking the conflicting images on each side
train_in_conflict = pre_train_df["image"].isin(conflict_imgs)
val_in_conflict = pre_val_df["image"].isin(conflict_imgs)

# Pool every row belonging to a conflicting image
conflict_rows = pd.concat(
    [pre_train_df[train_in_conflict], pre_val_df[val_in_conflict]],
    ignore_index=True,
)
print(f"num of conflict {len(conflict_rows)}")

# Draw one side per image; re-seeding and walking the images in sorted order
# keeps the assignment reproducible between runs.
random.seed(SEED)
image_to_split = {}
for img in sorted(conflict_imgs):
    image_to_split[img] = "train" if random.random() < 0.8 else "val"

# Tag each pooled row with its image's side, then peel the two sides apart
conflict_rows["split"] = conflict_rows["image"].map(image_to_split)
resolved_train = conflict_rows.loc[conflict_rows["split"].eq("train")].drop(columns=["split"])
resolved_val = conflict_rows.loc[conflict_rows["split"].eq("val")].drop(columns=["split"])

# Rows whose image was never in conflict stay where the split put them
safe_train = pre_train_df[~train_in_conflict]
safe_val = pre_val_df[~val_in_conflict]

# Final overlap portions = untouched rows + reassigned rows
final_overlap_train = pd.concat([safe_train, resolved_train], ignore_index=True)
final_overlap_val = pd.concat([safe_val, resolved_val], ignore_index=True)

print(f"Step 4️⃣ - Conflict resolved: train={len(final_overlap_train)}, val={len(final_overlap_val)}")

# === Step 5: Merge with original non-overlapping data ===
# Rows whose image never appeared in both files are carried over unchanged.
train_is_overlap = train_df["image"].isin(overlapping_images)
val_is_overlap = val_df["image"].isin(overlapping_images)
base_train = train_df[~train_is_overlap]
base_val = val_df[~val_is_overlap]

final_train = pd.concat([base_train, final_overlap_train], ignore_index=True)
final_val = pd.concat([base_val, final_overlap_val], ignore_index=True)

# === Save to file ===
for frame, out_path in ((final_train, "train_final.csv"), (final_val, "val_final.csv")):
    frame.to_csv(out_path, index=False)

print(f"✅ Step 5️⃣ - Saved final train: {len(final_train)} → train_final.csv")
print(f"✅ Step 5️⃣ - Saved final val:   {len(final_val)} → val_final.csv")

#%%
#%%
# === Post-processing: limit each (target_object, spatial_answer) group in final_train to max 20 rows ===
LIMIT = 20

# Reload just in case (optional)
final_train = pd.read_csv("train_final.csv")

# Group and sample.
# Iterate the groups explicitly instead of going through groupby.apply: newer
# pandas releases no longer hand the grouping columns to the applied function
# (pandas 2.2+ warns about it), which would leave target_object /
# spatial_answer only in the result index and lose them once that index is
# discarded. Frames obtained by iterating a groupby always keep every column.
grouped = final_train.groupby(["target_object", "spatial_answer"])
limited_parts = [
    group.sample(n=LIMIT, random_state=SEED) if len(group) > LIMIT else group
    for _, group in grouped
]

# pd.concat rejects an empty list, so keep an empty frame with the same
# columns when there is nothing to sample from.
if limited_parts:
    limited_train = pd.concat(limited_parts, ignore_index=True)
else:
    limited_train = final_train.iloc[0:0].copy()

# Save updated version
limited_train.to_csv("train_final_limited.csv", index=False)
print(f"🔁 Limited train_final: {len(limited_train)} rows saved to train_final_limited.csv")

#%%
# === Compare (target_object, spatial_answer) sets between train and val ===

# Read the capped train split and the final val split back from disk
train_limited_df = pd.read_csv("train_final_limited.csv")
val_df = pd.read_csv("val_final.csv")

# Distinct (target_object, spatial_answer) combinations per split
pair_columns = ["target_object", "spatial_answer"]
train_pairs = set(train_limited_df[pair_columns].itertuples(index=False, name=None))
val_pairs = set(val_df[pair_columns].itertuples(index=False, name=None))

# Set algebra: exclusive to one side, or shared by both
only_in_train = train_pairs.difference(val_pairs)
only_in_val = val_pairs.difference(train_pairs)
in_both = train_pairs.intersection(val_pairs)

# Summary counts
print(f"✅ Unique (label, spatial) in train: {len(train_pairs)}")
print(f"✅ Unique (label, spatial) in val:   {len(val_pairs)}")
print(f"✅ Common in both: {len(in_both)}")
print(f"❌ Only in train:  {len(only_in_train)}")
print(f"❌ Only in val:    {len(only_in_val)}")

# List the combinations that exist on one side only, in sorted order
for split_name, exclusive_pairs in (("train", only_in_train), ("val", only_in_val)):
    print(f"\n🔍 (label, spatial) only in {split_name}:")
    for pair in sorted(exclusive_pairs):
        print(pair)

#%%
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

# === Load datasets ===
train_limited_df, val_df = (
    pd.read_csv(csv_path)
    for csv_path in ("train_final_limited.csv", "val_final.csv")
)

# === Output folders (created if missing, left alone if present) ===
for folder in ("histogram_train", "histogram_val"):
    os.makedirs(folder, exist_ok=True)

# === Function to generate histograms with tqdm ===
def plot_histograms_per_label(df, output_folder, title_prefix, directions=None):
    """Save one bar chart per ``target_object`` showing its ``spatial_answer`` counts.

    For every distinct, non-missing value in ``df["target_object"]`` a PNG
    named after the label is written into ``output_folder``. Each chart has
    one bar per entry of ``directions``; answers not listed in ``directions``
    are not plotted, and listed answers that never occur are drawn as 0.

    Args:
        df: DataFrame with ``target_object`` and ``spatial_answer`` columns.
        output_folder: Directory receiving the PNG files (created if missing).
        title_prefix: Text used in each chart title and in the progress bar.
        directions: Ordered spatial answers to plot. Defaults to
            ``["left up", "right up", "left bottom", "right bottom"]``.
    """
    if directions is None:
        directions = ["left up", "right up", "left bottom", "right bottom"]
    directions = list(directions)

    # savefig does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)

    # Missing labels never satisfy the equality filter below, so skip them.
    labels = df["target_object"].dropna().unique()
    for label in tqdm(labels, desc=f"📊 Generating {title_prefix} histograms"):
        subset = df[df["target_object"] == label]
        spatial_counts = subset["spatial_answer"].value_counts().to_dict()
        counts = [spatial_counts.get(d, 0) for d in directions]

        # Safe filename; str() keeps non-string labels (e.g. numbers) working.
        safe_label = str(label).replace("/", "_").replace("\\", "_").replace(" ", "_")

        plt.figure(figsize=(6, 4))
        try:
            plt.bar(directions, counts)
            plt.title(f"{title_prefix}: {label}")
            plt.xlabel("Spatial Answer")
            plt.ylabel("Number of Rows")
            plt.xticks(rotation=30)
            plt.tight_layout()
            plt.savefig(os.path.join(output_folder, f"{safe_label}.png"))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()

# === Generate histograms with progress bars ===
# One pass per split: (data, destination folder, title prefix)
histogram_jobs = (
    (train_limited_df, "histogram_train", "Train"),
    (val_df, "histogram_val", "Val"),
)
for frame, folder, prefix in histogram_jobs:
    plot_histograms_per_label(frame, folder, prefix)

print("✅ Histograms saved in histogram_train/ and histogram_val/")

#%%
# sys.exit() raises SystemExit, so when the file is executed top to bottom
# nothing below this line runs. NOTE(review): presumably a deliberate stop so
# the summary plots are only produced when run by hand — confirm.
sys.exit()
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

# === Load cleaned data ===
# Uses the capped train split; point this at "train_final.csv" to inspect the
# uncapped one instead.
train_df = pd.read_csv("train_final_limited.csv")
val_df = pd.read_csv("val_final.csv")

# === Combine both datasets ===
combined_df = pd.concat([train_df, val_df], ignore_index=True)

# === 1️⃣ Count unique images per target_object ===
# Collect the set of images seen for each label, in order of first appearance.
label_to_images = defaultdict(set)
for label, img in zip(combined_df["target_object"], combined_df["image"]):
    label_to_images[label].add(img)

image_count_per_label = {
    label: len(imgs) for label, imgs in label_to_images.items()
}

# === 2️⃣ Count total rows per target_object ===
row_count_per_label = combined_df["target_object"].value_counts().to_dict()

# === Plots: unique image count per label, then row count per label ===
# Each entry: (counts by label, y-axis caption, title, extra bar options)
label_charts = (
    (image_count_per_label, "Number of Unique Images",
     "📊 Histogram: Unique Images per Target Label", {}),
    (row_count_per_label, "Number of Rows",
     "📊 Histogram: Row Count per Target Label", {"color": "orange"}),
)
for counts_by_label, y_caption, chart_title, bar_options in label_charts:
    plt.figure(figsize=(16, 5))
    plt.bar(counts_by_label.keys(), counts_by_label.values(), **bar_options)
    plt.xticks(rotation=90, ha="right")
    plt.xlabel("Target Object Label")
    plt.ylabel(y_caption)
    plt.title(chart_title)
    plt.tight_layout()
    plt.show()


# %%
#%%
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

# === Load cleaned data ===
train_df = pd.read_csv("train_final.csv")
val_df = pd.read_csv("val_final.csv")
combined_df = pd.concat([train_df, val_df], ignore_index=True)

# === 1️⃣ Unique image count per (label, spatial_answer) ===
# Keys are "<target_object>-<spatial_answer>" strings, in order of first appearance.
label_dir_to_images = defaultdict(set)
for obj, answer, img in zip(
    combined_df["target_object"],
    combined_df["spatial_answer"],
    combined_df["image"],
):
    label_dir_to_images[f"{obj}-{answer}"].add(img)

image_count_per_pair = {
    pair_key: len(imgs) for pair_key, imgs in label_dir_to_images.items()
}

# === 2️⃣ Row count per (label, spatial_answer) ===
label_dir_to_rows = combined_df.groupby(["target_object", "spatial_answer"]).size()
row_count_per_pair = {
    f"{obj}-{answer}": n_rows
    for (obj, answer), n_rows in label_dir_to_rows.items()
}

# === Plots: unique image count per pair, then row count per pair ===
# Each entry: (counts by pair, y-axis caption, title, extra bar options)
pair_charts = (
    (image_count_per_pair, "Number of Unique Images",
     "📊 Unique Images per Target Label + Spatial Answer", {}),
    (row_count_per_pair, "Number of Rows",
     "📊 Row Count per Target Label + Spatial Answer", {"color": "orange"}),
)
for counts_by_pair, y_caption, chart_title, bar_options in pair_charts:
    plt.figure(figsize=(20, 6))
    plt.bar(counts_by_pair.keys(), counts_by_pair.values(), **bar_options)
    plt.xticks(rotation=90, ha="right")
    plt.xlabel("Target Object - Spatial Answer")
    plt.ylabel(y_caption)
    plt.title(chart_title)
    plt.tight_layout()
    plt.show()

# %%
