"""
Format Validation Reward Functions

Provides reward functions that validate the formatting of model responses,
particularly the presence and structure of XML-like tags (<reasoning>, <answer>).

These can be used to encourage proper formatting during training.

Ported from an internal reward-functions module (pre-existing utility code)
"""

import re
from typing import Tuple, Dict


def count_xml_tags(text: str) -> Dict[str, int]:
    """
    Tally opening and closing occurrences of the known XML-like tags.

    The tags inspected are ``reasoning``, ``answer`` and
    ``intermediate_grid``. For each one the result holds two entries,
    ``<tag>_open`` and ``<tag>_close``.

    Args:
        text: Model response

    Returns:
        Dict mapping "<tag>_open" / "<tag>_close" to the number of matches
    """
    tag_names = ('reasoning', 'answer', 'intermediate_grid')
    # (key suffix, literal text that precedes the tag name)
    variants = (('open', '<'), ('close', '</'))

    return {
        f'{name}_{suffix}': len(re.findall(f'{lead}{name}>', text))
        for name in tag_names
        for suffix, lead in variants
    }


def check_format_strict(text: str) -> Tuple[bool, str]:
    """
    Strict format validation.

    Required format:
    - Exactly one <reasoning> section
    - Exactly one <answer> section
    - Each closing tag appears after its opening tag
    - The <reasoning> section ends before the <answer> section starts

    Tag counts are checked first (reasoning, then answer), then tag
    positions. The first failing check determines the returned message.

    Args:
        text: Model response

    Returns:
        Tuple of (is_valid, error_message); the message is "Valid format"
        when is_valid is True
    """
    # Collect match offsets once per tag: the list length gives the count
    # and the first element gives the position used by the order checks.
    reasoning_opens = [m.start() for m in re.finditer(r'<reasoning>', text)]
    reasoning_closes = [m.start() for m in re.finditer(r'</reasoning>', text)]

    if not reasoning_opens:
        return False, "Missing <reasoning> tag"
    if len(reasoning_opens) > 1:
        return False, f"Multiple <reasoning> tags ({len(reasoning_opens)})"
    if len(reasoning_closes) != 1:
        return False, "Unclosed <reasoning> tag"

    answer_opens = [m.start() for m in re.finditer(r'<answer>', text)]
    answer_closes = [m.start() for m in re.finditer(r'</answer>', text)]

    if not answer_opens:
        return False, "Missing <answer> tag"
    if len(answer_opens) > 1:
        return False, f"Multiple <answer> tags ({len(answer_opens)})"
    if len(answer_closes) != 1:
        return False, "Unclosed <answer> tag"

    # From here on there is exactly one of each tag.
    reasoning_start, reasoning_end = reasoning_opens[0], reasoning_closes[0]
    answer_start, answer_end = answer_opens[0], answer_closes[0]

    # Opening-tag order is checked first so inputs that were already
    # rejected for this reason keep the same message.
    if reasoning_start > answer_start:
        return False, "<answer> appears before <reasoning>"

    # Equal counts alone do not prove a tag is closed: the closing tag
    # must come after the opening tag.
    if reasoning_end < reasoning_start:
        return False, "</reasoning> appears before <reasoning>"
    if answer_end < answer_start:
        return False, "</answer> appears before <answer>"

    # The reasoning section must be finished before the answer begins,
    # otherwise the two sections are nested or overlapping.
    if answer_start < reasoning_end:
        return False, "<answer> appears before </reasoning>"

    return True, "Valid format"


def check_format_soft(text: str) -> Tuple[bool, str]:
    """
    Lenient format validation that only looks at the <answer> tag.

    Passes when:
    - <answer> occurs at least once
    - the number of </answer> tags equals the number of <answer> tags

    The <reasoning> tag is not inspected here.

    Args:
        text: Model response

    Returns:
        Tuple of (is_valid, error_message)
    """
    n_open, n_close = (
        len(re.findall(pattern, text))
        for pattern in (r'<answer>', r'</answer>')
    )

    if not n_open:
        return False, "Missing <answer> tag"
    if n_close != n_open:
        return False, "Unclosed <answer> tag"
    return True, "Valid format"


def strict_format_reward_func(
    solution_str: str,
    ground_truth: str = None,
    extra_info: dict = None,
    **kwargs
) -> float:
    """
    Binary reward derived from the strict format check.

    Args:
        solution_str: Model response
        ground_truth: Unused
        extra_info: Unused
        **kwargs: Additional arguments (ignored)

    Returns:
        1.0 when check_format_strict accepts the response, else 0.0
    """
    # Only the validity flag matters here; the message is discarded.
    passed, _reason = check_format_strict(solution_str)
    if passed:
        return 1.0
    return 0.0


def soft_format_reward_func(
    solution_str: str,
    ground_truth: str = None,
    extra_info: dict = None,
    **kwargs
) -> float:
    """
    Soft format reward: half credit per balanced tag pair.

    Scoring:
    - +0.5 if <answer> occurs and has as many closing as opening tags
    - +0.5 if <reasoning> occurs and has as many closing as opening tags

    Args:
        solution_str: Model response
        ground_truth: Unused
        extra_info: Unused
        **kwargs: Additional arguments (ignored)

    Returns:
        Reward score (0.0, 0.5 or 1.0)
    """
    total = 0.0

    # Both tags are scored by the same rule, so handle them in one loop.
    for tag in ('answer', 'reasoning'):
        n_open = len(re.findall(f'<{tag}>', solution_str))
        n_close = len(re.findall(f'</{tag}>', solution_str))
        if n_open and n_open == n_close:
            total += 0.5

    return total


def xmlcount_reward_func(
    solution_str: str,
    ground_truth: str = None,
    extra_info: dict = None,
    **kwargs
) -> float:
    """
    Granular XML tag reward built from four independent criteria.

    Each satisfied criterion contributes 0.25:
    - exactly one <reasoning> and exactly one </reasoning>
    - exactly one <answer> and exactly one </answer>
    - both opening tags occur and the first <reasoning> precedes the
      first <answer>
    - at most one <reasoning> and at most one <answer> opening tag

    Args:
        solution_str: Model response
        ground_truth: Unused
        extra_info: Unused
        **kwargs: Additional arguments (ignored)

    Returns:
        Reward score (0.0 to 1.0, in steps of 0.25)
    """
    def occurrences(pattern: str) -> int:
        return len(re.findall(pattern, solution_str))

    r_open = occurrences(r'<reasoning>')
    r_close = occurrences(r'</reasoning>')
    a_open = occurrences(r'<answer>')
    a_close = occurrences(r'</answer>')

    first_reasoning = re.search(r'<reasoning>', solution_str)
    first_answer = re.search(r'<answer>', solution_str)

    criteria = (
        # Reasoning tag present exactly once and closed exactly once
        r_open == 1 and r_close == 1,
        # Answer tag present exactly once and closed exactly once
        a_open == 1 and a_close == 1,
        # Proper order of the first opening tags
        first_reasoning is not None
        and first_answer is not None
        and first_reasoning.start() < first_answer.start(),
        # No duplicated opening tags
        r_open <= 1 and a_open <= 1,
    )

    # Each True counts as 1, so this is 0.25 per satisfied criterion.
    return 0.25 * sum(criteria)


def format_reward(
    solution_str: str,
    ground_truth: str = None,
    extra_info: dict = None,
    method: str = "strict",
    **kwargs
) -> float:
    """
    Dispatch to one of the format reward functions by name.

    Args:
        solution_str: Model response
        ground_truth: Forwarded to the selected reward function
        extra_info: Forwarded to the selected reward function
        method: Reward method ("strict", "soft", "xmlcount")
        **kwargs: Additional arguments, forwarded unchanged

    Returns:
        Reward score produced by the selected reward function

    Raises:
        ValueError: If method is not one of the supported names
    """
    # Reject unknown names up front so the happy path is a plain lookup.
    if method not in ("strict", "soft", "xmlcount"):
        raise ValueError(f"Unknown format reward method: {method}")

    scorers = {
        "strict": strict_format_reward_func,
        "soft": soft_format_reward_func,
        "xmlcount": xmlcount_reward_func,
    }
    return scorers[method](solution_str, ground_truth, extra_info, **kwargs)


# Test code
if __name__ == "__main__":
    # Manual smoke test: run every reward and checker over a handful of
    # hand-written responses and print the results for inspection.
    # Keys are the labels shown in the output; values are the responses.
    samples = {
        # Well-formed response
        "Valid": """<reasoning>
Let me solve this step by step.
</reasoning>
<answer>
Solution here
</answer>""",

        # No reasoning section at all
        "Missing reasoning": """<answer>
Solution here
</answer>""",

        # No answer section at all
        "Missing answer": """<reasoning>
Some reasoning
</reasoning>""",

        # Answer section placed ahead of the reasoning section
        "Wrong order": """<answer>
Solution here
</answer>
<reasoning>
Some reasoning
</reasoning>""",

        # Reasoning section never closed
        "Unclosed reasoning": """<reasoning>
Some reasoning
<answer>
Solution here
</answer>""",

        # Two reasoning sections
        "Multiple reasoning": """<reasoning>
First reasoning
</reasoning>
<reasoning>
Second reasoning
</reasoning>
<answer>
Solution here
</answer>""",
    }

    print("Testing format rewards:")
    print("=" * 60)

    for label, response in samples.items():
        strict_score = strict_format_reward_func(response)
        soft_score = soft_format_reward_func(response)
        xml_score = xmlcount_reward_func(response)
        strict_ok, strict_msg = check_format_strict(response)
        soft_ok, soft_msg = check_format_soft(response)

        print(f"\nTest: {label}")
        print(f"  Strict:   {strict_score:.2f}")
        print(f"  Soft:     {soft_score:.2f}")
        print(f"  XMLCount: {xml_score:.2f}")
        print(f"  Strict check: {strict_ok} - {strict_msg}")
        print(f"  Soft check:   {soft_ok} - {soft_msg}")
