from typing import Any, Dict, Tuple, Optional, List
import copy
import re
import random

from rllm.environments.base.base_env import BaseEnv
from rllm.environments.games.tower_of_hanoi_utils import TowerOfHanoiState, TowerOfHanoiAction


class TowerOfHanoiEnv(BaseEnv):
    """
    Multi-turn environment for Tower of Hanoi puzzles.

    An LLM solves the puzzle by emitting text responses containing an
    ``ACTION:`` section (see ``_parse_action``). Each turn yields a reward dict
    with a step-level component ("step") and a trajectory-level component
    ("traj", equal to ``success_reward`` once the puzzle is complete).
    """
    # Penalties returned by _parse_action (get_reward_and_next_obs currently
    # discards the returned value).
    parsing_error_penalty = 1.0
    parsing_duplicate_penalty = 1.0
    parsing_invalid_penalty = 1.0
    # NOTE(review): the two penalties below are not referenced in this class;
    # presumably kept for subclasses or future reward shaping.
    invalid_action_penalty = 0.25
    wrong_action_penalty = 0.5
    # step-level reward
    format_reward = 0.5
    valid_action_reward = 0.5
    progress_reward = 0.5  # NOTE(review): not referenced in this class

    # trajectory reward
    success_reward = 1.0

    # Matches "ACTION:" followed either by a fenced code block or by inline
    # text running to the end of the response. Compiled once for all instances.
    _ACTION_PATTERN = re.compile(
        r"""
        ACTION:\s*
        (?:                                     # Option A: fenced code block
            ```(?:[a-zA-Z0-9_-]+)?\s*          # opening fence with optional language
            (?P<fenced>.*?)                    # content inside fences (non-greedy)
            \s*```                             # closing fence
        |
            (?P<inline>.+)                     # Option B: inline content until end
        )
        """,
        re.DOTALL | re.VERBOSE,
    )

    def __init__(
        self,
        task: dict,
    ):
        """
        Initialize the Tower of Hanoi multi-turn environment.

        Args:
            task: Task dictionary containing Tower of Hanoi task information
                  - "num_disks": Number of disks (default: 3)
                  - "num_pegs": Number of pegs (default: 3)
                  - "start_peg": Starting peg number (default: 1)
                  - "target_peg": Target peg number (default: the value of num_pegs)
                  - "auxiliary_pegs": List of auxiliary peg numbers (default: [])
                  - "max_turns": Maximum number of turns (default: 30)
                  - "strict_termination": Use the tighter training-time
                    termination budgets and format checks (default: False)

        Raises:
            ValueError: If start_peg/target_peg are outside 1..num_pegs or equal.
        """
        self.task = task
        self.max_turns = task.get("max_turns", 30)
        self.strict_termination = task.get("strict_termination", False)  # for training, we use strict termination

        # Episode bookkeeping, re-initialized by reset(). Setting it here means
        # step()/_get_info() do not hit missing attributes before reset().
        self.done = False
        self.current_turn = 0
        self.history = []

        # State tracking for detecting loops
        self.current_state: Optional[TowerOfHanoiState] = None
        self.initial_state: Optional[TowerOfHanoiState] = None
        self.state_history = []

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0

        # terminate conditions (same budgets reset() restores)
        self._reset_termination_limits()

        self._initialize_from_task(task)

    def _reset_termination_limits(self) -> None:
        """Restore the per-episode budgets whose exhaustion ends an episode.

        ``max_parsing_error`` and ``max_invalid_actions`` are decremented in
        place by ``get_reward_and_next_obs``, so they must be restored on every
        reset. Strict (training) mode uses tighter budgets.
        """
        if self.strict_termination:  # for training
            self.max_consecutive_same_states = 3
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
        else:
            self.max_consecutive_same_states = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 10

    def _initialize_from_task(self, task: dict):
        """Read puzzle parameters from ``task``, validate them, and build the initial state.

        Raises:
            ValueError: If start_peg/target_peg are outside 1..num_pegs or equal.
        """
        # Get number of disks and pegs
        self.num_disks = task.get("num_disks", 3)
        self.num_pegs = task.get("num_pegs", 3)
        self.start_peg = task.get("start_peg", 1)
        self.target_peg = task.get("target_peg", self.num_pegs)
        self.auxiliary_pegs = task.get("auxiliary_pegs", [])

        # Validate pegs
        if self.start_peg < 1 or self.start_peg > self.num_pegs:
            raise ValueError(f"Invalid start_peg: {self.start_peg}. Should be between 1 and {self.num_pegs}")
        if self.target_peg < 1 or self.target_peg > self.num_pegs:
            raise ValueError(f"Invalid target_peg: {self.target_peg}. Should be between 1 and {self.num_pegs}")
        if self.start_peg == self.target_peg:
            raise ValueError("start_peg and target_peg must be different")

        # Set initial state
        self.initial_state = TowerOfHanoiState(
            num_disks=self.num_disks,
            num_pegs=self.num_pegs,
            start_peg=self.start_peg,
            target_peg=self.target_peg
        )

    def reset(self):
        """Reset the environment to the initial state.

        Returns:
            (observation, info) for the freshly copied initial puzzle state.
        """
        self.done = False
        self.current_turn = 0
        self.history = []
        self.state_history = []

        # terminate conditions
        self._reset_termination_limits()

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0

        if self.initial_state:
            # Work on a copy so the pristine initial state survives the episode.
            self.current_state = self.initial_state.copy()
            self.initial_progress = self._calculate_progress()

        # Return the first observation
        observation = self._get_observation()
        info = self._get_info()

        return observation, info

    def _get_observation(self) -> dict:
        """Return the observation of the current environment state ({} if no state exists)."""
        if not self.current_state:
            return {}

        observation = {
            "state_ascii": self.current_state.to_ascii(),
            "state_string": self.current_state.to_string(),
            "current_turn": self.current_turn,
            "max_turns": self.max_turns,
            "env_message": self.env_message,
            "num_disks": self.num_disks,
            "num_pegs": self.num_pegs,
            "start_peg": self.start_peg,
            "target_peg": self.target_peg,
            "auxiliary_pegs": self.auxiliary_pegs,
            "disk_names": self.current_state.disk_names,
        }

        # The full task description is only attached to the first observation.
        if self.current_turn == 0:
            observation["task_info"] = self.task
        observation["progress"] = 0.0   # do not use progress for tower of hanoi
        return observation

    def _get_info(self) -> dict:
        """Return additional information about the current episode.

        State-dependent fields ("disk_names", "is_complete") are included only
        when a puzzle state exists, so this is safe to call before reset().
        """
        info = {
            "turn": self.current_turn,
            "max_turns": self.max_turns,
            "num_disks": self.num_disks,
            "num_pegs": self.num_pegs,
            "start_peg": self.start_peg,
            "target_peg": self.target_peg,
            "auxiliary_pegs": self.auxiliary_pegs,
        }

        if self.current_state:
            info["disk_names"] = self.current_state.disk_names
            info["is_complete"] = self._is_puzzle_complete()

        return info

    def step(self, action: Any) -> tuple[dict, dict, bool, dict]:
        """
        Take a step in the environment based on the action.

        Args:
            action: The LLM response for this turn (parsed by _parse_action).

        Returns:
            (next_obs, reward, done, task). ``reward`` is a dict with "step" and
            "traj" entries; once the episode is done, ``next_obs`` also carries
            a "termination_reason" key.
        """
        # Store the action in history
        self.history.append(action)

        assert self.task is not None, "Task is not set"

        # Increment turn counter
        self.current_turn += 1

        # Apply the action and compute this turn's reward.
        reward, next_obs = self.get_reward_and_next_obs(self.task, action)

        if self.current_state and not self.done:
            current_state_str = self.current_state.to_string()
            self.state_history.append(current_state_str)

            # End the episode (with no step reward) if the state has not
            # changed for max_consecutive_same_states consecutive turns.
            if self._check_consecutive_same_states():
                self.done = True
                reward["step"] = 0.0
                self.termination_reason = "STATE_UNCHANGED"

            # check if puzzle is complete
            if self._is_puzzle_complete():
                self.done = True
                self.termination_reason = "PUZZLE_COMPLETE"

        # check if we've reached the maximum number of turns
        if self.current_turn >= self.max_turns:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_TURNS"

        if self.done:
            next_obs["termination_reason"] = self.termination_reason

        return next_obs, reward, self.done, self.task

    def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[dict, dict]:
        """
        Calculate reward and next observation for the given action.

        A parseable response earns ``format_reward``; if every parsed move is
        then applied without error it also earns ``valid_action_reward``.
        Failed moves are reported via ``env_message`` and consume the
        invalid-action budget; unparseable responses consume the parsing-error
        budget. Exhausting either budget ends the episode.

        Args:
            task: Task dictionary
            action: LLM's response (expected to be a string)

        Returns:
            (reward, next_observation) tuple, where reward is
            {"step": float, "traj": float}.
        """
        _, action_list = self._parse_action(action)
        if not action_list:
            # Unparseable response: consume one unit of the parsing-error budget.
            self.max_parsing_error -= 1
            step_reward = 0
        else:
            step_reward = self.format_reward

            error_messages = []
            for hanoi_action in action_list:
                error_msg = None
                try:
                    # Execute action
                    success = self.current_state.move_disk(
                        hanoi_action.disk_name,
                        hanoi_action.from_peg,
                        hanoi_action.to_peg
                    )
                    if not success:
                        error_msg = f"Cannot move {hanoi_action.disk_name} from peg {hanoi_action.from_peg} to peg {hanoi_action.to_peg}"
                except Exception as e:
                    error_msg = str(e)
                    print(e)

                if error_msg is None:
                    continue

                # Record the failure BEFORE checking the budget so a failing
                # move can never be mistaken for a fully valid turn below.
                error_messages.append(error_msg)
                self.max_invalid_actions -= 1
                if self.max_invalid_actions <= 0:
                    break

            if error_messages:
                self.env_message = "\n".join(error_messages)
            else:
                self.env_message = None
                step_reward += self.valid_action_reward

        next_obs = self._get_observation()
        # Reward for completing the puzzle
        traj_reward, _ = self._calculate_reward(next_obs.get("progress", 0))

        reward = {
            "step": step_reward,
            "traj": traj_reward,
        }

        # Budget exhaustion ends the episode but must not overwrite a reason
        # already recorded (e.g. PUZZLE_COMPLETE set by _calculate_reward).
        if self.max_invalid_actions <= 0:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_INVALID_ACTIONS"

        if self.max_parsing_error <= 0:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_PARSING_ERROR"

        return reward, next_obs

    def _parse_action(self, action: Any) -> tuple[float, list[TowerOfHanoiAction]]:
        """Parse an LLM response into a list of TowerOfHanoiAction objects.

        One action is parsed per non-empty line of the ACTION section
        (lower-cased before deserialization). In strict mode the response must
        also end with a code fence and contain exactly one "</think>", one
        "REASON" and one "ACTION" marker.

        Returns:
            (penalty, actions): (0.0, non-empty list) on success, otherwise one
            of the parsing_* class penalties and an empty list.
        """
        # Only text can be parsed; anything else is a parsing error rather
        # than an AttributeError/TypeError from the string operations below.
        if not isinstance(action, str):
            return self.parsing_error_penalty, []

        if self.strict_termination:
            required_markers = ["</think>", "REASON", "ACTION"]
            if action.strip()[-3:] != "```":
                return self.parsing_error_penalty, []
            # redundant format penalty
            think_num = action.count("</think>")
            reason_num = action.count("REASON")
            action_num = action.count("ACTION")
            if think_num > 1 or reason_num > 1 or action_num > 1:
                return self.parsing_duplicate_penalty, []
        else:
            required_markers = ["ACTION"]

        if any(marker not in action for marker in required_markers):
            return self.parsing_error_penalty, []

        m = self._ACTION_PATTERN.search(action)
        if not m:
            return self.parsing_invalid_penalty, []

        content = m.group("fenced") if m.group("fenced") is not None else m.group("inline")
        # Normalize lines: strip blanks and ignore empty lines
        lines = [ln.strip() for ln in content.strip().splitlines() if ln.strip()]

        parsed_actions: List[TowerOfHanoiAction] = []
        for line in lines:
            try:
                parsed_actions.append(TowerOfHanoiAction.from_serialized(line.lower()))
            except Exception:
                # Any malformed line invalidates the whole response.
                return self.parsing_invalid_penalty, []

        if not parsed_actions:
            return self.parsing_invalid_penalty, []

        return 0.0, parsed_actions

    def _calculate_progress(self) -> float:
        """do not use progress for tower of hanoi"""
        return 0.0

    def _calculate_reward(self, progress: float) -> tuple[float, float]:
        """Return (trajectory_reward, progress_reward) for the current state.

        Marks the episode done with reason PUZZLE_COMPLETE when the puzzle is
        solved. ``progress`` is accepted but unused; progress_reward is always 0.0.
        """
        if self._is_puzzle_complete():
            traj_reward = self.success_reward
            self.done = True
            self.termination_reason = "PUZZLE_COMPLETE"
        else:
            traj_reward = 0.0

        progress_reward = 0.0

        return traj_reward, progress_reward

    def _is_puzzle_complete(self) -> bool:
        """Check if the puzzle is complete (False when no state exists)."""
        if not self.current_state:
            return False

        return self.current_state.is_complete()

    def _check_consecutive_same_states(self) -> bool:
        """
        Check if the same state has appeared consecutively for a specified number of times.

        Returns:
            bool: True if the last max_consecutive_same_states entries of
            state_history are all identical
        """
        if len(self.state_history) < self.max_consecutive_same_states:
            return False

        # Check the last max_consecutive_same_states states
        recent_states = self.state_history[-self.max_consecutive_same_states:]

        # Check if all states are the same
        return all(state == recent_states[0] for state in recent_states)

    @staticmethod
    def from_dict(env_args: dict) -> "TowerOfHanoiEnv":
        """Generate a TowerOfHanoiEnv from a dictionary (used as the task dict)."""
        return TowerOfHanoiEnv(task=env_args)
