import jax
import jax.numpy as jnp
import chex
@chex.dataclass
class DFSState:
    """Loop-carried state of the JAX depth-first search.

    Updated functionally via ``state.replace(...)`` and carried through
    ``jax.lax.while_loop`` by ``jax_dfs_path_planning``. Inside traced code
    the fields annotated ``int`` hold JAX integer scalars, not Python ints.
    """
    # (row, col) cells still to be processed; entry `stack_index` is the top.
    stack : chex.Array
    # Index of the top entry of `stack`; -1 means empty / search finished.
    stack_index : int
    # Boolean mask shaped like the grid; True for cells already processed.
    visited : chex.Array
    # (row, col) cells of the current search branch from the start cell;
    # only the first `path_length` rows are meaningful.
    path : chex.Array
    # Number of meaningful rows in `path`.
    path_length : int
    # Number of DFS steps taken so far.
    iterations : int
    
# Pure-JAX greedy depth-first search for grid path planning.
def jax_dfs_path_planning(grid, start_pos, target_pos, max_depth=100, adjacent_only=False):
    """Greedy depth-first search on a 2-D grid, written in pure JAX.

    At every processed cell the in-bounds, unblocked, unvisited 4-neighbours
    are pushed onto a stack so that the neighbour with the smallest Manhattan
    distance to ``target_pos`` is explored first.

    Args:
        grid: 2-D array indexed as ``grid[row, col]``; any non-zero cell is
            treated as blocked.
        start_pos: ``(row, col)`` of the start cell.
        target_pos: ``(row, col)`` of the goal cell.
        max_depth: Python int. Number of rows in ``path`` and the maximum
            number of DFS steps (``while_loop`` iterations) performed.
        adjacent_only: If True, the search succeeds at any cell 4-adjacent to
            ``target_pos`` instead of at ``target_pos`` itself.

    Returns:
        The final ``DFSState``. ``path[:path_length]`` holds the current
        search branch, starting at ``start_pos``:

        * ``stack_index == -1`` and ``path_length > 0``: the goal was reached
          and is ``path[path_length - 1]``.
        * ``stack_index == -1`` and ``path_length == 0``: the stack was
          exhausted without reaching the goal.
        * ``stack_index >= 0``: the ``max_depth`` step budget ran out first.

        ``stack`` is scratch space and is returned truncated to ``max_depth``
        rows.
    """
    start_pos = jnp.array(start_pos, dtype=jnp.int32)
    target_pos = jnp.array(target_pos, dtype=jnp.int32)

    # Every step pushes at most 4 neighbours and at most `max_depth` steps run,
    # so this capacity can never overflow. A `max_depth`-row stack did overflow
    # on open grids. Out-of-bounds `.at[].set` pushes are silently dropped and
    # `stack[stack_index]` reads are clipped, so the search expanded the wrong
    # cells.
    stack_capacity = 4 * max_depth + 1

    dfs_state = DFSState(
        stack=jnp.zeros((stack_capacity, 2), dtype=jnp.int32).at[0].set(start_pos),
        stack_index=0,
        visited=jnp.zeros_like(grid, dtype=jnp.bool_),
        path=jnp.zeros((max_depth, 2), dtype=jnp.int32),
        path_length=0,
        iterations=0
    )

    def is_valid(pos):
        """Check if a position is valid (in bounds and not a wall)."""
        pos = pos.astype(jnp.int32)
        row, col = pos
        in_bounds = jnp.logical_and(
            jnp.logical_and(row >= 0, row < grid.shape[0]),
            jnp.logical_and(col >= 0, col < grid.shape[1])
        )

        def check_grid(_):
            blocked = grid[row, col] != 0
            return jnp.logical_not(blocked)

        # Out-of-bounds positions are invalid; only in-bounds ones consult the grid.
        return jax.lax.cond(
            in_bounds,
            check_grid,
            lambda _: False,
            operand=None
        )

    def is_target(pos):
        """Check if `pos` is the target (or 4-adjacent to it when adjacent_only)."""
        pos = pos.astype(jnp.int32)
        if adjacent_only:
            adjacent_positions = jnp.array([
                [target_pos[0] + 1, target_pos[1]],  # South
                [target_pos[0] - 1, target_pos[1]],  # North
                [target_pos[0], target_pos[1] + 1],  # East
                [target_pos[0], target_pos[1] - 1],  # West
            ], dtype=jnp.int32)
            return jnp.any(jnp.all(pos == adjacent_positions, axis=1))
        else:
            return jnp.array_equal(pos, target_pos)

    def get_neighbors(pos):
        """Return the 4 axis-aligned neighbours of `pos` (not bounds-checked)."""
        pos = pos.astype(jnp.int32)
        row, col = pos
        neighbors = jnp.array([
            [row + 1, col],  # South
            [row - 1, col],  # North
            [row, col + 1],  # East
            [row, col - 1],  # West
        ], dtype=jnp.int32)
        return neighbors

    def dfs_step(state):
        """Process the cell on top of the stack (one while_loop iteration)."""
        # The while_loop predicate guarantees the stack is non-empty here.
        current_pos = state.stack[state.stack_index].astype(jnp.int32)

        def handle_target_reached(_):
            # Append the goal cell to the path and mark the search as done.
            return state.replace(
                path_length=state.path_length + 1,
                stack_index=-1,  # done
                path=state.path.at[state.path_length].set(current_pos),
                iterations=state.iterations + 1
            )

        def handle_continue_search(_):
            visited_updated = state.visited.at[current_pos[0], current_pos[1]].set(True)

            # Sort by descending Manhattan distance so the closest neighbour is
            # pushed last and therefore popped (explored) first.
            neighbors = get_neighbors(current_pos)
            manhattan_distances = jnp.sum(jnp.abs(neighbors - target_pos), axis=1)
            neighbors = neighbors[jnp.argsort(manhattan_distances, descending=True)]

            valid_neighbors = jax.vmap(is_valid)(neighbors)
            # Lookups for out-of-bounds neighbours return arbitrary values, but
            # those entries are masked out by `valid_neighbors`.
            visited_neighbors = state.visited[neighbors[:, 0], neighbors[:, 1]]
            unvisited_neighbors = jnp.logical_and(valid_neighbors, jnp.logical_not(visited_neighbors))
            has_valid_neighbor = jnp.any(unvisited_neighbors)

            def handle_backtrack(_):
                # Pop the current cell from the stack. It is also removed from
                # the path only when it is the last path entry, i.e. a cell whose
                # children are now exhausted; dead ends were never appended.
                # The `path_length > 0` guard stops `path[-1]` from wrapping to
                # the end of the buffer, which could drive `path_length` to -1
                # (e.g. for an isolated start cell at (0, 0)).
                last_idx = jnp.maximum(state.path_length - 1, 0)
                leaving_path_cell = jnp.logical_and(
                    state.path_length > 0,
                    jnp.array_equal(state.path[last_idx], current_pos),
                )
                return jax.lax.cond(
                    leaving_path_cell,
                    lambda: state.replace(
                        stack_index=state.stack_index - 1,
                        path_length=state.path_length - 1,
                        visited=visited_updated,
                        iterations=state.iterations + 1
                    ),
                    lambda: state.replace(
                        visited=visited_updated,
                        iterations=state.iterations + 1,
                        stack_index=state.stack_index - 1,
                    )
                )

            def handle_explore(_):
                def _add_to_stack(stack_state, i):
                    stack, stack_index = stack_state
                    return jax.lax.cond(
                        unvisited_neighbors[i],
                        lambda: (stack.at[stack_index + 1].set(neighbors[i]), stack_index + 1),
                        lambda: (stack, stack_index),
                    ), None

                (stack, stack_index), _ = jax.lax.scan(
                    _add_to_stack, (state.stack, state.stack_index), jnp.arange(4)
                )
                # The current cell stays on the stack beneath its children, so
                # it is processed again (and popped from the path) once they
                # are exhausted.
                return state.replace(
                    stack=stack,
                    path=state.path.at[state.path_length].set(current_pos),
                    stack_index=stack_index,
                    path_length=state.path_length + 1,
                    visited=visited_updated,
                    iterations=state.iterations + 1
                )

            return jax.lax.cond(
                has_valid_neighbor,
                handle_explore,
                handle_backtrack,
                operand=None
            )

        return jax.lax.cond(
            is_target(current_pos),
            handle_target_reached,
            handle_continue_search,
            operand=None
        )

    final_state = jax.lax.while_loop(
        # Stop when the stack is empty / goal found, or the step budget is spent.
        lambda state: (state.stack_index >= 0) & (state.iterations < max_depth),
        dfs_step,
        dfs_state,
    )

    # Return a `(max_depth, 2)` stack so the returned state has the same array
    # shapes as one allocated with `max_depth` rows (e.g. for callers carrying
    # it through `lax.scan`). The extra rows were only needed during the search.
    return final_state.replace(stack=final_state.stack[:max_depth])