#!/usr/bin/env python3
"""
Reward function for Pattern (Nonogram / Picross) puzzle in VERL.

Pattern grids are ASCII art with column-clue header, row-clue prefix on each
content row, and `+--+--+` border separators between rows:

       2  3  1  2          <- column clues (fixed)
     +--+--+--+--+         <- border (fixed)
    1|..|##|..|..|         <- content row: cells are `..` (empty) or `##` (filled)
     +--+--+--+--+
    2|##|##|..|..|
     +--+--+--+--+
   ...

Most characters are structural (clues, borders, separators, spaces, `|`).
The model-controlled content is the sequence of `..` / `##` cells. Char-by-
char scoring via `generic_puzzle.calculate_partial_correctness_v2` couples
model errors on cell content to errors on layout, which gave a noisy signal
(median rollout `partial_v2 ≈ 0.0` because layout perturbations dominate).

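This module scores by extracting cells in row-major order and comparing the
resulting sequences position by position. Layout differences are tolerated as
long as the cell sequence is recoverable. Only the total cell count is
checked: a mismatch there returns 0, but a board with the right number of
cells in a different row/column shape is still compared positionally.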

Pattern eval data does not carry `initial_state` in `extra_info`, but the
`problem` field is sometimes present on rollout JSONLs. We do not require
it — the cell-extraction approach is self-sufficient.
"""

import math
import os
from typing import List, Optional

from reward_function.generic_puzzle import (
    extract_answer,
    normalize_grid,
)

# Opt-in debug switches: each is True only when its environment variable is
# set to exactly "1" (unset or any other value -> False). Neither flag is read
# inside this section of the module.
VERBOSE_REWARD_LOGGING = os.getenv("VERBOSE_REWARD_LOGGING", "0") == "1"
DEBUG_GENERATIONS = os.getenv("DEBUG_GENERATIONS", "0") == "1"


def extract_cells(grid: str) -> Optional[List[str]]:
    """Return the `..` / `##` cells of a pattern ASCII grid in row-major order.

    Only lines with at least two `|` characters count as content rows. Border
    lines, clue headers, and any other line without two bars are skipped. On
    a content row, the text before the first `|` (the row clue) and after the
    last `|` is discarded. Each segment in between is one cell after
    whitespace stripping.

    Accepted cell spellings:
        empty  -> `..`  from `..`, `.`, `0`, or a blank cell
        filled -> `##`  from `##`, `x`, `X`, or `1`

    Returns None when ``grid`` is falsy or yields no cells. It also returns
    None as soon as any cell holds unrecognized content, so a garbled grid is
    never scored on a silently shortened cell list.
    """
    if not grid:
        return None

    canonical = {
        "..": "..",
        ".": "..",
        "0": "..",
        "": "..",  # blank cell, e.g. an unsolved problem grid
        "##": "##",
        "x": "##",
        "X": "##",
        "1": "##",
    }

    found: List[str] = []
    for row in grid.splitlines():
        segments = row.split("|")
        # Fewer than three segments means fewer than two bars (a line with no
        # `|` at all splits into one segment): not a content row.
        if len(segments) < 3:
            continue
        for raw in segments[1:-1]:
            cell = canonical.get(raw.strip())
            if cell is None:
                # Unrecognized content: bail out rather than miscount.
                return None
            found.append(cell)
    return found or None


def calculate_partial_cells(
    board: str,
    solution: str,
    power_exponent: float = 1.0,
) -> float:
    """Return the fraction of board cells that match the solution, optionally shaped.

    Both grids are parsed with ``extract_cells``. The score is 0.0 in three
    cases: the solution yields no cells, the board cannot be parsed, or the
    two cell counts differ. Otherwise it is ``matches / len(solution_cells)``,
    raised to ``power_exponent`` unless that exponent is exactly 1.0.
    """
    board_cells = extract_cells(board)
    solution_cells = extract_cells(solution)
    if (
        not solution_cells
        or not board_cells
        or len(board_cells) != len(solution_cells)
    ):
        return 0.0

    matches = sum(got == want for got, want in zip(board_cells, solution_cells))
    fraction = matches / len(solution_cells)
    if power_exponent == 1.0:
        return fraction
    return fraction ** power_exponent


def compute_score(
    solution_str: str = None,
    ground_truth: str = None,
    extra_info: dict = None,
    method: str = "exact",
    data_source: str = None,
    **kwargs,
) -> float:
    """Score a Pattern (Nonogram) rollout against its ground-truth grid.

    The answer is pulled out of ``solution_str`` with ``extract_answer``. Both
    grids are passed through ``normalize_grid`` and then compared cell by
    cell via ``extract_cells``.

    Args:
        solution_str: Raw model output. ``None`` scores 0.0.
        ground_truth: Reference solution grid. ``None`` scores 0.0.
        extra_info: Optional per-sample metadata. Only ``"power_exponent"`` is
            read, and only when ``method == "partial_v2"``.
        method: Selects the scoring rule:
            ``"exact"`` / ``"strict"`` gives 1.0 iff every cell matches.
            ``"partial"`` gives the fraction of matching cells.
            ``"partial_v2"`` gives that fraction raised to ``power_exponent``
            (default 5.0).
            Any other value falls back to exact matching.
        data_source: Not used by this function; accepted for signature
            compatibility with the caller.
        **kwargs: Ignored.

    Returns:
        A score in [0, 1]. It is 0.0 when either input is missing, no answer
        can be extracted, or the grids cannot be parsed or differ in total
        cell count.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    extracted = extract_answer(solution_str)
    if extracted is None:
        return 0.0

    extracted = normalize_grid(extracted)
    ground_truth_norm = normalize_grid(ground_truth)

    if method == "partial":
        return calculate_partial_cells(extracted, ground_truth_norm, power_exponent=1.0)

    if method == "partial_v2":
        power = _resolve_power_exponent(extra_info, default=5.0)
        return calculate_partial_cells(extracted, ground_truth_norm, power_exponent=power)

    # "exact", "strict", and any unrecognized method all use exact cell matching.
    return _exact_cell_match(extracted, ground_truth_norm)


def _exact_cell_match(board: str, solution: str) -> float:
    """Return 1.0 iff both grids parse and their cell sequences are identical."""
    b_cells = extract_cells(board)
    s_cells = extract_cells(solution)
    if not b_cells or not s_cells:
        return 0.0
    return 1.0 if b_cells == s_cells else 0.0


def _resolve_power_exponent(extra_info: Optional[dict], default: float) -> float:
    """Read ``extra_info["power_exponent"]`` as a positive finite float, else ``default``.

    ``dict.get(key, default)`` only falls back when the key is absent. A key
    that is present but holds ``None`` would reach ``frac ** None`` and raise
    TypeError. Rows whose column was filled with ``None``/NaN by the data
    loader may look like this (unverified). A string or NaN value would
    likewise raise TypeError or yield a NaN reward.

    Non-positive exponents are rejected as well, for two reasons:
    ``0.0 ** 0 == 1.0`` would give full reward to an all-wrong board, and
    ``0.0 ** -1`` raises ZeroDivisionError.
    """
    if not extra_info:
        return default
    raw = extra_info.get("power_exponent")
    if raw is None:
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    if not math.isfinite(value) or value <= 0:
        return default
    return value


# Names exported by `from <module> import *`.
__all__ = [
    "compute_score",
    "extract_cells",
]
