#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM作为评判者的实现（verl专用版本，支持重试机制）
用于判断模型回答与标准答案的一致性
"""

import json
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import sys
import os
import time

# 使用完整路径导入，确保在 Ray worker 中也能正常工作
try:
    from recipe.fileagent.api import LLMClient
except ImportError:
    # 如果完整路径导入失败，尝试相对导入（本地运行时）
    from api import LLMClient


class LLMJudge:
    """LLM-as-judge that checks a model response against a reference answer.

    This is the verl-specific variant with a retry loop: each item is rendered
    into a grading prompt, sent to the judge model through ``self.client.chat``
    and the reply is reduced to a boolean verdict.
    """

    def __init__(self, judge_model="gpt4o", max_workers=4, max_retries=3, retry_delay=1.0):
        """
        Initialise the judge.

        Args:
            judge_model: Name of the model used as the judge.
            max_workers: Default number of worker threads for ``judge_batch``.
            max_retries: Total number of attempts made per item before the
                item is reported as failed (default 3).
            retry_delay: Seconds to sleep between two attempts (default 1.0).
        """
        self.judge_model = judge_model
        self.max_workers = max_workers
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.client = LLMClient(model=judge_model)

        # Grading prompt; filled in with str.format() for every item.
        self.judge_prompt_template = """You are a judge who needs to determine whether the given answer is correct.

Question: {question}

Standard Answer: {standard_answer}

Model Response: {model_response}

Please carefully analyze whether the model response is consistent with the standard answer. Consider the following factors:
1. Accuracy of the answer - whether the model response is factually correct
2. Completeness of the answer - whether the model response contains all key information from the standard answer
3. Precision of the answer - for numerical answers, whether they match exactly; for text answers, whether they express the same meaning
4. Due to the possibility of carry-over issues, even if the last digit after the decimal point does not match, the calculation can still be correct.

Please only answer "Correct" or "Incorrect", do not provide any other explanation.

Your judgment:"""

    @staticmethod
    def _resolve_standard_answer(item):
        """Return the reference answer: gold_answer, then answer, then final_answer."""
        for key in ("gold_answer", "answer", "final_answer"):
            value = item.get(key)
            # Only a missing key or an explicit None falls through; an empty
            # string is kept as a deliberate value.
            if value is not None:
                return value
        return ""

    @staticmethod
    def _parse_verdict(judge_response):
        """Return True only when the reply says "correct" and never "incorrect"."""
        verdict = judge_response.strip().lower()
        # "incorrect" contains "correct", so the negative verdict must be
        # checked first; it also wins when the reply mentions both words.
        if "incorrect" in verdict:
            return False
        return "correct" in verdict

    def judge_single_item(self, item):
        """
        Judge one item, retrying when the judge call fails.

        An attempt counts as failed when ``self.client.chat`` raises or returns
        an empty / whitespace-only reply.

        Args:
            item: Dict describing the sample. Keys read: ``formatted_question``,
                ``gold_answer`` / ``answer`` / ``final_answer`` (first one that
                is present and not None is the reference answer),
                ``last_iteration_output``, ``task_id`` and ``level``.

        Returns:
            dict: ``task_id``, ``level``, ``question``, ``standard_answer``,
            ``model_response``, ``judge_response``, ``is_correct``,
            ``judge_model`` and ``retry_count``. On success ``retry_count`` is
            the number of failed attempts that preceded the successful one.
            When every attempt fails, ``is_correct`` is False, ``retry_count``
            equals ``max_retries`` and an extra ``error`` key holds the last
            exception message.
        """
        question = item.get("formatted_question", "")
        standard_answer = self._resolve_standard_answer(item)
        model_response = item.get("last_iteration_output", "")

        task_id = item.get("task_id", "")
        level = item.get("level", "")

        judge_prompt = self.judge_prompt_template.format(
            question=question,
            standard_answer=standard_answer,
            model_response=model_response
        )

        last_exception = None
        for attempt in range(self.max_retries):
            try:
                # First argument is passed as an empty string
                # (presumably the system prompt - confirm against LLMClient).
                judge_response = self.client.chat("", judge_prompt)

                # An empty reply carries no verdict, so treat it as a failure
                # and let the retry logic below handle it.
                if not judge_response or not judge_response.strip():
                    raise ValueError("Empty response from LLM Judge")

                return {
                    "task_id": task_id,
                    "level": level,
                    "question": question,
                    "standard_answer": standard_answer,
                    "model_response": model_response,
                    "judge_response": judge_response,
                    "is_correct": self._parse_verdict(judge_response),
                    "judge_model": self.judge_model,
                    "retry_count": attempt
                }

            except Exception as e:
                # Deliberately broad: any failure of the judge call is retried
                # and finally reported in the result instead of propagating.
                last_exception = e
                if attempt < self.max_retries - 1:
                    print(f"评判失败 (task_id: {task_id}, 尝试 {attempt + 1}/{self.max_retries}): {e}, {self.retry_delay}秒后重试...")
                    time.sleep(self.retry_delay)
                else:
                    print(f"评判最终失败 (task_id: {task_id}，已重试 {self.max_retries} 次): {e}")

        # Every attempt failed: report the item as incorrect with the error.
        return {
            "task_id": task_id,
            "level": level,
            "question": question,
            "standard_answer": standard_answer,
            "model_response": model_response,
            "judge_response": f"评判失败（重试{self.max_retries}次后）: {last_exception}",
            "is_correct": False,
            "judge_model": self.judge_model,
            "error": str(last_exception),
            "retry_count": self.max_retries
        }

    def judge_batch(self, data, max_workers=None):
        """
        Judge a list of items concurrently in a thread pool.

        Args:
            data: Sequence of item dicts accepted by ``judge_single_item``.
            max_workers: Number of worker threads; when None the value given
                to the constructor is used.

        Returns:
            list: One result dict per item, in the same order as ``data``.
        """
        if max_workers is None:
            max_workers = self.max_workers

        print(f"开始批量评判，共 {len(data)} 条数据，使用 {max_workers} 个并行线程")

        # Futures finish in arbitrary order; remember each one's input index
        # so the returned list lines up with ``data``.
        ordered_results = [None] * len(data)
        future_to_index = {}
        completed = 0
        correct_count = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for index, item in enumerate(data):
                future = executor.submit(self.judge_single_item, item)
                future_to_index[future] = index

            for future in tqdm(as_completed(future_to_index), total=len(future_to_index), desc="评判进度"):
                result = future.result()
                if not result:
                    continue

                ordered_results[future_to_index[future]] = result
                completed += 1
                if result.get("is_correct", False):
                    correct_count += 1

                # Report the running accuracy every 10 judged items.
                if completed % 10 == 0:
                    accuracy = correct_count / completed * 100
                    print(f"已处理 {completed} 条，当前准确率: {accuracy:.2f}%")

        return [result for result in ordered_results if result is not None]
