import argparse
import json
import math
import os
import random
from typing import Dict, List, Tuple

def load_raw_dataset(dataset_path: str) -> List[Dict]:
    """
    Loads the raw dataset from a JSON file.

    :param dataset_path: Path to the JSON file to read.
    :return: The parsed JSON content (annotated as a list of dataset entries).
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(dataset_path, "r", encoding="utf-8") as f:
        return json.load(f)

def split_dataset(
    data: List[Dict],
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    seed: int = 42
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Splits the dataset into train, validation, and test sets.

    The input list is left untouched: a shuffled copy is sliced instead.
    Shuffling uses a private ``random.Random(seed)`` instance, so the
    module-level random generator is not reseeded as a side effect.

    :param data: List of dataset entries.
    :param train_ratio: Ratio of data for training set.
    :param val_ratio: Ratio of data for validation set.
    :param test_ratio: Ratio of data for test set.
    :param seed: Random seed for reproducibility.
    :return: (train_set, val_set, test_set)
    :raises AssertionError: If the three ratios do not sum to 1.0
        (within floating-point tolerance).
    """
    # Float sums such as 0.7 + 0.2 + 0.1 evaluate to 0.9999999999999999, so an
    # exact comparison would reject valid ratios. The error is raised
    # explicitly (rather than via `assert`) so the check survives `python -O`
    # while keeping the exception type existing callers would see.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise AssertionError("Ratios must sum to 1.0")

    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(data)
    random.Random(seed).shuffle(shuffled)

    total = len(shuffled)
    # Sizes are truncated toward zero; any remainder lands in the test set.
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

def save_splits(train_set: List[Dict], val_set: List[Dict], test_set: List[Dict], 
    output_dir: str, 
    train_ratio: float,
    val_ratio: float,
    test_ratio: float
    ):
    """Saves the dataset splits as separate JSON files."""
    train_path = f"{output_dir}/random_train_dataset_{train_ratio}_{val_ratio}_{test_ratio}.json"
    val_path = f"{output_dir}/random_val_dataset_{train_ratio}_{val_ratio}_{test_ratio}.json"
    test_path = f"{output_dir}/random_test_dataset_{train_ratio}_{val_ratio}_{test_ratio}.json"

    with open(train_path, "w") as f:
        json.dump(train_set, f, indent=2)
    with open(val_path, "w") as f:
        json.dump(val_set, f, indent=2)
    with open(test_path, "w") as f:
        json.dump(test_set, f, indent=2)

    print(f"Datasets saved:\n  Train: {train_path} ({len(train_set)} samples)\n  Val: {val_path} ({len(val_set)} samples)\n  Test: {test_path} ({len(test_set)} samples)")

if __name__ == "__main__":
    # Command-line entry point: parse options, split the dataset, write the files.
    cli = argparse.ArgumentParser(
        description="Split raw dataset into train, val, and test splits."
    )

    # Required path options.
    for flag, help_text in (
        ("--dataset_path", "Path to the raw dataset JSON file."),
        ("--output_dir", "Directory to save the split datasets."),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional split proportions with their defaults.
    for flag, default_value, help_text in (
        ("--train_ratio", 0.8, "Proportion of data for training set."),
        ("--val_ratio", 0.1, "Proportion of data for validation set."),
        ("--test_ratio", 0.1, "Proportion of data for test set."),
    ):
        cli.add_argument(flag, type=float, default=default_value, help=help_text)

    cli.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )

    opts = cli.parse_args()
    ratios = (opts.train_ratio, opts.val_ratio, opts.test_ratio)

    # Load the raw entries and partition them.
    entries = load_raw_dataset(opts.dataset_path)
    splits = split_dataset(entries, *ratios, opts.seed)

    # Write the three splits to disk.
    save_splits(*splits, opts.output_dir, *ratios)
