#!/usr/bin/env python3
"""
Generate extended MBPP training data by prompting Claude Opus to create similar examples.

This script:
1. Loads a random sample from the MBPP dataset
2. Formats the examples clearly
3. Prompts Claude Opus to generate new examples
4. Saves the results to extended_mbpp_train_data.txt
"""

import argparse
import asyncio
import json
import random
from pathlib import Path
from typing import Dict, List, Any

from datasets import load_dataset

import sys

# Append the "safety-tooling" directory found three levels above this script's
# own directory to sys.path before importing safetytooling.
# NOTE(review): presumably this is what makes the safetytooling imports below
# resolvable when the package is not installed — confirm the repo layout.
sys.path.append(str(Path(__file__).parent.parent.parent.parent / "safety-tooling"))
from safetytooling.apis import InferenceAPI
from safetytooling.data_models import ChatMessage, MessageRole, Prompt
from safetytooling.utils import utils


def format_mbpp_example(example: Dict[str, Any]) -> str:
    """Render one MBPP record as a pretty-printed JSON string.

    The output object has three keys, in this order: ``prompt`` and ``code``
    copied from ``example``, and ``test`` holding the first entry of
    ``example["test_list"]``. Any other fields of ``example`` are ignored.
    """
    # Copy the pass-through fields first so key order in the JSON is stable.
    record = {field: example[field] for field in ("prompt", "code")}
    # Only the first test case is shown to keep each example compact.
    record["test"] = example["test_list"][0]
    return json.dumps(record, indent=2)


def create_generation_prompt(
    sample_examples: List[Dict[str, Any]], num_to_generate: int
) -> Prompt:
    """Build the single-turn user prompt asking for new MBPP-style examples.

    The message lists every entry of ``sample_examples`` (numbered from 1 and
    rendered with ``format_mbpp_example``) and then asks for
    ``num_to_generate`` new examples in the same format.

    Returns:
        A ``Prompt`` containing exactly one user-role ``ChatMessage``.
    """
    # Number the few-shot examples starting at 1, one block per example.
    examples_text = "\n".join(
        f"Example {i}:\n{format_mbpp_example(example)}"
        for i, example in enumerate(sample_examples, 1)
    )

    content = f"""Here are {len(sample_examples)} examples:

{examples_text}

Now, please generate {num_to_generate} new examples in the exact same format. Each example should:
1. Have a clear, concise prompt describing a Python programming task
2. Include correct Python code that solves the task
3. Have one test case using assert statements
4. Cover a variety of programming concepts.
5. Have similar difficulty and format to the examples I showed you.

Make sure each solution is correct and the test case properly validates it."""

    user_message = ChatMessage(role=MessageRole.user, content=content)
    return Prompt(messages=[user_message])


async def generate_extended_mbpp_data(args):
    """Sample MBPP examples, ask Claude Opus for new ones, and save the reply.

    Loads the "sanitized" train split of MBPP, draws ``args.num_samples``
    examples at random, builds a generation prompt and either prints it
    (dry run) or sends it to the model and writes the raw completion to
    ``extended_mbpp_train_data.txt`` next to this script.

    Args:
        args: Parsed command-line namespace. Reads ``num_samples``,
            ``num_to_generate`` and ``dry_run``.

    Raises:
        ValueError: If ``args.num_samples`` is negative or larger than the
            number of examples in the loaded split.
    """

    print("Loading MBPP dataset...")
    # Load the training split of MBPP
    dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="train")

    total_examples = len(dataset)
    # random.sample would raise an opaque ValueError here; fail early with a
    # message that names the offending flag and the valid range instead.
    if not 0 <= args.num_samples <= total_examples:
        raise ValueError(
            f"--num_samples must be between 0 and {total_examples} "
            f"(size of the MBPP train split), got {args.num_samples}"
        )
    sample_indices = random.sample(range(total_examples), args.num_samples)
    sample_examples = [dataset[i] for i in sample_indices]

    print(f"Sampled {args.num_samples} examples from {total_examples} total examples")

    # Create the generation prompt
    prompt = create_generation_prompt(sample_examples, args.num_to_generate)

    if args.dry_run:
        # Show exactly what would be sent, without touching the API or disk.
        print("\n" + "=" * 80)
        print("DRY RUN - Generated prompt:")
        print("=" * 80 + "\n")
        print(prompt.messages[0].content)
        print("\n" + "=" * 80)
        return

    # Setup environment for API calls
    utils.setup_environment()

    # Initialize API; the cache directory is created relative to the current
    # working directory.
    cache_dir = Path(".cache")
    cache_dir.mkdir(exist_ok=True)
    api = InferenceAPI(cache_dir=cache_dir)

    print(f"Calling Claude Opus to generate {args.num_to_generate} new examples...")
    print("This may take a few minutes...")

    # NOTE(review): with max_tokens=4096 a large --num_to_generate may yield a
    # truncated reply — confirm against the model's output limit.
    response = await api(
        model_id="claude-3-opus-20240229",
        prompt=prompt,
        temperature=0.7,  # Some creativity but not too much
        max_tokens=4096,
        print_prompt_and_response=False,  # Don't clutter output
    )

    # Extract the generated content
    generated_content = response.completion

    # Save to file. The encoding is explicit so that non-ASCII characters in
    # the model output are written identically on every platform instead of
    # depending on the locale's default encoding.
    output_file = Path(__file__).parent / "extended_mbpp_train_data.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(generated_content)

    print(f"Successfully generated extended MBPP data and saved to {output_file}")


def main():
    """Parse command-line options and run the async generation routine."""
    cli = argparse.ArgumentParser(
        description="Generate extended MBPP training data using Claude Opus"
    )
    cli.add_argument(
        "--dry_run",
        action="store_true",
        help="Only print the prompt without making API calls",
    )

    # The remaining options are plain integers with a default; register them
    # from a table, in the same order they appear in --help.
    integer_options = (
        ("--num_to_generate", 50, "Number of examples to generate (default: 50)"),
        (
            "--num_samples",
            10,
            "Number of MBPP examples to use as samples (default: 10)",
        ),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    parsed = cli.parse_args()

    # Drive the coroutine to completion on a fresh event loop.
    asyncio.run(generate_extended_mbpp_data(parsed))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
