import ast
from collections import deque
from typing import Annotated

from langchain_core.runnables import RunnableConfig
from langchain_core.tools import InjectedToolArg
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState
from langgraph.prebuilt.chat_agent_executor import AgentState

from orchestrator_maze_implementation.state.maze_state import MazeState
from orchestrator_maze_implementation.utils.maze_utils.maze_wrapper import MazeWrapper

# Tool: ask this agent's maze wrapper to step one cell north.
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim. Every outcome, including exceptions, is reported as a text result.
@tool
def move_north(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Move one step north if possible. Returns success/failure status."""
    print(f"Moving north for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        moved = wrapper.try_move('north')
        position = wrapper.get_agent_position()
        if not moved:
            return "Move north: Failed - blocked by wall or boundary"
        return f"Move north: Success! New position: {position}"
    except Exception as exc:
        # e.g. a KeyError for an unknown agent_id is surfaced as text to the model
        return f"Move north: Error - {exc!s}"

# Tool: ask this agent's maze wrapper to step one cell south.
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim. Every outcome, including exceptions, is reported as a text result.
@tool
def move_south(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Move one step south if possible. Returns success/failure status."""
    print(f"Moving south for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        moved = wrapper.try_move('south')
        position = wrapper.get_agent_position()
        if not moved:
            return "Move south: Failed - blocked by wall or boundary"
        return f"Move south: Success! New position: {position}"
    except Exception as exc:
        # e.g. a KeyError for an unknown agent_id is surfaced as text to the model
        return f"Move south: Error - {exc!s}"

# Tool: ask this agent's maze wrapper to step one cell east.
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim. Every outcome, including exceptions, is reported as a text result.
@tool
def move_east(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Move one step east if possible. Returns success/failure status."""
    print(f"Moving east for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        moved = wrapper.try_move('east')
        position = wrapper.get_agent_position()
        if not moved:
            return "Move east: Failed - blocked by wall or boundary"
        return f"Move east: Success! New position: {position}"
    except Exception as exc:
        # e.g. a KeyError for an unknown agent_id is surfaced as text to the model
        return f"Move east: Error - {exc!s}"

# Tool: ask this agent's maze wrapper to step one cell west.
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim. Every outcome, including exceptions, is reported as a text result.
@tool
def move_west(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Move one step west if possible. Returns success/failure status."""
    print(f"Moving west for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        moved = wrapper.try_move('west')
        position = wrapper.get_agent_position()
        if not moved:
            return "Move west: Failed - blocked by wall or boundary"
        return f"Move west: Success! New position: {position}"
    except Exception as exc:
        # e.g. a KeyError for an unknown agent_id is surfaced as text to the model
        return f"Move west: Error - {exc!s}"

# Tool: ask this agent's maze wrapper to mark its current cell as a dead end.
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim. Every outcome, including exceptions, is reported as a text result.
@tool
def mark_dead_end(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Mark a dead end at the current position. Returns success/failure status."""
    print(f"Marking dead end for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        # Read the position before marking so the message names the marked cell.
        position = wrapper.get_agent_position()
        marked = wrapper.mark_dead_end()
        if not marked:
            return f"Mark dead end: Failed - position {position} already marked or invalid"
        return f"Mark dead end: Success! Marked position {position} as dead end"
    except Exception as exc:
        return f"Mark dead end: Error - {exc!s}"

# Tool: describe the agent's surroundings as text.
# The output is a header with the current position, the possible moves and the
# exit flag, followed by the wrapper's own view string. The docstring is the
# LLM-facing tool description (via @tool), so it is kept verbatim.
@tool
def get_current_view(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Get the current view from the current position. Returns a string representation of the view."""
    print(f"Getting current view for agent {agent_id}")
    try:
        wrapper = state['maze_wrappers'][agent_id]
        position = wrapper.get_agent_position()
        raw_view = wrapper.get_current_view()
        moves = wrapper.get_possible_moves()
        at_exit = wrapper.is_at_exit()
        header = (
            f"Current position: {position}\n"
            f"Possible moves: {moves}\n"
            f"At exit: {at_exit}\n\n"
        )
        # TODO: this result also needs to be stored in state as a temporary object.
        return header + raw_view
    except Exception as exc:
        return f"Get current view: Error - {exc!s}"

# Tool: start (or report on) a backtracking run toward the nearest known
# UNEXPLORED opening for ``agent_id``.
#
# The docstring is the LLM-facing tool description (via @tool), so it is kept
# verbatim and the behavior is documented here instead.
#
# Reads:  state['agent_backtracking_state'][agent_id]
#         state['maze_wrappers'][agent_id]
#         state['known_openings'][agent_id]
# Writes: state['agent_backtracking_state'][agent_id] (mutated in place).
# Returns a human-readable instruction string; errors are returned as text.
@tool
def start_backtracking(state: Annotated[MazeState, InjectedState], agent_id: Annotated[str, InjectedToolArg] = "agent_1") -> str:
    """Start backtracking to the nearest unexplored opening. Use when stuck or at dead end."""
    print(f"Starting backtracking for agent {agent_id}")
    try:
        # A missing entry and a None entry both mean "no backtracking state yet".
        existing_backtrack_state = state.get('agent_backtracking_state') or {}
        agent_backtrack = existing_backtrack_state.get(agent_id) or {}
        if agent_backtrack.get('is_backtracking', False):
            # Already backtracking: report the next step of the stored path
            # instead of recalculating. Default current_step to 0 so a missing
            # or None value cannot break the comparison below.
            current_step = agent_backtrack.get('current_step') or 0
            path = agent_backtrack.get('path') or []
            target_pos = agent_backtrack.get('target_position')

            if current_step < len(path) - 1:
                next_move = _get_next_backtrack_move(path[current_step], path)
                remaining_steps = len(path) - current_step - 1
                return f"Already backtracking to {target_pos}. Continue with: {next_move} ({remaining_steps} steps remaining)"
            # NOTE(review): is_backtracking is not cleared here, so repeated calls keep
            # returning this message until other code resets the entry. Confirm this is intended.
            return "Backtracking target reached. Resume normal exploration."

        maze_wrapper = state['maze_wrappers'][agent_id]
        current_pos = maze_wrapper.get_agent_position()
        move_history = maze_wrapper.move_history
        known_openings = (state.get('known_openings') or {}).get(agent_id) or []

        # Shortest route (through visited cells) to the nearest unexplored opening.
        target_pos, path = _calculate_backtrack_path(current_pos, known_openings, move_history)

        if not target_pos:
            return "Start backtracking: No valid openings found. Continue exploration manually."

        # Create the container when it is missing or None; otherwise reuse the existing dict.
        if state.get('agent_backtracking_state') is None:
            state['agent_backtracking_state'] = {}

        # NOTE(review): this mutates the injected state in place. Confirm the graph
        # actually persists it; otherwise return a state update/Command instead.
        state['agent_backtracking_state'][agent_id] = {
            'is_backtracking': True,
            'lock_mode': True,  # Lock mode prevents any deviations
            'target_position': target_pos,
            'path': path,
            'current_step': 0
        }

        next_move = _get_next_backtrack_move(current_pos, path)

        return f"Start backtracking: Target {target_pos} via path {path[:3]}{'...' if len(path) > 3 else ''}. Next move: {next_move}"

    except Exception as e:
        return f"Start backtracking: Error - {str(e)}"

def _calculate_backtrack_path(current_pos, known_openings, move_history):
    """Calculate shortest path to nearest unexplored opening using BFS through visited positions."""
    if not known_openings:
        return None, []
    
    # Parse openings to get target positions - ONLY UNEXPLORED ones
    targets = []
    for opening in known_openings:
        # Parse format: "(1,2)-north-UNEXPLORED" -> position (1,2)
        try:
            if "UNEXPLORED" not in opening:
                continue  # Skip explored openings
            pos_str = opening.split('-')[0]
            pos = eval(pos_str)  # Safe since we control the format
            targets.append(pos)
        except Exception:
            continue
    
    if not targets:
        return None, []
    
    # Use BFS to find shortest path through visited positions
    visited_positions = set(move_history)
    
    best_target = None
    best_path = []
    shortest_distance = float('inf')
    
    for target in targets:
        # Skip if we're already at the target
        if current_pos == target:
            continue
            
        path = _bfs_path(current_pos, target, visited_positions)
        if path and len(path) < shortest_distance:
            shortest_distance = len(path)
            best_target = target
            best_path = path
    
    # If no BFS path found, try direct path to closest target
    if not best_path and targets:
        closest_target = min(targets, key=lambda t: abs(t[0] - current_pos[0]) + abs(t[1] - current_pos[1]))
        if closest_target != current_pos:
            best_target = closest_target
            best_path = [current_pos, closest_target]
    
    return best_target, best_path

def _bfs_path(start, target, valid_positions):
    """BFS pathfinding through valid (visited) positions only."""
    if start == target:
        return [start]
    
    queue = deque([(start, [start])])
    visited = {start}
    
    while queue:
        current, path = queue.popleft()
        
        # Check all 4 directions
        for direction in ['north', 'south', 'east', 'west']:
            next_pos = _get_next_position_for_backtrack(current, direction)
            
            if (next_pos and 
                next_pos not in visited and 
                next_pos in valid_positions):
                
                new_path = path + [next_pos]
                
                if next_pos == target:
                    return new_path
                
                queue.append((next_pos, new_path))
                visited.add(next_pos)
    
    return []  # No path found

def _get_next_position_for_backtrack(current_pos, direction):
    """Calculate next position for backtracking (same as maze execution agent)."""
    row, col = current_pos
    direction_deltas = {
        'north': (-1, 0),
        'south': (1, 0),
        'east': (0, 1),
        'west': (0, -1)
    }
    
    if direction in direction_deltas:
        delta_row, delta_col = direction_deltas[direction]
        return (row + delta_row, col + delta_col)
    return None

def _get_next_backtrack_move(current_pos, path):
    """Get the next move direction from current position following the path."""
    if not path or len(path) < 2:
        return "Target reached"
    
    # Find current position in path
    found_current = False
    for i, pos in enumerate(path):
        if pos == current_pos:
            current_step = i
            found_current = True
            break
    
    # If current position not found in path, assume we're at the start
    if not found_current:
        print(f"WARNING: Current position {current_pos} not found in path {path[:3]}...")
        current_step = 0
    
    # Get next position in path
    next_step = current_step + 1
    if next_step >= len(path):
        return "Target reached"
    
    next_pos = path[next_step]
    row_diff = next_pos[0] - current_pos[0]
    col_diff = next_pos[1] - current_pos[1]
    
    # Validate move is adjacent
    if abs(row_diff) + abs(col_diff) != 1:
        print(f"WARNING: Invalid move from {current_pos} to {next_pos}")
        return f"Invalid move: {current_pos} -> {next_pos}"
    
    if row_diff == -1 and col_diff == 0:
        return "NORTH"
    elif row_diff == 1 and col_diff == 0:
        return "SOUTH"
    elif row_diff == 0 and col_diff == 1:
        return "EAST"
    elif row_diff == 0 and col_diff == -1:
        return "WEST"
    else:
        return f"Invalid move: {current_pos} -> {next_pos}"

# Public list of every maze tool defined in this module.
MazeTools = [
    move_north,
    move_south,
    move_east,
    move_west,
    mark_dead_end,
    get_current_view,
    start_backtracking,
]