# Code adapted from https://github.com/huggingface/open-r1
"""Reward functions for GRPO training."""

import math
import re
from typing import Callable, Dict, Optional, Literal
from dataclasses import dataclass, field

from math_verify import LatexExtractionConfig, parse, verify

from bbeh_evaluator import evaluate_correctness


def normalize_gold_answer(answer: str) -> str:
    r"""Normalize a gold answer into a ``\boxed{...}`` string for ``math_verify.parse``.

    Processing order:

    1. Strip ``$...$`` and ``\[...\]`` math delimiters (and surrounding whitespace).
    2. Whole-string shortcuts, each returned immediately:
       - intervals ``(a,b]``, ``[a,b)``, ``(a,b)``, ``[a,b]`` (inner spaces removed);
       - multiple-choice letters A-E written as ``\text{(B)}``, ``(B)``, ``B)`` or
         ``b`` (upper-cased);
       - a lone ``\text{word}`` (unwrapped).
    3. Otherwise rewrite common LaTeX into plain text: drop ``\text{}`` wrappers,
       ``\frac{a}{b}`` -> ``a/b``, ``\sqrt{x}`` -> ``sqrt(x)``, ``\pi`` -> ``pi``,
       ``\infty`` -> ``oo``, drop ``\,`` / ``\!`` and the word "degree(s)".
    4. If a backslash remains, try ``latex2sympy`` + ``sympy.simplify`` (best
       effort; the text is kept unchanged on any failure, including the optional
       packages being missing).
    5. Wrap in ``\boxed{...}`` unless the text already starts with ``\boxed{``
       or ``$``.

    Args:
        answer: Gold answer; non-strings are converted with ``str()``.

    Returns:
        The normalized answer, or ``""`` if ``answer`` is blank.
    """
    text = str(answer).strip()
    if not text:
        return text

    # Remove surrounding $...$ or \[ \] delimiters; strip again so the
    # whole-string shortcuts below are not defeated by "$ (1,2] $"-style spacing.
    text = re.sub(r"\$(.*?)\$", r"\1", text)
    text = re.sub(r"\\\[(.*?)\\\]", r"\1", text, flags=re.S)
    text = text.strip()

    # --- Whole-string shortcuts ---------------------------------------------
    # fullmatch (not match) everywhere here: a prefix match would silently drop
    # the rest of the answer, e.g. "(1,2) \cup (3,4)" -> "(1,2)".

    # 1. Interval notation: (a,b], [a,b), (a,b), [a,b].
    interval_match = re.fullmatch(r"([\[\(])\s*([^,]+)\s*,\s*([^,\)\]]+)\s*([\]\)])", text)
    if interval_match:
        left_bracket, left_val, right_val, right_bracket = interval_match.groups()
        # The greedy value groups can capture trailing spaces; strip them explicitly.
        return f"\\boxed{{{left_bracket}{left_val.strip()},{right_val.strip()}{right_bracket}}}"

    # 2. Multiple choice wrapped in \text: \text{(A)}, \text{A}.
    mc_match = re.fullmatch(r"\\text\s*\{\(?\s*([A-Ea-e])\s*\)?\}", text)
    if mc_match:
        return f"\\boxed{{{mc_match.group(1).upper()}}}"

    # 3. Bare choice letters, optionally parenthesized: (A), A), (A, A.
    # NOTE(review): a lone lowercase a-e (e.g. Euler's "e") is also upper-cased here.
    simple_mc_match = re.fullmatch(r"\(?\s*([A-Ea-e])\s*\)?", text)
    if simple_mc_match:
        return f"\\boxed{{{simple_mc_match.group(1).upper()}}}"

    # 4. A lone \text{word} -> word.
    text_match = re.fullmatch(r"\\text\s*\{([^}]+)\}", text)
    if text_match:
        return f"\\boxed{{{text_match.group(1).strip()}}}"

    # --- General LaTeX clean-up ---------------------------------------------
    # Remove \text{...} wrappers anywhere in the string.
    text = re.sub(r"\\text\s*\{([^}]*)\}", r"\1", text)

    # LaTeX fractions -> a/b, for braced and single-token forms
    # (\frac{a}{b}, \frac a{b}, \frac{a} b, \frac a b). Nested braces such as
    # \frac{\sqrt{3}}{2} are not matched by these patterns.
    text = re.sub(r"\\frac\s*\{([^}]+)\}\s*\{([^}]+)\}", r"\1/\2", text)
    text = re.sub(r"\\frac\s*([^{}\s]+)\s*\{([^}]+)\}", r"\1/\2", text)
    text = re.sub(r"\\frac\s*\{([^}]+)\}\s*([^{}\s]+)", r"\1/\2", text)
    text = re.sub(r"\\frac\s*([^{}\s]+)\s*([^{}\s]+)", r"\1/\2", text)

    # Square roots: \sqrt{x} / \sqrt x -> sqrt(x).
    text = re.sub(r"\\sqrt\s*\{([^}]+)\}", r"sqrt(\1)", text)
    text = re.sub(r"\\sqrt\s*([^{}\s]+)", r"sqrt(\1)", text)

    text = text.replace("\\pi", "pi")
    text = text.replace("\\infty", "oo")  # sympy's spelling of infinity

    # Remove LaTeX thin/negative spaces often used as thousands separators.
    text = text.replace("\\,", "").replace("\\!", "")

    # Remove the unit word "degree(s)".
    text = re.sub(r"\bdegrees?\b", "", text, flags=re.I)
    text = text.strip()

    # Best-effort canonicalisation of whatever LaTeX commands are left.
    if "\\" in text:
        try:
            from latex2sympy2_extended import latex2sympy
            import sympy as sp

            text = str(sp.simplify(latex2sympy(text)))
        except Exception:
            # Optional dependency missing or unparseable input: keep the text as is.
            pass

    # math_verify.parse() reliably extracts \boxed{...}; avoid double-wrapping.
    if not text.startswith("\\boxed{") and not text.startswith("$"):
        text = f"\\boxed{{{text}}}"

    return text


def correctness_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[float]:
    """Score each completion 1.0 or 0.0 using ``bbeh_evaluator.evaluate_correctness``.

    Args:
        completions: One entry per sample; the text judged is ``entry[0]["content"]``.
        solution: Gold answers, aligned with ``completions``.
        **kwargs: Ignored; accepted for trainer compatibility.

    Returns:
        1.0 when the evaluator reports a truthy result, else 0.0. An exception
        raised by the evaluator is printed and scored as 0.0.
    """

    def _score(text: str, gold: str) -> float:
        # Evaluate and coerce inside the try so any evaluator failure maps to 0.0.
        try:
            return 1.0 if evaluate_correctness(text, gold) else 0.0
        except Exception as err:
            print(f"evaluate_correctness failed: {err}, content: {text}, solution: {gold}")
            return 0.0

    texts = [item[0]["content"] for item in completions]
    return [_score(text, gold) for text, gold in zip(texts, solution)]


def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
    """Binary reward: does the completion's math answer match the gold answer?

    The gold answer is normalized with ``normalize_gold_answer`` and parsed with
    math_verify; the completion is parsed with ``LatexExtractionConfig`` only, so
    the model must emit well-formed LaTeX.

    Args:
        completions: One entry per sample; the text scored is ``entry[0]["content"]``.
        solution: Gold answers, aligned with ``completions``.
        **kwargs: Ignored; accepted for trainer compatibility.

    Returns:
        Per sample: ``float(verify(gold, answer))``, or ``None`` when the gold
        answer cannot be parsed or ``verify`` raises (both cases are printed).
    """
    texts = [item[0]["content"] for item in completions]
    scores: list[Optional[float]] = []

    for text, gold in zip(texts, solution):
        normalized_gold = normalize_gold_answer(gold)
        gold_expr = parse(normalized_gold, extraction_mode="first_match")

        if not gold_expr:
            # Unparseable gold: None lets the caller skip this sample.
            print("Failed to parse gold solution: ", gold, " (normalized: ", normalized_gold, ")")
            scores.append(None)
            continue

        pred_expr = parse(
            text,
            extraction_config=[LatexExtractionConfig()],
            extraction_mode="first_match",
        )
        try:
            # math_verify.verify expects the gold answer first.
            score = float(verify(gold_expr, pred_expr))
        except Exception as err:
            print(f"verify failed: {err}, answer: {pred_expr}, gold: {gold_expr}")
            score = None
        scores.append(score)

    return scores



def len_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.

    Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599

    Length is measured in characters (``len`` of the completion text) and is
    relative to the shortest/longest completion in this batch.

    Args:
        completions: List of model completions; the text is ``entry[0]["content"]``.
        solution: List of ground truth solutions, aligned with ``completions``.
        **kwargs: Ignored; accepted for trainer compatibility.

    Returns:
        List of rewards where, with ``lam = 0.5 - (len - min_len) / (max_len - min_len)``:
        - for correct answers: ``lam``
        - for incorrect answers: ``min(0, lam)``
        Samples whose gold answer cannot be parsed are treated as correct.
        Returns all zeros when every completion has the same length, and ``[]``
        for an empty batch.
    """
    contents = [completion[0]["content"] for completion in completions]
    if not contents:
        # Nothing to score; also avoids min()/max() on an empty sequence.
        return []

    # First check correctness of answers.
    correctness = []
    for content, sol in zip(contents, solution):
        normalized_sol = normalize_gold_answer(sol)
        gold_parsed = parse(
            normalized_sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) == 0:
            # Unparseable gold: treat as correct so the sample is not penalized.
            correctness.append(True)
            print("Failed to parse gold solution: ", sol, " (normalized: ", normalized_sol, ")")
            continue

        answer_parsed = parse(
            content,
            extraction_config=[LatexExtractionConfig()],
            extraction_mode="first_match",
        )
        try:
            # math_verify.verify is not symmetric: gold first, then the prediction
            # (same order as accuracy_reward).
            correctness.append(verify(gold_parsed, answer_parsed))
        except Exception as e:
            print(f"verify failed in len_reward: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
            correctness.append(False)  # Treat as incorrect if verification fails

    lengths = [len(content) for content in contents]
    min_len = min(lengths)
    max_len = max(lengths)

    # If all responses have the same length, there is no length signal.
    if max_len == min_len:
        return [0.0] * len(completions)

    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
        # Correct answers get the full (possibly positive) length bonus; wrong
        # answers can only be penalized, never rewarded for being short.
        reward = lambda_val if is_correct else min(0.0, lambda_val)
        rewards.append(float(reward))

    return rewards


def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    """Build a correctness reward whose value is shaped by completion length.

    Args:
        min_value_wrong: Reward for an empty wrong completion (start of the schedule).
        max_value_wrong: Reward for wrong completions of ``max_len`` characters or more.
        min_value_correct: Reward for correct completions of ``max_len`` characters or more.
        max_value_correct: Reward for an empty correct completion (start of the schedule).
        max_len: Length in characters at which the cosine schedule reaches its end value.

    Returns:
        ``cosine_scaled_reward(completions, solution, **kwargs) -> list[float]``.
    """

    def cosine_scaled_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[float]:
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.
        Length is ``len(content)`` in characters, clamped at ``max_len``.

        Args:
            completions: List of model completions; the text is ``entry[0]["content"]``.
            solution: List of ground truth solutions, aligned with ``completions``.
            **kwargs: Ignored; accepted for trainer compatibility.

        Returns:
            One float per sample. Samples whose gold answer cannot be parsed get 1.0.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol in zip(contents, solution):
            normalized_sol = normalize_gold_answer(sol)
            gold_parsed = parse(
                normalized_sol,
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
            )
            if len(gold_parsed) == 0:
                # NOTE(review): labelled "skip" upstream, but this assigns a fixed
                # reward of 1.0 to every completion of this sample - confirm intended.
                rewards.append(1.0)
                print("Failed to parse gold solution: ", sol, " (normalized: ", normalized_sol, ")")
                continue

            answer_parsed = parse(
                content,
                extraction_config=[LatexExtractionConfig()],
                extraction_mode="first_match",
            )

            try:
                # math_verify.verify is not symmetric: gold first, then the prediction.
                is_correct = verify(gold_parsed, answer_parsed)
            except Exception as e:
                print(f"verify failed in cosine_scaled_reward: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                is_correct = False  # Treat as incorrect if verification fails

            # Clamp to [0, 1]: without it, cos() wraps around past max_len and very
            # long correct answers would start earning *more* reward again.
            progress = min(len(content) / max_len, 1.0)
            cosine = math.cos(progress * math.pi)

            if is_correct:
                start_value, end_value = max_value_correct, min_value_correct
            else:
                start_value, end_value = min_value_wrong, max_value_wrong

            # cosine == 1 at length 0 -> start_value; cosine == -1 at max_len -> end_value.
            reward = end_value + 0.5 * (start_value - end_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = "en"):
    """
    Build an N-gram repetition penalty reward, as described in Appendix C.2 of
    https://huggingface.co/papers/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
        ngram_size: Size of the n-grams; must be positive.
        max_penalty: Penalty applied when every n-gram is a repeat; must be <= 0.
        language: Language of the text, defaults to `en`. "en" splits on whitespace
            (lower-cased); "zh" segments with jieba.

    Returns:
        ``repetition_penalty_reward(completions, **kwargs) -> list[float]``.

    Raises:
        ValueError: If ``max_penalty`` is positive, ``ngram_size`` is not positive,
            ``language`` is unsupported, or ``language == "zh"`` and jieba is not installed.
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    if ngram_size <= 0:
        raise ValueError(f"ngram_size {ngram_size} should be positive")

    if language == "en":

        def zipngram(text: str, ngram_size: int):
            # Returns (iterator of n-gram tuples, list of tokens).
            words = text.lower().split()
            return zip(*[words[i:] for i in range(ngram_size)]), words

    elif language == "zh":
        # Stdlib check instead of a private transformers helper.
        import importlib.util

        if importlib.util.find_spec("jieba") is None:
            raise ValueError("Please install jieba to use Chinese language")

        def zipngram(text: str, ngram_size: int):
            import jieba

            seg_list = list(jieba.cut(text))
            return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list

    else:
        raise ValueError(
            f"Word splitting for language `{language}` is not yet implemented. Please implement your own zip-ngram function."
        )

    def repetition_penalty_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
        """
        Reward function that penalizes repeated n-grams.
        Reference implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions; the text is ``entry[0]["content"]``.
            **kwargs: Ignored; accepted for trainer compatibility.

        Returns:
            ``(1 - unique_ngrams / total_ngrams) * max_penalty`` per completion;
            0.0 for empty completions or ones with fewer tokens than ``ngram_size``.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue

            ngram_iter, words = zipngram(completion, ngram_size)

            if len(words) < ngram_size:
                rewards.append(0.0)
                continue

            all_ngrams = list(ngram_iter)
            # len(words) >= ngram_size guarantees at least one n-gram here.
            scaling = 1 - len(set(all_ngrams)) / len(all_ngrams)
            rewards.append(scaling * max_penalty)
        return rewards

    return repetition_penalty_reward


@dataclass
class RewardArguments:
    """
    Script arguments for the reward script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'correctness', 'math_accuracy', 'cosine',
            'repetition_penalty', 'length'.
        cosine_min_value_wrong (`float`):
            Minimum reward for cosine scaling for wrong answers (given to the shortest wrong answers).
        cosine_max_value_wrong (`float`):
            Maximum reward for cosine scaling for wrong answers (given to wrong answers of
            `cosine_max_len` characters or more).
        cosine_min_value_correct (`float`):
            Minimum reward for cosine scaling for correct answers (given to the longest correct answers).
        cosine_max_value_correct (`float`):
            Maximum reward for cosine scaling for correct answers (given to the shortest correct answers).
        cosine_max_len (`int`):
            Maximum length (in characters) for cosine scaling.
        repetition_n_grams (`int`):
            N-gram size used by the repetition penalty reward.
        repetition_max_penalty (`float`):
            Maximum (negative) penalty for the repetition penalty reward; must be <= 0.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["correctness", "repetition_penalty"],
        metadata={
            "help": "List of reward functions. Possible values: 'correctness', 'math_accuracy', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    # Must be <= cosine_max_value_wrong; the previous default of 0.0 inverted the
    # schedule (short wrong answers scored higher than long ones). -1.0 matches
    # get_cosine_scaled_reward's own default.
    cosine_min_value_wrong: float = field(
        default=-1.0,
        metadata={"help": "Minimum reward for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Maximum reward for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.5,
        metadata={"help": "Minimum reward for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for scaling"},
    )
    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-1.0,
        metadata={"help": "Maximum (negative) penalty for repetition penalty reward"},
    )


def get_reward_funcs(script_args: RewardArguments) -> dict[str, Callable]:
    """Build the reward functions named in ``script_args.reward_funcs``.

    Args:
        script_args: Reward configuration; ``reward_funcs`` lists registry names and
            the ``cosine_*`` / ``repetition_*`` fields parameterize those factories.

    Returns:
        Dict mapping each requested name to its reward callable, in request order.
        Iterating the dict yields the names; use ``.values()`` for the callables.

    Raises:
        KeyError: If a requested name is not registered (the message lists valid names).
        ValueError: Propagated from ``get_repetition_penalty_reward`` for invalid
            repetition settings (all factories are built eagerly).
    """
    registry = {
        "correctness": correctness_reward,
        "math_accuracy": accuracy_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }

    unknown = [name for name in script_args.reward_funcs if name not in registry]
    if unknown:
        # KeyError kept for backward compatibility, but with an actionable message.
        raise KeyError(f"Unknown reward function(s) {unknown}; available: {sorted(registry)}")

    return {name: registry[name] for name in script_args.reward_funcs}


if __name__ == "__main__":
    # Smoke test: run each reward function on a small fixed set of completions and
    # print the results for manual inspection. Reward values are not asserted; only
    # unexpected exceptions / missing validation errors are recorded in `failures`.
    args = RewardArguments()
    failures = []

    print("=== Testing Reward Functions ===\n")

    # Mixed test data: some entries use LaTeX (\boxed, \frac), others are plain text.
    test_completions = [
        [{"content": r"\boxed{\frac{63}{400}}"}],
        [{"content": r"\boxed{The answer is \frac{63}{400}}"}],
        [{"content": "The answer is 42"}],
        [{"content": "After careful calculation, I found that 42 is the solution"}],
        [{"content": "24"}],
        [{"content": "The result is 100"}],
        [{"content": "I think the answer might be 42 but I'm not sure"}],
        [{"content": "Alright! The final answer is: 2, 3, 4"}],
        [{"content": "Ok The answer is: **25**\nHere's why."}],
        [{"content": "Ok The answer is: (A). \nHere's why."}],
        [{"content": "[Reasoning] The final answer is: \\boxed{4}."}],
        [{"content": "Ok The final answer is: \\boxed{undecided}. Here is the ..."}],
        [{"content": "[Reasoning] The final answer is: \\boxed{I don't know}. Here is the ..."}],
        [{"content": "[Reasoning] The final answer is: \\boxed{IDknow}. Here are my ..."}],
        [{"content": "[Reasoning] The final answer is: \\boxed{B}. Here are my ..."}]
    ]

    test_solutions = [
        r"\frac{63}{400}",
        r"\frac{63}{400}",
        "42",
        "42",
        "24",
        "100",
        "42",
        "2,3,4",
        "25.0",
        "a",
        "4",
        "undecided",
        "I don't know",
        "IDKnow",
        "B"
    ]

    print("1-1. Testing correctness_reward:")
    try:
        correctness_rewards = correctness_reward(test_completions, test_solutions)
        for i, (comp, sol, reward) in enumerate(zip(test_completions, test_solutions, correctness_rewards)):
            print(f"   Test {i+1}: Completion='{comp[0]['content']}' | Solution='{sol}' | Reward={reward}")
    except Exception as e:
        print(f"   Error testing correctness_reward: {e}")
        failures.append("correctness_reward")
    print()

    print("1-2. Testing accuracy_reward:")
    try:
        accuracy_rewards = accuracy_reward(test_completions, test_solutions)
        for i, (comp, sol, reward) in enumerate(zip(test_completions, test_solutions, accuracy_rewards)):
            print(f"   Test {i+1}: Completion='{comp[0]['content']}' | Solution='{sol}' | Reward={reward}")
    except Exception as e:
        print(f"   Error testing accuracy_rewards: {e}")
        failures.append("accuracy_reward")
    print()

    # Test repetition_penalty_reward (doesn't need math verification)
    print("2. Testing repetition_penalty_reward:")
    try:
        repetition_rewards = get_repetition_penalty_reward(ngram_size=2, max_penalty=-0.5)(test_completions)
        for i, (comp, reward) in enumerate(zip(test_completions, repetition_rewards)):
            print(f"   Test {i+1}: Completion='{comp[0]['content']}' | Reward={reward}")
    except Exception as e:
        print(f"   Error testing repetition_penalty_reward: {e}")
        failures.append("repetition_penalty_reward")
    print()

    print("3. Testing edge cases for repetition_penalty_reward:")

    # Empty completion
    empty_completion = [[{"content": ""}]]
    try:
        empty_reward = get_repetition_penalty_reward(ngram_size=2, max_penalty=-0.5)(empty_completion)
        print(f"   Empty completion: {empty_reward}")
    except Exception as e:
        print(f"   Error with empty completion: {e}")
        failures.append("repetition_penalty_reward (empty completion)")

    # Repetitive text
    repetitive_completion = [[{"content": "the the the the the the the the the the"}]]
    try:
        repetitive_reward = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)(repetitive_completion)
        print(f"   Repetitive text: {repetitive_reward}")
    except Exception as e:
        print(f"   Error with repetitive text: {e}")
        failures.append("repetition_penalty_reward (repetitive text)")

    # Short text (fewer tokens than ngram_size)
    short_completion = [[{"content": "hello"}]]
    try:
        short_reward = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)(short_completion)
        print(f"   Short text (n=3): {short_reward}")
    except Exception as e:
        print(f"   Error with short text: {e}")
        failures.append("repetition_penalty_reward (short text)")
    print()

    # Each invalid configuration below must raise ValueError at factory time.
    print("4. Testing parameter validation:")
    invalid_configs = [
        ({"ngram_size": 0, "max_penalty": -0.5}, "ngram_size=0"),
        ({"ngram_size": 2, "max_penalty": 0.5}, "positive max_penalty"),
        ({"ngram_size": 2, "max_penalty": -0.5, "language": "fr"}, "unsupported language"),
    ]
    for bad_kwargs, label in invalid_configs:
        try:
            get_repetition_penalty_reward(**bad_kwargs)
            print(f"   ❌ Should have raised ValueError for {label}")
            failures.append(f"parameter validation ({label})")
        except ValueError as e:
            print(f"   ✅ Correctly raised ValueError: {e}")
    print()

    print("5. Testing function registry:")
    try:
        funcs = get_reward_funcs(args)
        print(f"   ✅ Successfully created {len(funcs)} reward functions")
        # get_reward_funcs returns {name: callable}; iterating yields the names.
        print(f"   Functions: {list(funcs)}")
    except Exception as e:
        print(f"   ❌ Error creating reward functions: {e}")
        failures.append("function registry")
    print()

    print("=== Test Summary ===")
    if failures:
        print(f"❌ {len(failures)} check(s) failed: {', '.join(failures)}")
    else:
        print("✅ All reward function checks ran without unexpected errors!")