import random, os.path as osp, re, numpy as np
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel
from mmengine import load, dump
from collections import Counter


from utils.few_shot_prompts import prompt_fn
from utils.deepseek_api_utils import DecodingArguments, ChatBot
# from utils.deepseek_utils import DecodingArguments, get_tokens_number, ChatBot
# from utils.deepseek_api_utils import DecodingArguments, get_tokens_number, ChatBot
from utils.basic import remove_prefix, compute_metric, match_answers
from utils.cot_voting import get_top_k_voting
from output_preprocessing import extract_reasoning_steps, extract_yes_or_no
from extract_final_answer import extract_final_question
from itertools import chain
from tqdm import tqdm


def _any_check_failed(flat_steps, output_key):
    """Return True if any output stored under ``output_key`` in ``flat_steps`` is judged failed.

    Each output is parsed on its own (wrapped in a one-element list) with
    ``extract_yes_or_no``; it counts as failed when any parsed result is falsy.
    Scanning stops at the first failure.
    """
    for step in flat_steps:
        for out in step.get(output_key, []):
            _, filtered = extract_yes_or_no([out])
            if not all(filtered):
                return True
    return False


def calculate_candidate_discard_score(candidate_steps):
    """
    Calculate the discard score of a single candidate chain of thought.

    Each check category adds its penalty at most once, as soon as any one of its
    stored outputs is judged failed:
    - Conditions check (``conditions_check_check_outputs``): +0.4
    - Logic check (``logic_check_check_outputs``): +0.4
    - Backward verification (``backward_verification_check_outputs``): +0.2

    :param candidate_steps: the stored verification record of one sample, e.g. the
        ``(ret, final_pass)`` value returned by ``verify_steps`` (a list after a JSON
        round trip). Dict items, and dicts inside list/tuple items, are scanned;
        anything else (the bool flag, ``None`` placeholders) is ignored.
    :return: the discard score in [0.0, 1.0]; higher means more check categories failed.
    """
    # Flatten one level so both "list of dicts" and "(list of dicts, flag)" shapes work.
    flat_steps = []
    for item in candidate_steps:
        if isinstance(item, dict):
            flat_steps.append(item)
        elif isinstance(item, (list, tuple)):
            flat_steps.extend(x for x in item if isinstance(x, dict))

    score = 0.0
    # Order matters only for float accumulation; kept as conditions, logic, backward.
    for output_key, penalty in (
        ("conditions_check_check_outputs", 0.4),
        ("logic_check_check_outputs", 0.4),
        ("backward_verification_check_outputs", 0.2),
    ):
        if _any_check_failed(flat_steps, output_key):
            score += penalty
    return score


def direct_reasoning(model_output, question):
    """Judge a whole solution with one zero-shot "step by step" prompt.

    The question and the model's full answer are embedded in a single prompt asking
    whether the reasoning process is correct; nothing is split into steps.

    :param model_output: the complete generated solution text.
    :param question: the question the solution answers.
    :return: whatever ``call_model`` returns for that prompt.
    """
    judge_prompt = (
        "Here is a question and its solution:\n"
        f'"Question:\n{question}\n\nAnswer:\n{model_output}"\n\n'
        "You are a math teacher. Do you think the reasoning process is correct?\n"
        "Let's think step by step. End with \"The reasoning process is\"."
    )
    return call_model(judge_prompt)


def extract_reasoning_instructions(reasoning_chain):
    """
    Extract the "Step N ..." parts of a reasoning chain.

    Each match runs from "Step N" to the end of its line. A match cannot cross a
    '#' character, so a step whose line contains '#' after its "Step N" marker
    is not extracted.

    Args:
    reasoning_chain (str): The full reasoning chain text, containing steps.

    Returns:
    str: The matched steps, each stripped, joined with newlines ("" if none match).
    """
    # The terminator group must be non-capturing: with a capturing group,
    # re.findall returns only the group ("\n" or "") instead of the step text.
    reasoning_steps = re.findall(r"Step \d+[^#]*?(?:\n|$)", reasoning_chain)
    return "\n".join(step.strip() for step in reasoning_steps)


def verify_steps(steps, reasoning_steps, question):
    """
    Verify a parsed chain of thought forward (per step), then backward (whole chain).

    1) Forward verification: for every reasoning step, run a ``conditions_check``
       and a ``logic_check`` prompt through ``call_model``. Each prompt contains the
       step's statement and the earlier steps it cites as grounding.
    2) Backward verification: run one ``backward_verification`` prompt over the
       extracted "Step N" instructions and the original question. This runs
       unconditionally, even when a forward check failed.

    :param steps: A list of the split reasoning text (1-based step numbers index into it).
        A copy is made; cleaned statements are written into the copy only.
    :param reasoning_steps: A list of extracted step texts (including markers like "#X." / "Step X:").
    :param question: The original question, passed through ``extract_final_question``
        for the backward verification prompt.
    :return:
       ret: One entry per reasoning step — a dict with the prompts/outputs of both
            forward checks, or None if the step could not be parsed/verified — followed
            by one dict with the backward verification prompt/outputs.
       final_pass: False if any step raised, any check's mean parsed "yes" vote
            (``extract_yes_or_no``) is below 0.5, or the backward check fails that test.
    """
    steps = deepcopy(steps)
    ret = []
    final_pass = True  # mark all checks as passed or not

    for reasoning_step in reasoning_steps:
        # Best-effort per step: any error (parsing or model call) marks this step None
        # and fails the chain instead of aborting the whole verification.
        try:
            # The leading token carries the step number, e.g. "#3." -> 3. int() instead of
            # eval(): this text is model-generated and must never be executed.
            step_idx = int(reasoning_step.split(" ")[0].strip("#.: "))
            if step_idx < 1:
                # Steps are 1-based; 0 or negative would silently overwrite steps[-k].
                raise ValueError(f"invalid step index {step_idx}")

            if args.ref_end:
                reasoning_step_clean = remove_prefix(
                    reasoning_step.strip(), f"{step_idx}."
                ).strip(". ")
                parts = reasoning_step_clean.split("Reference")
                statement = parts[0].strip()
                grounding = parts[1].strip(". :") if len(parts) > 1 else ""

                # Remove the "Step X:" and "#X" markers
                for g in re.findall(r"(Step \d+:|#\d+)", statement):
                    statement = statement.replace(g, "").strip()
            else:
                reasoning_step_clean = remove_prefix(
                    reasoning_step, f"#{step_idx}."
                ).strip()
                # The statement is everything after the first ")".
                statement = ")".join(reasoning_step_clean.split(")")[1:]).strip()
                for g in re.findall(r"(Step \d+:)", statement):
                    statement = statement.replace(g, "").strip()
                grounding_match = re.findall(r"\((by[^)]+)\)", reasoning_step_clean)
                grounding = grounding_match[0] if grounding_match else ""
                grounding = remove_prefix(grounding, "by").strip()

            steps[step_idx - 1] = statement

            # Grounding is either a range "a-b" or a list of step numbers.
            if "-" in grounding:
                gstart, gend = grounding.split("-")
                grounding_list = list(range(int(gstart.strip(" #")), int(gend.strip(" #")) + 1))
            else:
                grounding_list = [int(num) for num in re.findall(r"\d+", grounding)]

            # Text of every cited step that exists (out-of-range citations are ignored)
            ground_materials = "\n".join(
                remove_prefix(steps[ref - 1], f"{ref}. ").strip()
                for ref in grounding_list
                if 1 <= ref <= len(steps)
            )

            check_results = {}
            for check_mode in ["conditions_check", "logic_check"]:
                core_text = f"""This is the given information：
"{ground_materials}"

Based on the following information，the reasoning steps are：
"{statement}"
"""
                prompt = prompt_fn(core_text, check_mode)
                verify_model_outputs = call_model(prompt)
                check_results[f"{check_mode}_check_inputs"] = prompt
                check_results[f"{check_mode}_check_outputs"] = verify_model_outputs

                # Extracting "yes/no" judgments; any failing check fails the chain.
                _, filtered_results = extract_yes_or_no(verify_model_outputs)
                if np.mean(filtered_results) < 0.5:
                    final_pass = False

            ret.append(check_results)

        except Exception as e:
            print(f"An error occurred while parsing the step! Error：{e}")
            final_pass = False
            ret.append(None)

    # Backward verification over the whole chain
    entire_chain = "\n".join(reasoning_steps)
    reasoning_instructions = extract_reasoning_instructions(entire_chain)

    backward_prompt = f"""Here is a list of reasoning instructions:
"{reasoning_instructions}"

The original question is:
"{extract_final_question(question)}"
"""
    backward_prompt = prompt_fn(backward_prompt, "backward_verification")
    backward_check_output = call_model(backward_prompt)

    _, filtered_results = extract_yes_or_no(backward_check_output)
    ret.append({
        "backward_verification_check_inputs": backward_prompt,
        "backward_verification_check_outputs": backward_check_output
    })

    if np.mean(filtered_results) < 0.5:
        final_pass = False

    return ret, final_pass


def run_verify_cot(model_output, question):
    """Split ``model_output`` into reasoning steps and verify them with ``verify_steps``.

    :param model_output: one generated solution.
    :param question: the question it answers (used for backward verification).
    :return: the ``(ret, final_pass)`` pair produced by ``verify_steps``.
    """
    parsed_steps, step_texts, _unused = extract_reasoning_steps(model_output)
    outcome = verify_steps(parsed_steps, step_texts, question)
    return outcome


def _select_examples_to_verify():
    """Choose the examples to verify; return ``(examples, correct_sum)``.

    When ``--num-examples`` is -1 (and the task is not ``verify_step``), every input
    is annotated in place:
    - ``majority_answer_percentage`` is averaged.
    - ``answers_need_verify`` receives the sample indices behind the top-2 voted answers.

    An example is then kept only if its vote is contested, or if it has exactly one
    candidate answer. Contested means at least two top answers whose groups each
    have >= 2 samples and differ in size by at most 2.

    Kept examples whose ``majority_answer_percentage`` is > 0 go to the "correct"
    group. The others go to the "wrong" group only if a top-2 answer matches
    ``final_answer`` or there is a single candidate. ``--wrong-only`` keeps only the
    "wrong" group.

    Otherwise the first ``--num-examples`` inputs are used unchanged.

    ``correct_sum`` is the summed ``majority_answer_percentage`` over all inputs, i.e.
    the numerator of the self-consistency accuracy before verification.
    """
    random.seed(42)
    # Examples skipped for a wrong ground-truth answer or extremely long
    # reasoning steps (enumeration).
    examples_be_skip = defaultdict(list)
    examples_be_skip.update({"single_eq": [149, 209, 482], "aqua_rat": [213], "date": [297, 322], "MATH": [2874]})
    n_inputs = len(input_result)

    if args.num_examples != -1 or args.task_name == "verify_step":
        examples = input_result[: args.num_examples]
        random.shuffle(examples)
        correct_sum = sum(float(np.mean(r.get("majority_answer_percentage", 0))) for r in input_result)
        return examples, correct_sum

    error_examples = []
    correct_examples = []
    count_delta_size = defaultdict(int)
    majority_correct = 0
    for result_i in input_result:
        result_i.pop("majority_result", None)  # deprecated key
        result_i["majority_answer_percentage"] = np.mean(result_i["majority_answer_percentage"])
        pred_answers = result_i["candidate_answers"]
        # Top-2 answers by vote, each mapped to the indices of the samples producing it.
        top2_voting = get_top_k_voting(pred_answers, k=2)
        result_i["answers_need_verify"] = []

        sign = False  # whether one of the top-2 answers matches the ground truth
        for key in top2_voting:
            sign = sign or match_answers(key, result_i["final_answer"])
            result_i["answers_need_verify"] += top2_voting[key]

        majority_correct += result_i["majority_answer_percentage"]
        if result_i["index_number"] in examples_be_skip[data_name] or len(top2_voting) == 0:
            continue
        if len(top2_voting) < 2 and len(pred_answers) != 1:
            continue

        group_sizes = [len(indices) for indices in top2_voting.values()]
        min_size, max_size = min(group_sizes), max(group_sizes)
        # "answer->number of samples" for each top-2 answer
        result_i["need_verify_samples_info"] = ", ".join([f"{key}->{len(item)}" for key, item in top2_voting.items()])
        if (min_size >= 2 and max_size - min_size <= 2) or len(pred_answers) == 1:
            if result_i["majority_answer_percentage"] > 0:
                correct_examples.append(result_i)
            elif sign or len(pred_answers) == 1:
                count_delta_size[max_size - min_size] += 1
                error_examples.append(result_i)

    denom = n_inputs or 1  # avoid ZeroDivisionError on an empty input file
    print(f"#Number of test set: {n_inputs}, #Number of correct: {majority_correct}, #Accuracy (%): {majority_correct / denom * 100:.2f}%")
    print(
        f"#Number of verified examples {len(correct_examples) + len(error_examples)}, #Number of correct examples {len(correct_examples)}, #Number of Wrong examples {len(error_examples)}, #Wrong ratio(%) {len(error_examples) / denom * 100:.2f}%"
    )
    for delta in sorted(count_delta_size):
        print(f"{delta} -> {count_delta_size[delta] / denom * 100:.2f}%")

    random.seed(42)
    random.shuffle(correct_examples)
    random.shuffle(error_examples)
    examples = error_examples if args.wrong_only else correct_examples + error_examples
    random.shuffle(examples)
    return examples, majority_correct


def _judge_sample(sample_verify_result, verify_mode):
    """Turn one stored verification record into ``(passed, raw_judgements)``.

    direct mode: the record is the model output; it passes when the mean parsed
    "yes" vote of ``extract_yes_or_no`` is > 0.5.
    forward_backward mode: the record is the ``(ret, final_pass)`` value from
    ``verify_steps``. Every non-input output of every dict in ``ret`` is parsed,
    and the sample passes only when each of them has a mean vote > 0.5.
    ``None`` entries are treated as passing, as before.
    """
    if verify_mode == "direct":
        correct_raw, parsed = extract_yes_or_no(sample_verify_result)
        return bool(np.mean(parsed) > 0.5), [correct_raw]

    step_passes = []
    raw_judgements = []
    for step_result in sample_verify_result:
        if step_result is None:
            step_passes.append(True)
            raw_judgements.append(None)
            continue
        if isinstance(step_result, bool):
            # verify_steps' final_pass flag carries no model outputs to judge.
            raw_judgements.append([])
            continue
        correct_raw = []
        check_passes = []
        for step in step_result:
            if step is None:  # NOTE(review): unparseable steps do not fail the sample
                continue
            for key, outputs in step.items():
                if "input" in key:  # skip the stored prompts
                    continue
                raw_i, parsed = extract_yes_or_no(outputs)
                correct_raw.append(raw_i)
                check_passes.append(np.mean(parsed) > 0.5)
        step_passes.append(bool(np.all(check_passes)))
        raw_judgements.append(correct_raw)
    return bool(np.all(step_passes)), raw_judgements


def verification():
    """Verify contested self-consistency answers and re-pick each example's answer.

    For every selected example the samples listed in ``answers_need_verify`` are
    verified with ``direct_reasoning`` or ``run_verify_cot`` (per ``--verify-mode``).
    Results are dumped to ``result_file`` after every new model call, so a rerun
    resumes from the cached verifier outputs.

    In forward_backward mode each verified sample gets a discard score
    (``calculate_candidate_discard_score``), averaged per answer over the top-2 most
    common candidate answers. The answer with the lowest average is chosen. Without
    discard scores (direct mode), the most common answer among samples that passed
    verification is chosen. If no sample passes, the self-consistency correctness is
    kept and ``final_majority_result`` is None.

    Raises:
        NotImplementedError: for a ``--task-name`` other than ``Forward_Backward_Verification``.
    """
    examples, global_majority_correct = _select_examples_to_verify()
    existing_results = load(result_file) if osp.exists(result_file) else []
    existing_results = {result["index_number"]: result for result in existing_results}

    if args.task_name != "Forward_Backward_Verification":
        raise NotImplementedError

    num_call = 0
    count = count0 = count1 = num = num0 = num1 = 0
    wrong2right = right2wrong = 0
    results = []
    for example in tqdm(examples, desc="Verifying Examples", unit="example"):
        example_idx = example["index_number"]
        model_outputs = example["n_responses"]
        final_answer = example["final_answer"]
        per_sample_result = example["candidate_answers"]
        sample_idx_need_verify = example["answers_need_verify"]
        question = example["question"]
        per_sample_correct = example["per_answer_correct"]
        baseline_correct = example["majority_answer_percentage"]
        results_need_verify = [per_sample_result[idx] for idx in sample_idx_need_verify]

        if example_idx in existing_results:
            # Resume: reuse the verifier outputs saved by a previous run.
            cached = existing_results.pop(example_idx)
            cached["voting_results"] = example["voting_results"]
            cached["majority_answer_percentage"] = baseline_correct
            cached.setdefault("verify_model_results", [])
            example = cached
        else:
            num_call += 1
            example["verify_model_results"] = []

        # Discard scores are only tracked for the top-2 most common candidate answers.
        # The accumulators are rebuilt from verify_model_results on every run, so values
        # saved by an earlier (possibly partial) run are never counted twice.
        answer_counter = Counter(example.get("candidate_answers", []))
        example["discard_scores"] = {ans: 0 for ans, _ in answer_counter.most_common(2)}
        discard_sum = defaultdict(float)
        discard_count = defaultdict(int)
        example["discard_scores_sum"] = discard_sum
        example["discard_scores_count"] = discard_count
        results.append(example)

        verify_model_results = example["verify_model_results"]
        for j, sample_idx in enumerate(sample_idx_need_verify):
            if len(verify_model_results) <= j:
                model_output = model_outputs[sample_idx]
                if args.verify_mode == "direct":
                    verify_model_results.append(direct_reasoning(model_output, question))
                else:
                    verify_model_results.append(run_verify_cot(model_output, question))
                # Checkpoint after every model call.
                dump(list(existing_results.values()) + results, str(result_file), indent=4)

            if args.verify_mode != "direct":
                answer_text = per_sample_result[sample_idx]
                if answer_text in example["discard_scores"]:
                    discard_sum[answer_text] += calculate_candidate_discard_score(verify_model_results[j])
                    discard_count[answer_text] += 1

        example["verify_correct"] = []
        example["verify_result"] = []
        for sample_idx, sample_verify_result in zip(sample_idx_need_verify, verify_model_results):
            passed, raw_judgements = _judge_sample(sample_verify_result, args.verify_mode)
            example["verify_correct"].append(passed)
            example["verify_result"].append(raw_judgements)
            # Agreement of the verifier with per-sample ground-truth correctness.
            agrees = int(passed == per_sample_correct[sample_idx])
            num += 1
            count += agrees
            if per_sample_correct[sample_idx]:
                num1 += 1
                count1 += agrees
            else:
                num0 += 1
                count0 += agrees

        verified_answers = [ans for ans, ok in zip(results_need_verify, example["verify_correct"]) if ok]
        majority_result = None
        majority_correct = baseline_correct
        if not verified_answers:
            print("All the answers are verified to be wrong !!")
        else:
            weighted_discard_scores = {
                ans: discard_sum[ans] / discard_count[ans] if discard_count[ans] > 0 else float("inf")
                for ans in discard_sum
            }
            if weighted_discard_scores:
                example["discard_scores"] = weighted_discard_scores
                print(f"==> Discard scores for example {example['index_number']}:")
                for ans, score in weighted_discard_scores.items():
                    print(f"   - Answer: {repr(ans)} -> Weighted discard score: {score:.4f}")
                majority_result = min(weighted_discard_scores, key=weighted_discard_scores.get)
            else:
                # No discard scores (direct mode): vote among the verified answers.
                majority_result = Counter(verified_answers).most_common(1)[0][0]
            majority_correct = np.mean(match_answers(majority_result, final_answer))

        if per_sample_result:
            global_majority_correct = global_majority_correct - baseline_correct + majority_correct
            if majority_correct > baseline_correct:
                wrong2right += 1
            elif majority_correct < baseline_correct:
                right2wrong += 1

        example["final_majority_result"] = (
            majority_result[0] if isinstance(majority_result, list) and len(majority_result) > 0 else majority_result
        )
        example["final_majority_correct"] = float(majority_correct)
        # Running totals up to and including this example.
        example["wrong_to_right"] = int(wrong2right)
        example["right_to_wrong"] = int(right2wrong)

    dump(list(existing_results.values()) + results, str(result_file), indent=4)

    n_inputs = len(input_result)
    if n_inputs:
        print(f"#Accuracy after verification (%): {global_majority_correct / n_inputs * 100:.2f}%, #Wrong->Right: {wrong2right}, #Right->Wrong: {right2wrong}, #Newly verified examples: {num_call}")
    if num:
        print(f"#Verifier agreement with per-sample correctness: {count}/{num} (correct samples: {count1}/{num1}, wrong samples: {count0}/{num0})")


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into True, so flags use this converter instead.

    :raises ValueError: for an unrecognised spelling (argparse reports it as an
        invalid argument value).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Evaluate Verification Performance")
    parser.add_argument("--model-name", default="deepseek-v3", type=str)
    parser.add_argument("--data-name", default="gsm8k", type=str)
    parser.add_argument("--input-result", default="./deepseek_results/gsm8k/gsm8k_reasoning.json", type=str)
    parser.add_argument("--output-result", default="./deepseek_results/gsm8k/gsm8k_verification.json", type=str)
    parser.add_argument("--num-examples", default=-1, type=int)
    # Accepts "--wrong-only", "--wrong-only True" and "--wrong-only False".
    parser.add_argument("--wrong-only", default=False, type=str2bool, nargs="?", const=True)
    parser.add_argument("--max-seq-len", default=2048, type=int)
    parser.add_argument("--task-name", default="Forward_Backward_Verification", type=str)
    # direct: zero-shot prompt to verify the whole process,
    # forward_backward: one-shot prompt to verify in order (forward condition, forward logic, backward verification)
    parser.add_argument("--verify-mode", default="forward_backward", type=str, choices=["direct", "forward_backward"])
    parser.add_argument("--ref-end", default=False, type=str2bool, nargs="?", const=True)
    parser.add_argument("--tag", default=None, type=str)
    parser.add_argument("--n", default=1, type=int)
    parser.add_argument("--greedy", action="store_true", default=False)
    args = parser.parse_args()

    # Module-level globals read by verification() and its helpers.
    data_name = args.data_name
    input_result = load(args.input_result)
    result_file = Path(args.output_result)

    sample_n = 1 if args.greedy else args.n
    temperature = 0 if args.greedy else 0.7
    decoding_args = DecodingArguments(max_tokens=args.max_seq_len, n=sample_n, temperature=temperature)

    print(f"The results have been already saved in {str(result_file)}; Dataset: {data_name}; Used Method: {args.task_name}.")

    def call_model(prompt):
        """Send ``prompt`` to ``ChatBot.call_chat_deepseek`` with this run's decoding settings."""
        return ChatBot.call_chat_deepseek(
            prompt,
            eos_pattern=None,
            max_new_tokens=args.max_seq_len,
            early_stopping=True,
            do_sample=not args.greedy,
            return_list=True,
            temperature=temperature,
            num_beams=sample_n,
            num_return_sequences=sample_n,
            decoding_args=decoding_args,
        )

    ChatBot.dataset_name = data_name
    ChatBot.init()
    verification()