import os
import shutil
import argparse
from typing import Union

from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict


def load_mmmlu_subset(subset: str) -> Union[Dataset, DatasetDict]:
    """Fetch one configuration of the ``openai/MMMLU`` dataset from the Hub.

    Args:
        subset: Configuration name forwarded to ``load_dataset`` as ``name``
            (for example ``"all"`` or a locale code such as ``"SW_KE"``).

    Returns:
        Whatever ``datasets.load_dataset`` returns for this call. No ``split``
        argument is passed, so this is normally a ``DatasetDict`` keyed by
        split name.
    """
    repo_id = "openai/MMMLU"
    return load_dataset(path=repo_id, name=subset)


def merge_all_splits(ds: Union[Dataset, DatasetDict]) -> Dataset:
    """Flatten ``ds`` into a single ``Dataset``.

    Args:
        ds: Either a single ``Dataset``, which is returned unchanged, or a
            ``DatasetDict`` whose splits are concatenated row-wise in the
            dict's iteration order.

    Returns:
        A single ``Dataset`` containing every row of every split.

    Raises:
        ValueError: If ``ds`` is a ``DatasetDict`` with no splits.
    """
    if not isinstance(ds, DatasetDict):
        return ds

    splits = list(ds.values())
    if not splits:
        # concatenate_datasets would also raise ValueError here, but with a
        # message that does not say where the empty list came from.
        raise ValueError("Cannot merge an empty DatasetDict: it contains no splits.")
    return concatenate_datasets(splits)


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options used by ``main``.

    Validation runs before any download, so bad values fail fast with a
    usage message instead of an exception deep inside ``datasets``.
    """
    parser = argparse.ArgumentParser(description="Download MMMLU and split into train/test, saving to disk.")
    parser.add_argument("--subset", type=str, default="SW_KE", help="Subset name for openai/MMMLU (e.g., 'all').")
    parser.add_argument("--test_size", type=float, default=0.3, help="Test split size (fraction between 0 and 1).")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling/splitting.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="local/teacher_datasets/MMMLU",
        help="Directory to save the dataset splits.",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Optionally limit total samples before splitting (for quick tests).",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="If set, remove existing output_dir before saving.",
    )

    args = parser.parse_args()

    # A float test_size is interpreted as a fraction by train_test_split, so it
    # must lie strictly inside (0, 1). The chained comparison also rejects NaN.
    if not 0.0 < args.test_size < 1.0:
        parser.error(f"--test_size must be strictly between 0 and 1 (got {args.test_size}).")
    # A zero or negative limit would select no rows, leaving nothing to split.
    if args.max_samples is not None and args.max_samples < 1:
        parser.error(f"--max_samples must be a positive integer (got {args.max_samples}).")
    return args


def main():
    """Download an MMMLU subset, split it into train/test and save it to disk.

    The result is written as a ``DatasetDict`` with ``"train"`` and ``"test"``
    splits at ``--output_dir``, so ``load_from_disk(output_dir)`` reloads both.
    """
    args = _parse_args()

    print(f"Loading dataset 'openai/MMMLU' subset='{args.subset}' ...")
    ds = load_mmmlu_subset(args.subset)
    full = merge_all_splits(ds)
    print(f"Total samples before filtering: {len(full)}")

    if args.max_samples is not None:
        n = min(args.max_samples, len(full))
        full = full.select(range(n))
        print(f"Selected first {n} samples for quick run.")

    # train_test_split already shuffles by default. This explicit shuffle is
    # kept so a given --seed keeps producing the same splits as before.
    print("Shuffling dataset ...")
    full = full.shuffle(seed=args.seed)

    print(f"Splitting with test_size={args.test_size}, seed={args.seed} ...")
    split = full.train_test_split(test_size=args.test_size, seed=args.seed)
    train_ds = split["train"]
    test_ds = split["test"]
    print(f"Split sizes -> train: {len(train_ds)}, test: {len(test_ds)}")

    # Prepare the output dir. This runs only after a successful split, so a
    # failed download never deletes a previously saved copy.
    if os.path.isdir(args.output_dir):
        if args.overwrite:
            print(f"Overwriting existing directory: {args.output_dir}")
            shutil.rmtree(args.output_dir)
        elif os.listdir(args.output_dir):
            print(
                f"Warning: '{args.output_dir}' already contains files and --overwrite "
                "is not set; saving into it anyway."
            )

    os.makedirs(args.output_dir, exist_ok=True)

    # Save as DatasetDict at root so load_from_disk(output_dir) works
    dataset_dict = DatasetDict({
        "train": train_ds,
        "test": test_ds,
    })
    print(f"Saving DatasetDict to: {args.output_dir}")
    dataset_dict.save_to_disk(args.output_dir)

    print("Done.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


