import re
import argparse
import pandas as pd
from datasets import load_dataset
import os
from typing import Optional

from transformers import AutoTokenizer
from tqdm import tqdm


def process_samples_by_search_type(dataset):
    """Split samples into 2 clusters based on each item's `search_type`.

    Items whose ``search_type`` is exactly ``"dfs"`` go to split 0. Items whose
    ``search_type`` is a string starting with ``"bfs"`` go to split 1. Each
    kept sample is reduced to a dict holding only its ``"messages"`` value.

    Args:
        dataset: Iterable of mapping-like items supporting ``.get("search_type")``
            and ``["messages"]`` (e.g. rows of a Hugging Face ``Dataset``).

    Returns:
        A list of two lists, ``[split_0, split_1]``, each containing
        ``{"messages": ...}`` dicts in input order.

    Raises:
        ValueError: If an item's ``search_type`` is missing, not a string, or
            matches neither rule above.
    """
    splits = [[], []]
    for item in tqdm(dataset, desc="Processing dataset (type split)"):
        st = item.get("search_type")
        if st == "dfs":
            split_idx = 0
        # Guard the type first: a missing key yields None, and calling
        # None.startswith would raise AttributeError instead of the intended
        # ValueError below.
        elif isinstance(st, str) and st.startswith("bfs"):
            split_idx = 1
        else:
            raise ValueError(f"Invalid search_type: {st}")
        splits[split_idx].append({
            "messages": item["messages"],
        })

    return splits

   
def main(*, name="N4T200", max_length=2048):
    """Load a Countdown dataset and write one parquet file per search-type split.

    Loads the ``train`` split of ``xxx98/Countdown-{name}-{max_length}`` via
    ``datasets.load_dataset``. It partitions the rows with
    ``process_samples_by_search_type`` and writes ``train_{idx}.parquet``
    files into ``data/countdown/type_split-{name}-{max_length}``, creating the
    directory if needed.

    Args:
        name: Variant tag used in both the dataset repo id and the output
            directory name. Defaults to the previously hard-coded value.
        max_length: Length tag used in both the dataset repo id and the output
            directory name. Defaults to the previously hard-coded value.
    """
    dataset_id = f"xxx98/Countdown-{name}-{max_length}"
    dataset = load_dataset(dataset_id, split="train")

    # Create output directory for type-based split
    save_dir = f"data/countdown/type_split-{name}-{max_length}"
    os.makedirs(save_dir, exist_ok=True)

    # Process and split samples by search_type (2 clusters)
    sft_samples_list = process_samples_by_search_type(dataset)
    for idx, sft_samples in enumerate(sft_samples_list):
        # Pin the column list so an empty split still produces a file with a
        # "messages" column rather than a column-less DataFrame/parquet file.
        sft_dataset = pd.DataFrame(sft_samples, columns=["messages"])
        out_path = os.path.join(save_dir, f"train_{idx}.parquet")
        sft_dataset.to_parquet(out_path, index=False)
        print(f"Split {idx}: {len(sft_dataset)} samples")

    print("\nProcessing complete!")
    print(f"Total splits created: {len(sft_samples_list)}")


if "__main__" == __name__:
    # Run the export only when executed as a script, not when imported.
    main()
