import json
import re
import string
import re
from difflib import SequenceMatcher
import logging

from meta_researcher.tool.tools.batch_infer import get_batch_infer
from meta_researcher.src.reward_score.llm_score import generate_scoring_prompt, extract_scoring_data

# Module-level LLM-judge client used by compute_score_answer and
# batch_compute_score_answer. It is created once, at import time.
# NOTE(review): config_path is relative; whether it resolves against the
# current working directory or this file depends on get_batch_infer — confirm.
infer_handler = get_batch_infer(
    config_path="../../tool/tools/config",
    config_name="eval_search",
)

def normalize_answer(s):
    """Canonicalize an answer string for lenient comparison.

    Steps, applied in order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs into single spaces and trim both ends.
    """
    text = s.lower()
    # Deletion table: every ASCII punctuation character maps to None.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def em_check(prediction, golden_answers):
    """Exact-match check after normalization.

    Args:
        prediction: the model's answer string.
        golden_answers: a single reference string or an iterable of them.

    Returns:
        float: 1.0 if ``normalize_answer(prediction)`` equals the normalized
        form of any reference, else 0.0.
    """
    references = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    target = normalize_answer(prediction)
    matched = any(normalize_answer(ref) == target for ref in references)
    return 1.0 if matched else 0.0


def subem_check(prediction, golden_answers):
    """Substring-match check after normalization.

    Args:
        prediction: the model's answer string.
        golden_answers: a single reference string or an iterable of them.

    Returns:
        float: 1.0 if the normalized form of any reference occurs inside
        ``normalize_answer(prediction)``, else 0.0. A reference that
        normalizes to "" therefore always matches.
    """
    references = [golden_answers] if isinstance(golden_answers, str) else golden_answers
    haystack = normalize_answer(prediction)
    matched = any(normalize_answer(ref) in haystack for ref in references)
    return 1.0 if matched else 0.0


def extract_solution(solution_str):
    """Return the stripped text of the first <answer>...</answer> span.

    The span may cover several lines. Returns None when ``solution_str``
    contains no complete answer span.
    """
    found = re.search(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    return found.group(1).strip() if found else None

def compute_score_format(solution_str, extra_info=None, think_len_belta=0.0, think_len_threshold=0):
    """Score the tag structure of a single multi-turn rollout.

    Assistant turns are the segments between "<|im_start|>assistant" (plus a
    newline) and "<|im_end|>".

    Every turn except the last is examined. A turn that contains exactly one
    <think>/</think> pair and exactly one <tool_call>/</tool_call> pair counts
    as a tool-call step. The step earns 0.5 when the whole turn is
    <think>...</think>, then optional whitespace, then
    <tool_call>...</tool_call>. Step rewards are averaged over all tool-call
    steps.

    The last turn earns another 0.5 when it is <think>...</think>, then
    optional whitespace, then <answer>...</answer>.

    Args:
        solution_str: the full rollout text.
        extra_info: unused; accepted for call-site compatibility.
        think_len_belta: weight forwarded to compute_score_think_length. When
            it is 0 (the default) the think-length terms are skipped.
        think_len_threshold: length threshold forwarded to
            compute_score_think_length. Only used when think_len_belta is
            non-zero.

    Returns:
        float: the format reward, which is at most 1.0 when think-length
        scoring is disabled. Returns 0.0 when solution_str is None, contains no
        assistant turn, or scoring raises.
    """
    print(f"[SOLUTION] solution string: {solution_str}")
    if solution_str is None:
        return 0.0

    def _think_term(block):
        # Think-length contribution of one turn. It is 0.0 when think scoring
        # is disabled or the turn has no <think>...</think> span.
        if not think_len_belta:
            return 0.0
        think_match = re.search(r'<think>(.*?)</think>', block, re.DOTALL)
        if think_match is None:
            return 0.0
        return compute_score_think_length(think_match.group(1), think_len_belta, think_len_threshold)

    try:
        assistant_blocks = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
        if not assistant_blocks:
            return 0.0

        toolcall_step = 0
        toolcall_format_reward = 0.0
        for assistant_block in assistant_blocks[:-1]:
            is_toolcall_turn = (
                assistant_block.count('<think>') == 1
                and assistant_block.count('</think>') == 1
                and assistant_block.count('<tool_call>') == 1
                and assistant_block.count('</tool_call>') == 1
            )
            if not is_toolcall_turn:
                continue
            toolcall_step += 1
            if re.search(r'^<think>(.*?)</think>(\s*)<tool_call>(.*?)</tool_call>$', assistant_block, re.DOTALL):
                toolcall_format_reward += 0.5
            toolcall_format_reward += _think_term(assistant_block)

        format_reward = toolcall_format_reward / toolcall_step if toolcall_step else 0.0

        last_block = assistant_blocks[-1]
        if re.search(r'^<think>(.*?)</think>(\s*)<answer>(.*?)</answer>$', last_block, re.DOTALL):
            format_reward += 0.5
        format_reward += _think_term(last_block)
    except Exception as e:
        # Best effort: a malformed rollout should score 0 instead of aborting training.
        print(f"[DEBUG] Error in compute_score_format: {e}")
        return 0.0

    return format_reward


def compute_score_answer(solution_str, ground_truth):
    """Score the final answer of a single rollout with the LLM judge.

    The last assistant turn is located. The text of its first
    <answer>...</answer> span is sent to the judge, and a truthy "is_right"
    verdict earns 1.0. Otherwise the reward falls back to 0.2 when
    subem_check finds the ground truth inside the whole last assistant turn.

    NOTE(review): this call site passes (question, ground_truth, answer) to
    generate_scoring_prompt, passes two arguments to extract_scoring_data and
    reads "is_right". batch_compute_score_answer instead passes
    (question, answer, data_source, ground_truth), three arguments, and reads
    "score". Confirm which matches llm_score. Any mismatch is swallowed by the
    except below and yields 0.0.
    NOTE(review): the 0.2 fallback checks the whole turn (including <think>),
    not just the extracted answer. Confirm this is intended.

    Args:
        solution_str: the full rollout text.
        ground_truth: the reference answer.

    Returns:
        float: 1.0, 0.2 or 0.0. Returns 0.0 when solution_str is None, contains
        no assistant turn, or scoring raises.
    """
    if solution_str is None:
        return 0.0

    try:
        turns = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
        if not turns:
            return 0.0
        final_turn = turns[-1]
        answer = extract_solution(final_turn)

        reward = 0.0
        if answer is not None:
            question = ""
            prompt = generate_scoring_prompt(str(question), str(ground_truth), str(answer))
            reply = infer_handler.call_qwen_model(prompt, "")
            verdict = extract_scoring_data(question, reply)
            if verdict.get('is_right', False):
                reward = 1.0

        # Partial credit when the judge did not accept the answer.
        if reward == 0.0 and subem_check(final_turn, ground_truth):
            reward = 0.2
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_answer: {e}")
        return 0.0

    return reward

def compute_score_format_answer(solution_str, ground_truth):
    """Combine the format reward and the answer reward for a single rollout.

    The format reward is capped at 1.0. When the capped value is at least 0.5
    the result is -1.0 + format + answer. Otherwise it is -1.0 + format, and
    the answer (an LLM-judge call) is not scored at all.

    Args:
        solution_str: the full rollout text.
        ground_truth: the reference answer.

    Returns:
        float: the combined reward. Returns 0.0 when either argument is None,
        and -1.0 if scoring raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        # compute_score_format takes an extra_info argument it does not use. It
        # must still be supplied, otherwise the call raises TypeError and every
        # rollout scores -1.0.
        format_reward = min(compute_score_format(solution_str, None), 1.0)
        if format_reward < 0.5:
            # The answer reward would be discarded, so skip the judge call.
            return -1.0 + format_reward
        answer_reward = compute_score_answer(solution_str, ground_truth)
        return -1.0 + format_reward + answer_reward
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format_answer: {e}")
        return -1.0

def compute_score_em(solution_str, ground_truth):
    """Binary answer score for a single rollout.

    The first <answer>...</answer> span of the last assistant turn is compared
    with the ground truth. Despite the "EM" name, the comparison uses
    subem_check: the normalized ground truth must occur inside the normalized
    answer.

    Args:
        solution_str: the full rollout text.
        ground_truth: a reference string or a list of them.

    Returns:
        float: 1.0 on a match, else 0.0. Also returns 0.0 when an argument is
        None, there is no assistant turn or answer span, or scoring raises.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    try:
        turns = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
        if not turns:
            return 0.0
        answer = extract_solution(turns[-1])
        return 0.0 if answer is None else float(subem_check(answer, ground_truth))
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_em: {e}")
        return 0.0


def compute_similarity_answer(solution_str, ground_truth):
    """Return the difflib similarity ratio of the two strings, rounded to 4 places.

    Surrounding whitespace is stripped from ground_truth only; solution_str
    is compared as given.
    """
    matcher = SequenceMatcher(a=solution_str, b=ground_truth.strip())
    return round(matcher.ratio(), 4)

def batch_compute_score_answer(data_source_list, solution_str_list, ground_truth_list, extra_info_list, turn_list, inference_turn_list):
    """Score the final answers of a batch of rollouts with the LLM judge.

    For each sample the last assistant turn is taken; turns are delimited by
    "<|im_start|>assistant" and "<|im_end|>". The text of the turn's first
    <answer>...</answer> span is the model answer, falling back to the whole
    turn when there is no answer span. All answers are sent to the judge in
    one batched call, and the "score" field of each parsed judgment becomes
    that sample's answer reward.

    Samples that are empty/None or contain no assistant turn are not sent to
    the judge. For those samples the reward is 0.0, the model answer is None
    and the scoring dict is empty. Every per-sample list returned is aligned
    index-for-index with the input batch.

    Args:
        data_source_list: per-sample data source. It is forwarded to the prompt
            builder and to the judgment parser for that same sample.
        solution_str_list: full rollout strings.
        ground_truth_list: reference answers.
        extra_info_list: per-sample dicts. extra_info["question"] is read for
            every judged sample.
        turn_list: unused; kept for interface compatibility.
        inference_turn_list: unused; kept for interface compatibility.

    Returns:
        tuple: (solution_str_list, question_list, ref_answer_list,
        model_answer_list, answer_reward_list, scoring_result_dict_list).
    """
    answer_reward_list = [0.0] * len(solution_str_list)
    question_list = []
    ref_answer_list = []
    model_answer_list = []
    scoring_result_dict_list = []

    # For samples sent to the judge, record their batch index and their own
    # data source. Rewards are then written back to the right sample, and each
    # judgment is parsed with its own data source.
    judged_indices = []
    judged_data_sources = []
    scoring_prompt_input_list = []

    samples = zip(solution_str_list, ground_truth_list, data_source_list, extra_info_list)
    for i, (solution_str, ground_truth, data_source, extra_info) in enumerate(samples):
        assistant_blocks = (
            re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
            if solution_str else []
        )
        ref_answer_list.append(ground_truth)
        scoring_result_dict_list.append({})

        if not assistant_blocks:
            # Nothing to judge. Keep placeholders so every list stays index-aligned.
            question_list.append(extra_info.get('question') if isinstance(extra_info, dict) else None)
            model_answer_list.append(None)
            continue

        last_block = assistant_blocks[-1]
        answer = extract_solution(last_block)
        if answer is None:
            answer = last_block
        question = extra_info['question']

        scoring_prompt_input_list.append(
            generate_scoring_prompt(str(question), answer, data_source, str(ground_truth))
        )
        question_list.append(question)
        model_answer_list.append(answer)
        judged_indices.append(i)
        judged_data_sources.append(data_source)

    if scoring_prompt_input_list:
        assistant_input = ""
        scoring_result_list = infer_handler.batch_call_qwen_model(scoring_prompt_input_list, assistant_input)
        # Assumes the judge returns one result per prompt, in prompt order.
        for i, data_source, scoring_result in zip(judged_indices, judged_data_sources, scoring_result_list):
            scoring_result_dict = extract_scoring_data(question_list[i], scoring_result["reply"], data_source)
            scoring_result_dict_list[i] = scoring_result_dict
            answer_reward_list[i] = scoring_result_dict.get('score', 0.0)

    for answer_reward in answer_reward_list:
        print(f"[REWARD] Answer reward is: {answer_reward}")

    if answer_reward_list:
        total = sum(answer_reward_list)
        count = len(answer_reward_list)
        avg_score = total / count * 100
        print(f"[AVE REWARD] AVE Answer reward is: {total}/{count} == {avg_score:.2f}")
    else:
        print("[AVE REWARD] Error in batch_compute_score_answer")
    return solution_str_list, question_list, ref_answer_list, model_answer_list, answer_reward_list, scoring_result_dict_list

def compute_score_think_length(think_content, belta, threshold):
    """Score a think segment by its character length.

    Below threshold the score rises linearly with length (belta * length).
    At or above threshold it is a linear penalty,
    -belta * (length - threshold). The score is therefore discontinuous at
    threshold: just below it the score is close to belta * threshold, while
    at threshold it is 0.

    Args:
        think_content: the text between <think> and </think>.
        belta: weight coefficient controlling how fast the score changes.
        threshold: length at which the score switches from reward to penalty.

    Returns:
        The computed score (a float for float belta).
    """
    n_chars = len(think_content)
    return belta * n_chars if n_chars < threshold else -belta * (n_chars - threshold)

def batch_compute_think_length(solution_str_list, extra_info_list, turn_list, inference_turn_list):
    """Average <think> length (in characters) for each rollout in a batch.

    Intermediate assistant turns (all but the last) that contain exactly one
    <think> and one </think> are counted as steps. Their first think span
    lengths are summed and divided by the number of such steps. If the last
    turn has a think span, its length is then averaged with that value, or
    used alone when the intermediate average is 0.

    Args:
        solution_str_list: list of rollout strings. Empty or None entries give 0.0.
        extra_info_list: unused.
        turn_list: unused.
        inference_turn_list: unused.

    Returns:
        list: one think-length value per rollout.
    """
    turn_pattern = r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>'
    think_pattern = r'<think>(.*?)</think>'

    def _think_len(solution):
        if not solution:
            return 0.0
        turns = re.findall(turn_pattern, solution, re.DOTALL)
        if not turns:
            return 0.0

        *middle_turns, final_turn = turns
        steps = 0
        chars = 0.0
        for turn in middle_turns:
            if turn.count('<think>') != 1 or turn.count('</think>') != 1:
                continue
            # The turn counts as a step even if the tags are out of order
            # and the span does not match.
            steps += 1
            span = re.search(think_pattern, turn, re.DOTALL)
            if span:
                chars += len(span.group(1))
        mean_len = chars / steps if steps else chars

        final_span = re.search(think_pattern, final_turn, re.DOTALL)
        if final_span is None:
            return mean_len
        final_len = len(final_span.group(1))
        return (mean_len + final_len) / 2 if mean_len > 0.0 else final_len

    return [_think_len(solution) for solution in solution_str_list]

def batch_compute_score_format(solution_str_list, extra_info_list, turn_list, inference_turn_list, think_len_threshold, think_len_belta, max_step_num):
    """
    Score tool-call formatting, Plan/Reflect usage and think length for a batch.

    For each rollout, assistant turns are the segments between
    "<|im_start|>assistant" (plus a newline) and "<|im_end|>".

    Intermediate turns (all but the last) that contain a <tool_call> and
    exactly one </tool_call>, and end with </tool_call>, are tool-call steps.
    Each step adds 0.5, and the sum is averaged over steps, so the tool-call
    part is 0.5 or 0.0. The JSON body of each tool call is inspected:
      - name "Plan" marks the rollout as planned and sets a 0.2 plan bonus;
      - name "Reflect" adds +0.05 per call up to max_step_num calls and -0.05
        for each call beyond that.
    If the body is not valid JSON, substring checks for "Plan"/"Reflect" are
    used instead.

    The last turn adds 0.5 when it ends with <answer>...</answer>.

    Think scores come from compute_score_think_length on the first
    <think>...</think> of each intermediate tool-call turn (averaged). That
    average is then averaged with the last turn's think score when present.

    Args:
        solution_str_list: list of rollout strings. Falsy entries score 0.
        extra_info_list: unused here.
        turn_list: unused here.
        inference_turn_list: unused here.
        think_len_threshold: threshold forwarded to compute_score_think_length.
        think_len_belta: weight forwarded to compute_score_think_length.
        max_step_num: number of Reflect calls that earn a bonus before further
            Reflect calls are penalized.

    Returns:
        tuple: (format_rewards, think_rewards), one float per rollout.
        A rollout with no Plan tool call gets format reward 0.0. With a Plan
        call the format reward is -0.2 + format_reward + plan_score when there
        was no Reflect call, else format_reward + Reflect_score + plan_score.
    """
    format_rewards = [0.0] * len(solution_str_list)
    think_rewards = [0.0] * len(solution_str_list)
    for idx, solution in enumerate(solution_str_list):
        if not solution:
            format_rewards[idx] = 0.0
            continue
        format_reward = 0.0
        toolcall_step = 0           # number of well-formed tool-call turns
        has_plan = False            # a "Plan" tool call was seen
        toolcall_format_reward = 0.0
        think_format_answer = 0.0   # sum of think-length scores over intermediate turns
        Reflect_score = 0.0         # accumulated Reflect bonus/penalty
        think_step = 0              # intermediate turns with a <think> span
        reflect_step = 0            # number of "Reflect" tool calls so far
        plan_score = 0.0

        # Extract the assistant turns.
        assistant_blocks = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution, re.DOTALL)

        if not assistant_blocks:
            format_rewards[idx] = 0.0
            continue

        # Score every intermediate turn (all but the last).
        # NOTE(review): the two ">= 0" <think>/</think> checks are always true,
        # so only the tool_call counts actually gate this branch.
        for i, assistant_block in enumerate(assistant_blocks[:-1]):
            if assistant_block.count('<think>') >= 0 and assistant_block.count('</think>') >= 0 and assistant_block.count('<tool_call>') >= 1 and assistant_block.count('</tool_call>') == 1:
                # The turn must end with the tool call; group(2) is the tool-call body.
                soft_think_match = re.search(r'(\s*)<tool_call>(.*?)</tool_call>$', assistant_block, re.DOTALL)
                think_length_match = re.search(r'<think>(.*?)</think>', assistant_block, re.DOTALL)
                if soft_think_match:
                    toolcall_step += 1
                    toolcall_format_reward += 0.5

                    tool_call_content = soft_think_match.group(2).strip()
                    try:
                        tool_data = json.loads(tool_call_content)
                        if isinstance(tool_data, dict) and tool_data.get('name') == 'Plan':
                            has_plan = True
                            plan_score = 0.2
                        if isinstance(tool_data, dict) and tool_data.get('name') == 'Reflect':
                            reflect_step += 1
                            # Reward the first max_step_num Reflect calls, penalize the rest.
                            if reflect_step > max_step_num:
                                Reflect_score -= 0.05
                            else:
                                Reflect_score += 0.05  # at most 0.05 * max_step_num
                    except json.JSONDecodeError:
                        # Fallback for tool calls whose body is not valid JSON.
                        # NOTE(review): this path sets has_plan but not plan_score,
                        # unlike the JSON path above — confirm this is intended.
                        if 'Plan' in tool_call_content:
                            has_plan = True
                        if 'Reflect' in tool_call_content:
                            reflect_step += 1
                            if reflect_step > max_step_num:
                                Reflect_score -= 0.05
                            else:
                                Reflect_score += 0.05  # at most 0.05 * max_step_num

                if think_length_match:
                    think_step += 1
                    think_content = think_length_match.group(1)
                    think_score = compute_score_think_length(think_content, think_len_belta, think_len_threshold)
                    think_format_answer += think_score
            else:
                # Turns that are not tool-call turns contribute nothing.
                toolcall_format_reward += 0.0
                think_format_answer += 0.0
        if toolcall_step != 0:
            format_reward = toolcall_format_reward / toolcall_step  # exactly 0.5 here
        else:
            format_reward = toolcall_format_reward
        if think_step != 0:
            think_reward = think_format_answer / think_step  # mean think-length score
        else:
            think_reward = think_format_answer

        # The last turn is expected to end with <answer>...</answer>.
        if assistant_blocks:
            last_assistant_block = assistant_blocks[-1]
            think_answer_match = re.search(r'^(.*?)<answer>(.*?)</answer>$', last_assistant_block, re.DOTALL)
            think_answer_length_match = re.search(r'<think>(.*?)</think>', last_assistant_block, re.DOTALL)
            if think_answer_match:
                format_reward += 0.5
            if think_answer_length_match:
                think_answer_content = think_answer_length_match.group(1)
                think_answer_score = compute_score_think_length(think_answer_content, think_len_belta, think_len_threshold)
                # Also halves the final score when no intermediate turn had a <think> span.
                think_reward = (think_reward + think_answer_score) / 2
        # Format reward is only granted to rollouts that called the Plan tool.
        if has_plan:
            if reflect_step == 0:
                format_rewards[idx] = -0.2 + format_reward + plan_score
            else:
                format_rewards[idx] = format_reward + Reflect_score + plan_score
        else:
            format_rewards[idx] = 0.0
        think_rewards[idx] = think_reward
    return format_rewards, think_rewards

def batch_compute_score_turn(turn_list, inference_turn_list, scale=0.2):
    """Linear penalty for rollouts that use more turns than their budget.

    For each pair (turn, inference_turn) the reward is 0.0 when
    turn <= inference_turn, and -scale * (turn - inference_turn) otherwise.
    The result always has len(turn_list) entries. Positions past the end of
    inference_turn_list stay 0.0.
    """
    penalties = [0.0] * len(turn_list)
    for pos, (used, budget) in enumerate(zip(turn_list, inference_turn_list)):
        if used <= budget:
            continue
        penalties[pos] = -scale * (used - budget)
    return penalties

def write_to_jsonl(extra_info_list, solution_str_list, file_path):
    """Append one JSON object per (extra_info, answer) pair to a JSONL file.

    Each line is {"question": extra_info["question"], "short_answer": item},
    serialized with ensure_ascii=False. Pairs are formed with zip, so surplus
    items in the longer list are ignored. The file is opened in append mode
    with UTF-8 encoding and is created if it does not exist.

    Args:
        extra_info_list: mappings that provide a "question" key.
        solution_str_list: the values stored under "short_answer".
        file_path: path of the target JSONL file.
    """
    with open(file_path, 'a', encoding='utf-8') as out:
        # Every JSONL line must be a complete JSON object.
        out.writelines(
            json.dumps({"question": info["question"], "short_answer": answer}, ensure_ascii=False) + '\n'
            for info, answer in zip(extra_info_list, solution_str_list)
        )

def batch_compute_score_format_answer(data_source_list, solution_str_list, ground_truth_list, extra_info_list, turn_list, inference_turn_list, think_len_threshold, think_len_belta, max_step_num):
    """Compute the combined reward for a batch and print a per-sample report.

    Runs, in order:
      - batch_compute_score_format, for format and think rewards;
      - batch_compute_score_answer, for LLM-judge answer rewards;
      - batch_compute_think_length, whose result is only returned;
      - batch_compute_score_turn.

    For every sample the format reward is capped at 1.0. When the capped value
    is at least 0.8 the reward is -1.0 + format + answer + think; otherwise it
    is -1.0 + format.

    NOTE(review): the turn penalties from batch_compute_score_turn are never
    added to the reward. They only take part in the zip that bounds the number
    of rewards. Confirm the turn penalty is intentionally disabled.

    Args:
        data_source_list: per-sample data source.
        solution_str_list: full rollout strings.
        ground_truth_list: reference answers.
        extra_info_list: per-sample dicts (must provide "question").
        turn_list: turns used per sample.
        inference_turn_list: turn budget per sample.
        think_len_threshold: forwarded to batch_compute_score_format.
        think_len_belta: forwarded to batch_compute_score_format.
        max_step_num: forwarded to batch_compute_score_format.

    Returns:
        tuple: (reward_list, format_reward_list, answer_reward_list,
        think_len_list, think_reward_list).
    """
    format_reward_list, think_reward_list = batch_compute_score_format(
        solution_str_list, extra_info_list, turn_list, inference_turn_list,
        think_len_threshold, think_len_belta, max_step_num,
    )
    answer_outputs = batch_compute_score_answer(
        data_source_list, solution_str_list, ground_truth_list,
        extra_info_list, turn_list, inference_turn_list,
    )
    (solution_str_list, question_list, ref_answer_list,
     model_answer_list, answer_reward_list, scoring_result_dict_list) = answer_outputs
    think_len_list = batch_compute_think_length(solution_str_list, extra_info_list, turn_list, inference_turn_list)
    turn_reward_list = batch_compute_score_turn(turn_list, inference_turn_list, scale=0.2)

    report_rows = zip(question_list, ref_answer_list, solution_str_list, model_answer_list,
                      scoring_result_dict_list, answer_reward_list, data_source_list)
    for item_no, row in enumerate(report_rows):
        question, ground_truth, solution, model_answer, scoring_result_dict, answer_reward, data_source = row
        print("#" * 100 + "\n")
        print(f"Test data item {item_no}")
        print("## Question ##：", question)
        print("## Question Type ##：", data_source)
        print("## Solution Process ##：", solution)
        print("-" * 100 + "\n")
        print("## Model Response ##：", model_answer)
        print("## Ground Truth ##：", ground_truth)
        print("## Scoring Details ##：\n", scoring_result_dict)
        print("## Score ##：", answer_reward)
        print("#" * 100 + "\n\n")

    def _combine(format_reward, answer_reward, think_reward):
        # The answer and think rewards only count once the format is good enough.
        capped = min(format_reward, 1.0)
        if capped >= 0.8:
            return float(-1.0 + capped + answer_reward + think_reward)
        return float(-1.0 + capped)

    # solution_str_list and turn_reward_list stay in the zip so the number of
    # rewards equals the length of the shortest of these lists.
    reward_list = [
        _combine(format_reward, answer_reward, think_reward)
        for format_reward, think_reward, answer_reward, _solution, _turn_reward in zip(
            format_reward_list, think_reward_list, answer_reward_list,
            solution_str_list, turn_reward_list,
        )
    ]

    return reward_list, format_reward_list, answer_reward_list, think_len_list, think_reward_list