#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch wrapper for FileAgent's compute_score function.

This wrapper adapts the single-sample compute_score to work with verl's BatchRewardManager,
which expects batch inputs (lists) and returns a list of rewards.
"""

from typing import List, Union
from recipe.fileagent.reward_score import compute_score


def _build_llm_judge_base_url(llm_judge_ip: str, llm_judge_port: Union[str, int]) -> str:
    """Return the ``http://<host>:<port>/v1`` base URL for the LLM judge server."""
    # A bare IPv6 literal contains ':' and must be bracketed so the port
    # separator stays unambiguous; IPv4 addresses, hostnames and literals that
    # already start with '[' are used unchanged.
    needs_brackets = ":" in llm_judge_ip and not llm_judge_ip.startswith("[")
    host = f"[{llm_judge_ip}]" if needs_brackets else llm_judge_ip
    return f"http://{host}:{llm_judge_port}/v1"


def compute_score_batch(
    data_sources: List[str],
    solution_strs: List[str],
    ground_truths: List[Union[str, List[str]]],
    extra_infos: List[dict],
    llm_judge_ip: Union[str, None] = None,
    llm_judge_port: Union[str, int, None] = None,
    llm_judge_model_name: Union[str, None] = None,
    **kwargs
) -> List[dict]:
    """
    Batch version of compute_score for verl's BatchRewardManager.

    Each sample is scored independently with ``compute_score``. If scoring a
    sample raises, a warning is printed and an all-zero score dict is used for
    that sample, so one bad sample does not abort the whole batch.

    Side effects: when both ``llm_judge_ip`` and ``llm_judge_port`` are given,
    the environment variable ``LLM_JUDGE_BASE_URL`` is overwritten; when
    ``llm_judge_model_name`` is given, ``LLM_JUDGE_MODEL_NAME`` is overwritten.
    (Presumably these are read by the judge client used by ``compute_score``;
    that code is not part of this module.)

    Args:
        data_sources: List of data source identifiers
        solution_strs: List of model output strings
        ground_truths: List of ground truth answers
        extra_infos: List of extra information dicts
        llm_judge_ip: LLM Judge server IP address (IPv4 or IPv6); only used
            together with llm_judge_port
        llm_judge_port: LLM Judge server port; only used together with
            llm_judge_ip
        llm_judge_model_name: LLM Judge model name (overrides environment variable)
        **kwargs: Additional keyword arguments (ignored)

    Returns:
        One entry per input sample, in input order. Successful entries are
        whatever ``compute_score`` returned; the fallback entry used on failure
        contains:
            - score: final weighted score
            - acc_reward: accuracy reward
            - format_reward: format reward
            - tool_reward: tool reward

    Raises:
        ValueError: If the four input lists do not all have the same length.
    """
    # Local import: os is only needed to export the judge settings below.
    import os

    # The base URL can only be built when both host and port are known.
    if llm_judge_ip and llm_judge_port:
        os.environ["LLM_JUDGE_BASE_URL"] = _build_llm_judge_base_url(
            llm_judge_ip, llm_judge_port
        )

    if llm_judge_model_name:
        os.environ["LLM_JUDGE_MODEL_NAME"] = llm_judge_model_name

    # zip() would silently drop trailing samples on a length mismatch, handing
    # the caller fewer rewards than batch entries. Fail loudly instead.
    lengths = {
        "data_sources": len(data_sources),
        "solution_strs": len(solution_strs),
        "ground_truths": len(ground_truths),
        "extra_infos": len(extra_infos),
    }
    if len(set(lengths.values())) > 1:
        raise ValueError(
            f"compute_score_batch inputs must all have the same length, got {lengths}"
        )

    results = []

    for data_source, solution_str, ground_truth, extra_info in zip(
        data_sources, solution_strs, ground_truths, extra_infos
    ):
        try:
            score_dict = compute_score(
                data_source=data_source,
                solution_str=solution_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )
        except Exception as e:
            # Deliberate best-effort: a failing sample gets a zero score
            # instead of aborting the whole batch.
            print(f"Warning: compute_score failed for data_source={data_source}: {e}")
            score_dict = {
                "score": 0.0,
                "acc_reward": 0.0,
                "format_reward": 0.0,
                "tool_reward": 0.0,
            }
        results.append(score_dict)

    return results


if __name__ == "__main__":
    # Manual smoke test: score two hand-written samples and print each
    # component of the returned score dicts.
    samples = [
        (
            "extracted_bench",
            "<think>Let me solve this...</think>\nThe answer is 42.",
            "42",
            {"question": "What is 6 times 7?"},
        ),
        (
            "extracted_bench",
            "<think>Calculating...</think><answer>100</answer>",
            "100",
            {"question": "What is 10 times 10?"},
        ),
    ]

    # Transpose the per-sample tuples into the four parallel lists that the
    # batch function takes.
    sources = [sample[0] for sample in samples]
    outputs = [sample[1] for sample in samples]
    truths = [sample[2] for sample in samples]
    infos = [sample[3] for sample in samples]

    scores = compute_score_batch(sources, outputs, truths, infos)

    print("Test Results:")
    for number, entry in enumerate(scores, start=1):
        print(f"\nSample {number}:")
        print(f"  Final Score: {entry['score']}")
        print(f"  Accuracy: {entry['acc_reward']}")
        print(f"  Format: {entry['format_reward']}")
        print(f"  Tool: {entry['tool_reward']}")

