import numpy as np
import os
import cv2
from sklearn.model_selection import train_test_split

def generate_gray_dataset(
    grayscale_range: tuple = (0, 255),
    img_width: int = 256,
    img_height: int = 20,
    data_path: str = "shap_dataset",
    train_split: float = 0.7,
    val_split: float = 0.1,
    test_split: float = 0.2,
    seed: int = 42,
) -> None:
    """
    Generate a synthetic dataset of uniform grayscale images, one per gray level.

    One image is produced for every integer gray level in ``grayscale_range``
    (inclusive). Each image is a single-channel ``uint8`` array of shape
    ``(img_height, img_width)`` filled with that gray level. The label is 1
    when the gray level is >= 127.5 and 0 otherwise.

    The samples are split in two steps with ``train_test_split``: first
    ``1 - train_split`` of the data is held out, then the held-out part is
    divided between validation and test, with ``test_split / (val_split +
    test_split)`` of it going to the test set. Images are written to
    ``<data_path>/train``, ``<data_path>/validation`` and ``<data_path>/test``
    as ``img_<index>_label_<label>.png``, where ``<index>`` is the zero-padded
    position of the sample inside its subset. Existing directories are reused.

    Parameters:
        grayscale_range (tuple): Range of grayscale values (inclusive) as
            (min_value, max_value); must satisfy 0 <= min_value <= max_value <= 255.
        img_width (int): Width of each image in pixels (must be positive).
        img_height (int): Height of each image in pixels (must be positive).
        data_path (str): Directory to save the dataset.
        train_split (float): Proportion of the training dataset, strictly
            between 0 and 1.
        val_split (float): Proportion of the validation dataset (must be positive).
        test_split (float): Proportion of the test dataset (must be positive).
        seed (int): Random seed for reproducibility. It is passed to both
            splits and is also used to seed NumPy's global random state.

    Returns:
        None

    Raises:
        ValueError: If ``grayscale_range``, the image size or the split
            proportions are invalid (checked before anything is written).
        OSError: If an image could not be written to disk.

    Example usage:
        generate_gray_dataset(grayscale_range=(0, 255), img_width=256, img_height=20, data_path="shap_dataset")
    """
    # Validate everything up front so that bad arguments never leave a
    # half-written dataset on disk.
    min_value, max_value = grayscale_range
    if not 0 <= min_value <= max_value <= 255:
        raise ValueError(
            "grayscale_range must satisfy 0 <= min_value <= max_value <= 255, "
            f"got {grayscale_range!r}"
        )
    if img_width <= 0 or img_height <= 0:
        raise ValueError(
            "img_width and img_height must be positive, "
            f"got img_width={img_width!r}, img_height={img_height!r}"
        )
    if not 0 < train_split < 1:
        raise ValueError(
            f"train_split must be strictly between 0 and 1, got {train_split!r}"
        )
    if val_split <= 0 or test_split <= 0:
        # Also guards the division below against val_split + test_split == 0.
        raise ValueError(
            "val_split and test_split must be positive, "
            f"got val_split={val_split!r}, test_split={test_split!r}"
        )

    # The splits below receive random_state=seed explicitly; the global seed is
    # still set so the global NumPy RNG state after this call is unchanged
    # from earlier versions of this function.
    np.random.seed(seed)

    # Build the range with the default integer dtype and cast afterwards, so
    # the exclusive stop value (up to 256) never has to fit into uint8.
    grayscale_values = np.arange(min_value, max_value + 1).astype(np.uint8)

    # Binary label: 1 for the upper half of the 0..255 scale, 0 for the lower.
    labels = np.where(grayscale_values >= 127.5, 1, 0)

    # Step 1: separate the training samples from the held-out remainder.
    X_train, X_temp, y_train, y_temp = train_test_split(
        grayscale_values, labels, test_size=(1 - train_split), random_state=seed
    )
    # Step 2: divide the remainder between validation and test.
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=(test_split / (val_split + test_split)), random_state=seed
    )

    subsets = {
        "train": (X_train, y_train),
        "validation": (X_val, y_val),
        "test": (X_test, y_test),
    }
    for subset_name, (subset_values, subset_labels) in subsets.items():
        subset_path = os.path.join(data_path, subset_name)
        os.makedirs(subset_path, exist_ok=True)

        for idx, (value, label) in enumerate(zip(subset_values, subset_labels)):
            # Create an image filled with the grayscale value
            image = np.full((img_height, img_width), value, dtype=np.uint8)

            # Save image with label in the filename
            img_filename = os.path.join(subset_path, f"img_{idx:05d}_label_{label}.png")
            # cv2.imwrite reports failure through its return value rather than
            # an exception, so it has to be checked explicitly.
            if not cv2.imwrite(img_filename, image):
                raise OSError(f"Failed to write image: {img_filename}")

    print(f"Dataset generated and saved to {data_path}")
