#!/usr/bin/env python3
"""
Reward function for Galaxies (Spiral Galaxies / Tentai Show) puzzle in VERL.

Galaxies grids alternate structural cells (corners `2`, drawn edges `3`/`5`,
fixed-content cells `0`/`4`) with model-controlled cells (interior horizontal
and vertical edges). The previous generic char-by-char scoring gave ~80%
partial credit for an "echo input" answer because most chars are structural
and unchanged.

This module mirrors `undead.py`'s "fillable-only" approach: only the cells
that *change* between the initial puzzle (`extra_info["initial_state"]`,
falling back to `extra_info["problem"]`) and the ground-truth solution count
toward partial credit. The trivial "do nothing" answer scores 0 because all
fillable positions remain wrong.
"""

import os
from typing import Optional

from reward_function.generic_puzzle import (
    extract_answer,
    normalize_grid,
)

# Debug switches, read from the environment once at import time. Each flag is
# enabled only when its variable is set to exactly "1"; unset or any other
# value leaves it off.
VERBOSE_REWARD_LOGGING = os.environ.get("VERBOSE_REWARD_LOGGING", "0") == "1"
# NOTE(review): not referenced in the visible part of this module — presumably
# kept for parity with sibling reward modules or read by importers; confirm.
DEBUG_GENERATIONS = os.environ.get("DEBUG_GENERATIONS", "0") == "1"


def calculate_partial_fillable_v2(
    board: str,
    solution: str,
    problem: Optional[str] = None,
    power_exponent: float = 5.0,
    changed_cell_weight: float = 2.0,
) -> float:
    """Return power-shaped partial credit computed over fillable cells only.

    A cell is *fillable* when the initial puzzle and the solution disagree at
    that position (``problem[i][j] != solution[i][j]``); those are the cells
    the model has to change. Every other position is ignored, whatever the
    board contains there.

    Args:
        board: model's extracted answer (normalized).
        solution: ground-truth solution (normalized).
        problem: initial puzzle state (normalized). Required; ``None`` scores
            0.0.
        power_exponent: shaping exponent applied to the fraction of correct
            fillable cells. ``None`` falls back to the default of 5.0 and
            negative values are treated as 0, so the result never leaves
            [0, 1].
        changed_cell_weight: kept for signature parity with generic_puzzle;
            unused (all fillable cells are weighted equally — they are by
            definition the changed cells).

    Returns:
        Reward in [0, 1]. 0.0 is returned when ``problem`` is missing, when
        the three grids disagree on the number of rows, when a board row and
        its solution row differ in length, when there are no fillable cells,
        or when no fillable cell is correct.
    """
    if problem is None:
        return 0.0
    if power_exponent is None:
        # Callers may forward a null value read from a config/dataset record.
        power_exponent = 5.0

    board_rows = board.strip().splitlines()
    solution_rows = solution.strip().splitlines()
    problem_rows = problem.strip().splitlines()

    if not (len(board_rows) == len(solution_rows) == len(problem_rows)):
        return 0.0

    fillable = 0
    correct = 0
    for board_row, solution_row, problem_row in zip(
        board_rows, solution_rows, problem_rows
    ):
        if len(board_row) != len(solution_row):
            return 0.0
        # zip() stops at the shortest row, so solution columns that lie past
        # the end of a short problem row are never counted as fillable.
        for board_ch, solution_ch, problem_ch in zip(
            board_row, solution_row, problem_row
        ):
            if problem_ch == solution_ch:
                continue  # not a fillable cell
            fillable += 1
            if board_ch == solution_ch:
                correct += 1

    # Without the `correct == 0` guard, an exponent of 0 would award 1.0 to a
    # completely wrong answer (0.0 ** 0 == 1.0) and a negative exponent would
    # raise ZeroDivisionError.
    if fillable == 0 or correct == 0:
        return 0.0

    # Flooring the exponent at 0 keeps the shaped value inside [0, 1].
    return (correct / fillable) ** max(power_exponent, 0.0)


def compute_score(
    solution_str: Optional[str] = None,
    ground_truth: Optional[str] = None,
    extra_info: Optional[dict] = None,
    method: str = "exact",
    data_source: Optional[str] = None,
    **kwargs,
) -> float:
    """Score a model response for a Galaxies puzzle.

    Args:
        solution_str: raw model output; the answer is pulled out of it with
            ``extract_answer``.
        ground_truth: reference solution grid.
        extra_info: optional per-sample dict. Keys read here:
            ``reward_mode`` (``"legacy"`` routes to generic_puzzle scoring),
            ``initial_state`` (falling back to ``problem``) for the partial
            methods, and ``power_exponent`` / ``changed_cell_weight`` for
            ``partial_v2``.
        method: ``"exact"`` / ``"strict"`` for all-or-nothing comparison,
            ``"partial"`` for the plain fraction of correct fillable cells,
            ``"partial_v2"`` for the power-shaped fraction. Any other value
            is scored like ``"exact"``.
        data_source: not used by the non-legacy path; forwarded to the legacy
            scorer.
        **kwargs: forwarded to the legacy scorer only.

    Returns:
        The score as a float. 0.0 is returned when either input is ``None``,
        when no answer can be extracted, or when a partial method is
        requested without an initial state.
    """
    # A/B pilot toggle: route to the original generic_puzzle scoring so we can
    # reproduce the broken pre-fix behavior on demand. Set
    # extra_info["reward_mode"]="legacy" to enable.
    # NOTE(review): an earlier comment said this could also be passed via
    # reward_kwargs, but only extra_info is inspected here — confirm how
    # reward_kwargs reach this function before relying on that.
    if extra_info and extra_info.get("reward_mode") == "legacy":
        from reward_function.generic_puzzle import compute_score as _legacy
        return _legacy(
            solution_str=solution_str,
            ground_truth=ground_truth,
            extra_info=extra_info,
            method=method,
            data_source=data_source,
            **kwargs,
        )

    if solution_str is None or ground_truth is None:
        return 0.0

    extracted = extract_answer(solution_str)
    if extracted is None:
        return 0.0

    extracted = normalize_grid(extracted)
    ground_truth = normalize_grid(ground_truth)

    initial_state = None
    if extra_info:
        initial_state = extra_info.get("initial_state") or extra_info.get("problem")
    if initial_state:
        initial_state = normalize_grid(initial_state)

    if method == "partial":
        # Fillable-cell partial without power shaping
        if not initial_state:
            return 0.0
        return calculate_partial_fillable_v2(
            extracted, ground_truth, initial_state, power_exponent=1.0
        )

    if method == "partial_v2":
        if not initial_state:
            if VERBOSE_REWARD_LOGGING:
                print("[galaxies] partial_v2 requires initial_state; returning 0")
            return 0.0
        # A non-empty initial_state can only have come from extra_info, so
        # extra_info is a non-empty mapping here. A key that is present but
        # holds None (or another non-numeric value) must not reach the power
        # computation, hence the coercion with a default.
        power = _float_or_default(extra_info.get("power_exponent"), 5.0)
        ccw = _float_or_default(extra_info.get("changed_cell_weight"), 2.0)
        return calculate_partial_fillable_v2(
            extracted, ground_truth, initial_state,
            power_exponent=power, changed_cell_weight=ccw,
        )

    # "exact", "strict", and any unknown method: all-or-nothing comparison.
    return 1.0 if extracted == ground_truth else 0.0


def _float_or_default(value, default: float) -> float:
    """Return *value* as a float, or *default* when it is None or non-numeric."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


# Public API: `compute_score` is the only name exported by a star-import.
__all__ = ["compute_score"]
