from typing import Any, Dict, Tuple, Optional, List

import copy
import re
import random
import os  # [ADD]

from rllm.environments.base.base_env import BaseEnv
from rllm.environments.games.sudoku_utils import SudokuBoard, SudokuAction, SudokuCell, ValueType, ActionType, OperationType


class SudokuEnv(BaseEnv):
    """
    Multi-turn environment for Sudoku puzzles.
    Provides an environment where an LLM solves Sudoku puzzles through actions.

    Each step parses the model's text into Sudoku actions, applies them to the
    current board, and returns a reward equal to the goal/progress reward minus
    a format penalty. The class attributes below are the penalty and reward
    magnitudes used by that computation.
    """
    # Returned by _parse_action when a required marker is missing or, in strict
    # mode, when the response does not end with a closing code fence.
    parsing_error_penalty = 0.1
    # Returned by _parse_action in strict mode when a marker occurs more than once.
    parsing_duplicate_penalty = 0.1
    # Returned by _parse_action when the ACTION payload cannot be parsed.
    parsing_invalid_penalty = 1.0
    # Applied by get_reward_and_next_obs when executing an action raises.
    invalid_action_penalty = 0.1
    # NOTE(review): not referenced by any method visible in this class.
    wrong_action_penalty = 0.1

    # Reward granted by _calculate_reward once the puzzle is complete.
    success_reward = 1.0

    def __init__(
        self, 
        task: dict,
    ):
        """
        Initialize the Sudoku multi-turn environment.

        Args:
            task: Task dictionary containing Sudoku task information
                  - "initial_board": Initial board state (string or SudokuBoard)
                  - "solution": Solution board (optional)
            max_turns: Maximum number of turns
            progress_reward_type: Reward for progress
        """
        self.task = task
        self.max_turns = task.get("max_turns", 30)
        self.progress_reward_type = task.get("progress_reward_type", "")
        self.strict_termination = task.get("strict_termination", False) # for training, we use strict termination
        self.only_one_action = task.get("only_one_action", False)

        # [ADD] ONLY_K_ACTION env var support
        self.only_k_action: Optional[int] = None
        _k = os.getenv("ONLY_K_ACTION")
        if _k is not None and str(_k).strip() != "":
            try:
                self.only_k_action = int(_k)
            except ValueError:
                self.only_k_action = None

        # [ADD] internal flag for "insufficient actions" case
        self._insufficient_actions = False

        # Board state tracking for detecting loops
        self.current_board: Optional[SudokuBoard] = None
        self.initial_board: Optional[SudokuBoard] = None
        self.solution_board: Optional[SudokuBoard] = None
        self.board_history = []

        self.env_message = None
        self.termination_reason = None
        self.prev_progress= 0.0

        # terminate conditions
        if self.strict_termination: # for training
            self.max_consecutive_same_boards = 3
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
            self.max_wrong_actions = 100
        else:
            self.max_consecutive_same_boards = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 20
            self.max_wrong_actions = 1 # for early termination

            if os.getenv("ONLY_K_ACTION") is not None:
                self.max_parsing_error = 50
                self.max_invalid_actions = 30
        
        self._initialize_from_task(task)

    def _initialize_from_task(self, task: dict):
        """Initialize the environment from a task dictionary."""
        # Set board size
        if "rows" in task and "cols" in task:
            self.rows = task["rows"]
            self.cols = task["cols"]
            sr, sc = self._get_subgrid_dims()
            self.sr = sr
            self.sc = sc
        else:
            raise ValueError("rows and cols must be provided in the task")
        
        # extract difficulty level from task
        self.difficulty_level = task.get("difficulty", "default")
        # extract sub task type from task
        if task.get("visual_elements", "None") == "None":
            self.sub_task_type = "standard_sudoku"
        else:
            self.sub_task_type = "ctc_sudoku"
        
        # Set initial board
        initial_board_data = task.get("initial_board")
        if isinstance(initial_board_data, str):
            # Create board from string
            self.initial_board = SudokuBoard.from_ascii(
                initial_board_data, self.rows, self.cols
            )
        elif isinstance(initial_board_data, SudokuBoard):
            self.initial_board = initial_board_data
        elif isinstance(initial_board_data, dict):
            # Create board from dictionary
            self.initial_board = SudokuBoard.from_serialized(
                initial_board_data.get("state", "{}"),
                self.rows, self.cols,
                initial_board_data.get("givens", {})
            )

        # save the initial board coordinates
        self.initial_board_coordinates = [
            f"r{cell.row}c{cell.col}" for cell in self.initial_board.cells if cell.value != ValueType.empty
        ]
        self.initial_board.save_initial_board_coordinates(self.initial_board_coordinates)
        
        # Set solution board (optional)
        solution_data = task.get("solution")
        if isinstance(solution_data, str):
            self.solution_board = SudokuBoard.from_ascii(
                solution_data, self.rows, self.cols
            )
        elif isinstance(solution_data, SudokuBoard):
            self.solution_board = solution_data
        self.solution_board_ascii = self.solution_board.to_ascii()

    def reset(self):
        """Reset the environment to the initial state."""
        self.done = False
        self.current_turn = 0
        self.history = []
        self.board_history = []

        # terminate conditions
        if self.strict_termination: # for training
            self.max_consecutive_same_boards = 3
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
            self.max_wrong_actions = 100
        else:
            self.max_consecutive_same_boards = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 20
            self.max_wrong_actions = 1 # for early termination
        
        import os
        if os.getenv("ONLY_K_ACTION") is not None:
            self.max_parsing_error = 50
            self.max_invalid_actions = 30


        self.env_message = None
        self.termination_reason = None
        self.prev_progress= 0.0
        
        if self.initial_board:
            self.current_board = copy.deepcopy(self.initial_board)
            self.total_cells = self.rows * self.cols
            self.initial_correct = self._count_correct_cells(self.initial_board)
            self.subgrid_results, self.initial_correct_subgrid_count = self._make_subgrid_results_dict(self.initial_board)
        
        # Return the first observation
        observation = self._get_observation()
        info = self._get_info()
        
        return observation, info

    def _get_observation(self) -> dict:
        """Return the observation of the current environment state."""
        if not self.current_board:
            return {}
        
        observation = {
            "board_ascii": self.current_board.to_ascii(),
            "board_spaced": self.current_board.to_spaced_ascii(),
            "board_tokens": self.current_board.to_string(),
            "current_turn": self.current_turn,
            "max_turns": self.max_turns,
            "env_message": self.env_message,
        }

        if self.current_turn == 0:
            observation["task_info"] = self.task
        
        # Add progress information
        if self.solution_board:
            progress, correct_cells = self._calculate_progress()
            observation["progress"] = progress
            observation["correct_cells"] = correct_cells
        
        return observation

    def _get_info(self) -> dict:
        """Return additional information."""
        info = {
            "turn": self.current_turn,
            "max_turns": self.max_turns,
            "board_size": f"{self.rows}x{self.cols}",
        }
        
        if self.current_board:
            info["is_complete"] = self._is_puzzle_complete()
        
        return info

    def step(self, action: Any) -> tuple[dict, float, bool, dict]:
        """
        Take a step in the environment based on the action.
        """
        # Store the action in history
        self.history.append(action)

        # Calculate reward for the current turn using the abstract method
        assert self.task is not None, "Task is not set"

        # Increment turn counter
        self.current_turn += 1

        reward, next_obs = self.get_reward_and_next_obs(self.task, action)

        if self.current_board and not self.done:
            current_board_spaced = self.current_board.to_spaced_ascii()
            self.board_history.append(current_board_spaced)
            
            # check if the same board has appeared consecutively for 3 times
            if self._check_consecutive_same_boards():
                self.done = True
                self.termination_reason = "BOARD_STATE_UNCHANGED"
            # check if the puzzle is filled
            if self._is_puzzle_filled():
                self.done = True
                self.termination_reason = "PUZZLE_FILLED"

        # check if we've reached the maximum number of turns
        if self.current_turn >= self.max_turns:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_TURNS"
        
        if self.done:
            next_obs["termination_reason"] = self.termination_reason

        return next_obs, reward, self.done, self.task

    def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[float, dict]:
        """
        Calculate reward and next observation for the given action.

        Args:
            task: Task dictionary
            action: LLM's action (string or dictionary)

        Returns:
            (reward, next_observation) tuple
        """
        format_penalty, action_list = self._parse_action(action)
        if len(action_list) == 0:
            # [MOD] Do not count as parsing error if it was due to insufficient actions
            if not getattr(self, "_insufficient_actions", False):
                self.max_parsing_error -= 1
            # reset flag each step
            self._insufficient_actions = False
        else:
            # reset flag each step
            self._insufficient_actions = False
            error_messages = []
            for sudoku_action in action_list:
                try:
                    # Execute action
                    self.current_board.execute_action(sudoku_action)
                except Exception as e:
                    self.max_invalid_actions -= 1
                    if self.max_invalid_actions <= 0:
                        break
                    error_messages.append(str(e))
                    print(e)
                    continue
            
            if len(error_messages) > 0:
                self.env_message = "\n".join(error_messages)
                format_penalty = self.invalid_action_penalty
            else:
                self.env_message = None
        
        if self.max_invalid_actions <= 0:
            self.done = True
            self.termination_reason = "MAX_INVALID_ACTIONS"
        if self.max_wrong_actions <= 0:
            self.done = True
            self.termination_reason = "MAX_WRONG_ACTIONS"
        if self.max_parsing_error <= 0:
            self.done = True
            self.termination_reason = "MAX_PARSING_ERROR"

        # Generate next observation
        next_obs = self._get_observation()
        goal_reward = self._calculate_reward(next_obs.get("progress", 0))
        reward = goal_reward - format_penalty
        
        return reward, next_obs

    def _parse_action(self, action: Any) -> tuple[float, list[SudokuAction]]:
        """Parse the action into a SudokuAction object."""
        # [ADD] reset flag at the beginning
        self._insufficient_actions = False

        if self.strict_termination:
            required_markers = ["</think>", "REASON", "ACTION"]
            if action.strip()[-3:] != "```":
                return self.parsing_error_penalty, []
            think_num = action.count("</think>")
            reason_num = action.count("REASON")
            action_num = action.count("ACTION")
            if think_num > 1 or reason_num > 1 or action_num > 1:
                return self.parsing_duplicate_penalty, []
        else:
            required_markers = ["ACTION"]
        if any(marker not in action for marker in required_markers):
            return self.parsing_error_penalty, []

        pattern = re.compile(
            r"""
            ACTION:\s*
            (?:                                     # Option A: fenced code block
                ```(?:[a-zA-Z0-9_-]+)?\s*          # opening fence with optional language
                (?P<fenced>.*?)                    # content inside fences (non-greedy)
                \s*```                             # closing fence
            |
                (?P<inline>.+)                     # Option B: inline content until end
            )
            """,
            re.DOTALL | re.VERBOSE,
        )

        m = pattern.search(action)
        if not m:
            return self.parsing_invalid_penalty, []

        content = m.group("fenced") if m.group("fenced") is not None else m.group("inline")
        if self.only_one_action:
            content = content.split("\n")[0].strip()
        # Normalize lines: strip blanks and ignore empty lines
        lines = [ln.strip() for ln in content.strip().splitlines() if ln.strip()]

        parsed_actions: List["SudokuAction"] = []
        for line in lines:
            try:
                parsed_actions.append(SudokuAction.from_serialized(line.lower()))
            except Exception:
                return self.parsing_invalid_penalty, []
        
        if len(parsed_actions) == 0:
            return self.parsing_invalid_penalty, []

                # [ADD] ONLY_K_ACTION enforcement
        k = getattr(self, "only_k_action", None)
        if isinstance(k, int) and k > 0:
            if len(parsed_actions) > k:
                parsed_actions = parsed_actions[:k]
            elif len(parsed_actions) < k:
                # [ADD] allow "final step" where remaining empties == provided actions
                empty_cnt = None
                if self.current_board is not None:
                    empty_cnt = sum(1 for cell in self.current_board.cells if cell.value == ValueType.empty)

                if empty_cnt is not None and empty_cnt == len(parsed_actions):
                    # allow applying fewer actions on the last step
                    return 0.0, parsed_actions

                self.env_message = (
                    f"Insufficient number of actions: required {k}, but got {len(parsed_actions)}. "
                    f"No actions will be taken this step."
                )
                self._insufficient_actions = True
                return 0.0, []

        return 0.0, parsed_actions

    def _calculate_progress(self) -> tuple[float, int]:
        """Calculate progress for the current board."""
        curr_correct = self._count_correct_cells()
        try:
            progress = (curr_correct - self.initial_correct) / (self.total_cells - self.initial_correct)
        except:
            print("Error calculating progress")
            progress = 0.0
        return progress, curr_correct

    def _calculate_reward(self, progress: float) -> float:
        """Calculate reward for the action."""
        # Reward for completing the puzzle
        progress_mode = str(self.progress_reward_type or "")

        if self._is_puzzle_complete():
            reward = self.success_reward
            self.done = True
            self.termination_reason = "PUZZLE_COMPLETE"
        else:
            if "progress_A" in progress_mode:
                subgrid_progress = self._get_subgrid_results(self.current_board)
                max_progress_reward = 2.0
                reward_interval = round(max_progress_reward / self.rows, 3)
                reward_per_improvement = round(max_progress_reward / self.rows, 3)
                # if self.rows==self.cols==9:
                #     reward_interval = 0.111
                #     reward_per_improvement = 0.111
                # elif self.rows==self.cols==6:
                #     reward_interval = 0.166
                #     reward_per_improvement = 0.166
                # elif self.rows==self.cols==4:
                #     reward_interval = 0.25
                #     reward_per_improvement = 0.25

                if subgrid_progress >= self.prev_progress + reward_interval:
                    progress_interval = int((subgrid_progress - self.prev_progress) / reward_interval)
                    reward = reward_per_improvement * progress_interval
                    self.prev_progress += reward_interval * progress_interval
                else:
                    reward = 0.0
            elif "progress_B" in progress_mode:
                reward_interval = 0.2
                if progress >= self.prev_progress + reward_interval:
                    progress_interval = int((progress - self.prev_progress) / reward_interval)
                    reward_per_improvement = 0.4
                    # if "B1" in progress_mode:
                    #     reward_per_improvement = 0.1
                    # elif "B2" in progress_mode:
                    #     reward_per_improvement = 0.2
                    # else:
                    #     raise ValueError(f"Invalid progress mode: {progress_mode}")
                    reward = reward_per_improvement * progress_interval
                    self.prev_progress += reward_interval * progress_interval
                else:
                    reward = 0.0
            elif "progress_C" in progress_mode:
                reward_interval = 0.05
                if progress >= self.prev_progress + reward_interval:
                    progress_interval = int((progress - self.prev_progress) / reward_interval)
                    reward_per_improvement = 0.1  # or 0.05
                    reward = reward_per_improvement * progress_interval
                    self.prev_progress += reward_interval * progress_interval
                else:
                    reward = 0.0
            else:
                reward = 0.0
        
        if "noise" in progress_mode:
            noise_reward = self._get_noise_reward(reward, progress_mode)
            reward = noise_reward

        reward = round(reward, 3)

        return reward

    def _get_noise_reward(self, reward: float, progress_mode: str) -> float:
        """Get noise reward for the current progress."""
        if reward > 0.0 and reward < 1.0:
            if "noise_A1" in progress_mode:
                noise_reward = reward * round(random.uniform(0.0, 1.0), 3)
            elif "noise_A2" in progress_mode:
                noise_reward = reward * round(random.uniform(0.3, 1.0), 3)
            elif "noise_A3" in progress_mode:
                noise_reward = reward * round(random.uniform(0.7, 1.0), 3)
            elif "noise_B1" in progress_mode:
                give_reward = random.random() < 0.5
                if give_reward:
                    noise_reward = reward
                else:
                    noise_reward = 0.0
            elif "noise_B2" in progress_mode:
                give_reward = random.random() < 0.25
                if give_reward:
                    noise_reward = reward
                else:
                    noise_reward = 0.0
            else:
                noise_reward = reward
        elif reward == 1.0:
            if "add_correct_1" in progress_mode:
                add_noise = random.random() < 0.3
                if add_noise:
                    noise_reward = round(random.uniform(0.5, 1.0), 3)
                else:
                    noise_reward = 1.0
            elif "add_correct_2" in progress_mode:
                add_noise = random.random() < 0.7
                if add_noise:
                    noise_reward = round(random.uniform(0.5, 1.0), 3)
                else:
                    noise_reward = 1.0
            else:
                noise_reward = 1.0
        else:
            if "add_wrong" in progress_mode:
                give_reward = random.random() < 0.5
                if give_reward:
                    noise_reward = round(random.uniform(0.0, 0.15), 3)
                else:
                    noise_reward = 0.0
            else:
                noise_reward = 0.0
        return noise_reward

    def _is_puzzle_filled(self) -> bool:
        """Check if the puzzle is filled."""
        if not self.current_board:
            return False
        
        for cell in self.current_board.cells:
            if cell.value == ValueType.empty:
                return False
        
        return True

    def _is_puzzle_complete(self) -> bool:
        """Check if the puzzle is complete."""        
        # Check if all cells are filled
        is_filled = self._is_puzzle_filled()
        if not is_filled:
            return False
        
        # If there is a solution board, check if the current board is correct
        if self.solution_board:
            return self._count_correct_cells() == self.rows * self.cols
         
        return False
    
    def _count_correct_cells(self, board: Optional[SudokuBoard] = None) -> int:
        """Return the number of cells that match the solution."""
        if not self.solution_board:
            return 0
        
        target_board = board if board else self.current_board
        if not target_board:
            return 0
        
        correct_count = 0
        for cell in target_board.cells:
            if cell.value != ValueType.empty:
                solution_cell = self.solution_board.get_cell(cell.row, cell.col)
                if cell.value == solution_cell.value:
                    correct_count += 1
        
        return correct_count
    
    def _get_subgrid_dims(self) -> tuple[int, int]:
        """Return (subgrid_rows, subgrid_cols) for supported sizes."""
        if self.rows == self.cols == 9:
            return 3, 3
        if self.rows == self.cols == 6:
            return 2, 3
        if self.rows == self.cols == 4:
            return 2, 2
        raise ValueError(f"Unsupported board size for subgrid check: {self.rows}x{self.cols}")

    def _make_subgrid_results_dict(self, board: Optional[SudokuBoard] = None) -> dict:
        """Make a dictionary of subgrid results."""
        subgrid_results = {}
        max_sg_r = self.rows // self.sr
        max_sg_c = self.cols // self.sc
        correct_subgrid_count = 0
        for sg_row in range(1, max_sg_r + 1):
            for sg_col in range(1, max_sg_c + 1):
                is_correct = self._is_subgrid_correct(sg_row, sg_col, board)
                subgrid_results[f"sg{sg_row}sg{sg_col}"] = is_correct
                if is_correct:
                    correct_subgrid_count += 1
        return subgrid_results, correct_subgrid_count

    def _is_subgrid_correct(self, sg_row: int, sg_col: int, board: Optional[SudokuBoard] = None) -> bool:
        """Check if the subgrid is correct."""
        r0 = (sg_row - 1) * self.sr + 1
        c0 = (sg_col - 1) * self.sc + 1
        cnt = 0
        for r in range(r0, r0 + self.sr):
            for c in range(c0, c0 + self.sc):
                curr = board.get_cell(r, c)
                if curr.value == ValueType.empty:
                    continue
                sol = self.solution_board.get_cell(r, c)
                if curr.value == sol.value:
                    cnt += 1
        return cnt == self.sr * self.sc
    
    def _get_subgrid_results(self, board: Optional[SudokuBoard] = None) -> float:
        """Get the results of the subgrid."""
        max_sg_r = self.rows // self.sr
        max_sg_c = self.cols // self.sc
        correct_subgrid_count = 0
        for sg_row in range(1, max_sg_r + 1):
            for sg_col in range(1, max_sg_c + 1):
                if self.subgrid_results[f"sg{sg_row}sg{sg_col}"]:
                    correct_subgrid_count += 1
                else:
                    is_correct = self._is_subgrid_correct(sg_row, sg_col, board)
                    self.subgrid_results[f"sg{sg_row}sg{sg_col}"] = is_correct
                    if is_correct:
                        correct_subgrid_count += 1
        
        subgrid_progress = (correct_subgrid_count - self.initial_correct_subgrid_count) / (max_sg_r * max_sg_c - self.initial_correct_subgrid_count)
        return subgrid_progress

    def _check_consecutive_same_boards(self) -> bool:
        """
        Check if the same board has appeared consecutively for a specified number of times.
        
        Returns:
            bool: True if the same board has appeared consecutively for max_consecutive_same_boards times
        """
        if len(self.board_history) < self.max_consecutive_same_boards:
            return False
        
        # Check the last max_consecutive_same_boards boards
        recent_boards = self.board_history[-self.max_consecutive_same_boards:]
        
        # Check if all boards are the same
        return all(board == recent_boards[0] for board in recent_boards)

    @staticmethod
    def from_dict(env_args: dict) -> "SudokuEnv":
        """Generate a SudokuMultiTurnEnvironment from a dictionary."""
        return SudokuEnv(task=env_args)
