"""
Prepare the Countdown dataset for TinyZero-style RL training.

The Countdown task: given a target number and a list of N numbers, write an arithmetic equation (using +, -, *, /, with each number used only once) that evaluates to the target.
Based on: https://github.com/Jiayi-Pan/TinyZero
"""

import argparse
import os
from datasets import load_dataset
from tqdm import tqdm


def make_prefix(target, numbers, template_type='qwen-instruct'):
    """
    Build the prompt prefix for a single Countdown problem.

    Both templates share the same task instruction. They differ only in the
    framing around it: a plain-text User/Assistant dialogue for 'base', and
    <|im_start|>/<|im_end|> chat markers for 'qwen-instruct'. Each template
    ends with an opened <think> tag, so the model's continuation starts
    inside the reasoning section.

    Args:
        target: The number the equation must evaluate to.
        numbers: The numbers the model may use. They are interpolated into
            the prompt via f-string formatting (e.g. ``[1, 2, 3]`` for a list).
        template_type: 'qwen-instruct' (default) or 'base'.

    Returns:
        The fully formatted prompt string.

    Raises:
        ValueError: If ``template_type`` is not one of the supported values.
    """
    if template_type not in ('base', 'qwen-instruct'):
        raise ValueError(f"Unknown template_type: {template_type}")

    # Task instruction shared verbatim by both templates.
    instruction = (
        f"Using the numbers {numbers}, create an equation that equals {target}. "
        "You can use basic arithmetic operations (+, -, *, /) and each number "
        "can only be used once. Show your work in <think> </think> tags. "
        "And return the final answer in <answer> </answer> tags, "
        "for example <answer> (1 + 2) / 3 </answer>."
    )
    # Pre-filled start of the assistant turn; generation continues after <think>.
    opening = "Let me solve this step by step.\n<think>"

    if template_type == 'base':
        return (
            "A conversation between User and Assistant. The user asks a question, "
            "and the Assistant solves it. The assistant first thinks about the "
            "reasoning process in the mind and then provides the user with the answer.\n"
            f"User: {instruction}\n"
            f"Assistant: {opening}"
        )

    # NOTE: the system text ("You first thinks ...") is reproduced verbatim from
    # the reference template; it is left unchanged so generated prompts stay identical.
    return (
        "<|im_start|>system\n"
        "You are a helpful assistant. You first thinks about the reasoning "
        "process in the mind and then provides the user with the answer.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{instruction}<|im_end|>\n"
        "<|im_start|>assistant\n"
        f"{opening}"
    )


def process_example(example, idx, split, data_source, template_type):
    """
    Convert one raw Countdown row into a SkyRL-format record.

    Args:
        example: A raw dataset row; only its 'target' and 'nums' fields are read.
        idx: Row index (supplied by ``Dataset.map(..., with_indices=True)``),
            stored under ``extra_info['index']``.
        split: Split label stored under ``extra_info['split']``.
        data_source: Dataset identifier recorded on the output record.
        template_type: Prompt template name forwarded to ``make_prefix``.

    Returns:
        A dict with keys 'data_source', 'prompt', 'env_class',
        'reward_model' and 'extra_info'.
    """
    target, numbers = example['target'], example['nums']

    # The whole templated prompt is sent as a single user message.
    # NOTE(review): for 'qwen-instruct' the content already contains
    # <|im_start|>/<|im_end|> markers; confirm the trainer does not wrap it
    # in a chat template a second time.
    user_message = {
        "role": "user",
        "content": make_prefix(target, numbers, template_type=template_type),
    }

    # Ground truth used for reward verification.
    ground_truth = dict(target=target, numbers=numbers)

    return dict(
        data_source=data_source,
        prompt=[user_message],
        # Name of the environment class; presumably it must be registered separately.
        env_class="countdown",
        # SkyRL uses 'method' where the TinyZero format used 'style'.
        reward_model=dict(method="rule", ground_truth=ground_truth),
        extra_info=dict(split=split, index=idx),
    )


if __name__ == '__main__':
    # Script entry point: load the raw Countdown data, split it into disjoint
    # train/test ranges, convert every row with process_example, and write
    # both splits to parquet under --output_dir.
    parser = argparse.ArgumentParser(
        description="Prepare Countdown dataset for TinyZero-style RL training"
    )
    parser.add_argument(
        '--output_dir',
        default='~/data/countdown',
        help='Output directory for processed dataset'
    )
    parser.add_argument(
        '--train_size',
        type=int,
        default=327680,
        help='Number of training samples'
    )
    parser.add_argument(
        '--test_size',
        type=int,
        default=1024,
        help='Number of test samples'
    )
    parser.add_argument(
        '--template_type',
        type=str,
        default='qwen-instruct',
        choices=['qwen-instruct', 'base'],
        help='Prompt template type'
    )

    args = parser.parse_args()
    args.output_dir = os.path.expanduser(args.output_dir)

    data_source = 'Jiayi-Pan/Countdown-Tasks-3to4'
    TRAIN_SIZE = args.train_size
    TEST_SIZE = args.test_size

    # Validate up front, before the (potentially slow) download. Negative sizes
    # would otherwise yield empty or negative-index selections.
    if TRAIN_SIZE < 0 or TEST_SIZE < 0:
        parser.error("--train_size and --test_size must be non-negative")

    print(f"Loading dataset from {data_source}...")
    raw_dataset = load_dataset(data_source, split='train')

    print(f"Total samples in dataset: {len(raw_dataset)}")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(raw_dataset) < TRAIN_SIZE + TEST_SIZE:
        parser.error(
            f"Dataset has {len(raw_dataset)} samples but need {TRAIN_SIZE + TEST_SIZE}"
        )

    # Split into train and test: two contiguous, non-overlapping index ranges.
    print(f"\nSplitting into train ({TRAIN_SIZE}) and test ({TEST_SIZE})...")
    train_dataset = raw_dataset.select(range(TRAIN_SIZE))
    test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))

    print(f"\nProcessing datasets with template_type='{args.template_type}'...")

    # With with_indices=True, Dataset.map calls
    # process_example(example, idx, **fn_kwargs).
    common_kwargs = {'data_source': data_source, 'template_type': args.template_type}
    train_dataset = train_dataset.map(
        function=process_example,
        with_indices=True,
        fn_kwargs={'split': 'train', **common_kwargs},
    )
    test_dataset = test_dataset.map(
        function=process_example,
        with_indices=True,
        fn_kwargs={'split': 'test', **common_kwargs},
    )

    # Save to parquet
    os.makedirs(args.output_dir, exist_ok=True)
    train_path = os.path.join(args.output_dir, 'train.parquet')
    test_path = os.path.join(args.output_dir, 'validation.parquet')  # Using 'validation' to match SkyRL convention

    print("\nSaving datasets:")
    print(f"  Train: {train_path} ({len(train_dataset)} samples)")
    print(f"  Test: {test_path} ({len(test_dataset)} samples)")

    train_dataset.to_parquet(train_path)
    test_dataset.to_parquet(test_path)

    print("\nDataset preparation complete!")
    print("\nFinal sizes:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    # Show a sample (skipped when --train_size 0, where indexing row 0 would fail).
    if len(train_dataset) > 0:
        print("\n" + "=" * 60)
        print("Sample from training set:")
        print("=" * 60)
        sample = train_dataset[0]
        print("Prompt preview (first 300 chars):")
        print(sample['prompt'][0]['content'][:300] + "...")
        print(f"\nGround truth: {sample['reward_model']['ground_truth']}")
        print("=" * 60)
