import logging
import re
from typing import Any, List

from rllm.agents.agent import Action, Step, Trajectory
from rllm.agents.game_agents.base import BaseGameAgent, SYSTEM_PROMPT_TEMPLATE, USER_PROMPT_TEMPLATE, INTRO_USER_PROMPT_TEMPLATE

logger = logging.getLogger(__name__)


class SudokuAgentNew(BaseGameAgent):
    """
    Sudoku Agent class.

    Holds the prompt fragments for a Sudoku-solving agent (role, board-format
    explanation, action space, goal and output requirements) and the helper
    methods that turn task info and environment observations into prompt text.
    """

    # Persona text that opens the system prompt.
    role = "You are a professional sudoku solver. You are given a sudoku board and you need to solve it. You can solve the puzzle over multiple turns, so it's not necessary to outline the full solution at once."

    # Explains the rXcY coordinate system and how the board / visual elements are written.
    format_explanation = """Coordinates:
- We will use rXcY coordinates. For example, r1c1 is the top-left cell at row 1 column 1, r1c2 is the cell to the right at row 1 column 2, r2c1 is the cell below at row 2 column 1, and so on.

Representation of the board:
- Initial board values (given numbers) are represented as value (e.g., 4, 7).
- Empty cells are represented as '.' (e.g., '.') and your value will be placed in the cell (e.g., '3').

Visual Elements:
- Any visual elements will be described in text using rXcY coordinates.
- Please note the visual elements will be described as-is. If a thermo or arrow appears on the board, the location of the circle or bulb will be listed, and the line or arrow will be listed as a separate object. But you can infer they are part of the same object by their coordinates.
- If a visual element is described as "between" two cells, it means the visual element appears on the edge between the two cells.
- In some puzzles there may be visual elements outside of the grid and these will be described using the same coordinate system. For example an arrow in r0c1 pointing to the lower right means there is an arrow above r1c1 that points in the direction of the diagonal: r1c2, r2c3, etc.
- When "None" is given, it is a standard sudoku board, so you should only place digits in the empty cells."""

    # The single action the agent may emit per turn.
    action_space = """1. Value Setting: value(digit, rXcY)
- Assign a confirmed digit (1–9) to a specific cell.
- Must obey all Sudoku constraints (row/column/box/rules).
- You can use only ONE action at a time."""

    goal = "Complete the sudoku board based on the rules and visual elements."

    # Required response layout: <think> block, REASON summary, ACTION in triple backticks.
    output_requirements = """1. Thought:
Provide a detailed, step-by-step reasoning process explaining your thought process in solving the task.
2. Reason:
This will be used as a short-term memory in the next turn after the Thought section is removed. Summarize ALL key logical information needed to continue solving:
- The decided cell and digit
- The constraints involved (row / column / box)
- Any partially solved structures (e.g., pairs, triples, remaining candidates)
Reason does not exceed 8 sentences.
3. Action:
List the concrete solving action you want to take, wrapped in triple backticks.

### Output Format
You must generate your thought, reason and action in the following format:
<think>
[Your thought process in solving the task.]
</think>
REASON: [Summary of your reasoning process to maintain the logic consistency.]
ACTION: ```
[Your action]
```
"""

    def __init__(self, max_steps: int = 30, use_accumulate_thinking: bool = False, history_window: int | None = None, use_multi_turn_format: bool = True, additional_info_path: str | None = None, board_format: str = "base"):
        """
        Store the agent configuration and reset the agent state.

        Args:
            max_steps: Stored as ``self.max_steps``.
            use_accumulate_thinking: Controls whether to accumulate the thinking
                portion of the response.
            history_window: Stored as ``self.history_window``.
            use_multi_turn_format: Whether to use the multi-turn prompt format.
            additional_info_path: Optional path stored as
                ``self.additional_info_path``; it is not read in this class.
            board_format: Layout name forwarded to ``reformat_observation``.
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.step: int = 0
        self.use_accumulate_thinking = use_accumulate_thinking  # controls whether to accumulate the thinking portion of the response
        self.max_steps = max_steps
        self.history_window = history_window
        self.use_multi_turn_format = use_multi_turn_format  # reasoning models have good performance with single-turn format
        self.additional_info_path = additional_info_path
        self.board_format = board_format
        # state
        self.current_observation = None
        self.additional_info = None

        # reset() is not defined in this class; presumably inherited from BaseGameAgent.
        self.reset()

    def _make_system_prompt(self, task_info: dict) -> str:
        """
        Build the system prompt from the role, format explanation and optional insights.

        When ``self.additional_info`` is set, its "general_sudoku" entry and the
        entry for the puzzle type are appended. The type is "standard_sudoku"
        when ``task_info["visual_elements"]`` is missing or "None", otherwise
        "ctc_sudoku".
        """
        role_text = self.role

        role_text += f"\n\n## Format Explanation\n{self.format_explanation}"

        if self.additional_info:
            if task_info.get("visual_elements", "None") == "None":
                subtask_type = "standard_sudoku"
            else:
                subtask_type = "ctc_sudoku"

            general_sudoku_insights = self.additional_info.get("general_sudoku", "")
            subtask_insights = self.additional_info.get(subtask_type, "")
            role_text += f"\n\nHere are some insights to help you solve the puzzle:\n### General Insights\n{general_sudoku_insights}\n### Additional Insights\n{subtask_insights}"

        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
            role=role_text,
            output_requirements=self.output_requirements,
        )
        return system_prompt

    def _make_init_user_prompt(self, task_info: dict) -> str:
        """
        Build the first user prompt describing actions, goal, size, rules and visual elements.

        Reads the "rows", "cols", "rules" and "visual_elements" keys of
        ``task_info``; a missing key raises ``KeyError``.
        """
        environment_info_str = f"""### Available Actions
{self.action_space}

### Goal
{self.goal}

### Size
{task_info['rows']}x{task_info['cols']}

### Rules
{task_info['rules']}

### Visual Elements
{task_info['visual_elements']}
"""
        user_prompt = INTRO_USER_PROMPT_TEMPLATE.format(
            environment_info=environment_info_str,
        )
        return user_prompt

    def _make_user_prompt(self, observation: Any) -> str:
        """Build a per-turn user prompt from the "observation" entry of a processed observation."""
        user_prompt = USER_PROMPT_TEMPLATE.format(
            current_observation=observation["observation"],
        )
        return user_prompt

    def _reformat_action(self, response: str) -> str:
        """
        Reduce a response to a single ``value(...)`` action line.

        Returns the first line that contains "value(", stripped of surrounding
        whitespace. A response without "value(" is returned unchanged.
        """
        # Search for the first line that actually holds the action instead of
        # blindly taking line 0, which may be blank, a code fence or prose.
        return next(
            (line.strip() for line in response.split("\n") if "value(" in line),
            response,
        )

    def _process_observation(self, observation: Any) -> dict[str, Any]:
        """
        Process observation from environment.

        Args:
            observation: Either a dict carrying at least "board_spaced" (and
                optionally "env_message" plus board/progress fields) or a
                plain string.

        Returns:
            A dict with an "observation" string. For dict input the board is
            rendered with ``reformat_observation``, any environment message is
            appended, and the raw fields are copied under "observation_info".

        Raises:
            ValueError: If ``observation`` is neither a dict nor a str.
        """
        if isinstance(observation, dict):
            observation_str = reformat_observation(observation.get("board_spaced", ""), keep_dots=True, board_format=self.board_format)
            env_message = observation.get("env_message", None)
            if env_message:
                observation_str += f"\n\n- Error message from environment:\n{env_message}"
            return {
                "observation": observation_str,
                "observation_info": {
                    "board_ascii": observation.get("board_ascii", ""),
                    "board_spaced": observation.get("board_spaced", ""),
                    "board_tokens": observation.get("board_tokens", ""),
                    "current_turn": observation.get("current_turn", ""),
                    "max_turns": observation.get("max_turns", ""),
                    "progress": observation.get("progress", ""),
                    "correct_cells": observation.get("correct_cells", ""),
                    "total_cells": observation.get("total_cells", ""),
                }
            }
        elif isinstance(observation, str):
            return {
                "observation": observation,
            }
        else:
            raise ValueError(f"Invalid observation type: {type(observation)}")


def reformat_observation(puzzle_str: str, keep_dots: bool = True, board_format: str = "base") -> str:
    """
    Format a square Sudoku board (4x4, 6x6 or 9x9) into a fixed-width, left-aligned string.

    Only the text before the first blank line of ``puzzle_str`` is parsed as
    the board. Any text after that blank line is stripped and appended
    unchanged to the result, separated from the board by one blank line.

    Parameters
    ----------
    puzzle_str : str
        Multiline string of whitespace-separated tokens (digits, '.', '0', '_'),
        one board row per line. The board size is inferred from the number of
        tokens in the first non-empty line.
    keep_dots : bool
        Whether to keep blanks as '.' when rendering.
        - True  -> output '.'
        - False -> render as spaces (width 3)
    board_format : str
        - "base"     -> a header line (c1..cN) followed by one labelled line
          (r1..rN) per row.
        - "position" -> no board lines are produced; only the trailing text,
          if any, appears in the result.

    Rules
    -----
    - Blank tokens in {'.', '0', '_'} are treated as empty.
    - Header labels are c1..cN; row labels are r1..rN.
    - Each cell has width 3 and is left-aligned.

    Returns
    -------
    str
        The rendered board, followed by the trailing text when present.

    Raises
    ------
    ValueError
        - Unsupported size (must be 4, 6, or 9).
        - Any row has a token count different from the inferred size.
        - Number of puzzle rows differs from the inferred size.
        - The board part of `puzzle_str` is empty or contains only whitespace.
        - `board_format` is neither "base" nor "position".
    """
    ALLOWED_SIZES = (4, 6, 9)
    BLANK_TOKENS = {".", "0", "_"}
    CELL_WIDTH = 3

    def fmt_cell(s: str) -> str:
        """Left-align a token within the cell width."""
        return s.ljust(CELL_WIDTH)

    def normalize_token(t: str) -> str:
        """Map blank-like tokens to '.' or empty string depending on keep_dots."""
        if t in BLANK_TOKENS:
            return "." if keep_dots else ""
        return t

    # Everything before the first blank line is the board; the rest is free text.
    board_text, separator, trailing_text = puzzle_str.partition("\n\n")

    token_lines: list[list[str]] = [ln.split() for ln in board_text.splitlines() if ln.strip()]
    if not token_lines:
        raise ValueError("puzzle_str is empty.")

    # Infer the size from the first row and reject anything that is not a
    # supported board, rather than truncating rows and failing later with a
    # misleading row-count message.
    n = len(token_lines[0])
    if n not in ALLOWED_SIZES:
        raise ValueError(
            f"Unsupported size {n}: the board must be 4x4, 6x6 or 9x9."
        )

    row_lengths = [len(row) for row in token_lines]
    if any(length != n for length in row_lengths):
        raise ValueError(
            f"Every row must have exactly {n} tokens. (row token counts: {row_lengths})"
        )

    if len(token_lines) != n:
        raise ValueError(
            f"Number of rows ({len(token_lines)}) must match the size {n}."
        )

    if board_format == "base":
        # Header: four spaces align the column labels with the row-label gutter.
        lines = ["    " + "".join(fmt_cell(f"c{i}") for i in range(1, n + 1))]

        # Body: "rN" padded to width 3, one space, then the cells.
        for i, row_tokens in enumerate(token_lines, start=1):
            cells = "".join(fmt_cell(normalize_token(t)) for t in row_tokens)
            lines.append(fmt_cell(f"r{i}") + " " + cells)
    elif board_format == "position":
        # NOTE(review): this format emits no board lines at all; looks
        # unfinished, but the intended layout is not visible here - confirm.
        lines = []
    else:
        raise ValueError(f"Invalid board format: {board_format}")

    observation_str = "\n".join(lines)
    if separator:
        observation_str += "\n\n" + trailing_text.strip()

    return observation_str