#!/usr/bin/env python3
"""
Reward function for Undead puzzle in VERL.

Undead ground truth has structural lines (a monster-count header such as
"G: 2 V: 3 Z: 1" and blank separator lines) that the model may omit. This
module strips those lines from both the model answer and the ground truth
before any comparison (exact or partial), so the score does not depend on
whether the model reproduces them. Top/bottom clue rows are kept and must
still be present in the answer.

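For the `partial` method, only fillable cells (. in the initial state that
become G/V/Z in the solution) are counted — mirrors, clue numbers, and spacing
are ignored. This matches how bridges handles partial scoring. The
`partial_v2` method instead scores every cell, weights fillable cells more
heavily, and applies power shaping to the result. Both fall back to counting
every cell when the initial state is unavailable.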
"""

import re
import os
from typing import Optional

from reward_function.generic_puzzle import (
    extract_answer,
    normalize_grid,
)

# Opt-in diagnostics: each flag is enabled only when its environment variable
# is set to exactly "1".
VERBOSE_REWARD_LOGGING = os.environ.get("VERBOSE_REWARD_LOGGING", "0") == "1"
DEBUG_GENERATIONS = os.environ.get("DEBUG_GENERATIONS", "0") == "1"

# Header pattern: "G: N V: N Z: N" (monster counts). Each of the three labels
# may be any of G/V/Z (order is not enforced); it is matched against a line
# that has already been stripped of surrounding whitespace.
_HEADER_RE = re.compile(r'^[GVZ]:\s*\d+\s+[GVZ]:\s*\d+\s+[GVZ]:\s*\d+\s*$')


def strip_undead_headers(grid: str) -> str:
    """Strip monster count header and blank lines from an undead grid.

    Removes:
      - Lines matching "G: N V: N Z: N" (monster count header)
      - Blank/whitespace-only lines

    Keeps:
      - Top/bottom clue rows (e.g. "   2 0 0 0  ")
      - Grid rows with left/right clues (e.g. " 2 V \\ \\ / 0")

    Trailing whitespace is stripped from every kept line. Leading whitespace
    is preserved on every kept line, including the first one, because it
    carries column alignment (a top clue row starts with padding). The result
    is therefore the same whether or not a header preceded the first row.

    Args:
        grid: Multi-line grid text.

    Returns:
        The kept lines joined with "\\n".
    """
    kept = []
    # Do not strip() the whole string first: that would remove the padding
    # of the first row only when no header/blank line precedes it.
    for line in grid.splitlines():
        stripped = line.strip()
        if not stripped:
            continue  # Skip blank lines
        if _HEADER_RE.match(stripped):
            continue  # Skip header
        kept.append(line.rstrip())
    return "\n".join(kept)


def extract_problem_from_extra_info(extra_info: dict) -> Optional[str]:
    """Extract the initial puzzle state (problem) from extra_info.

    Checks in order:
      1. extra_info['initial_state']
      2. extra_info['problem']
      3. Parses from extra_info['solution_path'] (starts with "Problem:\\n...")

    For (3), the problem text is everything after the "Problem:" line up to
    the first "After move" or "Solved" line, or up to the end of the string
    when neither marker is present. Surrounding whitespace is stripped.

    Returns:
        The first non-empty value found, or None if extra_info is empty or
        no source yields a problem.
    """
    if not extra_info:
        return None

    # Direct fields
    for key in ('initial_state', 'problem'):
        val = extra_info.get(key)
        if val:
            return val

    # Parse from solution_path
    sp = extra_info.get('solution_path', '')
    if sp and 'Problem:' in sp:
        # Lazily capture from the line after "Problem:" up to the first
        # "After move"/"Solved" marker; \Z lets a trace that has neither
        # marker (e.g. problem-only text) still yield the problem.
        match = re.search(
            r'Problem:\s*\n(.*?)(?:\nAfter move|\nSolved|\Z)', sp, re.DOTALL
        )
        if match:
            problem = match.group(1).strip()
            if problem:
                return problem

    return None


def calculate_partial_fillable_only(
    board: str,
    solution: str,
    problem: str,
) -> float:
    """Fraction of solver-filled cells that the board gets right.

    A position is scored only if the problem character there differs from
    the solution character, i.e. a cell the solver had to fill (normally
    '.' becoming G/V/Z). Positions where problem and solution agree (mirrors,
    clue digits, spacing) are skipped. If ``problem`` does not have the same
    number of rows as ``solution``, it is ignored and every position is
    scored. Positions past the end of a shorter problem row are also scored.

    Args:
        board: Model's answer (after normalization + header strip)
        solution: Ground truth (after normalization + header strip)
        problem: Initial state (after normalization + header strip)

    Returns:
        correct / scored positions (0.0 to 1.0). Returns 0.0 if the board's
        row count or any row length differs from the solution, or if no
        position is scored.
    """
    rows_b = board.strip().splitlines()
    rows_s = solution.strip().splitlines()
    rows_p = problem.strip().splitlines()

    if len(rows_b) != len(rows_s):
        return 0.0

    # Masking is only meaningful when the problem is row-aligned.
    masks = rows_p if len(rows_p) == len(rows_s) else [None] * len(rows_s)

    # Shape gate: any row-length mismatch earns nothing.
    if any(len(rb) != len(rs) for rb, rs in zip(rows_b, rows_s)):
        return 0.0

    scored = 0
    hits = 0
    for rb, rs, rp in zip(rows_b, rows_s, masks):
        for col, (cb, cs) in enumerate(zip(rb, rs)):
            if rp and col < len(rp) and rp[col] == cs:
                continue  # fixed position: identical in problem and solution
            scored += 1
            if cb == cs:
                hits += 1

    return hits / scored if scored else 0.0


def calculate_partial_fillable_v2(
    board: str,
    solution: str,
    problem: Optional[str],
    power_exponent: float = 5.0,
    changed_cell_weight: float = 2.0,
) -> float:
    """Weighted, power-shaped partial credit for an undead grid.

    Every character position of the solution is scored. When ``problem`` is
    given and has the same number of rows as ``solution``, positions whose
    problem character differs from the solution character (cells the solver
    had to fill) weigh ``changed_cell_weight``, and every other position
    weighs 1.0. Without a usable problem, all positions weigh 1.0.

    Args:
        board: Model answer grid.
        solution: Ground-truth grid.
        problem: Initial-state grid, or None.
        power_exponent: Exponent applied to the weighted accuracy.
        changed_cell_weight: Weight for positions that changed from problem
            to solution.

    Returns:
        ``(correct_weight / total_weight) ** power_exponent``. Returns 0.0
        when the board's row count or any row length differs from the
        solution, or when the total weight is zero.
    """
    answer_rows = board.strip().splitlines()
    target_rows = solution.strip().splitlines()

    if len(answer_rows) != len(target_rows):
        return 0.0

    start_rows = problem.strip().splitlines() if problem else None
    # A problem grid that is not row-aligned with the solution is ignored.
    if start_rows is not None and len(start_rows) != len(target_rows):
        start_rows = None

    # Shape gate: any row-length mismatch earns nothing.
    if any(len(a) != len(t) for a, t in zip(answer_rows, target_rows)):
        return 0.0

    earned = 0.0
    available = 0.0
    for row_no, (a_row, t_row) in enumerate(zip(answer_rows, target_rows)):
        s_row = start_rows[row_no] if start_rows else None
        for col, (a_ch, t_ch) in enumerate(zip(a_row, t_row)):
            was_filled = bool(s_row) and col < len(s_row) and s_row[col] != t_ch
            weight = changed_cell_weight if was_filled else 1.0
            available += weight
            if a_ch == t_ch:
                earned += weight

    if available == 0:
        return 0.0
    return (earned / available) ** power_exponent


def _float_override(extra_info: Optional[dict], key: str, default: float) -> float:
    """Read an optional numeric override from extra_info.

    Returns ``default`` when extra_info is empty/None, the key is missing,
    its value is None, or the value cannot be converted with ``float()``,
    so one malformed sample cannot crash the reward computation.
    """
    if not extra_info:
        return default
    value = extra_info.get(key)
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def compute_score(
    solution_str: Optional[str] = None,
    ground_truth: Optional[str] = None,
    extra_info: Optional[dict] = None,
    method: str = "exact",
    data_source: Optional[str] = None,
    **kwargs
) -> float:
    """
    Compute reward score for Undead puzzle.

    For all methods: strips header and blank lines before comparison.
    For partial: only counts fillable cells (. → G/V/Z) when the initial
    state is available, ignoring fixed structural characters.
    For partial_v2: scores every cell, weighting fillable cells by
    ``changed_cell_weight``, and raises the weighted accuracy to
    ``power_exponent``.

    Args:
        solution_str: Raw model output; the answer is pulled out with
            extract_answer.
        ground_truth: Expected grid text.
        extra_info: Optional per-sample dict. Supplies the initial state
            (see extract_problem_from_extra_info) and, for partial_v2,
            optional 'power_exponent' / 'changed_cell_weight' overrides.
            Missing, None, or non-numeric overrides fall back to 5.0 / 2.0.
        method: "exact"/"strict" (default), "partial", or "partial_v2".
            Any other value falls back to exact matching.
        data_source: Accepted for interface compatibility; unused here.
        **kwargs: Ignored.

    Returns:
        The score. 0.0 if either input is None or no answer can be extracted.
    """
    if solution_str is None or ground_truth is None:
        return 0.0

    extracted_answer = extract_answer(solution_str)
    if extracted_answer is None:
        if DEBUG_GENERATIONS:
            print(f"DEBUG undead: No answer tags found in:\n{solution_str[:200]}...")
        return 0.0

    extracted_answer = normalize_grid(extracted_answer)
    ground_truth_norm = normalize_grid(ground_truth)

    # Strip headers and blanks for all methods (including exact)
    answer_stripped = strip_undead_headers(extracted_answer)
    gt_stripped = strip_undead_headers(ground_truth_norm)

    if method in ["exact", "strict"]:
        score = 1.0 if answer_stripped == gt_stripped else 0.0
        if DEBUG_GENERATIONS and score == 0.0:
            print(f"DEBUG undead exact:\nExpected:\n{gt_stripped[:200]}\nGot:\n{answer_stripped[:200]}")
        return score

    # Get initial state for fillable-cell scoring
    problem_str = extract_problem_from_extra_info(extra_info)
    problem_stripped = None
    if problem_str:
        problem_stripped = strip_undead_headers(normalize_grid(problem_str))

    if method == "partial":
        if problem_stripped:
            score = calculate_partial_fillable_only(answer_stripped, gt_stripped, problem_stripped)
        else:
            # Fallback: count all cells (same as generic)
            from reward_function.generic_puzzle import calculate_partial_correctness
            score = calculate_partial_correctness(answer_stripped, gt_stripped)
        if VERBOSE_REWARD_LOGGING:
            print(f"Undead partial: {score:.4f} (fillable_only={problem_stripped is not None})")
        return score

    elif method == "partial_v2":
        # A key can be present but null (or non-numeric); previously that
        # reached `** None` and raised TypeError. Fall back to defaults.
        power_exponent = _float_override(extra_info, 'power_exponent', 5.0)
        changed_cell_weight = _float_override(extra_info, 'changed_cell_weight', 2.0)

        score = calculate_partial_fillable_v2(
            answer_stripped, gt_stripped, problem_stripped,
            power_exponent=power_exponent,
            changed_cell_weight=changed_cell_weight,
        )
        if VERBOSE_REWARD_LOGGING:
            print(f"Undead partial_v2: {score:.4f} (exp={power_exponent})")
        return score

    else:
        return 1.0 if answer_stripped == gt_stripped else 0.0


__all__ = ['compute_score']
