import os
import json
import random
import shutil
from collections import defaultdict

# === Configuration ===
# Seed the global `random` RNG used for the per-group shuffles.
random.seed(42)

# Combined annotation JSON to split.
INPUT_JSON = "/fs/scratch/PAS2099/Jiacheng/Orientation/output/combined_general_complex_with_top.json"

# Image source roots. EgoOrient images are looked up flat under EGO_ROOT;
# CURE-OR images are looked up in these three directories, in list order.
EGO_ROOT = "/fs/scratch/PAS2099/Jiacheng/EgoOrientBench/EgoOrientBench/all_data/EgocentricDataset/imagenet_after"
CURE_DIRS = [
    "/fs/scratch/PAS2099/Jiacheng/Cure_or/01_no_challenge/white/iPhone",
    "/fs/scratch/PAS2099/Jiacheng/Cure_or/01_no_challenge/texture1/iPhone",
    "/fs/scratch/PAS2099/Jiacheng/Cure_or/01_no_challenge/texture2/iPhone",
]

# Output layout:
#   <BASE_OUT>/train/images/, <BASE_OUT>/train/train.json
#   <BASE_OUT>/val/images/,   <BASE_OUT>/val/val.json
BASE_OUT = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient"
TRAIN_DIR = os.path.join(BASE_OUT, "train")
VAL_DIR = os.path.join(BASE_OUT, "val")
TRAIN_IMG_DIR = os.path.join(TRAIN_DIR, "images")
VAL_IMG_DIR = os.path.join(VAL_DIR, "images")
TRAIN_JSON = os.path.join(TRAIN_DIR, "train.json")
VAL_JSON = os.path.join(VAL_DIR, "val.json")

# Create every output directory up front (no-op if it already exists).
for out_dir in (TRAIN_IMG_DIR, VAL_IMG_DIR, TRAIN_DIR, VAL_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Load combined data. Decode explicitly as UTF-8 (the JSON interchange
# encoding) instead of relying on the platform/locale default, which can
# mis-decode non-ASCII text on non-UTF-8 systems.
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    data = json.load(f)

# 1) Group by (category_name, label).
# Only entries whose "type" is "general_complex" are kept; everything else is
# skipped. Assumes `data` is a list of dicts carrying "category_name" and
# "label" keys (they are indexed directly) — TODO confirm against the producer.
groups = defaultdict(list)
for entry in data:
    if entry.get("type") != "general_complex":
        continue
    groups[(entry["category_name"], entry["label"])].append(entry)

# Fraction of each group's unique images assigned to train (rest go to val).
TRAIN_FRACTION = 0.8

train_entries = []
val_entries   = []


def _resolve_source(fn):
    """Return the source path of image basename *fn*, or None if not found.

    CURE-OR directories are searched first, in ``CURE_DIRS`` order; if none of
    them contains *fn*, the image is looked up directly under ``EGO_ROOT``.
    """
    for d in CURE_DIRS:
        p = os.path.join(d, fn)
        if os.path.exists(p):
            return p
    p = os.path.join(EGO_ROOT, fn)
    return p if os.path.exists(p) else None


def _copy_images(fnames, dest_dir):
    """Copy each basename in *fnames* into *dest_dir*; return those not found."""
    missing = []
    for fn in sorted(fnames):
        src = _resolve_source(fn)
        if src is None:
            missing.append(fn)
            continue
        shutil.copy(src, os.path.join(dest_dir, fn))
    return missing


missing_images = []
all_train_fnames = set()
all_val_fnames = set()

# 2) For each (object, orientation) group, split its unique images 80/20.
for entries in groups.values():
    # Sort before shuffling: iteration order of a set of str depends on
    # per-process hash randomization (PYTHONHASHSEED), so shuffling the raw
    # set order would give a different split on every run despite the seed.
    fnames = sorted({os.path.basename(e["image"]) for e in entries})
    random.shuffle(fnames)
    n_train = round(TRAIN_FRACTION * len(fnames))
    train_f = set(fnames[:n_train])
    val_f   = set(fnames[n_train:])

    # 3) Assign entries by image, so every entry of one image within this
    # group lands in the same split.
    for e in entries:
        if os.path.basename(e["image"]) in train_f:
            train_entries.append(e)
        else:
            val_entries.append(e)

    # 4) Copy this group's images into the flat per-split image dirs.
    # NOTE(review): images are keyed by basename only; if two source dirs hold
    # different images with the same basename, only the first found is used.
    missing_images += _copy_images(train_f, TRAIN_IMG_DIR)
    missing_images += _copy_images(val_f, VAL_IMG_DIR)
    all_train_fnames |= train_f
    all_val_fnames |= val_f

# Report problems instead of silently skipping them.
if missing_images:
    print(f"⚠️  {len(missing_images)} image(s) not found in any source dir and not copied, "
          f"e.g. {missing_images[:5]}")
leaked = all_train_fnames & all_val_fnames
if leaked:
    print(f"⚠️  {len(leaked)} image basename(s) appear in both train and val "
          f"(same image in several groups), e.g. {sorted(leaked)[:5]}")

# 5) Save train/val JSON. Encode explicitly as UTF-8: with ensure_ascii=False,
# non-ASCII characters are written verbatim, which can raise
# UnicodeEncodeError or be mis-encoded under a non-UTF-8 locale default.
with open(TRAIN_JSON, "w", encoding="utf-8") as f:
    json.dump(train_entries, f, indent=2, ensure_ascii=False)
with open(VAL_JSON, "w", encoding="utf-8") as f:
    json.dump(val_entries, f, indent=2, ensure_ascii=False)

# 6) Print overall entry counts for each split.
print("✅ Split complete!")
print(f"Train entries: {len(train_entries)}, Val entries: {len(val_entries)}")
