import os
from typing import Any, Optional, List
import copy
import re
import random

from rllm.environments.base.base_env import BaseEnv
from rllm.environments.games.sudoku_utils import (
    SudokuBoard,
    SudokuAction,
    ValueType,
    ActionType,
    OperationType,
)


class SudokuFormatEnv(BaseEnv):
    """
    Multi-turn environment for Sudoku puzzles.
    Provides an environment where an LLM solves Sudoku puzzles through actions.

    Each ``step`` receives the raw model response, parses the action lines found
    after its ``ACTION:`` marker into ``SudokuAction`` objects, applies them to the
    current board and returns ``(next_obs, reward, done, task)`` where ``reward``
    is a dict ``{"step": float, "traj": float}``.

    Add-on features:
      - WRONG_VALUE_FIX_CHANCES (env var, default 0):
          If a VALUE action places a wrong value on a cell, the env may replace it
          with the solution value, consuming 1 chance per wrong VALUE placement.
          The budget is read in reset() and shared by every step of the episode
          (it is NOT refilled between steps). Once it is exhausted, the next wrong
          VALUE placement stops the step immediately and terminates the episode
          with termination_reason "WRONG_VALUE_EXCEEDED_FIX_BUDGET".
      - MAX_K_ACTION (env var, optional):
          Only the first K parsed actions of a response are executed; K <= 0 makes
          every response a parsing failure.
      - ONLY_K_ACTION (env var, optional):
          When set, relaxes the parsing-error / invalid-action limits (50 / 30).
      - action_fix_history (observation field):
          Cumulative per-step log of (original_action -> fixed_action) pairs.
          Stored as list[list[tuple[str,str]]] where outer index is 1-based step idx.
    """

    parsing_error_penalty = 1.0
    parsing_duplicate_penalty = 1.0
    parsing_invalid_penalty = 1.0
    invalid_action_penalty = 0.5
    wrong_action_penalty = 0.5

    # step-level reward
    format_reward = 0.5
    valid_action_reward = 0.5
    progress_reward = 0.5

    # trajectory reward
    success_reward = 1.0

    # Matches "ACTION:" followed either by a fenced block (optionally tagged with a
    # language name) or by inline text running to the end of the response.
    _ACTION_PATTERN = re.compile(
        r"""
        ACTION:\s*
        (?:
            ```(?:[a-zA-Z0-9_-]+)?\s*
            (?P<fenced>.*?)
            \s*```
        |
            (?P<inline>.+)
        )
        """,
        re.DOTALL | re.VERBOSE,
    )

    def __init__(self, task: dict):
        """Create the environment from a task dict.

        Args:
            task: must contain ``rows``, ``cols`` and ``initial_board``
                (str | SudokuBoard | dict); optionally ``solution``, ``max_turns``,
                ``progress_reward_type``, ``strict_termination``,
                ``only_one_action``, ``difficulty`` and ``visual_elements``.

        Raises:
            ValueError: if the board size or initial board is missing/invalid.
        """
        self.task = task
        self.max_turns = task.get("max_turns", 30)
        self.progress_reward_type = task.get("progress_reward_type", "")
        self.strict_termination = task.get("strict_termination", False)  # for training
        self.only_one_action = task.get("only_one_action", False)

        # Parsed from the MAX_K_ACTION env var in reset(); None means "no limit".
        self.max_k_action: Optional[int] = None

        # Wrong VALUE auto-fix budget for the whole episode, loaded in reset() via env var.
        self.wrong_value_fix_chances: int = 0
        self.remaining_value_fix_chances: int = 0

        # Cumulative log: list[step_idx -> list[(orig_action_str, fixed_action_str)]]
        # We keep 1-based step_idx alignment (index 0 unused)
        self.action_fix_history: list[list[tuple[str, str]]] = []

        # Board state; board_history is used to detect loops (unchanged boards).
        self.current_board: Optional[SudokuBoard] = None
        self.initial_board: Optional[SudokuBoard] = None
        self.solution_board: Optional[SudokuBoard] = None
        self.board_history: list[str] = []

        # Episode state. reset() re-initialises all of these; they are defined here
        # so the object is consistent even before the first reset().
        self.done = False
        self.current_turn = 0
        self.history: list = []
        self.incorrect_cells_num = 0
        self.subgrid_results: dict = {}

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0

        # terminate conditions (re-applied on every reset())
        self._configure_termination_limits()

        self._initialize_from_task(task)

        log_postfix = ''
        if os.getenv('WRONG_VALUE_FIX_CHANCES'):
            log_postfix += "   WRONG_VALUE_FIX_CHANCES " + str(os.getenv('WRONG_VALUE_FIX_CHANCES'))
        if os.getenv("MAX_K_ACTION"):
            log_postfix += "   MAX_K_ACTION " + str(os.getenv("MAX_K_ACTION"))

        print('envformat!!', log_postfix)

    def _initialize_from_task(self, task: dict):
        """Initialize board size, initial board and (optional) solution from ``task``.

        Raises:
            ValueError: if ``rows``/``cols`` are missing, the size has no known
                subgrid layout, or ``initial_board`` has an unsupported type.
        """
        # Set board size
        if "rows" in task and "cols" in task:
            self.rows = task["rows"]
            self.cols = task["cols"]
            sr, sc = self._get_subgrid_dims()
            self.sr = sr
            self.sc = sc
        else:
            raise ValueError("rows and cols must be provided in the task")

        self.difficulty_level = task.get("difficulty", "default")

        if task.get("visual_elements", "None") == "None":
            self.sub_task_type = "standard_sudoku"
        else:
            self.sub_task_type = "ctc_sudoku"

        # Set initial board
        initial_board_data = task.get("initial_board")
        if isinstance(initial_board_data, str):
            self.initial_board = SudokuBoard.from_ascii(initial_board_data, self.rows, self.cols)
        elif isinstance(initial_board_data, SudokuBoard):
            self.initial_board = initial_board_data
        elif isinstance(initial_board_data, dict):
            self.initial_board = SudokuBoard.from_serialized(
                initial_board_data.get("state", "{}"),
                self.rows,
                self.cols,
                initial_board_data.get("givens", {}),
            )
        else:
            raise ValueError("task['initial_board'] must be str | SudokuBoard | dict")

        # save the initial board coordinates (every pre-filled cell as "r{row}c{col}")
        self.initial_board_coordinates = [
            f"r{cell.row}c{cell.col}"
            for cell in self.initial_board.cells
            if cell.value != ValueType.empty
        ]
        self.initial_board.save_initial_board_coordinates(self.initial_board_coordinates)

        # Set solution board (optional)
        solution_data = task.get("solution")
        if isinstance(solution_data, str):
            self.solution_board = SudokuBoard.from_ascii(solution_data, self.rows, self.cols)
        elif isinstance(solution_data, SudokuBoard):
            self.solution_board = solution_data
        else:
            self.solution_board = None

        self.solution_board_ascii = self.solution_board.to_ascii() if self.solution_board else None

    def _configure_termination_limits(self) -> None:
        """Set early-termination limits for a fresh episode.

        ``max_parsing_error`` and ``max_invalid_actions`` are remaining budgets that
        are decremented during the episode; the episode ends when they reach 0.
        ``max_wrong_actions`` is compared against the highest number of incorrect
        cells seen so far.
        """
        if self.strict_termination:  # for training
            self.max_consecutive_same_boards = 3
            self.max_parsing_error = 2
            self.max_invalid_actions = 5
            if "progress_C" in self.progress_reward_type or "progress_F" in self.progress_reward_type:
                self.max_wrong_actions = 1
            elif "progress_E" in self.progress_reward_type:
                self.max_consecutive_same_boards = 5
                self.max_parsing_error = 5
                self.max_invalid_actions = 20
                self.max_wrong_actions = 10
            else:
                self.max_wrong_actions = 100
        else:
            self.max_consecutive_same_boards = 5
            self.max_parsing_error = 5
            self.max_invalid_actions = 20
            self.max_wrong_actions = 1  # for early termination

    @staticmethod
    def _read_int_env(name: str) -> Optional[int]:
        """Return env var ``name`` as an int, or None when it is unset.

        Raises:
            ValueError: if the variable is set but is not an integer.
        """
        raw = os.getenv(name)
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError as err:
            raise ValueError(f"Invalid {name} (must be int): {raw}") from err

    def reset(self):
        """Reset the environment to the initial state.

        Re-reads MAX_K_ACTION, ONLY_K_ACTION and WRONG_VALUE_FIX_CHANCES from the
        environment.

        Returns:
            (observation, info) for the initial board.

        Raises:
            ValueError: if an env var is malformed or WRONG_VALUE_FIX_CHANCES < 0.
        """
        self.done = False
        self.current_turn = 0
        self.history = []
        self.board_history = []

        # reset cumulative fix history (must persist across steps, only last obs saved)
        self.action_fix_history = [[]]  # index 0 unused for 1-based step index

        # terminate conditions
        self._configure_termination_limits()

        # --- MAX_K_ACTION handling (env var) ---
        self.max_k_action = self._read_int_env("MAX_K_ACTION")

        if os.getenv("ONLY_K_ACTION") is not None:
            self.max_parsing_error = 50
            self.max_invalid_actions = 30

        # --- WRONG_VALUE_FIX_CHANCES handling (env var) ---
        fix_chances = self._read_int_env("WRONG_VALUE_FIX_CHANCES")
        self.wrong_value_fix_chances = 0 if fix_chances is None else fix_chances
        if self.wrong_value_fix_chances < 0:
            raise ValueError(f"WRONG_VALUE_FIX_CHANCES must be >= 0: {self.wrong_value_fix_chances}")
        # Episode-wide budget: decremented per fix, refilled only here.
        self.remaining_value_fix_chances = self.wrong_value_fix_chances

        self.env_message = None
        self.termination_reason = None
        self.prev_progress = 0.0
        self.incorrect_cells_num = 0

        if self.initial_board:
            self.current_board = copy.deepcopy(self.initial_board)
            self.total_cells = self.rows * self.cols
            self.initial_correct, _ = self._count_correct_cells(self.initial_board)
            self.subgrid_results, self.initial_correct_subgrid_count = self._make_subgrid_results_dict(
                self.initial_board
            )

        observation = self._get_observation()
        info = self._get_info()
        return observation, info

    def _get_observation(self) -> dict:
        """Return the observation of the current environment state ({} if no board)."""
        if not self.current_board:
            return {}

        observation = {
            "board_ascii": self.current_board.to_ascii(),
            "board_spaced": self.current_board.to_spaced_ascii(),
            "board_tokens": self.current_board.to_string(),
            "current_turn": self.current_turn,
            "max_turns": self.max_turns,
            "env_message": self.env_message,
            "action_fix_history": self.action_fix_history,
        }

        # The full task is only attached to the very first observation.
        if self.current_turn == 0:
            observation["task_info"] = self.task

        if self.solution_board:
            progress, correct_cell, incorrect_cell = self._calculate_progress()
            observation["progress"] = progress
            observation["correct_cells"] = correct_cell
            observation["incorrect_cells"] = incorrect_cell

        return observation

    def _get_info(self) -> dict:
        """Return additional information (turn counters, board size, completion flag)."""
        info = {
            "turn": self.current_turn,
            "max_turns": self.max_turns,
            "board_size": f"{self.rows}x{self.cols}",
        }
        if self.current_board:
            info["is_complete"] = self._is_puzzle_complete()
        return info

    def step(self, action: Any) -> tuple[dict, dict, bool, dict]:
        """Take a step in the environment based on the raw model response ``action``.

        Returns:
            (next_obs, reward, done, task). When the episode ends, ``next_obs``
            carries ``termination_reason`` and a zero trajectory reward is
            replaced by -1.0.
        """
        self.history.append(action)
        assert self.task is not None, "Task is not set"

        # Increment turn counter (this turn index is used for fix-history indexing)
        self.current_turn += 1

        reward, next_obs = self.get_reward_and_next_obs(self.task, action)

        if self.current_board and not self.done:
            self.board_history.append(self.current_board.to_spaced_ascii())

            if self._check_consecutive_same_boards():
                self.done = True
                reward["step"] = 0.0
                self.termination_reason = "BOARD_STATE_UNCHANGED"

            if self._is_puzzle_filled():
                self.done = True
                self.termination_reason = "PUZZLE_FILLED"

        if self.current_turn >= self.max_turns:
            self.done = True
            if self.termination_reason is None:
                self.termination_reason = "MAX_TURNS"

        if self.done:
            next_obs["termination_reason"] = self.termination_reason
            if reward["traj"] == 0:
                reward["traj"] = -1.0

        return next_obs, reward, self.done, self.task

    def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[dict, dict]:
        """Apply ``action`` to the board and compute the reward and next observation.

        Returns:
            (reward, next_obs) where reward is ``{"step": float, "traj": float}``.
        """
        parse_penalty, action_list = self._parse_action(action)
        step_reward = 0.0
        fix_budget_exceeded = False

        if not action_list:
            self.max_parsing_error -= 1
            step_reward -= parse_penalty  # format penalty
        else:
            step_reward += self.format_reward

            # 1-based step index: step() increments current_turn before calling this.
            step_idx = self.current_turn
            while len(self.action_fix_history) <= step_idx:
                self.action_fix_history.append([])

            error_messages, fix_budget_exceeded = self._execute_actions(action_list, step_idx)

            if error_messages:
                self.env_message = "\n".join(error_messages)
                step_reward -= self.invalid_action_penalty
            else:
                self.env_message = None
                step_reward += self.valid_action_reward

        next_obs = self._get_observation()

        # Reward for completing the puzzle (and progress-based reward shaping)
        traj_reward, progress_reward = self._calculate_reward(next_obs.get("progress", 0))
        incorrect_cells_num = next_obs.get("incorrect_cells", 0)
        # A "wrong action" is one that raises the highest incorrect-cell count seen so far.
        do_wrong_action = incorrect_cells_num > self.incorrect_cells_num

        traj_reward = self._shape_traj_reward(traj_reward, progress_reward, do_wrong_action, step_reward)

        if do_wrong_action:
            self.incorrect_cells_num = incorrect_cells_num

        reward = {"step": step_reward, "traj": traj_reward}

        if self.max_invalid_actions <= 0:
            self.done = True
            self.termination_reason = "MAX_INVALID_ACTIONS"
        if self.max_wrong_actions <= self.incorrect_cells_num:
            self.done = True
            self.termination_reason = "MAX_WRONG_ACTIONS"
        if self.max_parsing_error <= 0:
            self.done = True
            self.termination_reason = "MAX_PARSING_ERROR"
        # Applied last: the unfixed wrong value also trips MAX_WRONG_ACTIONS above,
        # which would otherwise hide this more specific reason.
        if fix_budget_exceeded:
            self.done = True
            self.termination_reason = "WRONG_VALUE_EXCEEDED_FIX_BUDGET"

        return reward, next_obs

    def _execute_actions(self, action_list: list, step_idx: int) -> tuple[list[str], bool]:
        """Execute parsed actions in order on the current board.

        Rejected actions consume one unit of ``max_invalid_actions`` and their error
        text is recorded; execution stops once that budget hits 0 or the wrong-value
        fix budget is exhausted. Any auto-fixes are appended to
        ``action_fix_history[step_idx]``.

        Returns:
            (error_messages, fix_budget_exceeded)
        """
        error_messages: list[str] = []
        step_fix_pairs: list[tuple[str, str]] = []
        fix_budget_exceeded = False

        for sudoku_action in action_list:
            try:
                # Execute action (VALUE on non-empty cell will raise ValueError here as before)
                self.current_board.execute_action(sudoku_action)
                fix_ok = self._maybe_fix_wrong_value(sudoku_action, step_fix_pairs)
            except Exception as e:
                self.max_invalid_actions -= 1
                # Record the error before possibly stopping, so the step is never
                # scored as valid when its last action was rejected.
                error_messages.append(str(e))
                if self.max_invalid_actions <= 0:
                    break
                continue

            if not fix_ok:
                self.done = True
                self.termination_reason = "WRONG_VALUE_EXCEEDED_FIX_BUDGET"
                error_messages.append(
                    f"Wrong VALUE placements exceed fix budget ({self.wrong_value_fix_chances})."
                )
                fix_budget_exceeded = True
                break

        # Commit per-step fixes into cumulative history (accumulate; do not overwrite)
        if step_fix_pairs:
            self.action_fix_history[step_idx].extend(step_fix_pairs)

        return error_messages, fix_budget_exceeded

    def _maybe_fix_wrong_value(self, sudoku_action: "SudokuAction", step_fix_pairs: list) -> bool:
        """Replace a wrong single-cell VALUE placement with the solution value.

        Only active when WRONG_VALUE_FIX_CHANCES > 0 and a solution is known. Each
        fix consumes one unit of the episode-wide ``remaining_value_fix_chances`` and
        appends ``(original_action, fixed_action)`` to ``step_fix_pairs``.

        Returns:
            False if a wrong value was placed but no fix chances remain; True otherwise.
        """
        if not (
            self.wrong_value_fix_chances > 0
            and self.solution_board is not None
            and sudoku_action.action_type == ActionType.VALUE
            and sudoku_action.coordinates
            and len(sudoku_action.coordinates) == 1
            and self.current_board is not None
        ):
            return True

        r, c = sudoku_action.coordinates[0]
        cur_cell = self.current_board.get_cell(r, c)
        sol_cell = self.solution_board.get_cell(r, c)

        if cur_cell.value == ValueType.empty or cur_cell.value == sol_cell.value:
            return True  # nothing wrong to fix
        if self.remaining_value_fix_chances <= 0:
            return False

        orig_str = sudoku_action.to_serialized()
        fixed_action = SudokuAction(
            action_type=ActionType.VALUE,
            value=sol_cell.value,
            coordinates=[(r, c)],
        )
        fixed_str = fixed_action.to_serialized()
        step_fix_pairs.append((orig_str, fixed_str))

        # Force replace with correct value and clear any cell annotations
        cur_cell.value = sol_cell.value
        cur_cell.candidates = []
        cur_cell.pencilmarks = []
        cur_cell.colors = []
        cur_cell.impossible_candidates = []

        print('fix', orig_str, '->', fixed_str, f'(remain{self.remaining_value_fix_chances})')

        self.remaining_value_fix_chances -= 1
        return True

    def _shape_traj_reward(
        self,
        traj_reward: float,
        progress_reward: float,
        do_wrong_action: bool,
        step_reward: float,
    ) -> float:
        """Apply the ``progress_reward_type`` shaping (A-E) to the trajectory reward.

        The first matching mode wins, in the order A, B, C, D, E; otherwise
        ``traj_reward`` is returned unchanged.
        """
        reward_type = self.progress_reward_type
        if "progress_A" in reward_type:
            return 1.0 if (traj_reward > 0 or progress_reward > 0) else 0.0
        if "progress_B" in reward_type:
            if "noise" in reward_type:
                if do_wrong_action:
                    # Penalize wrong actions unless noise flips the signal.
                    if not self._apply_noise():
                        return -1.0
                elif self._apply_noise(noise_rate=0.01):
                    return -1.0
                return traj_reward
            return -1.0 if do_wrong_action else 0.2
        if "progress_C" in reward_type:
            return -1.0 if do_wrong_action else 1.0
        if "progress_D" in reward_type:
            if traj_reward == 0:
                return 1.0 if progress_reward > 0 else 0.0
            return traj_reward
        if "progress_E" in reward_type:
            if do_wrong_action:
                return -1.0
            # 1.0 == format_reward + valid_action_reward: fully valid step.
            return 1.0 if step_reward == 1.0 else -1.0
        return traj_reward

    def _parse_action(self, action: Any) -> tuple[float, list["SudokuAction"]]:
        """Parse the raw response into a list of SudokuAction.

        Returns:
            (penalty, actions). ``actions`` is empty on any parse failure, in which
            case ``penalty`` is one of the parsing_* class penalties; on success the
            penalty is 0.0.
        """
        if self.strict_termination:
            required_markers = ["</think>", "REASON", "ACTION"]
            # Strict format: response must end with a closing code fence.
            if not action.strip().endswith("```"):
                return self.parsing_error_penalty, []
            think_num = action.count("</think>")
            reason_num = action.count("REASON")
            action_num = action.count("ACTION")
            if think_num > 1 or reason_num > 1 or action_num > 1:
                return self.parsing_duplicate_penalty, []
        else:
            required_markers = ["ACTION"]

        if any(marker not in action for marker in required_markers):
            return self.parsing_error_penalty, []

        m = self._ACTION_PATTERN.search(action)
        if not m:
            return self.parsing_invalid_penalty, []

        content = m.group("fenced") if m.group("fenced") is not None else m.group("inline")
        if self.only_one_action:
            content = content.split("\n")[0].strip()

        lines = [ln.strip() for ln in content.strip().splitlines() if ln.strip()]

        parsed_actions: list["SudokuAction"] = []
        for line in lines:
            try:
                parsed_actions.append(SudokuAction.from_serialized(line.lower()))
            except Exception:
                # Any unparsable line invalidates the whole response.
                return self.parsing_invalid_penalty, []

        if not parsed_actions:
            return self.parsing_invalid_penalty, []

        # MAX_K_ACTION truncation
        if self.max_k_action is not None:
            if self.max_k_action <= 0:
                return self.parsing_invalid_penalty, []
            parsed_actions = parsed_actions[: self.max_k_action]

        return 0.0, parsed_actions

    def _calculate_progress(self) -> tuple[float, int, int]:
        """Return (progress, correct_cells, incorrect_cells) for the current board.

        ``progress`` is the fraction of initially-unsolved cells now correct; it is
        0.0 when the initial board already had every cell correct.
        """
        curr_correct, curr_incorrect = self._count_correct_cells()
        try:
            progress = (curr_correct - self.initial_correct) / (self.total_cells - self.initial_correct)
        except ZeroDivisionError:
            progress = 0.0
        return progress, curr_correct, curr_incorrect

    def _calculate_reward(self, progress: float) -> tuple[float, float]:
        """Calculate (traj_reward, progress_reward).

        Marks the episode done with PUZZLE_COMPLETE when the board matches the
        solution. progress_A/C reward newly-correct subgrids (2 or 3 at a time
        with "A2"/"A3"), progress_B rewards non-decreasing progress, progress_D/E
        reward every 0.25 progress step. "noise" / "traj_noise" randomly flip
        the corresponding reward.
        """
        if self._is_puzzle_complete():
            traj_reward = self.success_reward
            self.done = True
            self.termination_reason = "PUZZLE_COMPLETE"
        else:
            traj_reward = 0.0

        progress_reward = 0.0
        if "progress_A" in self.progress_reward_type or "progress_C" in self.progress_reward_type:
            # prev_progress holds the last rewarded count of fully-correct subgrids.
            correct_subgrid_count = self._get_subgrid_results(self.current_board)
            if "A2" in self.progress_reward_type:
                target_num_improvements = 2
            elif "A3" in self.progress_reward_type:
                target_num_improvements = 3
            else:
                target_num_improvements = 1
            if correct_subgrid_count >= self.prev_progress + target_num_improvements:
                progress_reward = 1.0
                self.prev_progress = correct_subgrid_count
            else:
                progress_reward = 0.0

            if "noise" in self.progress_reward_type:
                apply_noise = self._apply_noise()
                if apply_noise and progress_reward > 0:
                    progress_reward = 0.0
                elif apply_noise and progress_reward == 0:
                    progress_reward = 1.0

        elif "progress_B" in self.progress_reward_type:
            if progress >= self.prev_progress:
                progress_reward = 1.0
                self.prev_progress = progress

            if "noise" in self.progress_reward_type:
                apply_noise = self._apply_noise()
                if apply_noise and progress_reward > 0:
                    progress_reward = 0.0
                elif apply_noise and progress_reward == 0:
                    progress_reward = 1.0

        elif "progress_D" in self.progress_reward_type or "progress_E" in self.progress_reward_type:
            reward_interval = 0.25
            if progress >= self.prev_progress + reward_interval:
                progress_reward = 1.0
                self.prev_progress += reward_interval
            else:
                progress_reward = 0.0

            if "noise" in self.progress_reward_type:
                if progress_reward > 0:
                    apply_noise = self._apply_noise()
                    progress_reward = 0.0 if apply_noise else 1.0
                else:
                    apply_noise = self._apply_noise(noise_rate=0.01)
                    if apply_noise:
                        progress_reward = 1.0

        if "traj_noise" in self.progress_reward_type:
            apply_noise = self._apply_noise()
            if apply_noise and traj_reward > 0:
                traj_reward = 0.0
            elif apply_noise and traj_reward == 0:
                traj_reward = 1.0

        progress_reward = round(progress_reward, 3)
        return traj_reward, progress_reward

    def _apply_noise(self, noise_rate: float = 0.0) -> bool:
        """Return True with the configured noise probability.

        A level in ``progress_reward_type`` (noise_low/medium/high/very_high)
        takes precedence over ``noise_rate``.

        Raises:
            ValueError: if neither a noise level nor a positive noise_rate is given.
        """
        if "noise_low" in self.progress_reward_type:
            return random.random() < 0.05
        if "noise_medium" in self.progress_reward_type:
            return random.random() < 0.125
        if "noise_high" in self.progress_reward_type:
            return random.random() < 0.25
        if "noise_very_high" in self.progress_reward_type:
            return random.random() < 0.5
        if noise_rate > 0.0:
            return random.random() < noise_rate
        raise ValueError("Set noise level or noise rate for applying noise")

    def _is_puzzle_filled(self) -> bool:
        """Return True when every cell of the current board has a value."""
        if not self.current_board:
            return False
        return all(cell.value != ValueType.empty for cell in self.current_board.cells)

    def _is_puzzle_complete(self) -> bool:
        """Return True when the board is filled and every cell matches the solution."""
        if not self._is_puzzle_filled():
            return False
        if self.solution_board:
            correct_count, _ = self._count_correct_cells()
            return correct_count == self.rows * self.cols
        return False

    def _count_correct_cells(self, board: Optional[SudokuBoard] = None) -> tuple[int, int]:
        """Return (correct_count, incorrect_count) of filled cells vs. the solution.

        Uses the current board when ``board`` is not given. Without a solution
        (or a board), every cell is reported as incorrect.
        """
        if not self.solution_board:
            return 0, getattr(self, "total_cells", self.rows * self.cols)

        target_board = board if board else self.current_board
        if not target_board:
            return 0, getattr(self, "total_cells", self.rows * self.cols)

        correct_count = 0
        incorrect_count = 0
        empty_count = 0

        for cell in target_board.cells:
            if cell.value != ValueType.empty:
                solution_cell = self.solution_board.get_cell(cell.row, cell.col)
                if cell.value == solution_cell.value:
                    correct_count += 1
                else:
                    incorrect_count += 1
            else:
                empty_count += 1

        total = getattr(self, "total_cells", self.rows * self.cols)
        assert correct_count + incorrect_count + empty_count == total
        return correct_count, incorrect_count

    def _get_subgrid_dims(self) -> tuple[int, int]:
        """Return (subgrid_rows, subgrid_cols) for the supported 4x4, 6x6 and 9x9 boards."""
        if self.rows == self.cols == 9:
            return 3, 3
        if self.rows == self.cols == 6:
            return 2, 3
        if self.rows == self.cols == 4:
            return 2, 2
        raise ValueError(f"Unsupported board size for subgrid check: {self.rows}x{self.cols}")

    def _make_subgrid_results_dict(self, board: Optional[SudokuBoard] = None) -> tuple[dict, int]:
        """Evaluate every subgrid of ``board``.

        Returns:
            ({"sg{r}sg{c}": is_fully_correct}, number_of_fully_correct_subgrids)
        """
        subgrid_results = {}
        max_sg_r = self.rows // self.sr
        max_sg_c = self.cols // self.sc
        correct_subgrid_count = 0
        for sg_row in range(1, max_sg_r + 1):
            for sg_col in range(1, max_sg_c + 1):
                is_correct = self._is_subgrid_correct(sg_row, sg_col, board)
                subgrid_results[f"sg{sg_row}sg{sg_col}"] = is_correct
                if is_correct:
                    correct_subgrid_count += 1
        return subgrid_results, correct_subgrid_count

    def _is_subgrid_correct(self, sg_row: int, sg_col: int, board: Optional[SudokuBoard] = None) -> bool:
        """Return True if every cell of the 1-based subgrid (sg_row, sg_col) matches the solution."""
        if not self.solution_board or board is None:
            return False
        r0 = (sg_row - 1) * self.sr + 1
        c0 = (sg_col - 1) * self.sc + 1
        cnt = 0
        for r in range(r0, r0 + self.sr):
            for c in range(c0, c0 + self.sc):
                curr = board.get_cell(r, c)
                if curr.value == ValueType.empty:
                    continue
                sol = self.solution_board.get_cell(r, c)
                if curr.value == sol.value:
                    cnt += 1
        return cnt == self.sr * self.sc

    def _get_subgrid_results(self, board: Optional[SudokuBoard] = None) -> float:
        """Return the number of fully-correct subgrids, caching positives.

        A subgrid once marked correct in ``self.subgrid_results`` stays counted as
        correct without being re-checked. Returns 0.0 when ``board`` is None.
        """
        if board is None:
            return 0.0
        max_sg_r = self.rows // self.sr
        max_sg_c = self.cols // self.sc
        correct_subgrid_count = 0
        for sg_row in range(1, max_sg_r + 1):
            for sg_col in range(1, max_sg_c + 1):
                key = f"sg{sg_row}sg{sg_col}"
                if self.subgrid_results.get(key, False):
                    correct_subgrid_count += 1
                else:
                    is_correct = self._is_subgrid_correct(sg_row, sg_col, board)
                    self.subgrid_results[key] = is_correct
                    if is_correct:
                        correct_subgrid_count += 1
        return correct_subgrid_count

    def _check_consecutive_same_boards(self) -> bool:
        """Return True if the last ``max_consecutive_same_boards`` boards are identical."""
        if len(self.board_history) < self.max_consecutive_same_boards:
            return False
        recent_boards = self.board_history[-self.max_consecutive_same_boards :]
        return all(b == recent_boards[0] for b in recent_boards)

    @staticmethod
    def from_dict(env_args: dict) -> "SudokuFormatEnv":
        """Build an environment using ``env_args`` as the task dict."""
        return SudokuFormatEnv(task=env_args)
