#!/usr/bin/env python3
"""
Retrieve and process results from a batch inference job.

This script:
1. Polls the batch job until completion
2. Retrieves all results
3. Matches them with the original problems
4. Saves the results to a JSON file
"""

import argparse
import asyncio
import json
from pathlib import Path
from typing import Dict, List, Any, Tuple

import sys

sys.path.append(str(Path(__file__).parent.parent.parent.parent / "safety-tooling"))
from safetytooling.apis.inference.anthropic import AnthropicModelBatch
from safetytooling.data_models import ChatMessage, MessageRole, Prompt
from safetytooling.utils import utils


def load_problems_data(output_dir: Path) -> List[Dict[str, Any]]:
    """Load the problem records stored as ``problems.json`` in *output_dir*.

    Args:
        output_dir: Directory that contains ``problems.json``.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If ``problems.json`` is missing from *output_dir*.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin the encoding: without it, open() decodes with the platform locale
    # default, which mangles (or rejects) non-ASCII prompt text on some systems.
    with open(output_dir / "problems.json", "r", encoding="utf-8") as f:
        return json.load(f)


def verify_custom_id(
    custom_id: str, index: int, problem_prompt: str, batch_client: AnthropicModelBatch
) -> None:
    """Check that *custom_id* is the ID the batch client derives for this problem.

    The expected ID is recomputed by wrapping *problem_prompt* in a single
    user message and passing it, together with *index*, to
    ``batch_client.get_custom_id``.

    Args:
        custom_id: The custom ID attached to a batch result.
        index: Position of the problem the result is believed to belong to.
        problem_prompt: Prompt text of that problem.
        batch_client: Client whose ``get_custom_id`` produces the expected ID.

    Raises:
        AssertionError: If the recomputed ID differs from *custom_id*.
    """
    messages = [ChatMessage(role=MessageRole.user, content=problem_prompt)]
    expected_prompt = Prompt(messages=messages)
    expected_custom_id = batch_client.get_custom_id(index, expected_prompt)

    # Raise explicitly rather than using `assert`: assert statements are
    # removed under `python -O`, which would silently disable this check.
    # AssertionError is kept so existing callers see the same exception type.
    if custom_id != expected_custom_id:
        raise AssertionError(
            f"Custom ID mismatch for index {index}: expected {expected_custom_id}, got {custom_id}"
        )


def extract_completion(result: Any) -> str:
    """Return the text of the first text block in a result's message.

    Args:
        result: Object exposing ``result.message.content``, an iterable of
            content blocks (or a falsy value when there is no content).

    Returns:
        The ``text`` of the first block whose ``type`` is ``"text"`` and which
        has a ``text`` attribute; an empty string when there is no such block.
    """
    blocks = result.result.message.content
    if not blocks:
        return ""

    # Only the first matching block is used; later text blocks are ignored.
    text_candidates = (
        block.text
        for block in blocks
        if block.type == "text" and hasattr(block, "text")
    )
    return next(text_candidates, "")


def process_batch_results(
    results: List[Any],
    problems_data: List[Dict[str, Any]],
    batch_client: AnthropicModelBatch,
) -> List[Dict[str, Any]]:
    """Match batch results to their problems and build the output records.

    Each result's ``custom_id`` is expected to look like ``"{index}_{hash}"``;
    the leading index selects the problem in *problems_data*, and the full ID
    is then checked with ``verify_custom_id``.

    Args:
        results: Batch result objects exposing ``custom_id`` and ``result``.
        problems_data: Problem records with ``index``, ``name`` and ``prompt``
            keys, and optionally ``original_function_name``.
        batch_client: Client used to recompute the expected custom IDs.

    Returns:
        One record per result, sorted by the problem's ``index``. Succeeded
        results carry the extracted ``completion``; all others carry
        ``completion=None`` plus an ``error`` string.

    Raises:
        ValueError: If the prefix of a ``custom_id`` is not an integer.
        AssertionError: If the index is out of range, or (from
            ``verify_custom_id``) the custom ID does not match.
    """
    results_data = []

    for result in results:
        # Extract index from custom_id format: "{index}_{hash}"
        index = int(result.custom_id.split("_")[0])

        # Explicit raise instead of `assert` so the check survives `python -O`.
        # The lower bound matters too: a negative index would otherwise wrap
        # around and silently select a problem from the end of the list.
        if not 0 <= index < len(problems_data):
            raise AssertionError(
                f"Invalid index {index} in custom_id {result.custom_id}"
            )

        problem = problems_data[index]
        verify_custom_id(result.custom_id, index, problem["prompt"], batch_client)

        # Fields shared by successful and failed records.
        result_data = {
            "index": problem["index"],
            "name": problem["name"],
            "prompt": problem["prompt"],
        }

        if result.result.type == "succeeded":
            completion = extract_completion(result)

            # For MBPP, restore original function name in completion.
            # NOTE(review): this replaces every occurrence of the substring
            # "function", including inside prose or longer identifiers —
            # confirm that is intended before narrowing it.
            original_name = problem.get("original_function_name")
            if original_name:
                completion = completion.replace("function", original_name)

            result_data["completion"] = completion
        else:
            result_data["completion"] = None
            result_data["error"] = (
                str(result.result.error)
                if hasattr(result.result, "error")
                else "Unknown error"
            )

        results_data.append(result_data)

    return sorted(results_data, key=lambda x: x["index"])


async def retrieve_batch_results(batch_id: str, output_dir: Path) -> None:
    """Wait for a batch job, collect its results and write ``results.json``.

    Args:
        batch_id: Identifier of the batch to poll and retrieve.
        output_dir: Directory holding ``problems.json``; ``results.json`` is
            written to the same directory.
    """
    problems = load_problems_data(output_dir)

    print(f"Retrieving batch: {batch_id}")
    print(f"Number of problems: {len(problems)}")

    client = AnthropicModelBatch()

    print("\nPolling for batch completion...")
    status = await client.poll_message_batch(batch_id)
    print(f"Batch completed! Status: {status.processing_status}")

    print("\nRetrieving results...")
    raw_results = client.retrieve_message_batch_results(batch_id)
    records = process_batch_results(raw_results, problems, client)

    destination = output_dir / "results.json"
    with destination.open("w") as out:
        json.dump(records, out, indent=2)

    # A record counts as failed when it has no completion.
    failed = sum(1 for record in records if record.get("completion") is None)
    successful = len(records) - failed

    print(f"\n✓ Successfully retrieved {successful} completions")
    if failed > 0:
        print(f"✗ {failed} requests failed")

    print(f"\nResults saved to: {destination}")


async def main():
    """Parse the command-line options and retrieve the requested batch."""
    cli = argparse.ArgumentParser(
        description="Retrieve results from batch inference job"
    )
    cli.add_argument(
        "--batch-id",
        type=str,
        required=True,
        help="Batch ID to retrieve results for",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="reward_hack_data/generated_reward_hack_data",
        help="Directory where batch info was saved",
    )
    options = cli.parse_args()

    # Set up the environment before any batch client is created.
    utils.setup_environment(anthropic_tag="ANTHROPIC_BATCH_API_KEY")

    await retrieve_batch_results(options.batch_id, Path(options.output_dir))


# Run the async entry point only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
