import logging
import os
from typing import Any, List

from rllm.agents.agent import Action, Step, Trajectory
from rllm.agents.game_agents.base import BaseGameAgent, SYSTEM_PROMPT_TEMPLATE, USER_PROMPT_TEMPLATE, INTRO_USER_PROMPT_TEMPLATE

logger = logging.getLogger(__name__)

# Opening sentences shared by the fixed role prompts below.
_SUDOKU_SOLVER_INTRO = "You are a professional sudoku solver. You are given a sudoku board and you need to solve it"

# Role prompt asking the model to finish the whole puzzle in a single response.
ONLY_ONE_STEP_ROLE = f"{_SUDOKU_SOLVER_INTRO} entirely in one step."

# Role prompt that allows solving the puzzle incrementally over several turns.
ONLY_ONE_ACTION_ROLE = (
    f"{_SUDOKU_SOLVER_INTRO}. "
    "You can solve the puzzle over multiple turns, so it's not necessary to outline the full solution at once."
)

def _parse_only_k_action_env() -> int | None:
    v = os.environ.get("ONLY_K_ACTION", None)
    if v is None:
        return None
    try:
        k = int(v)
        return k if k > 0 else None
    except Exception:
        return None

def _only_k_action_role(k: int) -> str:
    """Return the role prompt for the mode that demands exactly ``k`` actions per turn.

    The prompt also permits returning fewer actions when fewer than ``k``
    empty cells remain on the board.
    """
    preamble = (
        "You are a professional sudoku solver. You are given a sudoku board and you need to solve it. "
        "You can solve the puzzle over multiple turns, so it's not necessary to outline the full solution at once."
    )
    per_turn_rule = f"IMPORTANT: In each turn, you should return exactly {k} action(s)."
    endgame_exception = (
        f"EXCEPTION: If the number of remaining empty cells is less than {k}, "
        "you may return exactly that remaining number of action(s)."
    )
    return "\n".join((preamble, per_turn_rule, endgame_exception))

def _only_k_action_output_requirements(k: int) -> str:
    """Build the output-requirements section for the ``ONLY_K_ACTION`` prompt mode.

    The text asks for a ``<think>`` block, a short ``REASON`` summary and an
    ``ACTION`` block listing exactly ``k`` actions (fewer allowed when fewer
    than ``k`` empty cells remain). The returned string ends with a newline.
    """
    prompt_lines = [
        "1. Thought:",
        "Provide a detailed, step-by-step reasoning process explaining your thought process in solving the task.",
        "2. Reason:",
        "This will be used as a short-term memory in the next turn after the Thought section is removed. Summarize ALL key logical information needed to continue solving:",
        "- The decided cell and digit (if any)",
        "- The constraints involved (row / column / box)",
        "- Any partially solved structures (e.g., pairs, triples, remaining candidates)",
        "Reason does not exceed 8 sentences.",
        "3. Action:",
        f"List exactly {k} concrete solving actions you want to take, each on its own line, wrapped in triple backticks.",
        f"EXCEPTION: If the number of remaining empty cells is less than {k}, you may list exactly that remaining number of actions.",
        "",
        "### Output Format",
        "<think>",
        "[Your thought process in solving the task.]",
        "</think>",
        "REASON: [Summary of your reasoning process to maintain the logic consistency.]",
        "ACTION: ```",
        f"[exactly {k} actions, OR fewer only if remaining empty cells < {k}]",
        "```",
    ]
    return "\n".join(prompt_lines) + "\n"

def _only_k_action_action_space(k: int) -> str:
    """Describe the action space for the ``ONLY_K_ACTION`` prompt mode.

    Only the ``value(digit, rXcY)`` action is offered. It is followed by a
    blank line and the instruction to emit exactly ``k`` actions per turn
    (fewer when fewer than ``k`` empty cells remain).
    """
    value_setting = "\n".join((
        "1. Value Setting: ```value(digit, rXcY)```",
        "- Assign a confirmed digit (1–9) to a specific cell.",
        "- Must obey all Sudoku constraints (row/column/box/rules).",
        "- Used only when the cell's value is fully deduced.",
    ))
    per_turn_rule = "\n".join((
        f"instruction: You should output exactly {k} action(s) per turn.",
        f"exception: If the number of remaining empty cells is less than {k}, you may output exactly that remaining number of action(s).",
    ))
    return f"{value_setting}\n\n{per_turn_rule}"

# Shared description of the single ``value(digit, rXcY)`` action. Every
# value-only action-space variant below starts with this text, so it is defined
# once to keep their wording in sync.
_VALUE_SETTING_ACTION = """1. Value Setting: ```value(digit, rXcY)```
- Assign a confirmed digit (1–9) to a specific cell.
- Must obey all Sudoku constraints (row/column/box/rules)."""

# Value-only action space that also limits the model to one action at a time
# (selected when task_info["add_info"] == "only_one_action").
ONLY_ONE_ACTION_ACTION_SPACE = _VALUE_SETTING_ACTION + "\n- You can use only ONE action at a time."

# Plain value-only action space (selected when add_info == "only_one_step").
ONLY_ONE_STEP_ACTION_SPACE = _VALUE_SETTING_ACTION

# Value-only action space that stresses placing only fully deduced digits
# (selected when add_info == "only_value_action").
ONLY_VALUE_ACTION_ACTION_SPACE = _VALUE_SETTING_ACTION + "\n- Used only when the cell's value is fully deduced."

# Output requirements for the one-action-per-turn mode: a <think> block, a
# REASON summary that serves as short-term memory for the next turn, and a
# single fenced ACTION.
ONLY_ONE_ACTION_OUTPUT_REQUIREMENTS = """1. Thought:
Provide a detailed, step-by-step reasoning process explaining your thought process in solving the task.
2. Reason:
This will be used as a short-term memory in the next turn after the Thought section is removed. Summarize ALL key logical information needed to continue solving:
- The decided cell and digit
- The constraints involved (row / column / box)
- Any partially solved structures (e.g., pairs, triples, remaining candidates)
Reason does not exceed 8 sentences.
3. Action:
List the concrete solving action you want to take, wrapped in triple backticks.

### Output Format
You must generate your thought, reason and action in the following format:
<think>
[Your thought process in solving the task.]
</think>
REASON: [Summary of your reasoning process to maintain the logic consistency.]
ACTION: ```
[Your action]
```
"""

# Action space used when the MAX_K_ACTION environment variable is set; its text
# is identical to ONLY_ONE_STEP_ACTION_SPACE.
K_MAX_ACTION_ACTION_SPACE = _VALUE_SETTING_ACTION

class SudokuAgent(BaseGameAgent):
    """
    Sudoku-solving game agent.

    Builds the system prompt and the introductory user prompt from a task
    description, and renders each environment observation as a labelled board.
    Prompt variants are chosen, in priority order, by the ``ONLY_K_ACTION``
    environment variable, then ``task_info["add_info"]``, then (for the action
    space only) the ``MAX_K_ACTION`` environment variable.
    """
    role = "You are a professional sudoku solver. You are given a sudoku board and you need to solve it. You can solve the puzzle over multiple turns, so it's not necessary to outline the full solution at once."

    format_explanation = """Coordinates:
- We will use rXcY coordinates. For example, r1c1 is the top-left cell at row 1 column 1, r1c2 is the cell to the right at row 1 column 2, r2c1 is the cell below at row 2 column 1, and so on.

Representation of the board:
- Initial board values (given numbers) are represented as value (e.g., 4, 7).
- Empty cells are represented as '.' (e.g., '.') and your value will be placed in the cell (e.g., '3').

Visual Elements:
- Any visual elements will be described in text using rXcY coordinates.
- Please note the visual elements will be described as-is. If a thermo or arrow appears on the board, the location of the circle or bulb will be listed, and the line or arrow will be listed as a separate object. But you can infer they are part of the same object by their coordinates.
- If a visual element is described as "between" two cells, it means the visual element appears on the edge between the two cells.
- In some puzzles there may be visual elements outside of the grid and these will be described using the same coordinate system. For example an arrow in r0c1 pointing to the lower right means there is an arrow above r1c1 that points in the direction of the diagonal: r1c2, r2c3, etc.
- When "None" is given, it is a standard sudoku board, so you should only place digits in the empty cells."""

    # Default action space: value placement plus candidate bookkeeping.
    action_space = """1. Value Setting: ```value(digit, rXcY)```
- Assign a confirmed digit (1–9) to a specific cell.
- Must obey all Sudoku constraints (row/column/box/rules).
- Used only when the cell's value is fully deduced.
2. Candidate Management: ```candidate('+', digit, rXcY)```, ```candidate('-', digit, rXcY)```
- Add or remove a candidate digit from a cell. 
- Adding a candidate ('+'):  Use this to explicitly mark that a digit is currently possible in a cell.
- Removing a candidate ('-'): Use this to explicitly mark that a digit is currently not possible in a cell.
  - If the candidate exists in the cell's candidate list, it will be removed.
  - If the candidate does not exist in the cell's candidate list, it will be marked as an "Impossible Candidate" for that cell."""

    goal = "Complete the sudoku board based on the rules and visual elements."

    # Default output requirements: <think>, REASON, and one or more fenced actions.
    output_requirements = """1. Thought:
Provide a detailed, step-by-step reasoning process explaining your thought process in solving the task.
2. Reason:
Give a concise explanation summarizing the key logic behind your action(s).
3. Action:
List the concrete solving actions you want to take, each on its own line, wrapped in triple backticks.

### Output Format
You must generate your thought, reason and action in the following format:
<think>
[Your thought process in solving the task.]
</think>
REASON: [Your reason for the action(s)]
ACTION: ```
[one or more actions, each on its own line]
```
"""

    def __init__(self, max_steps: int = 30, use_accumulate_thinking: bool = False, history_window: int | None = None, use_multi_turn_format: bool = True, additional_info_path: str | None = None, board_format: str = "base"):
        """
        Args:
            max_steps: Stored as ``self.max_steps``; presumably the turn budget
                enforced elsewhere (BaseGameAgent / runner) — confirm.
            use_accumulate_thinking: Whether to accumulate the thinking portion
                of past responses.
            history_window: Stored as ``self.history_window``; not read in this class.
            use_multi_turn_format: Multi-turn vs. single-turn chat formatting.
            additional_info_path: Stored only; ``self.additional_info`` starts
                as ``None`` here and is presumably loaded elsewhere — verify.
            board_format: Passed to ``reformat_observation`` ("base" or "position").
        """
        self._trajectory = Trajectory()
        self.messages = []
        # Read once at construction; None when ONLY_K_ACTION is unset or invalid.
        self.only_k_action = _parse_only_k_action_env()
        self.step: int = 0
        self.use_accumulate_thinking = use_accumulate_thinking  # controls whether to accumulate the thinking portion of the response
        self.max_steps = max_steps
        self.history_window = history_window
        self.use_multi_turn_format = use_multi_turn_format  # reasoning models perform well with the single-turn format
        self.additional_info_path = additional_info_path
        self.board_format = board_format
        # Set by _make_system_prompt to "standard_sudoku" or "ctc_sudoku".
        self.sub_task_type = None
        # state
        self.current_observation = None
        self.additional_info = None

        # reset() is not defined in this class; presumably inherited from BaseGameAgent.
        self.reset()

    def _make_system_prompt(self, task_info: dict) -> str:
        """
        Build the system prompt for a task.

        Side effect: sets ``self.sub_task_type`` to ``"standard_sudoku"`` when
        ``task_info["visual_elements"]`` is missing or the string ``"None"``,
        and to ``"ctc_sudoku"`` otherwise.
        """
        add_info = task_info.get("add_info", "None")

        # The ONLY_K_ACTION env var takes precedence over task_info["add_info"].
        if self.only_k_action is not None:
            role_text = _only_k_action_role(self.only_k_action)
            output_requirements = _only_k_action_output_requirements(self.only_k_action)
        elif add_info == "only_one_action":
            role_text = ONLY_ONE_ACTION_ROLE
            output_requirements = ONLY_ONE_ACTION_OUTPUT_REQUIREMENTS
        elif add_info == "only_one_step":
            role_text = ONLY_ONE_STEP_ROLE
            output_requirements = self.output_requirements
        else:
            role_text = self.role
            output_requirements = self.output_requirements

        role_text += f"\n\n## Format Explanation\n{self.format_explanation}\n"

        if task_info.get("visual_elements", "None") == "None":
            self.sub_task_type = "standard_sudoku"
        else:
            self.sub_task_type = "ctc_sudoku"

        enhanced_knowledge = ""
        if self.additional_info:
            guidance = self.additional_info.get("guidance", "")
            enhanced_knowledge += f"[General Guidance]\nThe following provides high-level problem-solving principles and preferences. Use them implicitly to guide your reasoning and decisions, without referring to them explicitly.\n{guidance}\n"

        # Task-specific hints are only injected when explicitly requested.
        if task_info.get("hint", None) is not None and add_info == "use_hint":
            role_text += f"[Hint]\nThe following is task-specific contextual information that may help resolve ambiguities or guide the solution path. At each step, selectively incorporate only the information that is necessary for that step and incorporate it silently into your reasoning.\n{task_info['hint']}\n"

        if enhanced_knowledge:
            role_text += f"\n<tips>\nYou may use the following tips to solve the task.\nCRITICAL: Do NOT mention, quote, or refer to this guidance in the final answer (Do NOT say 'based on the hint/tip/guidance/above').\nHere are some tips for solving the task:\n{enhanced_knowledge}\n</tips>"

        return SYSTEM_PROMPT_TEMPLATE.format(
            role=role_text,
            output_requirements=output_requirements,
        )

    def _make_init_user_prompt(self, task_info: dict) -> str:
        """
        Build the first user prompt: action space, goal, board size, rules and visual elements.

        Raises:
            KeyError: if ``task_info`` lacks ``rows``, ``cols``, ``rules`` or ``visual_elements``.
        """
        add_info = task_info.get("add_info", "None")

        # The ONLY_K_ACTION env var takes precedence over task_info["add_info"];
        # MAX_K_ACTION is consulted only when no add_info variant matched.
        if self.only_k_action is not None:
            action_space = _only_k_action_action_space(self.only_k_action)
        elif add_info == "only_one_action":
            action_space = ONLY_ONE_ACTION_ACTION_SPACE
        elif add_info == "only_one_step":
            action_space = ONLY_ONE_STEP_ACTION_SPACE
        elif add_info == "only_value_action":
            action_space = ONLY_VALUE_ACTION_ACTION_SPACE
        elif os.getenv("MAX_K_ACTION") is not None:
            action_space = K_MAX_ACTION_ACTION_SPACE
        else:
            action_space = self.action_space

        environment_info_str = f"""### Available Actions
{action_space}

### Goal
{self.goal}

### Size
{task_info['rows']}x{task_info['cols']}

### Rules
{task_info['rules']}

### Visual Elements
{task_info['visual_elements']}
"""
        return INTRO_USER_PROMPT_TEMPLATE.format(
            environment_info=environment_info_str,
        )

    def _make_user_prompt(self, observation: Any) -> str:
        """Format a processed observation (a dict with an ``"observation"`` key) as a user turn."""
        return USER_PROMPT_TEMPLATE.format(
            current_observation=observation["observation"],
        )

    def _process_observation(self, observation: Any) -> dict:
        """
        Process observation from environment.

        A dict observation has its ``board_spaced`` rendered via
        ``reformat_observation`` (using ``self.board_format``), with any
        ``env_message`` appended as an error note; selected raw fields are
        passed through under ``"observation_info"``. A string observation is
        wrapped unchanged.

        Returns:
            A dict with at least an ``"observation"`` key.

        Raises:
            ValueError: for any other observation type, or propagated from
                ``reformat_observation`` when the board text is malformed or empty.
        """
        if isinstance(observation, dict):
            observation_str = reformat_observation(observation.get("board_spaced", ""), keep_dots=True, board_format=self.board_format)
            env_message = observation.get("env_message", None)
            if env_message:
                observation_str += f"\n\n- Error message from environment:\n{env_message}"
            return {
                "observation": observation_str,
                "observation_info": {
                    "board_ascii": observation.get("board_ascii", ""),
                    "board_spaced": observation.get("board_spaced", ""),
                    "board_tokens": observation.get("board_tokens", ""),
                    "current_turn": observation.get("current_turn", ""),
                    "max_turns": observation.get("max_turns", ""),
                    "progress": observation.get("progress", ""),
                    "correct_cells": observation.get("correct_cells", ""),
                    "total_cells": observation.get("total_cells", ""),
                }
            }
        elif isinstance(observation, str):
            return {
                "observation": observation,
            }
        else:
            raise ValueError(f"Invalid observation type: {type(observation)}")


def reformat_observation(puzzle_str: str, keep_dots: bool = True, board_format: str = "base") -> str:
    """
    Render a Sudoku board (4x4, 6x6 or 9x9) as a labelled, fixed-width grid.

    Only the text before the first blank line (``"\\n\\n"``) is parsed as the
    board. Anything after it is stripped of surrounding whitespace and appended
    below the rendered grid, separated by one blank line.

    Parameters
    ----------
    puzzle_str : str
        Multiline string whose leading section holds whitespace-separated
        tokens (digits, '.', '0', '_'), one board row per line.
        Example:
            "1 . 3 .\\n. 2 . 4\\n3 . . .\\n. . . ."
    keep_dots : bool
        How blank cells are rendered.
        - True  -> '.'
        - False -> nothing (the cell is padded to width 3 with spaces)
    board_format : str
        - "base"     -> grid with c1..cN column labels and r1..rN row labels.
        - "position" -> no grid lines at all; only the trailing extra text (if any).

    Returns
    -------
    str
        The rendered board, followed by any trailing extra text.

    Rules
    -----
    - Blank tokens in {'.', '0', '_'} are treated as empty.
    - The board size N is the token count of the first non-empty line.
    - Each cell is left-aligned in a field of width 3.

    Raises
    ------
    ValueError
        - `puzzle_str` is empty or contains only whitespace (before the first blank line).
        - N is not one of 4, 6, 9.
        - Any row has a token count different from N.
        - The number of board rows differs from N.
        - `board_format` is neither "base" nor "position".
    """
    allowed_sizes = (4, 6, 9)
    blank_tokens = {".", "0", "_"}

    def fmt_cell(s: str) -> str:
        """Left-align a token within width 3."""
        return f"{s:<3}"

    def normalize_token(t: str) -> str:
        """Map blank-like tokens to '.' or the empty string depending on keep_dots."""
        if t in blank_tokens:
            return "." if keep_dots else ""
        return t

    # The grid is the first blank-line-separated section; the rest is extra text.
    sections = puzzle_str.split("\n\n")
    board_text, extra_sections = sections[0], sections[1:]

    raw_lines: list[str] = [ln.strip() for ln in board_text.strip().splitlines() if ln.strip()]
    if not raw_lines:
        raise ValueError("puzzle_str is empty.")

    token_lines: list[list[str]] = [ln.split() for ln in raw_lines]

    n = len(token_lines[0])
    # Previously an unsupported size triggered a row truncation that could never
    # produce a valid board; reject it up front as documented instead.
    if n not in allowed_sizes:
        raise ValueError(f"Unsupported board size {n} (must be 4, 6, or 9).")

    row_lengths = [len(r) for r in token_lines]
    if any(length != n for length in row_lengths):
        raise ValueError(
            f"Every row must have exactly {n} tokens. (row token counts: {row_lengths})"
        )

    if len(token_lines) != n:
        raise ValueError(
            f"Number of rows ({len(token_lines)}) must match the size {n}."
        )

    if board_format == "base":
        header = "    " + "".join(fmt_cell(f"c{i}") for i in range(1, n + 1))
        lines = [header]
        for i, row_tokens in enumerate(token_lines, start=1):
            row_label = f"r{i}"
            cells = "".join(fmt_cell(normalize_token(t)) for t in row_tokens)
            lines.append(f"{row_label:<3} " + cells)
    elif board_format == "position":
        # NOTE(review): this format renders no grid lines; looks unimplemented — confirm intent.
        lines = []
    else:
        raise ValueError(f"Invalid board format: {board_format}")

    observation_str = "\n".join(lines)
    if extra_sections:
        observation_str += "\n\n" + "\n\n".join(extra_sections).strip()
    return observation_str