import ast
import re
from collections import defaultdict
from typing import List, Optional, Any, Dict, Set

# States are stored as strings; every public method of StateActionStack
# converts list-shaped states with str() before using them as keys.
StateType = str
# Actions are opaque to this module; narrow this alias once they have a concrete type.
ActionType = Any


class StateActionStack:
    """DFS bookkeeping for states and the actions tried from each of them.

    States live on a LIFO stack.  For every state the object keeps a set of
    explored actions, a set of still-unexplored actions and an optional
    "fixing" action (``None`` means "not fixed").  Independently of the stack
    it records which input ids may lead to each state (``state_to_input_ids``)
    and which states are locked (``locked_map``); popping a state does not
    touch those two maps.

    Every method that takes a state accepts either a string or a list; lists
    are converted with ``str()`` so both spellings address the same entry.
    """

    def __init__(self):
        self.stack: List[StateType] = []  # States in LIFO order (DFS); top is the last element
        self.explored: Dict[StateType, Set[ActionType]] = {}  # Explored actions per state
        self.unexplored: Dict[StateType, Set[ActionType]] = {}  # Unexplored actions per state
        self.fixing: Dict[StateType, Optional[ActionType]] = {}  # Fixing action per state (None = not fixed)
        self._last_popped_state: Optional[StateType] = None  # Blocks an immediate re-push of this state

        # Every input id that may lead to a given state (see record_state / update_record).
        self.state_to_input_ids: Dict[StateType, Set[Any]] = defaultdict(set)
        # States locked at the beginning.  NOTE: this class only records the
        # lock; reset_fixing_action does not consult it.
        self.locked_map: Set[StateType] = set()

    @staticmethod
    def _to_key(state):
        """Return *state* as a stack/dict key: lists are stringified, anything else is returned unchanged."""
        return str(state) if isinstance(state, list) else state

    def record_state(self, state, input_id):
        """Remember that ``input_id`` may lead to ``state``."""
        state = self._to_key(state)
        self.state_to_input_ids[state].add(input_id)
        print(f"[RECORD STATE] {state}, input_id: {input_id}")

    def print_track_info(self):
        """Print every recorded state together with the input ids that may reach it."""
        for state, input_ids in self.state_to_input_ids.items():
            print(f"[TRACK INFO] State: {state}, setting_ids: {input_ids}")

    def lock_all_states(self, start_setting_id=None):
        """Lock every state currently on the stack.

        Args:
            start_setting_id: if not ``None``, also record this input id for
                each locked state.  Falsy ids such as ``0`` or ``""`` are
                recorded as well.
        """
        for state in self.stack:
            self.locked_map.add(state)
            # `is not None` rather than truthiness so that an id of 0 is not dropped.
            if start_setting_id is not None:
                self.record_state(state, start_setting_id)
            print(f"[LOCK STATE] {state} is locked.")

    def is_locked(self, state):
        """Return True if ``state`` was locked by ``lock_all_states``."""
        return self._to_key(state) in self.locked_map

    def update_record(self, input_id, reached_states):
        """Forget ``input_id`` for every recorded state it did not reach.

        Args:
            input_id: the input that was just simulated.
            reached_states: iterable of states (str or list) that the
                simulation actually reached.
        """
        reached_set = {self._to_key(s) for s in reached_states}
        # Mutating the per-state sets (not the dict itself) is safe while iterating items().
        for state, ids in self.state_to_input_ids.items():
            if input_id in ids and state not in reached_set:
                print(f"[UPDATE RECORD] Removed input_id {input_id} from {state}")
                ids.remove(input_id)

    def get_all_ids_for_fixed_states(self) -> set:
        """Return the union of input ids recorded for every state that currently has a fixing action."""
        ids_for_fixed_states = set()
        for state, action in self.fixing.items():
            if action is not None:
                # .get() instead of [] so the defaultdict does not grow empty entries.
                ids_for_fixed_states.update(self.state_to_input_ids.get(state, set()))
        return ids_for_fixed_states

    def push_state(self, state, actions: Set[ActionType]):
        """Push ``state`` onto the stack with its initial unexplored ``actions``.

        Explored/unexplored/fixing entries that already exist for the state
        are kept as they are.

        Returns:
            True if the state was pushed; False if it is the state popped most
            recently (with no push since) or is already on the stack.
        """
        state = self._to_key(state)
        assert isinstance(state, str), "State must be a string."

        if not self.should_push(state):
            print(f"[SKIP PUSH] {state} was just popped. Not pushing again.")
            return False
        if state in self.stack:
            return False

        self.explored.setdefault(state, set())
        if state not in self.unexplored:
            self.unexplored[state] = set(actions)  # Copy so the caller's collection is not mutated
        self.fixing.setdefault(state, None)

        self.stack.append(state)
        self._last_popped_state = None  # Any successful push lifts the re-push guard
        print(f"[PUSH] {state} with actions: {actions}")
        return True

    def should_push(self, state) -> bool:
        """Return False only if ``state`` is the most recently popped state."""
        return self._to_key(state) != self._last_popped_state

    def pop_state(self) -> Optional[StateType]:
        """Pop the top state and drop its explored/unexplored/fixing data.

        Returns the popped state, or None if the stack is empty.
        """
        if self.stack:
            state = self.stack.pop()
            self._last_popped_state = state
            self._clear_state_data(state)
            print(f"[POP] {state}.")
            return state
        print("[POP] No states left to pop.")
        return None

    def pop_specific_key(self, state) -> Optional[StateType]:
        """Remove ``state`` from anywhere in the stack and drop its per-state data.

        Returns the removed state, or None if it was not on the stack.
        """
        state = self._to_key(state)

        if state in self.stack:
            self.stack.remove(state)
            self._last_popped_state = state
            self._clear_state_data(state)
            print(f"[POP SPECIFIC] {state}. Cleared all related data.")
            return state
        print(f"[POP SPECIFIC] {state} not found in the stack.")
        return None

    def _clear_state_data(self, state):
        """Drop the explored/unexplored/fixing entries of ``state`` (input-id and lock records are kept)."""
        self.explored.pop(state, None)
        self.unexplored.pop(state, None)
        self.fixing.pop(state, None)

    def peek_state(self) -> Optional[StateType]:
        """Return the top state without popping it, or None if the stack is empty."""
        if self.stack:
            return self.stack[-1]
        return None

    def pop_next_unexplored_action(self, state) -> Optional[ActionType]:
        """Remove and return one unexplored action of ``state`` (arbitrary set order), or None if none remain."""
        state = self._to_key(state)

        if self.unexplored.get(state):
            action = self.unexplored[state].pop()  # set.pop(): arbitrary, not random
            print(f"[POP ACTION] {state}. Action: {action}.")
            return action
        print(f"[POP ACTION] {state}. No unexplored actions left for state .")
        return None

    def mark_explored(self, state, action: ActionType):
        """Record ``action`` as explored for ``state`` and remove it from the unexplored set if present."""
        state = self._to_key(state)

        self.explored.setdefault(state, set()).add(action)
        if state in self.unexplored:
            self.unexplored[state].discard(action)

        print(f"[MARK] {state}. Marked action {action} as explored in state {state}.")

    def is_explored(self, state, action: ActionType) -> bool:
        """Return True if ``action`` has been marked explored for ``state``."""
        return action in self.explored.get(self._to_key(state), set())

    def has_unexplored_actions(self, state) -> bool:
        """Return True if ``state`` still has unexplored actions."""
        return bool(self.unexplored.get(self._to_key(state), set()))

    def get_unexplored_actions(self, state) -> Set[ActionType]:
        """Return the unexplored actions of ``state``.

        When the state is known this is the live internal set, so mutating it
        changes the stack's bookkeeping; otherwise a fresh empty set.
        """
        return self.unexplored.get(self._to_key(state), set())

    def set_fixing_action(self, state, action: ActionType):
        """Set ``action`` as the fixing action of ``state``."""
        state = self._to_key(state)

        self.fixing[state] = action
        print(f"[FIX] {state}. Set fixing action {action} for state {state}.")

    def get_fixing_action(self, state) -> Optional[ActionType]:
        """Return the fixing action of ``state``, or None if it has none."""
        return self.fixing.get(self._to_key(state), None)

    def reset_fixing_action(self, state):
        """Clear the fixing action of ``state`` if the state has a fixing entry (locks are not checked)."""
        state = self._to_key(state)

        if state in self.fixing:
            print(f"[RESET] {state}. Reset fixing action for state {state}.")
            self.fixing[state] = None

    def is_fixed(self, state) -> bool:
        """Return True if ``state`` has a non-None fixing action."""
        state = self._to_key(state)

        action = self.fixing.get(state, None)
        if action is not None:
            print(f"[FIXED] {state}. State {state} is fixed to {action}")
        return action is not None

    def find_latest_key(self, keys) -> Optional[StateType]:
        """Return the state closest to the top of the stack that appears in ``keys``.

        Args:
            keys: any iterable of states (list, set, tuple, ...); each
                list-shaped state is stringified individually.  May be empty.

        Returns:
            The matching state, or None if none of ``keys`` is on the stack.
        """
        # A set gives O(1) membership and works for any iterable (the old
        # `keys[0]` probe crashed on sets and empty lists).
        wanted = {self._to_key(key) for key in keys}
        for state in reversed(self.stack):  # Latest first
            if state in wanted:
                print(f"[FIND] {state}. Found latest key {state}.")
                return state
        return None

    def latest_key(self) -> Optional[StateType]:
        """Return the top state, or None if the stack is empty."""
        if self.stack:
            return self.stack[-1]
        return None

    def get_all_states(self) -> List[StateType]:
        """Return a copy of the stack, bottom to top."""
        return list(self.stack)

    def is_empty(self) -> bool:
        """Return True if the stack holds no states."""
        return not self.stack

    def print_stack_info(self):
        """Print the stack (top first) with each state's actions; readable back via ``load_from_stack_info``."""
        print("\n[STACK INFO] Current State Stack (Top -> Bottom):")
        if not self.stack:
            print("  [EMPTY STACK]")
            return

        for state in reversed(self.stack):
            unexplored = self.unexplored.get(state, set())
            explored = self.explored.get(state, set())
            fixing = self.fixing.get(state, None)
            print(f"  - State {state}")
            print(f"    - Unexplored Actions: {unexplored}")
            print(f"    - Explored Actions: {explored}")
            # repr() so that string actions keep their quotes and round-trip through literal_eval.
            print(f"    - Fixing Action: {fixing!r}")
        print("[END OF STACK INFO]\n")

    @staticmethod
    def _parse_action_set(text):
        """Parse a printed action set; ``set()`` is special-cased because literal_eval only accepts it on Python >= 3.9."""
        if text == "set()":
            return set()
        return ast.literal_eval(text)

    @staticmethod
    def _parse_fixing_action(text):
        """Parse a printed fixing action, falling back to the raw text when it is not a Python literal."""
        if text == "None":
            return None
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError):
            # Dumps written with str() (e.g. `up`) are not valid literals.
            return text

    def load_from_stack_info(self, output: str):
        """Rebuild stack, explored, unexplored and fixing from ``print_stack_info()`` text.

        ``state_to_input_ids`` and ``locked_map`` are left untouched because the
        dump does not contain them.  Action sets are parsed with
        ``ast.literal_eval``, so only actions whose repr is a Python literal
        round-trip.
        """
        self.stack.clear()
        self.explored.clear()
        self.unexplored.clear()
        self.fixing.clear()
        self._last_popped_state = None

        current_state = None
        states_in_order = []  # Top to bottom, as printed

        for raw_line in output.strip().splitlines():
            line = raw_line.strip()
            if line.startswith("- State"):
                match = re.match(r"- State (.+)", line)
                if match:
                    current_state = match.group(1)
                    states_in_order.append(current_state)
            elif line.startswith("- Unexplored Actions:"):
                match = re.match(r"- Unexplored Actions: (.+)", line)
                if match and current_state:
                    self.unexplored[current_state] = self._parse_action_set(match.group(1))
            elif line.startswith("- Explored Actions:"):
                match = re.match(r"- Explored Actions: (.+)", line)
                if match and current_state:
                    self.explored[current_state] = self._parse_action_set(match.group(1))
            elif line.startswith("- Fixing Action:"):
                match = re.match(r"- Fixing Action: (.+)", line)
                if match and current_state:
                    self.fixing[current_state] = self._parse_fixing_action(match.group(1))

        # The dump is top-first; the stack is stored bottom-first.
        self.stack.extend(reversed(states_in_order))