from typing import Any, Dict, Tuple, Optional, List
import copy
import re
import random

from rllm.environments.base.base_env import BaseEnv
from rllm.environments.games.rush_hour_utils import RushHourState, RushHourAction


class RushHourEnv(BaseEnv):
    """
    Multi-turn environment for Rush Hour puzzles.
    Provides an environment where an LLM solves Rush Hour puzzles through actions.

    Each call to ``step`` parses one LLM response, executes at most one move on
    the board and returns a reward dictionary with a per-step component
    (``"step"``) and a trajectory-level component (``"traj"``).
    """
    # Values returned as the first element of ``_parse_action``'s result when a
    # response cannot be parsed (generic format error / duplicated markers /
    # unparseable action content).
    parsing_error_penalty = 1.0
    parsing_duplicate_penalty = 1.0
    parsing_invalid_penalty = 1.0
    # NOTE(review): the two penalties below are not referenced by any method of
    # this class as shown here -- confirm whether external code reads them.
    invalid_action_penalty = 0.25
    wrong_action_penalty = 0.5
    # step-level reward
    # ``format_reward`` is granted when a response yields at least one parsed
    # action; ``valid_action_reward`` is added when executing it produced no
    # error message.
    format_reward = 0.5
    valid_action_reward = 0.5
    # NOTE(review): ``progress_reward`` is not read by the methods shown here.
    progress_reward = 0.5

    # trajectory reward
    # Granted as the ``"traj"`` reward once the puzzle is complete.
    success_reward = 1.0

    def __init__(
        self, 
        task: dict,
    ):
        """
        Initialize the Rush Hour multi-turn environment.

        Args:
            task: Task dictionary containing Rush Hour task information
                  - "board_config": Board configuration string
                  - "rows": Number of rows (default: 6)
                  - "cols": Number of columns (default: 6)
                  - "target_car": Target car letter (default: 'A' for red car)
                  - "max_turns": Maximum number of turns (default: 30)
        """
        self.task = task
        self.max_turns = task.get("max_turns", 30)
        self.strict_termination = task.get("strict_termination", False)

        # State tracking for detecting loops
        self.current_state: Optional[RushHourState] = None
        self.initial_state: Optional[RushHourState] = None
        self.state_history = []

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0

        # terminate conditions
        if self.strict_termination:
            self.max_consecutive_same_states = 2
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
        else:
            self.max_consecutive_same_states = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 20

        self._initialize_from_task(task)

    def _initialize_from_task(self, task: dict):
        """Initialize the environment from a task dictionary."""
        # Get board configuration
        self.board_config = task.get("board_config")
        if not self.board_config:
            raise ValueError("board_config must be provided in the task")
        
        self.rows = task.get("rows", 6)
        self.cols = task.get("cols", 6)
        self.target_car = task.get("target_car", "A")
        
        # Set initial state
        self.initial_state = RushHourState(
            board_config=self.board_config,
            rows=self.rows,
            cols=self.cols
        )

    def reset(self):
        """Reset the environment to the initial state."""
        self.done = False
        self.current_turn = 0
        self.history = []
        self.state_history = []

        # terminate conditions
        if self.strict_termination:
            self.max_consecutive_same_states = 3
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
        else:
            self.max_consecutive_same_states = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 10

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0
        self.incorrect_moves_num = 0
        
        if self.initial_state:
            self.current_state = self.initial_state.copy()
            self.initial_progress = self._calculate_progress()
        
        # Return the first observation
        observation = self._get_observation()
        info = self._get_info()
        
        return observation, info

    def _get_observation(self) -> dict:
        """Return the observation of the current environment state."""
        if not self.current_state:
            return {}
        
        observation = {
            "board_ascii": self.current_state.to_ascii(),
            "board_string": self.current_state.to_string(),
            "current_turn": self.current_turn,
            "max_turns": self.max_turns,
            "env_message": self.env_message,
            "target_car": self.target_car,
            "rows": self.rows,
            "cols": self.cols,
        }

        if self.current_turn == 0:
            observation["task_info"] = self.task
        
        # Add progress information
        progress = self._calculate_progress()
        observation["progress"] = progress
        
        return observation

    def _get_info(self) -> dict:
        """Return additional information."""
        info = {
            "turn": self.current_turn,
            "max_turns": self.max_turns,
            "board_size": f"{self.rows}x{self.cols}",
        }
        
        if self.current_state:
            info["is_complete"] = self._is_puzzle_complete()
        
        return info

    def step(self, action: Any) -> tuple[dict, dict, bool, dict]:
        """
        Take a step in the environment based on the action.
        """
        # Store the action in history
        self.history.append(action)

        # Calculate reward for the current turn using the abstract method
        assert self.task is not None, "Task is not set"

        # Increment turn counter
        self.current_turn += 1

        reward, next_obs = self.get_reward_and_next_obs(self.task, action)

        if self.current_state and not self.done:
            current_state_str = self.current_state.to_string()
            self.state_history.append(current_state_str)
            
            # check if the same state has appeared consecutively
            if self._check_consecutive_same_states():
                self.done = True
                reward["step"] = 0.0
                self.termination_reason = "STATE_UNCHANGED"
            
            # check if puzzle is complete
            if self._is_puzzle_complete():
                self.done = True
                self.termination_reason = "PUZZLE_COMPLETE"

        # check if we've reached the maximum number of turns
        if self.current_turn >= self.max_turns:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_TURNS"
        
        if self.done:
            next_obs["termination_reason"] = self.termination_reason

        return next_obs, reward, self.done, self.task

    def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[dict, dict]:
        """
        Calculate reward and next observation for the given action.

        Args:
            task: Task dictionary
            action: LLM's action (string or dictionary)

        Returns:
            (reward, next_observation) tuple
        """
        _, action_list = self._parse_action(action)
        if len(action_list) == 0:
            self.max_parsing_error -= 1
            step_reward = 0
        else:
            step_reward = self.format_reward

            error_messages = []
            # tmp code: only execute the first action
            if len(action_list) > 1:
                action_list = action_list[:1]
            for rush_action in action_list:
                try:
                    # Execute action
                    success = self.current_state.move_vehicle(
                        rush_action.vehicle,
                        rush_action.direction,
                        rush_action.num_moves
                    )
                    if not success:
                        error_msg = f"Cannot move vehicle {rush_action.vehicle} {rush_action.num_moves} step(s) in direction {rush_action.direction}"
                        error_messages.append(error_msg)
                        self.max_invalid_actions -= 1
                        if self.max_invalid_actions <= 0:
                            break
                except Exception as e:
                    self.max_invalid_actions -= 1
                    if self.max_invalid_actions <= 0:
                        break
                    error_messages.append(str(e))
                    print(e)
                    # continue to the next action
                    continue
            
            if len(error_messages) > 0:
                self.env_message = "\n".join(error_messages)
            else:
                self.env_message = None
                step_reward += self.valid_action_reward

        next_obs = self._get_observation()
        # Reward for completing the puzzle
        traj_reward, progress_reward = self._calculate_reward(next_obs.get("progress", 0))

        reward = {
            "step": step_reward,
            "traj": traj_reward,
        }
        
        if self.max_invalid_actions <= 0:
            self.done = True
            self.termination_reason = "MAX_INVALID_ACTIONS"
        if self.max_parsing_error <= 0:
            self.done = True
            self.termination_reason = "MAX_PARSING_ERROR"
        
        return reward, next_obs

    def _parse_action(self, action: Any) -> tuple[float, list[RushHourAction]]:
        """Parse the action into a RushHourAction object."""
        if self.strict_termination:
            required_markers = ["</think>", "REASON", "ACTION"]
            if action.strip()[-3:] != "```":
                return self.parsing_error_penalty, []
            # redundant format penalty
            think_num = action.count("</think>")
            reason_num = action.count("REASON")
            action_num = action.count("ACTION")
            if think_num > 1 or reason_num > 1 or action_num > 1:
                return self.parsing_duplicate_penalty, []
        else:
            required_markers = ["ACTION"]
        
        if any(marker not in action for marker in required_markers):
            return self.parsing_error_penalty, []

        pattern = re.compile(
            r"""
            ACTION:\s*
            (?:                                     # Option A: fenced code block
                ```(?:[a-zA-Z0-9_-]+)?\s*          # opening fence with optional language
                (?P<fenced>.*?)                    # content inside fences (non-greedy)
                \s*```                             # closing fence
            |
                (?P<inline>.+)                     # Option B: inline content until end
            )
            """,
            re.DOTALL | re.VERBOSE,
        )

        m = pattern.search(action)
        if not m:
            return self.parsing_invalid_penalty, []

        content = m.group("fenced") if m.group("fenced") is not None else m.group("inline")
        # Normalize lines: strip blanks and ignore empty lines
        lines = [ln.strip() for ln in content.strip().splitlines() if ln.strip()]

        parsed_actions: List[RushHourAction] = []
        for line in lines:
            try:
                parsed_actions.append(RushHourAction.from_serialized(line))
            except Exception:
                return self.parsing_invalid_penalty, []
        
        if len(parsed_actions) == 0:
            return self.parsing_invalid_penalty, []

        return 0.0, parsed_actions

    def _calculate_progress(self) -> float:
        """Calculate progress for the current state."""
        
        return 0.0

    def _calculate_reward(self, progress: float) -> tuple[float, float]:
        """Calculate reward for the action."""

        if self._is_puzzle_complete():
            traj_reward = self.success_reward
            self.done = True
            self.termination_reason = "PUZZLE_COMPLETE"
        else:
            traj_reward = 0.0

        progress_reward = 0.0

        return traj_reward, progress_reward

    def _is_puzzle_complete(self) -> bool:
        """Check if the puzzle is complete."""
        if not self.current_state:
            return False
        
        return self.current_state.is_complete()

    def _check_consecutive_same_states(self) -> bool:
        """
        Check if the same state has appeared consecutively for a specified number of times.
        
        Returns:
            bool: True if the same state has appeared consecutively for max_consecutive_same_states times
        """
        if len(self.state_history) < self.max_consecutive_same_states:
            return False
        
        # Check the last max_consecutive_same_states states
        recent_states = self.state_history[-self.max_consecutive_same_states:]
        
        # Check if all states are the same
        return all(state == recent_states[0] for state in recent_states)

    @staticmethod
    def from_dict(env_args: dict) -> "RushHourEnv":
        """Generate a RushHourEnv from a dictionary."""
        return RushHourEnv(task=env_args)

