import re
from multiprocessing import Pool

from verl.utils.reward_score.format_validation import validate_format, get_tool_call_times, \
                                                      search_query_repeat_nums, extract_json_from_markdown
from verl.utils.reward_score.f1_score import get_em_score

def extract_answer(text: str):
    """Return the text enclosed by the first ``<answer>...</answer>`` pair.

    Leading and trailing whitespace is stripped from *text* before searching.
    The match is non-greedy and may span several lines (``re.DOTALL``); the
    captured content itself is returned untrimmed.

    Returns:
        The captured string, or ``None`` when no complete tag pair is found.
    """
    found = re.search(r"<answer>(.*?)</answer>", text.strip(), re.DOTALL)
    return found.group(1) if found else None

def remove_boxed(s):
    """Strip the LaTeX ``\\boxed`` wrapper from *s* and return the payload.

    Two spellings are handled:

    * ``\\boxed <payload>`` -- everything after the prefix is returned.
    * ``\\boxed{<payload>}`` -- the text between the outer braces is returned.

    Raises:
        AssertionError: if *s* does not start with the expected prefix, or,
            in the braced form, does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        # The space-separated form must sit at the very start of the string.
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s[-1] == "}"
    return s[len(braced_prefix):-1]

def last_boxed_only_string(string):
    """Return the last ``\\boxed`` (or ``\\fbox``) expression found in *string*.

    * If the space-separated form ``\\boxed <payload>`` occurs anywhere, the
      result is ``\\boxed `` plus the text after its last occurrence, cut at
      the first ``$``.
    * Otherwise the last ``\\boxed`` (falling back to the last ``\\fbox``) is
      located and the text from there through the brace that brings the
      running brace count back to zero is returned.

    Returns:
        The extracted expression including its wrapper, or ``None`` when no
        wrapper is present or the braces never balance.
    """
    start = string.rfind("\\boxed")

    spaced = "\\boxed "
    if spaced in string:
        after_last = string.rpartition(spaced)[2]
        return spaced + after_last.partition("$")[0]

    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Closing brace that balances the wrapper's opening brace.
                return string[start:pos + 1]
    return None

def answer_length_control(pred_answers, gt_answers, tolerance, max_penalty):
    """Return a non-positive penalty for predictions much longer than the gold answers.

    The longest predicted answer is compared with the longest ground-truth
    string. Ground truths are either plain strings or dicts carrying an
    ``'answer'`` string and an ``'aliases'`` list; the kind is decided from
    the first element only.

    Args:
        pred_answers: predicted answer strings.
        gt_answers: ground-truth strings, or dicts with ``'answer'`` and
            ``'aliases'``.
        tolerance: largest allowed ratio ``max_pred_len / max_gt_len``.
        max_penalty: scale applied to the relative excess length.

    Returns:
        ``0`` when the ratio is within *tolerance*; otherwise
        ``-((max_pred_len - max_gt_len) / max_pred_len * max_penalty)``.

    Raises:
        ValueError: if the first ground truth is neither ``str`` nor ``dict``.
    """
    pred_lengths = [len(pred) for pred in pred_answers]

    first_gt = gt_answers[0]
    if isinstance(first_gt, str):
        gt_lengths = [len(gt) for gt in gt_answers]
    elif isinstance(first_gt, dict):
        gt_lengths = [len(gt['answer']) for gt in gt_answers]
        gt_lengths.extend(len(alias) for gt in gt_answers for alias in gt['aliases'])
    else:
        raise ValueError(f"Invalid gt_answers type: {type(gt_answers[0])}")

    longest_pred = max(pred_lengths)
    longest_gt = max(gt_lengths)

    if longest_pred / longest_gt <= tolerance:
        return 0
    return -((longest_pred - longest_gt) / longest_pred * max_penalty)

def get_multiple_answer_f1_score(reward_cfg, pred_answer, ground_truth, correct_count, return_dict, hit_count=None):
    """Turn multi-answer match counts into an F1-based score.

    Precision is ``correct_count / len(pred_answer)`` and recall is
    ``hit_count / len(ground_truth)`` (``correct_count`` is used when
    ``hit_count`` is ``None``); both are capped at 1. The final score is
    ``1 - max_penalty * (1 - f1)`` with ``max_penalty`` read from
    ``reward_cfg.multiple_answer.answer_count.max_penalty``.

    Degenerate inputs (no predictions, no ground truths, or zero precision
    and recall) yield the corresponding metric as ``0.0`` instead of raising
    ``ZeroDivisionError``.

    Args:
        reward_cfg: config object exposing
            ``multiple_answer.answer_count.max_penalty``.
        pred_answer: sized collection of predicted answers.
        ground_truth: sized collection of ground-truth answers.
        correct_count: number of predictions counted as correct.
        return_dict: result dict whose ``'score_breakdown'`` sub-dict is
            updated in place with the computed statistics.
        hit_count: optional number of ground truths that were hit.

    Returns:
        ``(final_score, return_dict)``.
    """
    gt_answer_count = len(ground_truth)
    pred_answer_count = len(pred_answer)

    precision = min(correct_count / pred_answer_count, 1) if pred_answer_count else 0.0

    recall_hits = correct_count if hit_count is None else hit_count
    recall = min(recall_hits / gt_answer_count, 1) if gt_answer_count else 0.0

    # F1 is defined as 0 when both precision and recall are 0.
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0

    final_score = 1 - reward_cfg.multiple_answer.answer_count.max_penalty * (1 - f1)

    return_dict['score_breakdown'].update({
        'correct_answer_count': correct_count,
        'gt_answer_count': gt_answer_count,
        'pred_answer_count': pred_answer_count,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
    })
    return final_score, return_dict

def compute_score(args) -> dict:
    """Score a single model rollout and explain how the score was obtained.

    The function is written to be mapped over a pool, so all inputs arrive as
    one tuple.

    Args:
        args: 12-tuple ``(eos_token, question, solution_str, ground_truth,
            question_decomposition, reward_cfg, valid_response_length,
            max_resp_len, is_validate, model_generated_tokens,
            model_generated_tokens_str, sub_hypothesis)``. Only ``eos_token``,
            ``solution_str``, ``ground_truth``, ``reward_cfg`` and
            ``is_validate`` are read here; the others are unpacked and ignored.

    Returns:
        A dict with the keys ``'score'``, ``'reason'``, ``'score_breakdown'``,
        ``'correctness'``, ``'call_search_times'`` and
        ``'multiple_answer_count'``. On any early exit (bad format, missing
        EOS, answer extraction failure) ``'score'`` stays ``0`` and
        ``'reason'`` says why.

    Raises:
        AssertionError: if not exactly one answer mode is enabled in
            ``reward_cfg``, or an internal consistency check fails.
        ValueError: for an unsupported ``reward_cfg.format.type`` or reward
            type.
        IndexError: if ``solution_str`` contains neither the chat assistant
            marker nor ``"Assistant:"``.
    """
    (eos_token, question, solution_str, ground_truth, question_decomposition,
     reward_cfg, valid_response_length, max_resp_len, is_validate,
     model_generated_tokens, model_generated_tokens_str, sub_hypothesis) = args

    assert sum([
        reward_cfg.single_answer.enable,
        reward_cfg.single_json_answer.enable,
        reward_cfg.multiple_answer.enable,
        reward_cfg.multiple_answer_in_plain_text.enable
    ]) == 1, "only one of single_answer, single_json_answer, multiple_answer, multiple_answer_in_plain_text can be enabled"

    reward_type = reward_cfg.reward_type if not is_validate else reward_cfg.validate_reward_type
    single_answer_mode = reward_cfg.single_answer.enable or reward_cfg.single_json_answer.enable
    multiple_answer_mode = reward_cfg.multiple_answer.enable or reward_cfg.multiple_answer_in_plain_text.enable

    return_dict = {
        'score': 0,
        'reason': '',
        'score_breakdown': {
            'format': 0,
        },
        'correctness': False,
        'call_search_times': 0,
        'multiple_answer_count': 0
    }

    if reward_cfg.format.type == "validate_format":
        validate_format_fn = validate_format
    else:
        raise ValueError(f"Invalid format type: {reward_cfg.format.type}")

    # Handle both the base model and the instruction-tuned model: keep
    # everything after the first assistant marker.
    assistant_marker = "<|im_start|>assistant\n"
    if assistant_marker in solution_str:
        response = solution_str.split(assistant_marker, 1)[1]
    else:
        response = solution_str.split("Assistant:")[1]

    valid_template, reason = validate_format_fn(response, reward_cfg.tool_call_tag, reward_cfg.tool_response_tag)
    if not valid_template:
        return_dict['reason'] = f'bad format: {reason}'
        return return_dict

    if response.endswith(eos_token):
        # Slice by explicit length so an empty eos_token leaves the response
        # intact (``response[:-0]`` would yield an empty string).
        response = response[:len(response) - len(eos_token)]
    else:
        return_dict['reason'] = 'over length'
        return return_dict

    answer_part = extract_answer(response)
    citations = []
    if answer_part is not None and reward_cfg.single_answer.enable:
        try:
            answer = remove_boxed(last_boxed_only_string(answer_part))
            return_dict['multiple_answer_count'] = 1
        except Exception as e:
            return_dict['reason'] = f'find box error: {e}'
            return return_dict
    elif answer_part is not None and reward_cfg.single_json_answer.enable:
        try:
            answer = extract_json_from_markdown(answer_part)['answer']
            if not isinstance(answer, str):
                raise ValueError('answer is not a string for single json answer')
            return_dict['multiple_answer_count'] = 1
        except Exception as e:
            return_dict['reason'] = f'find json error: {e}'
            return return_dict
    elif answer_part is not None and reward_cfg.multiple_answer.enable:
        try:
            extracted_answers = extract_json_from_markdown(answer_part)['answers']
            if len(extracted_answers) == 0:
                raise ValueError('no answer found')
            # NOTE(review): this branch cannot be taken at present because any
            # format type other than "validate_format" is rejected above.
            if reward_cfg.format.type == "validate_format_with_citation":
                answer = [i['answer'] for i in extracted_answers]
                citations = [c for i in extracted_answers for c in i['citations']]
                return_dict['score_breakdown']['citation_count'] = len(citations)
            else:
                answer = extracted_answers
            return_dict['multiple_answer_count'] = len(answer)
        except Exception as e:
            return_dict['reason'] = f'find json error: {e}'
            return return_dict
    elif answer_part is not None and reward_cfg.multiple_answer_in_plain_text.enable:
        try:
            answer = remove_boxed(last_boxed_only_string(answer_part))
            # Plain-text multi-answer responses separate answers with ';'.
            answer = [a.strip() for a in answer.split(';') if a.strip()]
            if len(answer) == 0:
                raise ValueError('no answer found')
            return_dict['multiple_answer_count'] = len(answer)
        except Exception as e:
            return_dict['reason'] = f'find plain text error: {e}'
            return return_dict
    else:
        return_dict['reason'] = 'cannot extract answer'
        return return_dict

    final_score = 0
    final_reason = ''
    correct_answer = None
    pred_answer = None
    return_dict['score_breakdown']['format'] = reward_cfg.format_reward
    if reward_type == "em_score":
        if multiple_answer_mode:
            if isinstance(answer, list):
                pred_answer = answer
            else:
                pred_answer = [answer]
        else:
            pred_answer = answer_part
        try:
            em_score, hit_count = get_em_score(answer, ground_truth, return_hit_count=True)
        except Exception as e:
            # Treat a scorer failure as a wrong answer. hit_count must be
            # bound here as well, otherwise the breakdown below would raise
            # NameError.
            em_score = 0
            hit_count = 0
            print(e)
        return_dict['score_breakdown']['em_score'] = em_score
        return_dict['score_breakdown']['hit_gt_count'] = hit_count
        if em_score == 0:
            final_score = reward_cfg.format_reward
            final_reason = f'wrong answer, em score: {em_score}, but good format: {answer_part}'
            correct_answer = False
        elif single_answer_mode:
            assert em_score == 1, f'em score: {em_score}, should be 1, answer: {answer}'
            final_score = em_score
            correct_answer = True
            final_reason = f'correct answer, get em score: {em_score}'
        elif multiple_answer_mode:
            assert 1 <= em_score <= len(answer), f'em score: {em_score}, should be between 1 and {len(answer)}, answers: {answer}'
            final_score, return_dict = get_multiple_answer_f1_score(reward_cfg, answer, ground_truth, em_score, return_dict, hit_count)
            correct_answer = True
            final_reason = f'correct answer, em score: {em_score}, gt answer count: {len(ground_truth)}, pred answer count: {len(answer)}, precision: {return_dict["score_breakdown"]["precision"]}, recall: {return_dict["score_breakdown"]["recall"]}, f1 score: {return_dict["score_breakdown"]["f1_score"]}'
        else:
            raise ValueError(f"Invalid reward type: {reward_type} for computing EM score")
    else:
        raise ValueError(f"Invalid reward type: {reward_type}")

    if is_binary_reward_type(reward_type):
        assert correct_answer is not None, "correct answer should not be None for binary reward type"

    if multiple_answer_mode:
        assert pred_answer is not None, f"pred_answer should not be None, answer part: {answer_part}, answer: {answer}"
        assert isinstance(pred_answer, list), "pred_answer should be a list"

    if reward_cfg.query_repeat_penalty.enable:
        query_repeat_penalty = search_query_repeat_nums(response, reward_cfg.tool_call_tag) * reward_cfg.query_repeat_penalty.penalty_factor
        return_dict['score_breakdown']['query_repeat_penalty'] = query_repeat_penalty
        # Remember the pre-penalty score so the reason reports both values.
        original_score = final_score
        final_score = original_score - query_repeat_penalty
        final_reason = f'{final_reason}, original score: {original_score}, query repeat penalty: {query_repeat_penalty}, revised score: {final_score}'

    if final_score < 0:
        final_score = 0
        final_reason = f'{final_reason}, the final score is negative, set to 0'

    return_dict['score'] = final_score
    return_dict['reason'] = final_reason
    return_dict['correctness'] = correct_answer
    return_dict['call_search_times'] = get_tool_call_times(response, reward_cfg.tool_call_tag)
    return return_dict

def is_binary_reward_type(reward_type):
    """Return ``True`` when *reward_type* is one of the binary reward types.

    ``"em_score"`` is currently the only such type.
    """
    binary_reward_types = ("em_score",)
    return reward_type in binary_reward_types


def compute_score_parallel(eos_token, qids, questions, solution_strs, ground_truths, question_decompositions, reward_cfg, valid_response_lengths, max_resp_len, is_validate, model_generated_tokens, model_generated_tokens_str, optimal_ref, sub_hypotheses, num_processes=16):
    """Score a batch of rollouts with ``compute_score`` in a process pool.

    Per-sample sequences are zipped together; ``eos_token``, ``reward_cfg``,
    ``max_resp_len`` and ``is_validate`` are repeated for every sample.
    Results come back in input order.

    Args:
        eos_token, reward_cfg, max_resp_len, is_validate: values shared by
            all samples.
        questions, solution_strs, ground_truths, question_decompositions,
            valid_response_lengths, model_generated_tokens,
            model_generated_tokens_str, sub_hypotheses: per-sample sequences.
        qids, optimal_ref: accepted for interface compatibility; not used.
        num_processes: upper bound on worker processes (default 16). The pool
            never starts more workers than there are samples.

    Returns:
        A list with one ``compute_score`` result per sample; ``[]`` for an
        empty batch, in which case no pool is started.
    """
    batch_size = len(solution_strs)
    if batch_size == 0:
        # Nothing to score: avoid the cost of spawning worker processes.
        return []

    task_args = zip(
        [eos_token] * batch_size,
        questions,
        solution_strs,
        ground_truths,
        question_decompositions,
        [reward_cfg] * batch_size,
        valid_response_lengths,
        [max_resp_len] * batch_size,
        [is_validate] * batch_size,
        model_generated_tokens,
        model_generated_tokens_str,
        sub_hypotheses,
    )

    worker_count = max(1, min(num_processes, batch_size))
    with Pool(processes=worker_count) as pool:
        # imap preserves input order; materialize before the pool is torn down.
        scores = list(pool.imap(compute_score, task_args))
    return scores
