#!/usr/bin/env python3
"""
Reward function for Bridges puzzle in VERL.

Ported from an internal reward-functions module (pre-existing utility code)
Adapted to VERL's compute_score signature.
"""

import re
import os
from typing import Optional, Tuple

# Debug flags (compatible with the legacy reward-debug flags).
# Each flag is True only when its environment variable is set to exactly "1".
VERBOSE_REWARD_LOGGING = os.getenv("VERBOSE_REWARD_LOGGING", "0") == "1"
DEBUG_GENERATIONS = os.getenv("DEBUG_GENERATIONS", "0") == "1"


def extract_answer(text: str) -> Optional[str]:
    """
    Extract the model's final answer from its raw output text.

    Args:
        text: Model's output text. ``None``, a non-string, or an empty string yields ``None``.

    Returns:
        The extracted (stripped) answer, or None if no answer marker was found.

    Priority order (the first source that matches wins; within a source the LAST
    occurrence in the text is used):
        1. <final>...</final> tags (preferred for puzzle grids, Qwen3/DeepSeek R1 puzzle training)
        2. <answer>...</answer> tags (Qwen2.5 style)
        3. "Final Solution:" format (legacy SFT training): the comma-containing lines
           after the last marker, up to the first blank line
        4. A ```json {...}``` code block, or else a bare {"response": ...} object
           (Qwen3 native format)

    Note:
        For JSON-wrapped answers (e.g., `{"response": "[[...]]"}`), the JSON is parsed
        and the grid is converted to newline-separated, comma-separated rows. If the
        "response" value is already a list it is formatted directly; if it is plain
        text that is not a JSON array (e.g. "4,4\\n6,3"), that text is returned.
    """
    # compute_score() defaults solution_str to None; re.findall would raise TypeError on it.
    if not isinstance(text, str) or not text:
        return None

    def _rows_to_text(grid):
        # One output line per row, cells joined with commas.
        return "\n".join(",".join(str(cell) for cell in row) for row in grid)

    extracted = None

    # Try <final> tags first (preferred for puzzle grids)
    all_finals = re.findall(r'<final>(.*?)</final>', text, re.DOTALL)
    if all_finals:
        extracted = all_finals[-1].strip()

    # Fallback: <answer> tags
    if extracted is None:
        all_answers = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
        if all_answers:
            extracted = all_answers[-1].strip()

    # Fallback: "Final Solution:" format (from legacy SFT training)
    if extracted is None and "Final Solution:" in text:
        solution_part = text.split("Final Solution:")[-1].strip()
        # Collect grid lines (those containing commas) until the first empty line
        lines = []
        for line in solution_part.split('\n'):
            line = line.strip()
            if not line:
                break
            if ',' in line:
                lines.append(line)
        if lines:
            extracted = '\n'.join(lines)

    # Fallback: last JSON code block, else last bare {"response": ...} (Qwen3 native format)
    if extracted is None:
        json_blocks = re.findall(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL)
        if json_blocks:
            extracted = json_blocks[-1].strip()
        elif '{"response"' in text:
            json_matches = re.findall(r'(\{"response".*?\})', text, re.DOTALL)
            if json_matches:
                extracted = json_matches[-1].strip()

    if extracted is None:
        return None

    extracted = extracted.strip()

    # Strip a markdown code block wrapper if present
    if extracted.startswith("```"):
        lines = extracted.split("\n")
        if len(lines) > 2 and lines[-1].strip() == "```":
            extracted = "\n".join(lines[1:-1]).strip()
        elif len(lines) > 1:
            extracted = "\n".join(lines[1:]).strip()

    # Handle JSON-wrapped answers (from DeepSeek R1 traces), e.g.
    # {"response": "[[\"4\", \"4\", ...], [...]]"}
    if '"response":' in extracted:
        import json

        # Stays None when the outer JSON itself is invalid.
        response = None
        try:
            data = json.loads(extracted)
            response = data.get("response", "") if isinstance(data, dict) else ""
            if response:
                # The nested value may already be a list, or a JSON-encoded string.
                grid = response if isinstance(response, list) else json.loads(response)
                if isinstance(grid, list) and len(grid) > 0:
                    return _rows_to_text(grid)
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, AttributeError, RecursionError):
            # Models often produce invalid JSON like {"response": "[["4", "8"...]]"}
            # where inner quotes break the outer string. Recover the [[...]] span by regex.
            grid_match = re.search(r'(\[\[.*\]\])', extracted, re.DOTALL)
            if grid_match:
                grid_str = grid_match.group(1).strip()
                # Try parsing directly, then with escaped quotes unescaped
                for candidate in (grid_str, grid_str.replace('\\"', '"')):
                    try:
                        grid = json.loads(candidate)
                        if isinstance(grid, list) and len(grid) > 0:
                            return _rows_to_text(grid)
                    except (json.JSONDecodeError, TypeError, ValueError, RecursionError):
                        continue

        # Valid JSON whose "response" is plain text (e.g. "4,4\n6,3") rather than a
        # JSON array: that text is the answer, not the JSON wrapper around it.
        if isinstance(response, str) and response.strip():
            return response.strip()

    return extracted


def calculate_partial_correctness(
    startboard: str,
    board: str,
    solution: str
) -> float:
    """
    Calculate the fraction of correctly filled cells (v1 - fillable cells only).

    Only positions whose solution character differs from the startboard character
    are scored. Comparison is character-by-character within each row, so this
    assumes single-character cell values. Separators (e.g. commas) are never
    counted, provided they sit at the same positions in startboard and solution.

    A fillable position missing from the model's answer (its row is too short)
    counts as wrong, so truncating an answer cannot inflate the score.

    Args:
        startboard: The initial puzzle state
        board: The model's extracted answer
        solution: The ground truth solution

    Returns:
        Fraction of correctly filled cells (0.0 to 1.0). Returns 0.0 when the
        answer's row count differs from the solution's, or when no cell needs filling.
    """
    startboard_lines = startboard.strip().splitlines()
    board_lines = board.strip().splitlines()
    solution_lines = solution.strip().splitlines()

    if len(board_lines) != len(solution_lines):
        return 0.0

    total_cells = 0
    correct_cells = 0

    # board and solution have equal row counts here, so only a shorter startboard
    # can cut the row iteration short.
    for sb_line, b_line, s_line in zip(startboard_lines, board_lines, solution_lines):
        # The set of fillable positions comes from startboard vs solution alone.
        # The board is indexed into rather than zipped, so a short board row
        # cannot shrink that set.
        for j, (sb_char, s_char) in enumerate(zip(sb_line, s_line)):
            if sb_char == s_char:
                continue
            total_cells += 1
            if j < len(b_line) and b_line[j] == s_char:
                correct_cells += 1

    if total_cells == 0:
        return 0.0

    return correct_cells / total_cells


def calculate_partial_correctness_v2(
    board: str,
    solution: str,
    startboard: str,
    changed_cell_weight: float = 2.0
) -> Tuple[float, int, int]:
    """
    Score an answer grid with per-cell weights (partial correctness v2).

    Every character position of the solution is scored, not just the fillable
    ones. Positions whose solution character differs from the startboard (the
    bridges) weigh ``changed_cell_weight``, and all other positions weigh 1.0.
    Any structural mismatch scores zero. That covers a different row count among
    board/solution/startboard, or a board row whose length differs from its
    solution row.

    Args:
        board: The model's extracted answer
        solution: The ground truth solution
        startboard: Initial puzzle state, used only to decide each position's weight
        changed_cell_weight: Weight multiplier for filled cells (default: 2.0)

    Returns:
        Tuple of (weighted_score, correct_cells, total_cells); (0.0, 0, 0) on any
        structural mismatch or when there is nothing to score.
    """
    failed = (0.0, 0, 0)

    rows = board.strip().splitlines()
    target_rows = solution.strip().splitlines()
    start_rows = startboard.strip().splitlines()

    # Guard: all three grids need the same number of rows ...
    if not (len(rows) == len(target_rows) == len(start_rows)):
        return failed
    # ... and each answer row must be exactly as long as its solution row.
    if any(len(row) != len(target) for row, target in zip(rows, target_rows)):
        return failed

    weight_sum = 0.0
    hit_weight = 0.0
    cell_count = 0
    hit_count = 0

    for row, target, start in zip(rows, target_rows, start_rows):
        for col, expected in enumerate(target):
            # A position beyond the end of a short startboard row counts as unchanged.
            is_changed = col < len(start) and start[col] != expected
            cell_weight = changed_cell_weight if is_changed else 1.0

            cell_count += 1
            weight_sum += cell_weight
            if row[col] == expected:
                hit_count += 1
                hit_weight += cell_weight

    if weight_sum == 0:
        return failed

    return hit_weight / weight_sum, hit_count, cell_count


def normalize_grid_format(grid_str: str) -> str:
    """
    Normalize grid format to newline-separated, comma-separated format.

    Handles multiple input formats:
    1. JSON array string: '[["4", "4"], ["6", "3"]]' -> '4,4\\n6,3'
    2. Already formatted: '4,4\\n6,3' -> '4,4\\n6,3' (unchanged)
    3. List/tuple of rows (from dataset): [["4", "4"], ["6", "3"]] -> '4,4\\n6,3'
    4. Array-likes exposing ``.tolist()`` (presumably numpy arrays from dataset
       loaders), converted to nested lists first

    Args:
        grid_str: Grid in any supported format (despite the name, it need not be a str)

    Returns:
        Normalized grid string with newline-separated rows. Returns "" for
        None/empty input. Anything that cannot be read as a 2-D grid is returned
        as its stripped ``str()``.
    """
    # Format structured grids directly. str() of a Python list gives a single-quoted
    # repr that json.loads cannot parse, so such input would otherwise come back
    # unnormalized.
    if not isinstance(grid_str, (str, bytes)) and hasattr(grid_str, "tolist"):
        grid_str = grid_str.tolist()
    if isinstance(grid_str, (list, tuple)) and grid_str:
        rows = [row.tolist() if hasattr(row, "tolist") else row for row in grid_str]
        if all(isinstance(row, (list, tuple)) for row in rows):
            return "\n".join(",".join(str(cell) for cell in row) for row in rows)

    if not grid_str:
        return ""

    grid_str = str(grid_str).strip()

    # Check if it's a JSON array string
    if grid_str.startswith('['):
        import json

        try:
            grid = json.loads(grid_str)
            if isinstance(grid, list) and len(grid) > 0 and isinstance(grid[0], list):
                # It's a 2D array
                return "\n".join(",".join(str(cell) for cell in row) for row in grid)
        except (json.JSONDecodeError, TypeError, ValueError, RecursionError):
            # Not a JSON 2-D array, or nested too deeply to parse (possible with
            # degenerate model output): keep the text as-is.
            pass

    # Return as-is if already in expected format or couldn't parse
    return grid_str


def _extra_value(extra_info: dict, key: str, default):
    """Return ``extra_info[key]``, using *default* when the key is missing OR set to None.

    ``dict.get(key, default)`` would hand back an explicit None, which can appear when
    a dataset column is absent for some rows and would break the arithmetic below.
    """
    value = extra_info.get(key)
    return default if value is None else value


def _resolve_format_weight(method: str, extra_info: dict) -> float:
    """Weight of the xmlcount format bonus for the "exact_plus_format[_W]" methods.

    Precedence: a numeric suffix W in the method name ("exact_plus_format_0.1"),
    then extra_info["format_weight"], then 0.05.
    """
    format_weight = _extra_value(extra_info, "format_weight", 0.05)
    parts = method.split("_")
    if len(parts) == 4:  # exact_plus_format_W
        try:
            format_weight = float(parts[3])
        except ValueError:
            pass  # Non-numeric suffix: keep the extra_info/default weight.
    return format_weight


def compute_score(
    solution_str: str = None,
    ground_truth: str = None,
    extra_info: dict = None,
    method: str = "exact",
    data_source: str = None,  # Passed by VERL but not used
    **kwargs  # Accept any additional kwargs from VERL
) -> float | dict:
    """
    Compute reward score for bridges puzzle.

    This is the main entry point for VERL's reward system.

    Args:
        solution_str: Model's generated response (full text including reasoning).
            None is treated as an empty response.
        ground_truth: Expected solution grid (JSON array string, list of rows, or grid text)
        extra_info: Dict containing (keys that are missing or None use the defaults):
            - initial_state: Initial puzzle grid (required for partial methods)
            - difficulty: Puzzle difficulty (optional)
            - power_exponent: Exponent for partial_v2 shaping (default: 5.0)
            - changed_cell_weight: Weight for changed cells in v2 (default: 2.0)
            - format_weight: Weight of the format bonus for exact_plus_format (default: 0.05)
        method: Scoring method
            - "exact" or "strict": Binary 1.0/0.0 for exact match
            - "partial": Partial correctness v1 (fillable cells only)
            - "partial_v2": Partial correctness v2 with power shaping; 1.0 only on exact match
            - "exact_plus_format" / "exact_plus_format_W": exact + W * xmlcount format
              reward (returns dict)
        data_source: Task identifier (passed by VERL, not used here)
        **kwargs: Additional arguments from VERL (ignored)

    Returns:
        float: Reward score (0.0 to 1.0) for exact/partial methods
        dict: {"score", "acc", "format_xmlcount"} for exact_plus_format methods

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    extra_info = extra_info or {}
    combined = method.startswith("exact_plus_format")

    # Validate up front. A misspelled method then fails on every sample, instead of
    # silently scoring 0.0 (no answer) or falling back to exact match (no initial_state).
    if not combined and method not in ("exact", "strict", "partial", "partial_v2"):
        raise ValueError(f"Unknown method: {method}")

    # Normalize ground_truth format (handles JSON array strings from DSR datasets)
    ground_truth = normalize_grid_format(ground_truth)

    # Treat a missing response as empty text rather than crashing during extraction.
    response = solution_str if solution_str is not None else ""

    # Extract the answer and normalize it too (the model may output a JSON array)
    extracted = extract_answer(response)
    if extracted is not None:
        extracted = normalize_grid_format(extracted)

    if extracted is None:
        # No valid answer found; combined methods still give partial format credit
        if combined:
            from reward_function.advanced.format_rewards import xmlcount_reward_func
            format_score = xmlcount_reward_func(response)
            format_weight = _resolve_format_weight(method, extra_info)
            score = format_weight * format_score
            if VERBOSE_REWARD_LOGGING:
                print(f"[bridges] No answer tag, format-only: {format_weight}*xmlcount={format_score:.2f} = {score:.4f}")
            return {"score": score, "acc": 0.0, "format_xmlcount": format_score}
        if VERBOSE_REWARD_LOGGING:
            print("[bridges] No <answer> tag found in response")
        return 0.0

    if VERBOSE_REWARD_LOGGING:
        print(f"[bridges] Extracted answer:\n{extracted[:200]}...")
        print(f"[bridges] Ground truth:\n{ground_truth[:200]}...")

    is_exact = extracted.strip() == ground_truth.strip()

    # Exact match
    if method in ("exact", "strict"):
        score = 1.0 if is_exact else 0.0
        if VERBOSE_REWARD_LOGGING:
            print(f"[bridges] Exact match: {score}")
        return score

    # Combined exact + format reward (exact + weight * xmlcount)
    if combined:
        from reward_function.advanced.format_rewards import xmlcount_reward_func

        exact_score = 1.0 if is_exact else 0.0
        format_score = xmlcount_reward_func(response)
        format_weight = _resolve_format_weight(method, extra_info)
        score = exact_score + format_weight * format_score
        if VERBOSE_REWARD_LOGGING:
            print(f"[bridges] Combined: exact={exact_score} + {format_weight}*xmlcount={format_score:.2f} = {score:.4f}")
        return {"score": score, "acc": exact_score, "format_xmlcount": format_score}

    # Partial correctness methods require initial_state
    initial_state = extra_info.get("initial_state")
    if initial_state is None:
        # Fall back to exact match if no initial state
        if VERBOSE_REWARD_LOGGING:
            print("[bridges] No initial_state, falling back to exact match")
        return 1.0 if is_exact else 0.0

    # Normalize initial_state format (handles JSON array strings)
    initial_state = normalize_grid_format(initial_state)

    if method == "partial":
        # Partial correctness v1
        if is_exact:
            return 1.0
        partial_score = calculate_partial_correctness(initial_state, extracted, ground_truth)
        if VERBOSE_REWARD_LOGGING:
            print(f"[bridges] Partial v1 score: {partial_score}")
        return partial_score

    # method == "partial_v2" (validated above): partial correctness v2 with power shaping
    power_exponent = _extra_value(extra_info, "power_exponent", 5.0)
    changed_cell_weight = _extra_value(extra_info, "changed_cell_weight", 2.0)

    partial_score, correct_cells, total_cells = calculate_partial_correctness_v2(
        extracted, ground_truth, initial_state, changed_cell_weight
    )

    # Full reward ONLY for an exact match. A float threshold such as >= 0.9995 would
    # also pay 1.0 for wrong answers on large grids (one wrong unchanged cell out of
    # 2000+ total weight). is_exact keeps an exact answer at 1.0 even when
    # initial_state is malformed, matching partial v1.
    if is_exact or (total_cells > 0 and correct_cells == total_cells):
        score = 1.0
    else:
        # Nonlinear shaping with power function
        # Note: VERL uses 0-1 scale, the legacy upstream uses 0-2
        score = partial_score ** power_exponent

    if VERBOSE_REWARD_LOGGING:
        print(f"[bridges] Partial v2: {partial_score:.4f} ({correct_cells}/{total_cells})")
        print(f"[bridges] Shaped score: {score:.4f} (exp={power_exponent})")

    return score


# For compatibility with VERL's aggregator pattern
def compute_score_exact(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Score with binary exact matching (1.0 on an exact grid match, else 0.0)."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="exact",
    )


def compute_score_partial(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Score with partial correctness v1 (share of fillable cells answered correctly)."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="partial",
    )


def compute_score_partial_v2(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
    """Score with partial correctness v2 (weighted per-cell accuracy, power-shaped)."""
    return compute_score(
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
        method="partial_v2",
    )


# Test code: a manual smoke test that runs only when this file is executed directly
if __name__ == "__main__":
    # A fully correct response wrapped in <reasoning>/<answer> tags
    test_response = """<reasoning>
Let me solve this step by step.
Looking at the grid...
</reasoning>
<answer>
7,3,3,3,7
2,4,4,4,2
2,5,3,3,7
2,4,4,4,4
8,D,8,D,6
</answer>"""

    test_ground_truth = """7,3,3,3,7
2,4,4,4,2
2,5,3,3,7
2,4,4,4,4
8,D,8,D,6"""

    test_initial = """7,4,4,4,7
4,4,4,4,4
4,5,4,4,7
4,4,4,4,4
8,4,8,4,6"""

    # Same grid, except that row 3 differs from the ground truth
    wrong_response = """<answer>
7,3,3,3,7
2,4,4,4,2
2,5,4,4,7
2,4,4,4,4
8,D,8,D,6
</answer>"""

    extra = {"initial_state": test_initial}

    for heading, scoring in (
        ("Testing exact match (correct):", "exact"),
        ("\nTesting partial v1:", "partial"),
        ("\nTesting partial v2:", "partial_v2"),
    ):
        print(heading)
        print(f"  Score: {compute_score(test_response, test_ground_truth, extra, scoring)}")

    print("\nTesting with partially wrong answer:")
    for label, scoring in (("Exact", "exact"), ("Partial v1", "partial"), ("Partial v2", "partial_v2")):
        print(f"  {label}: {compute_score(wrong_response, test_ground_truth, extra, scoring)}")
