import re
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from sympy import sympify, simplify
import math
from .reward_utils.math import accuracy_reward

# def think_format_reward_func(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
#     r"""
#     A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.

#     Example:
#     ```python
#     >>> from trl.rewards import think_format_reward

#     >>> completions = [
#     ...     [{"content": "\nThis is my reasoning.\n</think>\nThis is my answer."}],
#     ...     [{"content": "\nThis is my reasoning.\nThis is my answer."}],
#     ... ]
#     >>> think_format_reward(completions)
#     [1.0, 0.0]
#     ```
#     """
#     # pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$" # we insert <think> in prompt
#     # pattern = r"^(?!.*<think>)(.*?)</think>.*$" # so we don't need to match <think> at the start
#     pattern = r"^(?!.*<think>)(.*?)</think>(?!.*</think>).*$"
#     completion_contents = [completion[0]["content"] for completion in completions]
#     matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]

#     return [1.0 if match else 0.0 for match in matches]

def think_format_reward_func(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    """Score each completion on its use of the think tags.

    The opening ``<think>`` is expected to live in the prompt, so a
    well-formed completion contains exactly one ``</think>`` and no
    ``<think>`` at all.

    Args:
        completions: One message list per sample; only the ``"content"`` of
            the first message is inspected.
        **kwargs: Accepted and ignored.

    Returns:
        A list with 1.0 for every well-formed completion and 0.0 otherwise,
        in the same order as ``completions``.
    """
    first_messages = (messages[0]["content"] for messages in completions)
    return [
        1.0 if (text.count("</think>") == 1 and "<think>" not in text) else 0.0
        for text in first_messages
    ]

def answer_accuracy_reward_func(completions, reward_model, **kwargs):
    """Score the answer part of each completion against its ground truth.

    The text following the first ``</think>`` (whitespace-trimmed) is treated
    as the answer. When no closing tag is present the whole, untrimmed
    content is scored instead.

    Args:
        completions: One message list per sample; the first message's
            ``"content"`` is used.
        reward_model: One mapping per sample carrying a ``"ground_truth"``
            entry. Pairs are formed with ``zip``, so surplus items on either
            side are ignored.
        **kwargs: Accepted and ignored.

    Returns:
        The values returned by ``accuracy_reward`` for every
        (answer, ground truth) pair, in order.
    """
    scores = []
    for messages, reference in zip(completions, reward_model):
        text = messages[0]["content"]
        # partition() splits on the first closing tag; an empty separator
        # means the tag is absent and the content is scored as-is.
        _, closing_tag, answer = text.partition("</think>")
        candidate = answer.strip() if closing_tag else text
        scores.append(accuracy_reward(candidate, reference["ground_truth"]))
    return scores
def no_reference_answer_leakage_reward_func(completions, reward_model=None, **kwargs):
    """
    Reward function: penalize completions whose visible answer mentions the
    literal phrase "reference answer".

    Rules:
      - Matching is case-insensitive for both the phrase and the `</think>` tag.
      - Only the text after the first `</think>` is searched.
      - Without any `</think>` the entire content is searched.
      - Phrase absent → 1.0; phrase present → 0.0.

    `reward_model` and extra keyword arguments are accepted but unused.

    Returns: A rewards list of the same length as `completions`.
    """
    closing_tag = "</think>"
    forbidden = "reference answer"
    scores = []
    for messages in completions:
        lowered = messages[0]["content"].lower()
        _, found_tag, after_tag = lowered.partition(closing_tag)
        # No closing tag at all: fall back to scanning everything.
        searched = after_tag if found_tag else lowered
        scores.append(0.0 if forbidden in searched else 1.0)
    return scores

def valid_reasoning_reward_func(completions, reward_model=None, **kwargs):
    """
    Reward function: In the q(z|x,y) setting, require a non-trivial amount of
    text after the closing </think> tag.

    A completion earns 1.0 only when all of the following hold:
      - it contains `</think>`;
      - the trimmed text after the first `</think>` is longer than 1% of
        (length of the trimmed text before it + 1);
      - that trailing text, minus one character, is longer than the ground
        truth wrapped as `\\boxed{...}`.
    Everything else earns 0.0.

    `reward_model` supplies one mapping with a "ground_truth" entry per
    completion; the two sequences are paired with `zip`.
    """
    closing_tag = "</think>"
    scores = []
    for messages, reference in zip(completions, reward_model):
        head, found_tag, tail = messages[0]["content"].partition(closing_tag)
        if not found_tag:
            scores.append(0.0)
            continue

        thinking = head.strip()
        response = tail.strip()
        boxed_reference = f"\\boxed{{{reference['ground_truth']}}}"

        # +1 keeps the denominator positive when the thinking part is empty.
        share_of_thinking = len(response) / (len(thinking) + 1)
        share_of_answer = (len(response) - 1) / len(boxed_reference)
        passes = share_of_thinking > 0.01 and share_of_answer > 1
        scores.append(1.0 if passes else 0.0)
    return scores

def answer_accuracy_reward_func_for_naive_grpo(completions, reward_model=None, **kwargs):
    """
    Reward function: Provide a score based on the match between the model's answer
    and the reference/ground-truth answer, for the naive GRPO setting where the
    completion itself may begin with `<think>`.

    A single leading `<think>` is removed from the first message of every
    completion, then scoring is delegated to `answer_accuracy_reward_func`.

    The caller's `completions` are left untouched: the prefix is stripped on
    copies. Stripping in place would let a second reward function that
    performs the same normalization on the same list remove another leading
    `<think>`, changing its result for content such as "<think><think>...".

    Args:
        completions: One message list per sample; the first message must have
            a string "content".
        reward_model: Per-sample mappings with a "ground_truth" entry,
            forwarded unchanged.
        **kwargs: Forwarded to `answer_accuracy_reward_func`.

    Returns:
        The list produced by `answer_accuracy_reward_func`.
    """
    # our think_format_reward_func expects content to not start with <think>
    stripped_completions = [
        [
            {**completion[0], "content": completion[0]["content"].removeprefix("<think>")},
            *completion[1:],
        ]
        for completion in completions
    ]
    return answer_accuracy_reward_func(stripped_completions, reward_model, **kwargs)

def think_format_reward_func_for_naive_grpo(completions, **kwargs):
    """
    Reward function: Reward based on whether the output follows the expected
    think/answer format for the naive GRPO setting, where the completion
    itself may begin with `<think>`.

    A single leading `<think>` is removed from the first message of every
    completion, then scoring is delegated to `think_format_reward_func`.

    The caller's `completions` are left untouched: the prefix is stripped on
    copies. Stripping in place would let a second reward function that
    performs the same normalization on the same list remove another leading
    `<think>`, so malformed content such as "<think><think>...</think>" could
    end up passing the format check.

    Args:
        completions: One message list per sample; the first message must have
            a string "content".
        **kwargs: Forwarded to `think_format_reward_func`.

    Returns:
        The list produced by `think_format_reward_func`.
    """
    # our think_format_reward_func expects content to not start with <think>
    stripped_completions = [
        [
            {**completion[0], "content": completion[0]["content"].removeprefix("<think>")},
            *completion[1:],
        ]
        for completion in completions
    ]
    return think_format_reward_func(stripped_completions, **kwargs)

def combine_think_format_and_answer_accuracy_reward_func(completions, reward_model=None, **kwargs):
    """
    Reward function: Merge the format check and the accuracy score into one value.

    A single leading `<think>` is removed from the first message of every
    completion (on copies, so the caller's `completions` are not modified and
    repeated normalization by other reward functions cannot strip a second
    tag). The stripped completions are then scored by
    `think_format_reward_func` and `answer_accuracy_reward_func`.

    Combination rule, per sample:
      - format reward is 0.0           → -1.0
      - format reward is anything else → the accuracy reward

    Args:
        completions: One message list per sample; the first message must have
            a string "content".
        reward_model: Per-sample mappings with a "ground_truth" entry,
            forwarded to `answer_accuracy_reward_func`.
        **kwargs: Forwarded to both underlying reward functions.

    Returns:
        One combined reward per (format, accuracy) pair.
    """
    # our think_format_reward_func expects content to not start with <think>
    stripped_completions = [
        [
            {**completion[0], "content": completion[0]["content"].removeprefix("<think>")},
            *completion[1:],
        ]
        for completion in completions
    ]

    format_rewards = think_format_reward_func(stripped_completions, **kwargs)
    acc_rewards = answer_accuracy_reward_func(stripped_completions, reward_model, **kwargs)

    # A format failure overrides whatever the accuracy check returned.
    return [
        -1.0 if format_reward == 0.0 else acc_reward
        for format_reward, acc_reward in zip(format_rewards, acc_rewards)
    ]

# Unified test data builder to avoid duplication
def _build_test_data():
    """Build the shared fixtures used by the self-test code in this module.

    Returns:
        A 5-tuple ``(completions, reward_model, expected_think_format,
        expected_accuracy, expected_no_reference)``. All five lists are
        index-aligned: entry ``i`` of each expected list is the reward the
        corresponding function is expected to give for ``completions[i]``
        (paired with ``reward_model[i]`` where a ground truth is needed).
    """
    completions = [
        # 0
        [{"content": "\nReasoning\n</think> Answer as follows: \\boxed{42}"}],
        # 1
        [{"content": "\nReasoning not ended  \\boxed{42}"}],
        # 2
        [{"content": "\nReasoning\n</think> \\boxed{41}"}],
        # 3
        [{"content": "\nReasoning\n</think> give one \\boxed{42} then another \\boxed{42}"}],
        # 4
        [{"content": "\nReasoning\n</think> reference answer is: \\boxed{42}"}],
        # 5
        [{"content": "\nIs 0.5 equivalent to fraction form?\n</think> \\boxed{0.5}"}],
        # 6
        [{"content": "\nWrite as fraction\n</think> finally \\boxed{\\frac{1}{2}}"}],
        # 7
        [{"content": "\nNot simplified\n</think> \\boxed{\\frac{2}{4}}"}],
        # 8
        [{"content": "\nScientific notation\n</think> \\boxed{3.2\\times 10^{5}}"}],
        # 9
        [{"content": "\nAnother notation\n</think> \\boxed{3.2e5}"}],
        # 10
        [{"content": "\nAnalyze four choices\n</think> Conclusion \\boxed{A}"}],
        # 11
        [{"content": "\nAnalyzed but wrong\n</think> \\boxed{C}"}],
        # 12
        [{"content": "\nExpand (a+b)^2\n</think> \\boxed{(a+b)^2}"}],
        # 13
        [{"content": "\nWrite expansion directly\n</think> \\boxed{a^2+2ab+b^2}"}],
        # 14
        [{"content": "\nNested test\n</think> \\boxed{\\frac{(a+b)^2}{2}}"}],
        # 15
        [{"content": "\nEquivalent nested form\n</think> \\boxed{(a^2+2ab+b^2)/2}"}],
        # 16
        [{"content": "\nMissing term\n</think> \\boxed{(a^2+b^2)/2}"}],
        # 17
        [{"content": "\nUnclosed test\n</think> here \\boxed{42"}],
        # 18
        [{"content": "\nNoise\n</think> some notes: [value] => \\boxed{1/2}"}],
        # 19
        [{"content": "\nMultiple candidates\n</think> \\boxed{1/2} backup \\boxed{0.5}"}],
        # 20
        [{"content": "\nShow text <\\/think> without actual closing\nStill thinking \\boxed{42}"}],
        # 21
        [{"content": "\nInclude escaped <\\/think> example\n</think> \\boxed{42}"}],
        # 22
        [{"content": "\nExample: \"</think>\" is just display\n</think> \\boxed{42}"}],
        # 23
        [{"content": "one step</think> middle <think>two steps</think> \\boxed{42}"}],
        # 24
        [{"content": "outer inner</think> end</think> \\boxed{42}"}],
        # 25
        [{"content": "random text </think> \\boxed{42}"}],
        # 26
        [{"content": "empty content</think> \\boxed{}"}],
        # 27
        [{"content": "blank</think> \\boxed{   }"}],
        # 28
        [{"content": "nested</think> \\boxed{1+\\boxed{2}}"}],
        # 29
        [{"content": "implicit multiplication</think> \\boxed{2(a+b)}"}],
        # 30
        [{"content": "unicode multiply sign</think> \\boxed{3.2×10^{5}}"}],
        # 31
        [{"content": "middle dot</think> \\boxed{3.2·10^{5}}"}],
        # 32
        [{"content": "E notation</think> \\boxed{3.2E5}"}],
        # 33
        [{"content": "Expand (a-b)^2</think> \\boxed{a^2-2ab+b^2}"}],
        # 34
        [{"content": "Wrong expansion</think> \\boxed{a^2+2ab+b^2}"}],
        # 35
        [{"content": "Identity</think> \\boxed{((a+b)^2-(a-b)^2)/(4ab)}"}],
        # 36
        [{"content": "Identity direct</think> \\boxed{1}"}],
        # 37
        [{"content": "spaced fraction</think> \\boxed{\\frac{2}{4}}"}],
        # 38
        [{"content": "leading zeros</think> \\boxed{00042}"}],
        # 39
        [{"content": "plus sign</think> \\boxed{+42}"}],
        # 40
        [{"content": "difference is 0</think> \\boxed{0}"}],
        # 41 Missing </think> and contains reference answer (should trigger no_reference penalty)
        [{"content": "unclosed reference answer is: \\boxed{42}"}],
        # 42 Nested boxed equivalent but should be 0
        [{"content": "nested equivalent</think> \\boxed{1+\\boxed{1}}"}],
        # 43 Approximate decimal vs fraction 0.3333333 vs 1/3 (treated as equivalent)
        [{"content": "approximate fraction</think> \\boxed{0.3333333}"}],
        # 44 Negative exponent scientific notation
        [{"content": "negative exponent</think> \\boxed{1.2\\times 10^{-3}}"}],
        # 45 reference answer appears only inside think; should not be penalized
        [{"content": "reference answer is: note</think> \\boxed{42}"}],
        # 46 Ground truth carries extra text (a unit); test answer_accuracy_reward_func
        [{"content": "reference answer is: note</think> \\boxed{42}"}],
        # 47 Same completion as 45/46 against a bare ground truth; test answer_accuracy_reward_func
        [{"content": "reference answer is: note</think> \\boxed{42}"}],
        # 48 Ground truth carries a \text{...} unit
        [{"content": "reference answer is: note</think> \\boxed{16}"}],
        # 49 Slash fraction vs \frac ground truth
        [{"content": "reference answer is: note</think> \\boxed{1/42}"}],
        # 50 Single-letter answer
        [{"content": "reference answer is: note</think> \\boxed{F}"}],
        # 51 Answer containing \pi (backslash escaped so the literal is valid Python)
        [{"content": "reference answer is: note</think> \\boxed{24\\pi}"}],
        # 52 Two boxed answers after </think>
        [{"content": "reference answer is: note</think> \\boxed{23.0} \\boxed{23}"}],
    ]
    reward_model = [
        {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"},
        {"ground_truth": "1/2"}, {"ground_truth": "1/2"}, {"ground_truth": "1/2"}, {"ground_truth": "3.2e5"}, {"ground_truth": "3.2e5"},
        {"ground_truth": "A"}, {"ground_truth": "A"}, {"ground_truth": "(a+b)^2"}, {"ground_truth": "(a+b)^2"}, {"ground_truth": "(a+b)^2/2"},
        {"ground_truth": "(a+b)^2/2"}, {"ground_truth": "(a+b)^2/2"}, {"ground_truth": "42"}, {"ground_truth": "1/2"}, {"ground_truth": "1/2"},
        {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"},
        {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "1+2"}, {"ground_truth": "2*(a+b)"},
        {"ground_truth": "3.2E5"}, {"ground_truth": "3.2E5"}, {"ground_truth": "3.2E5"}, {"ground_truth": "(a-b)^2"}, {"ground_truth": "(a-b)^2"},
        {"ground_truth": "1"}, {"ground_truth": "1"}, {"ground_truth": "1/2"}, {"ground_truth": "42"}, {"ground_truth": "42"}, {"ground_truth": "0"},
        {"ground_truth": "42"},  # 41
        {"ground_truth": "2"},   # 42 nested equivalent but should be 0
        {"ground_truth": "1/3"}, # 43
        {"ground_truth": "1.2e-3"}, # 44
        {"ground_truth": "42"},  # 45
        {"ground_truth": "42cm^2"},  # 46 ground truth contains extra text; test answer_accuracy_reward_func
        {"ground_truth": "42"},  # 47
        {"ground_truth": "16 \\text{ square feet}"},  # 48
        {"ground_truth": "\\frac{1}{42}"},  # 49
        {"ground_truth": "F"},  # 50
        {"ground_truth": "24\\pi"},  # 51
        {"ground_truth": "23.0"},  # 52
    ]
    expected_think_format = [
        1.0,0.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0,
        1.0,1.0,1.0,1.0,1.0, 0.0,1.0,0.0,0.0,0.0, 1.0,1.0,1.0,1.0,1.0,
        1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,
        0.0, # 41 unclosed
        1.0, # 42
        1.0, # 43
        1.0, # 44
        1.0, # 45
        1.0, # 46
        1.0, # 47
        1.0, # 48
        1.0, # 49
        1.0, # 50
        1.0, # 51
        1.0, # 52
    ]
    expected_accuracy = [
        1.0,1.0,0.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,0.0,1.0,1.0,1.0,
        1.0,0.0,0.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,0.0,0.0,0.0,1.0,
        1.0,1.0,1.0,1.0,0.0, 1.0,1.0,1.0,1.0,1.0, 1.0,
        1.0, # 41 boxed unique and value equals
        0.0, # 42 nested boxed -> 0
        1.0, # 43 approximate decimal accepted as equivalent 0.3333333 ~= 1/3
        1.0, # 44 negative exponent equivalent
        1.0, # 45 equivalent
        1.0, # 46 ground truth contains extra text; test answer_accuracy_reward_func
        1.0, # 47
        1.0, # 48
        1.0, # 49
        1.0, # 50
        1.0, # 51
        1.0, # 52
    ]
    expected_no_reference = [
        1.0,1.0,1.0,1.0,0.0, 1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0,
        1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0,
        1.0,1.0,1.0,1.0,1.0, 1.0,1.0,1.0,1.0,1.0, 1.0,
        0.0, # 41 unclosed but contains the phrase
        1.0, # 42
        1.0, # 43
        1.0, # 44
        1.0, # 45 phrase only inside think
        1.0, # 46 phrase only inside think
        1.0, # 47
        1.0, # 48
        1.0, # 49
        1.0, # 50
        1.0, # 51
        1.0, # 52
    ]
    return completions, reward_model, expected_think_format, expected_accuracy, expected_no_reference


def validate_test_cases():
    """Run the three reward functions on the shared fixtures and report.

    Prints the raw reward vectors, whether each one equals its expected
    vector, a per-case listing of any mismatches, and a final summary.

    Returns:
        True when all three reward vectors equal their expected vectors,
        False otherwise.
    """
    (
        completions,
        reward_model,
        expected_think_format,
        expected_accuracy,
        expected_no_reference,
    ) = _build_test_data()
    print("=== Validate reward_func.py test cases ===")

    # (label, actual rewards, expected rewards); the order drives print order.
    checks = [
        ("think_format_reward", think_format_reward_func(completions), expected_think_format),
        ("answer_accuracy_reward", answer_accuracy_reward_func(completions, reward_model), expected_accuracy),
        ("no_reference_answer_reward", no_reference_answer_leakage_reward_func(completions), expected_no_reference),
    ]

    print(f"Total number of test cases: {len(completions)}")
    for label, actual, _ in checks:
        print(f"{label} results: {actual}")

    verdicts = [actual == expected for _, actual, expected in checks]

    print("\n=== Validation results ===")
    for (label, _, _), passed in zip(checks, verdicts):
        print(f"{label} correct: {passed}")

    for (label, actual, expected), passed in zip(checks, verdicts):
        if passed:
            continue
        print(f"\n{label} mismatches:")
        for case_idx, (got, want) in enumerate(zip(actual, expected)):
            if got != want:
                print(f"  Case {case_idx}: actual={got}, expected={want}")

    all_correct = all(verdicts)
    print("\n=== Summary ===")
    print(f"All test cases correct: {all_correct}")
    summary = (
        "✅ Test cases in reward_func.py are designed correctly; all reward functions behave as expected!"
        if all_correct
        else "❌ Issues found in test cases or reward function implementations; further investigation needed."
    )
    print(summary)
    return all_correct


if __name__ == "__main__":
    # Script entry point: dump the raw reward vectors for the shared fixtures,
    # then run the expected-vs-actual validation.
    print("=== Raw test outputs (unified data source) ===")
    # Only the inputs are needed here; the expected vectors are consumed by
    # validate_test_cases() below.
    completions, reward_model, *_expected_vectors = _build_test_data()

    format_rewards = think_format_reward_func(completions)
    print("think_format_reward:", format_rewards)

    acc_rewards = answer_accuracy_reward_func(completions, reward_model)
    print("answer_accuracy_reward:", acc_rewards)
    print(f"answer_accuracy_reward summary: total={sum(acc_rewards)} / {len(acc_rewards)}")

    ref_rewards = no_reference_answer_leakage_reward_func(completions)
    print("no_reference_answer_in_response_reward_func:", ref_rewards)
    print(f"no_reference summary: total={sum(ref_rewards)} / {len(ref_rewards)}")

    separator = "=" * 50
    print("\n" + separator)
    validate_test_cases()
