import re
from typing import Any, Optional

from rllm import Action
from rllm.rewards.reward_types import RewardOutput


def extract_solution(solution_str: str) -> Optional[str]:
    """Pull the candidate sudoku grid out of a raw model response.

    The text is first cut down to whatever follows the first assistant
    marker ("Assistant:" or, failing that, "<|im_start|>assistant").
    Extraction then proceeds in order of preference:

    1. The stripped contents of the last ``<answer>...</answer>`` pair
       (tags are matched case-insensitively and may span several lines).
    2. The first nine lines that contain a digit or an underscore,
       joined with newlines.
    3. The remaining text with surrounding whitespace stripped.

    Every return path yields a string; it may be empty.
    """
    # Keep only the text after the first recognised assistant marker.
    for marker in ("Assistant:", "<|im_start|>assistant"):
        if marker in solution_str:
            solution_str = solution_str.split(marker, 1)[1]
            break

    # Prefer explicitly tagged answers; the last one is taken as final.
    tagged = re.findall(r"<answer>(.*?)</answer>", solution_str, re.IGNORECASE | re.DOTALL)
    if tagged:
        return tagged[-1].strip()

    # No tags: fall back to lines that look like grid rows
    # (digits, or underscores standing in for empty cells).
    candidate_rows = [row for row in solution_str.split("\n") if re.search(r"[\d_]", row)]
    if len(candidate_rows) >= 9:
        return "\n".join(candidate_rows[:9])

    return solution_str.strip()


def _compute_score_and_stats(answer: str, entry: dict[str, Any]) -> tuple[float, dict[str, Any]]:
    """
    Compute the sudoku score (same semantics as before) and also return
    local statistics useful for feedback.
    """
    oracle_answer = entry["answer"]
    metadata = entry["metadata"]
    solution: list[list[int]] = metadata["solution"]
    board_size: int = len(solution[0])
    num_rows: int = len(solution)
    total_cells: int = board_size * num_rows

    # Match answer without trailing whitespaces (as before)
    answer_stripped = "\n".join(l.rstrip() for l in answer.split("\n"))
    oracle_answer_stripped = "\n".join(l.rstrip() for l in oracle_answer.split("\n"))

    # Default local stats structure
    local_stats: dict[str, Any] = {
        "num_conflicts": None,
        "num_filled_cells": 0,
        "invalid_rows": [],
        "invalid_cols": [],
        "invalid_boxes": [],
        "incorrect_cells": [],
        "num_matching_cells": 0,
    }

    # If exact string match (stripped), treat as perfect solution
    if answer_stripped == oracle_answer_stripped:
        reward = 1.0
        num_matching = total_cells
        num_conflicts = 0
        num_filled_cells = total_cells
        local_stats.update(
            num_conflicts=num_conflicts,
            num_filled_cells=num_filled_cells,
            invalid_rows=[],
            invalid_cols=[],
            invalid_boxes=[],
            incorrect_cells=[],
            num_matching_cells=num_matching,
        )
        return reward, local_stats

    # Otherwise, replicate the original scoring logic, but track local stats.
    row_idx = 0
    num_matching = 0
    num_filled_cells = 0
    incorrect_cells = []
    invalid_rows = set()
    invalid_cols = set()
    invalid_boxes = set()

    # Try to get box dimensions from metadata if present, otherwise fall back to sqrt.
    box_height = metadata.get("box_height")
    box_width = metadata.get("box_width")
    if not isinstance(box_height, int) or box_height <= 0:
        box_height = int(num_rows ** 0.5) if num_rows > 0 else 0
    if not isinstance(box_width, int) or box_width <= 0:
        box_width = int(board_size ** 0.5) if board_size > 0 else 0
    num_box_cols = board_size // box_width if box_width else 0

    for ln in answer.split("\n"):
        if row_idx >= num_rows:
            break
        numbers = [int(c) for c in ln if c in "123456789"]
        if len(numbers) != board_size:
            # Ignore lines without a full row of digits, same as original scoring
            continue

        num_filled_cells += len(numbers)
        for col_idx, cand in enumerate(numbers):
            correct_val = solution[row_idx][col_idx]
            if cand == correct_val:
                num_matching += 1
            else:
                incorrect_cells.append((row_idx, col_idx))
                invalid_rows.add(row_idx)
                invalid_cols.add(col_idx)
                if box_height and box_width and num_box_cols:
                    box_idx = (row_idx // box_height) * num_box_cols + (col_idx // box_width)
                    invalid_boxes.add(box_idx)
        row_idx += 1

    # Base score from matching cells (same as original)
    if total_cells > 0:
        reward = num_matching / total_cells
    else:
        reward = 0.0

    # Penalty for not using standard format (unchanged)
    reward *= 0.9

    # Penalty for additional length (unchanged)
    if len(answer) > len(oracle_answer):
        reward *= len(oracle_answer) / len(answer)

    num_conflicts = total_cells - num_matching if total_cells > 0 else None

    local_stats.update(
        num_conflicts=num_conflicts,
        num_filled_cells=num_filled_cells,
        invalid_rows=sorted(invalid_rows),
        invalid_cols=sorted(invalid_cols),
        invalid_boxes=sorted(invalid_boxes),
        incorrect_cells=incorrect_cells,
        num_matching_cells=num_matching,
    )

    return reward, local_stats


def sudoku_reward_fn_eval(task_info: dict, action) -> RewardOutput:
    """Evaluation-mode entry point: delegates to sudoku_reward_fn with eval=True."""
    outcome = sudoku_reward_fn(task_info=task_info, action=action, eval=True)
    return outcome


def score_answer(answer: Optional[str], entry: dict[str, Any]) -> float:
    """
    Return the scalar sudoku score for ``answer`` against ``entry``.

    Anything that is not a non-empty string scores 0.0 outright. Otherwise
    the reward component computed by _compute_score_and_stats is returned
    and the accompanying statistics are discarded. The file's original
    docstring states the scoring follows reasoning_gym's sudoku dataset.
    """
    # Reject None, non-string values and the empty string up front.
    if not (isinstance(answer, str) and answer):
        return 0.0

    score, _unused_stats = _compute_score_and_stats(answer, entry)
    return score


def sudoku_reward_fn(task_info: dict, action: str | Action, eval: bool = False) -> RewardOutput:
    """
    Reward function for sudoku tasks.

    The response (a string, or an Action whose ``.action`` is used) is run
    through extract_solution and scored with _compute_score_and_stats
    against ``task_info["ground_truth"]`` (falling back to
    ``task_info["answer"]``) and ``task_info["metadata"]``.

    Correctness threshold:
        - ``eval=True``: the score must equal 1.0.
        - ``eval=False``: a score of at least 0.9 is accepted.

    The returned metadata carries a "validation" label
    ("correct_solution", "partial_solution" or "invalid_solution"), the
    score, the first 200 characters of the extracted solution, and the
    local feedback statistics. Any exception raised along the way is
    reported as a zero-reward "invalid_solution" whose metadata holds the
    error text.
    """
    try:
        # Unwrap Action objects down to their raw payload.
        response = action.action if isinstance(action, Action) else action

        # Pull the candidate grid out of the response (may be inside <answer> tags).
        candidate = extract_solution(response)
        if not candidate:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "No solution extracted",
                },
            )

        # Assemble the entry structure the scorer expects.
        puzzle_metadata = task_info.get("metadata", {})
        entry = {
            "question": task_info.get("question", ""),
            "answer": task_info.get("ground_truth") or task_info.get("answer", ""),
            "metadata": puzzle_metadata,
        }

        score, stats = _compute_score_and_stats(candidate, entry)

        # Evaluation demands a perfect score; otherwise near-perfect passes.
        if eval:
            is_correct = score == 1.0
        else:
            is_correct = score >= 0.9

        # Label the outcome: a zero score with no parsed digits is "invalid",
        # anything else short of correct is "partial".
        if is_correct:
            validation = "correct_solution"
        elif score == 0.0 and (stats.get("num_filled_cells") or 0) == 0:
            validation = "invalid_solution"
        else:
            validation = "partial_solution"

        # Local feedback fields, copied through in a fixed order.
        feedback_keys = (
            "num_conflicts",
            "num_filled_cells",
            "invalid_rows",
            "invalid_cols",
            "invalid_boxes",
            "incorrect_cells",
            "num_matching_cells",
        )
        meta_out: dict[str, Any] = {
            "validation": validation,
            "score": score,
            "extracted_solution": candidate[:200],  # Truncate for logging
            **{key: stats.get(key) for key in feedback_keys},
        }

        return RewardOutput(reward=score, is_correct=is_correct, metadata=meta_out)

    except Exception as exc:
        # Best-effort boundary: any failure becomes a zero-reward result.
        failure_meta = {
            "validation": "invalid_solution",
            "error": str(exc),
        }
        return RewardOutput(reward=0.0, is_correct=False, metadata=failure_meta)

