#!/usr/bin/env python3
"""
Generic reward function for puzzle games in VERL.

Supports: galaxies, loopy, pattern, undead, and any other grid-based puzzles.
Provides exact match and partial correctness scoring.
"""

import re
import os
from typing import Optional

# Debug switches: each is True only when the environment variable of the
# same name is set to exactly "1" at import time.
DEBUG_GENERATIONS = os.getenv("DEBUG_GENERATIONS", "0") == "1"
VERBOSE_REWARD_LOGGING = os.getenv("VERBOSE_REWARD_LOGGING", "0") == "1"


def extract_answer(text: str) -> Optional[str]:
    """
    Pull the final answer payload out of a model response.

    The sources below are tried in order; the first one that yields at
    least one match wins, and within that source the LAST match is used:
        1. <final>...</final> tags
        2. <answer>...</answer> tags
        3. A fenced ```json code block containing a {...} object
        4. A bare {"response" ...} object anywhere in the text

    Args:
        text: Model's output text

    Returns:
        The whitespace-stripped payload, or None if no source matched
    """
    # Tag formats, in priority order.
    tag_patterns = (
        r'<final>(.*?)</final>',
        r'<answer>(.*?)</answer>',
    )
    for pattern in tag_patterns:
        hits = re.findall(pattern, text, re.DOTALL)
        if hits:
            return hits[-1].strip()

    # No tags: look for a fenced JSON block.
    fenced = re.findall(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL)
    if fenced:
        return fenced[-1].strip()

    # Last resort: an unfenced {"response" ...} object.
    if '{"response"' in text:
        bare = re.findall(r'(\{"response".*?\})', text, re.DOTALL)
        if bare:
            return bare[-1].strip()

    return None


# Escape sequences decoded by the regex fallback in normalize_grid.
_SIMPLE_JSON_ESCAPES = {'n': '\n', '"': '"', '\\': '\\'}


def _rows_from_nested_list(parsed) -> Optional[str]:
    """Render a non-empty list whose first item is a list as comma-separated rows.

    Returns None when `parsed` does not have that shape, or when a later
    row turns out not to be iterable.
    """
    if not (isinstance(parsed, list) and parsed and isinstance(parsed[0], list)):
        return None
    try:
        return "\n".join(",".join(str(cell) for cell in row) for row in parsed)
    except TypeError:
        # A later row was a bare scalar (e.g. a number): not a grid.
        return None


def normalize_grid(grid: str) -> str:
    """
    Normalize a grid string for comparison.

    Steps, in order:
        1. Strip surrounding whitespace.
        2. Unwrap a {"response": ...} JSON wrapper. The payload may be a
           string, or a nested list (which is returned directly as
           canonical comma-separated rows).
        3. Convert \\r\\n and \\r line endings to \\n.
        4. If the text is a JSON array of arrays, return it as canonical
           comma-separated rows, one row per line.
        5. Collapse doubled backslashes to single backslashes.

    Args:
        grid: Grid string (may have extra whitespace, JSON wrapping, etc.)

    Returns:
        Normalized grid string
    """
    # Function-scope import keeps this block independent of module imports.
    import json

    grid = grid.strip()

    # Unwrap {"response": "..."} JSON wrapper (from DSR/puzzle_mode training)
    if grid.startswith('{') and '"response"' in grid:
        try:
            # strict=False accepts raw control characters (literal newlines)
            # inside the JSON string, which models commonly emit for grids.
            obj = json.loads(grid, strict=False)
            if isinstance(obj, dict) and 'response' in obj:
                payload = obj['response']
                rows = _rows_from_nested_list(payload)
                if rows is not None:
                    # Nested-list payload: same canonical form as a bare
                    # JSON array of arrays.
                    return rows
                # Non-string payloads raise AttributeError here and drop
                # into the regex fallback below.
                grid = payload.strip()
        except (json.JSONDecodeError, AttributeError):
            # Malformed JSON or non-string payload: best-effort regex unwrap.
            resp_match = re.search(r'"response"\s*:\s*"(.*)"', grid, re.DOTALL)
            if resp_match:
                content = resp_match.group(1).strip()
                # Decode escapes in a single left-to-right pass so that an
                # escaped backslash followed by 'n' is not mistaken for an
                # escaped newline.
                content = re.sub(
                    r'\\([n"\\])',
                    lambda m: _SIMPLE_JSON_ESCAPES[m.group(1)],
                    content,
                )
                grid = content.strip()

    # Normalize line endings
    grid = grid.replace('\r\n', '\n').replace('\r', '\n')

    # Handle JSON array strings (e.g., galaxies intformat: [["2","3",...],["5","0",...]])
    # Convert to canonical comma-separated rows for character-level comparison
    if grid.startswith('['):
        try:
            rows = _rows_from_nested_list(json.loads(grid))
        except (json.JSONDecodeError, TypeError, ValueError):
            rows = None
        if rows is not None:
            return rows

    # Normalize escaped backslashes (model outputs \\ where ground truth has \)
    grid = grid.replace('\\\\', '\\')

    return grid


def calculate_partial_correctness(
    board: str,
    solution: str,
    changed_cell_weight: float = 1.0
) -> float:
    """
    Score an answer by the share of characters that match the solution.

    Both inputs are stripped and split into lines, then compared character
    by character at the same (row, column) position. Every character on a
    line counts as a cell, separators included.

    Args:
        board: The model's extracted answer
        solution: The ground truth solution
        changed_cell_weight: Accepted for interface compatibility; not
            used by this function (all cells are weighted equally)

    Returns:
        Fraction of matching characters (0.0 to 1.0). Returns 0.0 when the
        number of lines differs, when any line differs in length, or when
        there are no characters to compare.
    """
    answer_rows = board.strip().splitlines()
    target_rows = solution.strip().splitlines()

    # Shape guard: row counts must agree before any cell is compared.
    if len(answer_rows) != len(target_rows):
        if VERBOSE_REWARD_LOGGING:
            print(f"Line count mismatch: board={len(answer_rows)}, solution={len(target_rows)}")
        return 0.0

    matched = 0
    compared = 0
    for answer_row, target_row in zip(answer_rows, target_rows):
        # Shape guard: a single ragged row invalidates the whole answer.
        if len(answer_row) != len(target_row):
            if VERBOSE_REWARD_LOGGING:
                print(f"Line length mismatch: board={len(answer_row)}, solution={len(target_row)}")
            return 0.0

        compared += len(target_row)
        matched += sum(1 for got, want in zip(answer_row, target_row) if got == want)

    return matched / compared if compared else 0.0


def calculate_partial_correctness_v2(
    board: str,
    solution: str,
    startboard: Optional[str] = None,
    power_exponent: float = 5.0,
    changed_cell_weight: float = 2.0
) -> float:
    """
    Weighted partial correctness followed by power shaping (v2).

    Each character of the solution is a cell. A cell weighs
    `changed_cell_weight` when the start board has a character at the same
    position that differs from the solution, and 1.0 otherwise. The final
    value is (matched weight / total weight) ** power_exponent.

    Args:
        board: The model's extracted answer
        solution: The ground truth solution
        startboard: Initial state (optional). Ignored when its line count
            differs from the solution's; positions past the end of a start
            board line get weight 1.0
        power_exponent: Exponent applied to the weighted fraction (default: 5.0)
        changed_cell_weight: Weight for cells that differ between start
            board and solution (default: 2.0)

    Returns:
        Shaped score. Returns 0.0 when the line counts differ, when any
        line differs in length, or when the total weight is zero.
    """
    answer_rows = board.strip().splitlines()
    target_rows = solution.strip().splitlines()
    if len(answer_rows) != len(target_rows):
        return 0.0

    # The start board only takes part when it has one line per solution line.
    reference_rows = startboard.strip().splitlines() if startboard else None
    if reference_rows is not None and len(reference_rows) != len(target_rows):
        reference_rows = None

    earned = 0.0
    possible = 0.0
    for row_no, answer_row in enumerate(answer_rows):
        target_row = target_rows[row_no]
        if len(answer_row) != len(target_row):
            return 0.0

        reference_row = reference_rows[row_no] if reference_rows else None

        for col_no, target_cell in enumerate(target_row):
            # A cell is "changed" when the start board has a different
            # character at this position.
            changed = (
                bool(reference_row)
                and col_no < len(reference_row)
                and reference_row[col_no] != target_cell
            )
            cell_weight = changed_cell_weight if changed else 1.0

            possible += cell_weight
            if answer_row[col_no] == target_cell:
                earned += cell_weight

    if possible == 0:
        return 0.0

    # Power shaping of the weighted fraction.
    return (earned / possible) ** power_exponent


def compute_score(
    solution_str: Optional[str] = None,
    ground_truth: Optional[str] = None,
    extra_info: Optional[dict] = None,
    method: str = "exact",
    data_source: Optional[str] = None,
    **kwargs
) -> float:
    """
    Compute reward score for any puzzle.

    This is the main entry point for VERL's reward system. The answer is
    extracted from the model output with extract_answer(), then the answer
    and the ground truth are both passed through normalize_grid() before
    scoring.

    Args:
        solution_str: Model's generated response (full text including reasoning)
        ground_truth: Expected solution grid
        extra_info: Dict containing:
            - initial_state or problem: Initial puzzle grid (optional). Only
              used by "partial_v2"; normalized the same way as the ground
              truth, and ignored when it is not a string
            - power_exponent: Exponent for partial_v2 shaping (default: 5.0)
            - changed_cell_weight: Weight for changed cells in v2 (default: 2.0)
            A key that is present with a None value is treated as absent.
        method: Scoring method
            - "exact" or "strict": Binary 1.0/0.0 for exact match
            - "partial": Partial correctness (all cells weighted equally)
            - "partial_v2": Partial correctness with power shaping
            - anything else: falls back to exact match
        data_source: Task identifier (passed by VERL, not used here)
        **kwargs: Additional arguments from VERL (ignored)

    Returns:
        float: Reward score. 0.0 when either input is None or no answer
        could be extracted.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    # Extract answer from model output
    extracted_answer = extract_answer(solution_str)
    if extracted_answer is None:
        if DEBUG_GENERATIONS:
            print(f"DEBUG: No answer tags found in:\n{solution_str[:200]}...")
        return 0.0

    # Normalize both grids
    extracted_answer = normalize_grid(extracted_answer)
    ground_truth = normalize_grid(ground_truth)

    if method == "partial":
        # Partial correctness (all cells equal weight)
        score = calculate_partial_correctness(extracted_answer, ground_truth)

        if VERBOSE_REWARD_LOGGING:
            print(f"Partial correctness: {score:.4f}")

        return score

    if method == "partial_v2":
        # Partial correctness with power shaping
        power_exponent = 5.0
        changed_cell_weight = 2.0
        initial_state = None

        if extra_info:
            initial_state = extra_info.get('initial_state') or extra_info.get('problem')

            # A key can be present with a None value (e.g. null-filled
            # dataset columns); keep the default rather than crashing on
            # `score ** None`.
            exponent_override = extra_info.get('power_exponent')
            if exponent_override is not None:
                power_exponent = exponent_override
            weight_override = extra_info.get('changed_cell_weight')
            if weight_override is not None:
                changed_cell_weight = weight_override

        # The start grid is compared cell-by-cell against the normalized
        # ground truth, so it must go through the same normalization;
        # otherwise a differently formatted start grid fails the row-count
        # check and the weighting is silently dropped.
        if isinstance(initial_state, str):
            initial_state = normalize_grid(initial_state)
        else:
            initial_state = None

        score = calculate_partial_correctness_v2(
            extracted_answer,
            ground_truth,
            startboard=initial_state,
            power_exponent=power_exponent,
            changed_cell_weight=changed_cell_weight
        )

        if VERBOSE_REWARD_LOGGING:
            print(f"Partial v2 correctness: {score:.4f} (exp={power_exponent})")

        return score

    # Exact match: requested explicitly, or as the fallback for an unknown method.
    requested_exact = method in ("exact", "strict")
    if not requested_exact and VERBOSE_REWARD_LOGGING:
        print(f"Warning: Unknown method '{method}', defaulting to exact")

    score = 1.0 if extracted_answer == ground_truth else 0.0

    if requested_exact and DEBUG_GENERATIONS:
        print(f"DEBUG: Exact match = {score}")
        if score == 0.0:
            print(f"Expected:\n{ground_truth[:200]}")
            print(f"Got:\n{extracted_answer[:200]}")

    return score


# Convenience functions for different scoring methods
def compute_score_exact(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Delegate to compute_score with method="exact"."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="exact",
    )


def compute_score_partial(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Delegate to compute_score with method="partial"."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="partial",
    )


def compute_score_partial_v2(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Delegate to compute_score with method="partial_v2"."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="partial_v2",
    )


# Test code
if __name__ == "__main__":
    # Sample galaxies grid (11 rows of 11 comma-separated values). It is
    # used both as the ground truth and as the body of the model answer.
    sample_rows = [
        "2,3,2,3,2,3,2,3,2,3,2",
        "5,0,0,0,0,0,5,0,4,0,5",
        "2,0,2,4,2,0,2,3,2,3,2",
        "5,0,0,0,0,0,5,0,5,4,5",
        "2,3,2,3,2,3,2,0,2,3,2",
        "5,0,0,0,5,0,0,4,0,0,5",
        "2,0,2,0,2,3,2,0,2,3,2",
        "5,0,0,4,0,0,5,0,5,0,5",
        "2,3,2,0,2,0,2,3,2,4,2",
        "5,4,5,0,0,0,5,4,5,0,5",
        "2,3,2,3,2,3,2,3,2,3,2",
    ]
    test_ground_truth = "\n".join(sample_rows)

    # Model response: a reasoning section followed by the grid in <answer> tags.
    test_response = (
        "<reasoning>\n"
        "Let me solve this step by step.\n"
        "</reasoning>\n"
        "<answer>\n"
        f"{test_ground_truth}\n"
        "</answer>"
    )

    print("Testing generic puzzle reward function")
    print("=" * 50)

    # Correct answer under both scoring methods.
    for label, scoring in (("Exact match score", "exact"), ("Partial score", "partial")):
        print(f"{label}: {compute_score(test_response, test_ground_truth, method=scoring)}")

    # Corrupt every occurrence of one row prefix and score again.
    wrong_response = test_response.replace("5,0,0,0,0,0,5", "5,0,0,0,0,0,3")
    score_wrong = compute_score(wrong_response, test_ground_truth, method="exact")
    print(f"Wrong answer (exact): {score_wrong}")
    score_wrong_partial = compute_score(wrong_response, test_ground_truth, method="partial")
    print(f"Wrong answer (partial): {score_wrong_partial:.4f}")

    print("\nAll tests completed!")
