from datasets import load_dataset
import os

# --- Step 1: fetch GSM8K ("main" configuration) ------------------------------
print("Loading GSM8K dataset...")
ds = load_dataset("openai/gsm8k", "main")
print(
    "Dataset loaded with {} training examples and {} test examples".format(
        len(ds["train"]), len(ds["test"])
    )
)

# Echo the first raw training record so the source schema shows up in the log.
print("\nOriginal dataset example:", ds["train"][0], sep="\n")

# Convert to message format
def convert_to_messages(example):
    """Build a two-turn chat transcript from one GSM8K record.

    The record's ``question`` becomes the user turn and its ``answer`` the
    assistant turn. Any other keys in ``example`` are ignored.

    Args:
        example: Mapping with at least ``"question"`` and ``"answer"`` keys.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``. Each turn is a dict
        with keys ``"role"`` and ``"content"``, in that order.

    Raises:
        KeyError: If ``"question"`` or ``"answer"`` is missing.
    """
    user_turn = dict(role="user", content=example["question"])
    assistant_turn = dict(role="assistant", content=example["answer"])
    return dict(messages=[user_turn, assistant_turn])

# --- Step 2: convert every split to chat format -------------------------------
# Each row gains a "messages" column, the raw "answer" column is dropped, and
# the original "question" text is kept under the name "prompt".
# NOTE(review): the output keeps both "prompt" (plain string) and "messages".
# Some trainers decide the dataset format from column names, so this pair may
# be ambiguous. Confirm it matches what the downstream consumer expects.
print("\nConverting dataset to SFT format...")
sft_ds = (
    ds.map(convert_to_messages, remove_columns=["answer"])
    .rename_column("question", "prompt")
)

# Show one converted training row for a quick sanity check.
print("\nConverted dataset example:", sft_ds["train"][0], sep="\n")

# --- Step 3: persist the converted DatasetDict --------------------------------
output_dir = "datasets/gsm8k-sft"
os.makedirs(output_dir, exist_ok=True)  # no-op if the directory already exists

print("\nSaving dataset to " + output_dir + "...")
sft_ds.save_to_disk(output_dir)

print("Conversion complete! Dataset saved in SFT format.")