import re
import argparse
import pandas as pd
from datasets import load_dataset
import os
from typing import Optional

from transformers import AutoTokenizer
from tqdm import tqdm


def process_samples(dataset):
    """Reduce every sample to just its ``messages`` field.

    Args:
        dataset: Iterable of mapping-like rows; each row must provide a
            ``"messages"`` key.

    Returns:
        A new list holding one ``{"messages": ...}`` dict per input row,
        in iteration order. A tqdm progress bar is displayed while the
        rows are consumed.
    """
    return [
        {"messages": row["messages"]}
        for row in tqdm(dataset, desc="Processing dataset")
    ]

   
def main(name: str = "N4T200", max_length: int = 2048) -> None:
    """Download a Countdown dataset and save its messages as parquet.

    Loads the ``train`` split of the Hugging Face dataset
    ``xxx98/Countdown-{name}-{max_length}``, keeps only the ``messages``
    field of each row, and writes the result to
    ``data/countdown/full-{name}-{max_length}/train.parquet``.

    Args:
        name: Variant tag used in both the hub dataset id and the output
            folder name.
        max_length: Numeric suffix used in both the hub dataset id and the
            output folder name.
    """
    dataset_id = f"xxx98/Countdown-{name}-{max_length}"
    dataset = load_dataset(dataset_id, split="train")

    # Build the output path once so the directory we create is guaranteed
    # to be the directory we write into.
    save_dir = f"data/countdown/full-{name}-{max_length}"
    os.makedirs(save_dir, exist_ok=True)

    # Keep only the "messages" field of every sample (no splitting happens).
    sft_samples = process_samples(dataset)
    # Explicit columns keep the "messages" column in the parquet schema even
    # when the dataset is empty (pd.DataFrame([]) would have no columns).
    sft_dataset = pd.DataFrame(sft_samples, columns=["messages"])
    sft_dataset.to_parquet(f"{save_dir}/train.parquet", index=False)

    print("\nProcessing complete!")
    print(f"SFT samples used: {len(sft_dataset)}")

if __name__ == "__main__":
    # Run the conversion only when executed as a script, not on import.
    main()