import re
import argparse
import pandas as pd
from datasets import load_dataset
import os
from typing import Optional

from transformers import AutoTokenizer
from tqdm import tqdm


def process_samples_by_search_type(dataset):
    """Prefix every assistant reply with the search strategy of its sample.

    For each item, the ``"search_type"`` field selects a condition label:
    exactly ``"dfs"`` maps to ``"DFS"`` and any string starting with ``"bfs"``
    maps to ``"BFS"``. The content of the second message is then rewritten to
    ``"Using {label}\\n{original content}"``; the content of the first message
    is copied unchanged.

    Args:
        dataset: Iterable of dict-like items, each with a ``"search_type"``
            entry and a ``"messages"`` sequence of exactly two dicts that
            carry a ``"content"`` key. The first message is emitted with role
            ``"user"`` and the second with role ``"assistant"``.

    Returns:
        A list of ``{"messages": [...]}`` dicts, one per input item, in input
        order. All other fields of the input items are dropped.

    Raises:
        ValueError: If ``search_type`` is missing, not a string, or names an
            unknown strategy, or if an item does not hold exactly two
            messages.
    """
    new_dataset = []

    for item in tqdm(dataset, desc="Processing dataset"):
        st = item.get("search_type")
        if st == "dfs":
            condition = "DFS"
        # The isinstance guard makes a missing (None) or non-string value fall
        # through to the ValueError below instead of an AttributeError.
        elif isinstance(st, str) and st.startswith("bfs"):
            condition = "BFS"
        else:
            raise ValueError(f"Invalid search_type: {st}")

        messages = item["messages"]
        # Explicit check rather than `assert`, which is removed under -O.
        if len(messages) != 2:
            raise ValueError(
                f"Expected exactly 2 messages, got {len(messages)}"
            )
        user_content = messages[0]["content"]
        assistant_content = messages[1]["content"]

        new_dataset.append(
            {
                "messages": [
                    {"role": "user", "content": user_content},
                    {
                        "role": "assistant",
                        "content": f"Using {condition}\n{assistant_content}",
                    },
                ],
            }
        )

    return new_dataset

   
def main(name: str = "N4T200", max_length: int = 2048) -> None:
    """Build the search-type-conditioned SFT training file.

    Loads the ``train`` split of ``xxx98/Countdown-{name}-{max_length}``,
    runs it through ``process_samples_by_search_type``, and writes the result
    to ``data/countdown/type_condition-{name}-{max_length}/train.parquet``
    (relative to the current working directory), creating the directory if
    needed. Prints the number of samples written.

    Args:
        name: Variant tag embedded in both the source dataset id and the
            output directory name.
        max_length: Length tag embedded in both the source dataset id and the
            output directory name. It is only used to build those names here.
    """
    dataset_id = f"xxx98/Countdown-{name}-{max_length}"
    dataset = load_dataset(dataset_id, split="train")

    # Build the output path once so the created directory and the directory
    # written to can never drift apart.
    save_dir = f"data/countdown/type_condition-{name}-{max_length}"
    os.makedirs(save_dir, exist_ok=True)

    # Rewrite every sample, then persist the full set as one parquet file.
    sft_samples = process_samples_by_search_type(dataset)
    sft_dataset = pd.DataFrame(sft_samples)
    sft_dataset.to_parquet(f"{save_dir}/train.parquet", index=False)

    print("\nProcessing complete!")
    print(f"SFT samples used: {len(sft_dataset)}")


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()