import logging
from typing import Any, List

from rllm.agents.agent import Action, Step, Trajectory
from rllm.agents.game_agents.base import BaseGameAgent, SYSTEM_PROMPT_TEMPLATE, USER_PROMPT_TEMPLATE, INTRO_USER_PROMPT_TEMPLATE

logger = logging.getLogger(__name__)


class TowerOfHanoiAgent(BaseGameAgent):
    """
    Tower of Hanoi Agent class.

    Supplies the Tower of Hanoi-specific prompt text (role, board format,
    action space, output requirements) and observation processing on top of
    ``BaseGameAgent``. Prompts are rendered with ``SYSTEM_PROMPT_TEMPLATE``,
    ``INTRO_USER_PROMPT_TEMPLATE`` and ``USER_PROMPT_TEMPLATE``.
    """
    role = "You are a professional Tower of Hanoi solver. You are given a Tower of Hanoi board and you need to solve it."

    format_explanation = """Representation of the board:
- The disks are named from A to Z(given last alphabet), where A is the smallest disk and Z is the largest disk.
- The pegs are numbered from 1 to given number of pegs."""

    action_space = """1. Move: ```move(disk_name, from_peg, to_peg)```
- Move a disk from one peg to another. (e.g. move('A', 1, 3))
- disk_name: name of the disk to move (e.g. 'A')
- from_peg: starting peg number
- to_peg: target peg number"""

    # NOTE(review): not referenced by this class; _make_init_user_prompt writes
    # its own goal text. Possibly read by BaseGameAgent — confirm.
    goal = "Complete the Tower of Hanoi board based on the rules."

    output_requirements = """1. Thought:
Provide a detailed, step-by-step reasoning process explaining your thought process in solving the task.
2. Reason:
Give a concise explanation summarizing the key logic behind your action(s).
3. Action:
List the concrete solving actions you want to take, each on its own line, wrapped in triple backticks.

### Output Format
You must generate your thought, reason and action in the following format:
<think>
[Your thought process in solving the task.]
</think>
REASON: [Your reason for the action(s)]
ACTION: ```
[one or more actions, each on its own line]
```
"""

    def __init__(
        self,
        max_steps: int = 30,
        use_accumulate_thinking: bool = False,
        history_window: int | None = None,
        use_multi_turn_format: bool = True,
        additional_info_path: str | None = None,
        board_format: str = "base",
    ):
        """
        Store configuration, initialize per-episode state, then call ``self.reset()``.

        Args:
            max_steps: Maximum number of steps; stored as ``self.max_steps``.
            use_accumulate_thinking: Whether to accumulate the thinking portion
                of model responses; stored as ``self.use_accumulate_thinking``.
            history_window: Stored as ``self.history_window``; presumably bounds
                how much past history is kept in context (confirm against
                BaseGameAgent).
            use_multi_turn_format: Multi-turn vs. single-turn message format;
                stored as ``self.use_multi_turn_format``.
            additional_info_path: Stored as ``self.additional_info_path``; not
                read by this class.
            board_format: Stored as ``self.board_format``; not read by this
                class (previously accepted and silently discarded).
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.step: int = 0
        self.use_accumulate_thinking = use_accumulate_thinking  # controls whether to accumulate the thinking portion of the response
        self.max_steps = max_steps
        self.history_window = history_window
        self.use_multi_turn_format = use_multi_turn_format  # reasoning models have good performance with single-turn format
        self.additional_info_path = additional_info_path
        self.board_format = board_format

        # state
        self.current_observation = None
        self.additional_info = None

        # reset() is not defined in this class; presumably provided by BaseGameAgent.
        self.reset()

    def _make_system_prompt(self, task_info: dict) -> str:
        """
        Build the system prompt from the role, board format and output requirements.

        Args:
            task_info: Accepted for interface compatibility; not read here.

        Returns:
            ``SYSTEM_PROMPT_TEMPLATE`` formatted with ``role`` and
            ``output_requirements``.
        """
        role_text = self.role + f"\n\n## Format Explanation\n{self.format_explanation}\n"

        return SYSTEM_PROMPT_TEMPLATE.format(
            role=role_text,
            output_requirements=self.output_requirements,
        )

    def _make_init_user_prompt(self, task_info: dict) -> str:
        """
        Build the first user prompt describing the actions, goal and rules.

        Args:
            task_info: Must contain ``num_disks``, ``num_pegs``, ``start_peg``
                and ``target_peg``; a missing key raises ``KeyError``.

        Returns:
            ``INTRO_USER_PROMPT_TEMPLATE`` formatted with ``environment_info``.
        """
        environment_info_str = f"""### Available Actions
{self.action_space}

### Goal
Solve the Tower of Hanoi problem with {task_info['num_disks']} disks and {task_info['num_pegs']} pegs.
Move all disks from peg {task_info['start_peg']} to peg {task_info['target_peg']} following the rules:
- Only one disk can be moved at a time.
- A larger disk cannot be placed on top of a smaller disk.
- All disks must be on a peg at all times.
"""
        return INTRO_USER_PROMPT_TEMPLATE.format(
            environment_info=environment_info_str,
        )

    def _make_user_prompt(self, observation: Any) -> str:
        """
        Build a per-turn user prompt from a processed observation.

        Args:
            observation: Dict with an ``"observation"`` key, as returned by
                ``_process_observation``.

        Returns:
            ``USER_PROMPT_TEMPLATE`` formatted with ``current_observation``.
        """
        return USER_PROMPT_TEMPLATE.format(
            current_observation=observation["observation"],
        )

    def _process_observation(self, observation: Any) -> dict[str, Any]:
        """
        Convert a raw environment observation into the agent's observation dict.

        Args:
            observation: Mapping from the environment. Reads ``state_string``,
                ``env_message``, ``current_turn``, ``max_turns`` and
                ``progress``. Each is optional.

        Returns:
            A dict with ``"observation"`` and ``"observation_info"``.
            ``"observation"`` is the board text, with the environment's error
            message appended when one is present. ``"observation_info"`` holds
            the turn and progress metadata, using ``""`` for missing keys.
        """
        # `or ""` also covers an explicit None, so the append below cannot
        # raise TypeError.
        observation_str = observation.get("state_string") or ""
        env_message = observation.get("env_message")
        if env_message:
            observation_str += f"\n\n- Error message from environment:\n{env_message}"
        return {
            "observation": observation_str,
            "observation_info": {
                "current_turn": observation.get("current_turn", ""),
                "max_turns": observation.get("max_turns", ""),
                "progress": observation.get("progress", ""),
            },
        }