import os
import re
import string
from typing import Union, List
from collections import Counter

from recipe.fileagent.utils.llm_judge import LLMJudgeClient

# Try to import the external API-based LLM judge. The import is optional:
# API_JUDGE_AVAILABLE records whether it succeeded so later code can decide
# whether an API judge client can be created.
try:
    from recipe.fileagent.llm_judge import LLMJudge as APILLMJudge
    API_JUDGE_AVAILABLE = True
except ImportError as e:
    # Report the failure and carry on instead of failing the import of this module.
    print(f"⚠️  无法导入 API LLM Judge: {e}")
    API_JUDGE_AVAILABLE = False


# # Global LLM Judge Client (will be initialized with explicit parameters)
# _LLM_JUDGE_CLIENT = None
# _LLM_JUDGE_CONFIG = {}

# def get_llm_judge_client(base_url=None, model_name=None):
#     """
#     Get or create the LLM Judge Client.
#    
#     Supports two modes:
#     1. vLLM mode (default): Uses local vLLM server with OpenAI-compatible API
#     2. API mode: Uses commercial APIs (GPT-4o, Claude, etc.)
#    
#     Mode selection:
#     - Set LLM_JUDGE_MODE=api to use API mode
#     - Set LLM_JUDGE_MODE=vllm (or unset) to use vLLM mode
#    
#     Args:
#         base_url: Override base_url (if None, uses environment variable)
#         model_name: Override model_name (if None, uses environment variable)
#     """
#     global _LLM_JUDGE_CLIENT, _LLM_JUDGE_CONFIG
#    
#     # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
#     # 🔧 Hard-coded configuration (precedence: function arguments > environment variables > hard-coded defaults)
#     # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
#    
#     # Check which mode to use
#     judge_mode = os.getenv("LLM_JUDGE_MODE", "vllm").lower()  # defaults to vLLM mode when unset
#    
#     # Determine the configuration to use
#     # Hard-coded defaults for vLLM mode (must stay in sync with the configuration in run_extracted_bench.sh)
#     DEFAULT_VLLM_BASE_URL = "http://YOUR_VLLM_HOST:18901/v1"
#     DEFAULT_VLLM_MODEL_NAME = "judge-72b"
#    
#     # Hard-coded defaults for API mode
#     DEFAULT_API_MODEL = "gpt5-mini"
#    
#     config_base_url = base_url or os.getenv("LLM_JUDGE_BASE_URL", DEFAULT_VLLM_BASE_URL)
#     config_model_name = model_name or os.getenv("LLM_JUDGE_MODEL_NAME", DEFAULT_VLLM_MODEL_NAME)
#    
#     current_config = {
#         "mode": judge_mode,
#         "base_url": config_base_url, 
#         "model_name": config_model_name
#     }
#    
#     # Reinitialize if configuration changed or client is None
#     if _LLM_JUDGE_CLIENT is None or _LLM_JUDGE_CONFIG != current_config:
#         if judge_mode == "api":
#             # Use API mode (commercial APIs)
#             try:
#                 import sys
#                 # Add custom_reward_functions to sys.path for importing llm_judge and api modules
#                 custom_reward_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../custom_reward_functions"))
#                 if custom_reward_path not in sys.path:
#                     sys.path.insert(0, custom_reward_path)
               
#                 from llm_judge import LLMJudge as APILLMJudge
               
#                 api_model = os.getenv("LLM_JUDGE_API_MODEL", DEFAULT_API_MODEL)
#                 print(f"✅ Using API LLM Judge mode with model: {api_model}")
#                 print(f"   Custom reward path: {custom_reward_path}")
#                 _LLM_JUDGE_CLIENT = APIJudgeWrapper(api_model)
#             except Exception as e:
#                 print(f"⚠️  Failed to initialize API LLM Judge: {e}")
#                 print("⚠️  Falling back to vLLM mode")
#                 _LLM_JUDGE_CLIENT = LLMJudgeClient(
#                     base_url=config_base_url,
#                     model_name=config_model_name,
#                     temperature=0.6
#                 )
#         else:
#             # Use vLLM mode (default)
#             print(f"✅ Using vLLM LLM Judge mode with base_url: {config_base_url}")
#             _LLM_JUDGE_CLIENT = LLMJudgeClient(
#                 base_url=config_base_url,
#                 model_name=config_model_name,
#                 temperature=0.6
#             )
#         _LLM_JUDGE_CONFIG = current_config
#    
#     return _LLM_JUDGE_CLIENT


# class APIJudgeWrapper:
#     """Wrapper to make API LLMJudge compatible with vLLM LLMJudgeClient interface"""
#    
#     def __init__(self, model_name="gpt4o"):
#         import sys
#         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../custom_reward_functions"))
#         from llm_judge import LLMJudge as APILLMJudge
        
#         self.api_judge = APILLMJudge(
#             judge_model=model_name,
#             max_workers=1,  # Single thread for RL training
#             max_retries=3,
#             retry_delay=1.0
#         )
#         self.model_name = model_name
#    
#     def verify(self, prediction, ground_truth, question):
#         """
#         Verify if prediction matches ground truth.
#         Compatible with vLLM LLMJudgeClient.verify() interface.
#        
#         Returns:
#             float: 1.0 if correct, 0.0 if incorrect
#         """
#         try:
#             # Prepare item in the format expected by API judge
#             item = {
#                 "formatted_question": question,
#                 "answer": ground_truth,
#                 "last_iteration_output": prediction,
#                 "task_id": "",
#                 "level": ""
#             }
            
#             # Use API judge
#             result = self.api_judge.judge_single_item(item)
            
#             return 1.0 if result.get("is_correct", False) else 0.0
            
#         except Exception as e:
#             print(f"⚠️  API Judge failed: {e}")
#             return 0.0

# Settings of the local vLLM judge. They are named once so that the start-up
# log below reports the values the client is actually constructed with (the
# log previously announced a model name that differed from the one in use).
_VLLM_JUDGE_MODEL_NAME = "judge"
_VLLM_JUDGE_TEMPERATURE = 0.1

print("🔧 [INIT] 初始化 LLM_JUDGE_CLIENT")
# The log prints a fixed placeholder here rather than the resolved endpoint.
print("🔧 [INIT] base_url: <set via environment>")
print(f"🔧 [INIT] model_name: {_VLLM_JUDGE_MODEL_NAME}")
print(f"🔧 [INIT] temperature: {_VLLM_JUDGE_TEMPERATURE}")

# The endpoint comes from LLM_JUDGE_BASE_URL, with a placeholder default.
LLM_JUDGE_CLIENT = LLMJudgeClient(
    base_url=os.getenv("LLM_JUDGE_BASE_URL", "http://YOUR_VLLM_HOST:18901/v1"),
    model_name=_VLLM_JUDGE_MODEL_NAME,
    temperature=_VLLM_JUDGE_TEMPERATURE,
)

print("✅ [INIT] LLM_JUDGE_CLIENT (vLLM) 初始化完成\n")

# Optional external API judge.
# LLM_JUDGE_MODE is read from the environment ("vllm" when unset).
# NOTE(review): the API client below is created whenever the API judge module
# could be imported; LLM_JUDGE_MODE is not consulted first. Confirm whether
# creation should be gated on LLM_JUDGE_MODE == "api".
LLM_JUDGE_MODE = os.getenv("LLM_JUDGE_MODE", "vllm").lower()
API_JUDGE_CLIENT = None

if API_JUDGE_AVAILABLE:
    try:
        # All API judge settings are overridable through environment variables.
        # A malformed number raises here and is handled by the fallback below.
        api_model = os.getenv("LLM_JUDGE_API_MODEL", "gpt5-nano")
        max_workers = int(os.getenv("LLM_JUDGE_API_WORKERS", "1"))
        max_retries = int(os.getenv("LLM_JUDGE_API_RETRIES", "3"))
        retry_delay = float(os.getenv("LLM_JUDGE_API_RETRY_DELAY", "2.0"))

        print("🔧 [INIT] 初始化 API_JUDGE_CLIENT")
        print(f"🔧 [INIT] api_model: {api_model}")
        print(f"🔧 [INIT] max_workers: {max_workers}")
        print(f"🔧 [INIT] max_retries: {max_retries}")
        print(f"🔧 [INIT] retry_delay: {retry_delay}")

        API_JUDGE_CLIENT = APILLMJudge(
            judge_model=api_model,
            max_workers=max_workers,
            max_retries=max_retries,
            retry_delay=retry_delay
        )

        print("✅ [INIT] API_JUDGE_CLIENT 初始化完成 (模式: API)\n")
    except Exception as e:
        # Any failure while configuring or constructing the API judge falls
        # back to the vLLM judge instead of aborting the import of this module.
        print(f"❌ [INIT] API_JUDGE_CLIENT 初始化失败: {e}")
        print("⚠️  [INIT] 将回退到 vLLM 模式\n")
        LLM_JUDGE_MODE = "vllm"
        API_JUDGE_CLIENT = None
else:
    if LLM_JUDGE_MODE == "api":
        # API mode was requested but the API judge module could not be imported.
        print("⚠️  [INIT] LLM_JUDGE_MODE=api 但 API Judge 不可用，使用 vLLM 模式\n")
    else:
        print(f"ℹ️  [INIT] 使用 vLLM 模式 (LLM_JUDGE_MODE={LLM_JUDGE_MODE})\n")

def normalize_answer(s):
    """Normalize an answer string for token-level comparison.

    Follows the HotpotQA evaluation script
    (https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py):
    lowercase the text, delete every character in ``string.punctuation``,
    replace the standalone articles a/an/the with a space, and collapse
    whitespace runs into single spaces.
    """
    # Punctuation is deleted before the article regex runs, so the word
    # boundaries are evaluated on the punctuation-free text.
    without_punct = s.lower().translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one ground-truth answer.

    Follows the HotpotQA evaluation script
    (https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py).
    Both strings are passed through ``normalize_answer`` and compared as
    bags of whitespace-separated tokens.

    Returns:
        A ``(f1, precision, recall)`` tuple. ``(0, 0, 0)`` is returned when
        the answers share no token, or when either side normalizes to
        "yes" / "no" / "noanswer" and the two normalized strings differ.
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)

    no_credit = (0, 0, 0)
    exact_only = ('yes', 'no', 'noanswer')

    # These special answers earn credit only on an exact (normalized) match.
    if pred_norm != gold_norm and (pred_norm in exact_only or gold_norm in exact_only):
        return no_credit

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()

    # Multiset intersection: each shared token counts as often as it occurs on both sides.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return no_credit

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall), precision, recall


def max_f1_over_ground_truths(prediction: str, ground_truths: Union[str, List[str]]) -> float:
    """Return the best F1 of ``prediction`` against any of the ground truths.

    Args:
        prediction: The predicted answer.
        ground_truths: One reference answer, or an iterable of them. A bare
            string is treated as a single reference.

    Returns:
        The highest F1 value found, or 0.0 when there are no references.
    """
    references = [ground_truths] if isinstance(ground_truths, str) else ground_truths
    # Seeding the candidates with 0.0 keeps the result a float and gives the
    # empty-references case a defined value.
    return max([0.0] + [f1_score(prediction, reference)[0] for reference in references])


def extract_answer(solution_str: str) -> Union[str, None]:
    """Pull the final answer text out of a model response.

    Resolution order:
    1. If any ``<answer>...</answer>`` pair exists, return the stripped
       content of the last one (this may be an empty string).
    2. Otherwise, if ``</think>`` occurs, take the text after its last
       occurrence, remove ``<tool_call>...</tool_call>`` and
       ``<tool_returns>...</tool_returns>`` spans, and return the result.
       Text longer than 2000 characters is cut to "..." plus its last 1500
       characters. An empty result yields ``None``.
    3. Otherwise return the stripped input, cut to "..." plus its last 500
       characters when it is longer than 500; an empty input yields ``None``.
    """
    # Case 1: explicit <answer> tags win; the last pair is the final answer.
    tagged_answers = re.findall(r"<answer>(.*?)</answer>", solution_str, re.DOTALL)
    if tagged_answers:
        return tagged_answers[-1].strip()

    # Case 2: everything after the last </think>.
    _, closing_think, after_think = solution_str.rpartition("</think>")
    if closing_think:
        candidate = after_think.strip()
        # Tool calls and tool returns are not part of the model's answer text.
        for span_pattern in (r"<tool_call>.*?</tool_call>", r"<tool_returns>.*?</tool_returns>"):
            candidate = re.sub(span_pattern, "", candidate, flags=re.DOTALL).strip()
        # Keep only the tail of very long text, where the final answer most likely is.
        if len(candidate) > 2000:
            candidate = "..." + candidate[-1500:]
        return candidate or None

    # Case 3: no recognizable structure; fall back to the tail of the raw text.
    stripped = solution_str.strip()
    if not stripped:
        return None
    if len(stripped) > 500:
        return "..." + stripped[-500:]
    return stripped


def validate_tag_pairs(text, tag):
    """Check that ``<tag>`` / ``</tag>`` occur only as flat, closed pairs.

    Returns True when the text contains no such tags at all, or when the
    tags strictly alternate open, close, open, close, ... and end on a
    close. A leading closing tag, a nested opening tag, or an unclosed
    opening tag makes the result False.
    """
    tag_tokens = re.findall(fr"</?{re.escape(tag)}>", text)
    closing_flags = [token.startswith("</") for token in tag_tokens]

    # An odd number of tags can never pair up.
    if len(closing_flags) % 2 != 0:
        return False

    # Even positions (0, 2, ...) must be openers and odd positions closers.
    return all(
        is_closing == (position % 2 == 1)
        for position, is_closing in enumerate(closing_flags)
    )


def compute_format_reward(solution_str: str) -> float:
    """Score the structural well-formedness of a model response.

    Requirements:
    1. An opening ``<think>`` tag must be present.
    2. ``<think>`` tags must form flat, closed pairs.
    3. ``<answer>`` tags are optional, but any that appear must form flat,
       closed pairs.
    4. ``<tool_call>`` tags are optional, but any that appear must form flat,
       closed pairs.

    Args:
        solution_str: The model's full generated text.

    Returns:
        0.0 when every requirement holds, -1.0 otherwise.
    """
    # <think> is the only mandatory tag.
    if "<think>" not in solution_str:
        return -1.0

    # validate_tag_pairs accepts text that does not contain the tag at all, so
    # optional tags need no presence guard. Checking <answer> unconditionally
    # also rejects a stray "</answer>" that has no opening tag, which a guard
    # on "<answer>" being present would let through.
    for tag in ("think", "answer", "tool_call"):
        if not validate_tag_pairs(solution_str, tag):
            return -1.0

    return 0.0


def compute_tool_reward(solution_str: str) -> float:
    """Return the tool-usage reward for a response.

    This is currently a stub: ``solution_str`` is not inspected and the
    reward is always 0.0.
    """
    return 0.0


# Data sources whose accuracy is measured with token-level F1 rather than an LLM judge.
_F1_DATA_SOURCES = ("musique", "bamboogle", "2wikimultihopqa", "hotpotqa", "simpleqa", "GAIA")


def _question_from_extra_info(extra_info):
    """Return ``extra_info['question']``, or "" when extra_info is None or lacks the key."""
    if extra_info is None or 'question' not in extra_info:
        return ""
    return extra_info['question']


def _report_judge_failure(judge_label, exc, dump_exception_attributes=False):
    """Print a banner-delimited error report and traceback for a failed judge call.

    Must be called from inside the ``except`` block handling ``exc`` so that
    ``traceback.print_exc()`` prints the active exception.
    """
    import traceback

    banner = '!' * 80
    print(f"\n{banner}")
    print(f"❌❌❌ [ERROR] {judge_label} 调用失败!")
    print(banner)
    print(f"❌ [ERROR] 异常类型: {type(exc).__name__}")
    print(f"❌ [ERROR] 异常信息: {str(exc)}")

    if dump_exception_attributes:
        # Print whatever extra detail the exception object carries.
        if hasattr(exc, 'response'):
            print(f"❌ [ERROR] HTTP Response: {exc.response}")
        if hasattr(exc, 'status_code'):
            print(f"❌ [ERROR] Status Code: {exc.status_code}")
        if hasattr(exc, '__dict__'):
            print(f"❌ [ERROR] Exception 属性: {exc.__dict__}")

    print("❌ [ERROR] 完整堆栈追踪:")
    traceback.print_exc()
    print(f"{banner}\n")


def _judge_with_api(answer, ground_truth, question):
    """Score ``answer`` with the external API judge; 0.0 on any failure."""
    try:
        # Item layout expected by the API judge's judge_single_item().
        item = {
            "formatted_question": question,
            "answer": ground_truth,
            "last_iteration_output": answer,
            "task_id": "",
            "level": ""
        }
        result = API_JUDGE_CLIENT.judge_single_item(item)
        return 1.0 if result.get("is_correct", False) else 0.0
    except Exception as e:
        _report_judge_failure("API Judge", e)
        # Fall back to a default reward instead of crashing the caller.
        acc_reward = 0.0
        print(f"⚠️  [ERROR] 使用默认 acc_reward: {acc_reward}")
        return acc_reward


def _judge_with_vllm(answer, ground_truth, question):
    """Score ``answer`` with the vLLM judge client; 0.0 on any failure."""
    try:
        return LLM_JUDGE_CLIENT.verify(
            prediction=answer,
            ground_truth=ground_truth,
            question=question
        )
    except Exception as e:
        _report_judge_failure("LLM Judge", e, dump_exception_attributes=True)
        # Fall back to a default reward instead of crashing the caller.
        acc_reward = 0.0
        print(f"⚠️  [ERROR] 使用默认 acc_reward: {acc_reward}")
        return acc_reward


def compute_score(
    data_source: str,
    solution_str: str,
    ground_truth: Union[str, List[str]],
    extra_info: Union[dict, None] = None,
    **kwargs  # Accept additional arguments for compatibility with different reward managers
    ) -> dict:
    """
    The reward function for FileAgent.

    The accuracy reward is computed in one of three ways:
    - token-level F1 for the data sources listed in ``_F1_DATA_SOURCES``;
    - the external API judge when ``API_JUDGE_CLIENT`` is not None;
    - the vLLM judge (``LLM_JUDGE_CLIENT``) otherwise.
    A judge call that raises is reported on stdout/stderr and scored 0.0.

    Args:
        data_source: Source dataset identifier
        solution_str: Model's generated solution
        ground_truth: Ground truth answer(s)
        extra_info: Additional information; its 'question' entry, when
            present, is passed to the LLM judge (an empty string is used
            otherwise)
        **kwargs: Additional keyword arguments (ignored for compatibility)

    Returns:
        dict: Dictionary containing score, acc_reward, format_reward, tool_reward,
        where score = acc_reward * 1.0 + format_reward * 0.2 + tool_reward * 0.0
    """
    # Accuracy reward. A missing answer is scored as an empty string.
    answer = extract_answer(solution_str) or ""

    if data_source in _F1_DATA_SOURCES:
        acc_reward = max_f1_over_ground_truths(answer, ground_truth)
    elif API_JUDGE_CLIENT is not None:
        question = _question_from_extra_info(extra_info)
        acc_reward = _judge_with_api(answer, ground_truth, question)
    else:
        question = _question_from_extra_info(extra_info)
        acc_reward = _judge_with_vllm(answer, ground_truth, question)

    # Format and tool rewards are computed on the full response, not the extracted answer.
    format_reward = compute_format_reward(solution_str)
    tool_reward = compute_tool_reward(solution_str)

    # Final weighted score.
    final_score = acc_reward * 1.0 + format_reward * 0.2 + tool_reward * 0.0

    return {
        "score": final_score,
        "acc_reward": acc_reward,
        "format_reward": format_reward,
        "tool_reward": tool_reward,
    }


if __name__ == "__main__":
    # Manual smoke test: score three hand-written responses and print each
    # result dict. The data source "test" is not one of the F1 data sources,
    # so compute_score scores these through a judge client.
    capital_question = "What is the capital of China?"
    demo_cases = [
        # Well-formed response.
        (
            "<think>the thinking content</think><tool_call>the tool call content</tool_call><answer>397.470 million pounds</answer>",
            "397.470",
            "In 2021, compute the implied total (in million pounds) for fresh grapes by multiplying the per capita farm-weight availability for Noncitrus-Grapes (pounds) by the U.S. population (July 1, millions), then subtract the 2021 reported Food availability-Total (million pounds) for fresh strawberries. What is the resulting value, rounded to three decimals?",
        ),
        # The <think> tag is never closed.
        (
            "<think>the thinking content<tool_call>the tool call content</tool_call><answer>Beijing, China</answer>",
            "Beijing",
            capital_question,
        ),
        # An extra closing </tool_call> tag.
        (
            "<think>the thinking content</think><tool_call>the tool call content</tool_call></tool_call><answer>Beijing, China</answer>",
            "Beijing",
            capital_question,
        ),
    ]

    for solution_str, ground_truth, question in demo_cases:
        print(compute_score(data_source="test", solution_str=solution_str, ground_truth=ground_truth, extra_info={"question": question}))
