import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

# Mapping from class (directory) name to its integer label.
_CLASS_TO_LABEL = {
    "deercrossing": 0,
    "leftcurve": 1,
    "oneway": 2,
    "pedestrian": 3,
    "speedlimit25mph": 4,
    "stop": 5,
    "warning": 6,
    "workersahead": 7,
}


def class2label(c):
    """Return the integer label for the class name *c*.

    Args:
        c: Class name; must be one of the keys of ``_CLASS_TO_LABEL``
            (matched exactly, case-sensitive).

    Returns:
        The integer label (0-7) associated with the class name.

    Raises:
        ValueError: If *c* is not a known class name.
    """
    try:
        return _CLASS_TO_LABEL[c]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which can never be a class name.
        raise ValueError(
            f"Unknown class name {c!r}; expected one of "
            f"{sorted(_CLASS_TO_LABEL)}"
        ) from None


def split_imagefolder_to_npy_separate(root_dir, out_dir, image_size=None,
                                      train_ratio=0.64, val_ratio=0.16, test_ratio=0.20,
                                      random_state=None):
    """
    Read an ImageFolder-style dataset and split it into train/val/test sets.
    Save images and labels into separate .npy files.

    Each sub-directory of ``root_dir`` is treated as one class; its name is
    converted to an integer label with ``class2label``. Files whose name ends
    in .png/.jpg/.jpeg/.bmp (case-insensitive) are loaded as RGB images.

    Args:
        root_dir: Directory containing one sub-directory per class.
        out_dir: Directory the six .npy files are written to (created if
            it does not exist).
        image_size: Optional size passed to ``PIL.Image.Image.resize``.
            If None, images are not resized; ``np.stack`` then requires all
            images to already share the same shape.
        train_ratio: Fraction of samples for the training set.
        val_ratio: Fraction of samples for the validation set.
        test_ratio: Fraction of samples for the test set.
        random_state: Forwarded to both ``train_test_split`` calls. The
            default (None) keeps the splits randomly seeded.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test))``.

    Raises:
        ValueError: If the ratios do not sum to 1, if no image is found
            under ``root_dir``, or if a class directory containing images
            has a name ``class2label`` does not recognize.
    """

    # Validate split. An explicit raise is used instead of ``assert`` so the
    # check is not stripped when Python runs with -O.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1.")

    image_exts = ('.png', '.jpg', '.jpeg', '.bmp')
    images = []
    labels = []

    # ---------- Load All Images ----------
    for class_name in sorted(os.listdir(root_dir)):
        class_path = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_path):
            continue

        # Sort file names so the loading order does not depend on the
        # arbitrary order returned by os.listdir.
        fnames = sorted(
            fname for fname in os.listdir(class_path)
            if fname.lower().endswith(image_exts)
        )
        if not fnames:
            continue

        # Resolve the label once per class instead of once per image.
        label = class2label(class_name)

        for fname in fnames:
            img_path = os.path.join(class_path, fname)

            # The context manager closes the underlying file; convert()
            # returns an independent, fully loaded image.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
            if image_size is not None:
                img = img.resize(image_size)

            images.append(np.array(img))
            labels.append(label)

    if not images:
        raise ValueError(f"No images found under {root_dir!r}.")

    images = np.stack(images)
    labels = np.array(labels)

    # ---------- Split into train / temp ----------
    X_train, X_temp, y_train, y_temp = train_test_split(
        images, labels,
        test_size=(1 - train_ratio),
        shuffle=True, stratify=labels,
        random_state=random_state
    )

    # ---------- Split temp into val / test ----------
    # Re-express the validation share relative to the remaining (val + test) pool.
    val_ratio_adjusted = val_ratio / (val_ratio + test_ratio)

    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=(1 - val_ratio_adjusted),
        shuffle=True, stratify=y_temp,
        random_state=random_state
    )

    # ---------- Save to separate .npy files ----------
    os.makedirs(out_dir, exist_ok=True)

    splits = (
        ("train", X_train, y_train),
        ("val", X_val, y_val),
        ("test", X_test, y_test),
    )
    for split_name, split_images, split_labels in splits:
        np.save(os.path.join(out_dir, f"{split_name}_images.npy"), split_images)
        np.save(os.path.join(out_dir, f"{split_name}_labels.npy"), split_labels)

    print("Saved train/val/test splits to:", out_dir)
    print(f"Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")

    return (X_train, y_train), (X_val, y_val), (X_test, y_test)




# Script driver: convert the raw image folder into 32x32 train/val/test
# .npy splits under ./data.
img_root = '../../../real_data'
out_path = './data'
split_imagefolder_to_npy_separate(
    root_dir=img_root,
    out_dir=out_path,
    image_size=(32, 32),
)
# print(images.shape)
# print(labels.shape)
# saved_images = './data/train_images.npy'
# saved_labels = './data/train_labels.npy'
# data = np.load(saved_labels, allow_pickle=True)
# print(data.shape)
# print(data['labels'])