import os
import shutil
import random
from tqdm import tqdm
import json
from datetime import datetime, UTC
from config.config import IMAGENET_DIR, IMAGENET_DATA_DIR

# ===== Config =====
# Fraction of class folders assigned to the training split; the remaining
# classes go to the validation split.
TRAIN_RATIO = 0.75
# Seed for the class shuffle. Set the SPLIT_SEED environment variable to an
# integer to override it; an unset or empty variable falls back to 42.
# A non-integer value raises ValueError before anything on disk is touched.
SEED = int(os.environ.get("SPLIT_SEED") or 42)

image_net_dir = IMAGENET_DIR
source_dir = IMAGENET_DATA_DIR
output_dir = os.path.join(image_net_dir, "split_data")
train_dir = os.path.join(output_dir, "train")
val_dir = os.path.join(output_dir, "val")

# Get list of all class folders (sorted -> stable base order).
# The source is read BEFORE any cleanup so that a missing or unreadable
# source_dir fails here and leaves a previously generated split intact.
print(source_dir)
all_classes = sorted(
    d for d in os.listdir(source_dir)
    if os.path.isdir(os.path.join(source_dir, d))
)

os.makedirs(output_dir, exist_ok=True)

# Clean existing train/val folders if they exist
for folder in [train_dir, val_dir]:
    if os.path.exists(folder):
        shutil.rmtree(folder)

# Create output dirs
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Shuffle with a fixed seed (local RNG to avoid touching global state)
rng = random.Random(SEED)
rng.shuffle(all_classes)

# Split classes: the first TRAIN_RATIO share (rounded down) is train,
# everything after it is val.
split_idx = int(TRAIN_RATIO * len(all_classes))
train_classes = all_classes[:split_idx]
val_classes = all_classes[split_idx:]

def copy_class_images(class_list, target_dir):
    """Copy the files of the given classes into one flat directory.

    Each class name is resolved as a sub-folder of the module-level
    ``source_dir``. Every regular file in that folder is copied with
    ``shutil.copy2`` to ``target_dir`` under the name
    ``<class_name>_<original file name>``, which keeps same-named files
    from different class folders apart in the flat layout.

    Args:
        class_list: Iterable of class folder names located under
            ``source_dir``.
        target_dir: Directory that receives the copies; it is not created
            here.
    """
    progress = tqdm(class_list, desc=f"Copying to {os.path.basename(target_dir)}")
    for class_name in progress:
        class_path = os.path.join(source_dir, class_name)
        # Sorted listing gives a deterministic copy order.
        for fname in sorted(os.listdir(class_path)):
            src = os.path.join(class_path, fname)
            # Skip nested directories and other non-regular entries:
            # shutil.copy2 only copies files and would abort the run on them.
            if not os.path.isfile(src):
                continue
            dst = os.path.join(target_dir, f"{class_name}_{fname}")
            shutil.copy2(src, dst)

copy_class_images(train_classes, train_dir)
copy_class_images(val_classes, val_dir)

# Save minimal reproducibility metadata
meta = {
    "seed": SEED,
    "train_classes": train_classes,
    "val_classes": val_classes,
}
meta_path = os.path.join(output_dir, "split_meta.json")
# Explicit UTF-8 so the written files do not depend on the platform's
# default text encoding.
with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)

# Also write plain-text lists (one class name per line, no trailing newline)
for list_name, classes in (
    ("train_classes.txt", train_classes),
    ("val_classes.txt", val_classes),
):
    with open(os.path.join(output_dir, list_name), "w", encoding="utf-8") as f:
        f.write("\n".join(classes))

print(f"Done. Seed={SEED}. Metadata saved to {meta_path}")
