from collections import defaultdict
from typing import List, Set, Tuple, Dict
from numpy.random import Generator
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from tasks import register

def create_grid(rng: Generator, k: int, randomize=False) -> List[List[int]]:
    """Create a k x k grid whose cells hold the integers 0 .. k*k - 1.

    Args:
        rng: NumPy random generator; only consulted when ``randomize`` is true.
        k: Side length of the grid.
        randomize: If true, the numbers are placed in a random permutation;
            otherwise they are laid out sequentially in row-major order.

    Returns:
        A list of ``k`` rows, each a list of ``k`` built-in ``int`` values.
    """
    if randomize:
        # tolist() converts the NumPy scalars to built-in ints so both
        # branches return the same element type, as the annotation promises.
        shuffled = rng.permutation(k**2).tolist()
        return [shuffled[i * k:(i + 1) * k] for i in range(k)]
    return [[i * k + j for j in range(k)] for i in range(k)]

def select_subgrid(grid: List[List[int]], 
                  top_left: Tuple[int, int], 
                  height: int, 
                  width: int) -> List[List[int]]:
    """
    Select a subgrid from the main grid starting at top_left with given dimensions.

    Rows and columns that run past the bottom/right edge wrap around to the
    start of the grid (indices are taken modulo the grid's row and column
    counts), so the selection never raises for a large height/width.

    Args:
        grid: The original grid
        top_left: (row, col) coordinates of the top-left corner of subgrid
        height: Height of the subgrid (m)
        width: Width of the subgrid (n)

    Returns:
        A new list of ``height`` rows, each containing ``width`` cell values.

    Raises:
        ValueError: If the grid is empty or ``top_left`` has a negative coordinate.
    """
    if not grid or not grid[0]:
        raise ValueError("Grid cannot be empty")

    n_rows = len(grid)
    # Wrap columns by the row length, not the row count, so that non-square
    # grids are indexed correctly.
    n_cols = len(grid[0])
    i, j = top_left

    if i < 0 or j < 0:
        raise ValueError("Subgrid dimensions exceed grid boundaries")

    return [
        [grid[r % n_rows][c % n_cols] for c in range(j, j + width)]
        for r in range(i, i + height)
    ]

def visualize_grid_selection(full_grid: List[List[int]], 
                           subgrid: List[List[int]], 
                           top_left: Tuple[int, int]):
    """Plot the full grid's numbers and outline the selected subgrid in red.

    Args:
        full_grid: The grid to draw; its row count is used as the side length.
        subgrid: The selected subgrid; only its height and width are used.
        top_left: (row, col) of the subgrid's top-left corner in the full grid.
    """
    plt.figure(figsize=(12, 12))
    plt.title("Grid Selection Visualization")

    side = len(full_grid)

    # Faint grid lines; rows are drawn downward, hence the negated y values.
    for offset in range(side + 1):
        plt.axhline(y=-offset, color='gray', linestyle='-', alpha=0.3)
        plt.axvline(x=offset, color='gray', linestyle='-', alpha=0.3)

    # Write every cell's number at the centre of its cell.
    cells = [(row, col) for row in range(side) for col in range(side)]
    for row, col in cells:
        plt.text(col + 0.5, -row - 0.5, str(full_grid[row][col]),
                 horizontalalignment='center',
                 verticalalignment='center')

    # Red outline around the selection.
    top_row, left_col = top_left
    sub_height = len(subgrid)
    sub_width = len(subgrid[0])
    outline = plt.Rectangle((left_col, -top_row), sub_width, -sub_height,
                            fill=False, color='red', linewidth=2)
    plt.gca().add_patch(outline)

    plt.xlim(-0.5, side + 0.5)
    plt.ylim(-(side + 0.5), 0.5)
    plt.axis('equal')
    plt.show()

class WilsonMaze:
    """Maze over a rectangular grid of integer node ids, built with Wilson's algorithm.

    The maze is stored in ``self.edges`` as a set of undirected edges, each a
    sorted ``(a, b)`` tuple of two grid-adjacent node ids.  Node ids are
    presumably unique within the grid: ``node_positions`` keeps one position
    per id.
    """

    # Neighbor offsets as (row delta, col delta) in the order right, down,
    # left, up.  ``which_neighbor`` returns indices into this sequence.
    _DIRECTIONS = ((0, 1), (1, 0), (0, -1), (-1, 0))

    def __init__(self, grid: List[List[int]], rng: Generator, exclude_nodes: Set[int] = None):
        """Initialize maze with a grid of integer nodes.

        Args:
            grid: Non-empty list of rows of node ids.
            rng: NumPy random generator used for every random choice.
            exclude_nodes: Node ids that ``get_neighbors`` must never return.
                Defaults to a fresh empty set per instance.
        """
        self.grid = grid
        self.m = len(grid)
        self.n = len(grid[0])
        # node id -> (row, col) inside ``grid``
        self.node_positions = {grid[i][j]: (i, j) for i in range(self.m) for j in range(self.n)}
        self.edges = set()
        self.rng = rng
        # A ``set()`` default in the signature would be shared by all instances.
        self.exclude_nodes = set() if exclude_nodes is None else exclude_nodes

    def _has_edge(self, node1: int, node2: int) -> bool:
        """Return True if the maze contains an edge between the two nodes."""
        return tuple(sorted([node1, node2])) in self.edges

    def get_neighbors(self, node: int) -> List[int]:
        """Return the in-grid, non-excluded neighbors of ``node``.

        Neighbors are listed in the order right, down, left, up, regardless of
        whether a maze edge connects them to ``node``.
        """
        i, j = self.node_positions[node]
        neighbors = []
        for di, dj in self._DIRECTIONS:
            ni, nj = i + di, j + dj
            if 0 <= ni < self.m and 0 <= nj < self.n and self.grid[ni][nj] not in self.exclude_nodes:
                neighbors.append(self.grid[ni][nj])
        return neighbors

    def which_neighbor(self, node1: int, node2: int) -> int:
        """Return the direction index of ``node2`` relative to ``node1``.

        Returns:
            0 for right, 1 for down, 2 for left, 3 for up.

        Raises:
            ValueError: If the two nodes are not grid-adjacent.
        """
        i1, j1 = self.node_positions[node1]
        i2, j2 = self.node_positions[node2]
        offset = (i2 - i1, j2 - j1)
        if offset in self._DIRECTIONS:
            return self._DIRECTIONS.index(offset)
        raise ValueError("Nodes are not neighbors")

    def looped_erased_random_walk(self, start: int, target_nodes: Set[int]) -> List[int]:
        """Implement loop-erased random walk from start until hitting target nodes.

        Returns:
            The walk as a list of distinct nodes beginning with ``start`` and
            ending with the first target node reached.

        Raises:
            ValueError: If the walk reaches a node without usable neighbors.
        """
        current = start
        walk = [start]

        while current not in target_nodes:
            neighbors = self.get_neighbors(current)
            if not neighbors:
                raise ValueError(f"Node {current} has no usable neighbors; cannot continue the walk")
            next_node = self.rng.choice(neighbors)

            if next_node in walk:
                # Revisiting a node closes a loop: cut the walk back to it.
                walk = walk[:walk.index(next_node) + 1]
            else:
                walk.append(next_node)

            current = next_node

        return walk

    def generate_maze(self, n_nodes: int = None) -> Tuple[Set[Tuple[int, int]], int, int]:
        """Generate maze using Wilson's algorithm.

        Args:
            n_nodes: Optional cap on the number of nodes in the maze.  When
                falsy, every grid node is added.

        Returns:
            ``(edges, start_node, end_node)`` where ``edges`` is ``self.edges``
            and start/end are two distinct nodes drawn at random from the maze.

        Raises:
            ValueError: If the maze ends up with fewer than two nodes.
        """
        all_nodes = {self.grid[i][j] for i in range(self.m) for j in range(self.n)}
        seed_node = self.rng.choice(list(all_nodes))
        maze_nodes = {seed_node}
        remaining_nodes = all_nodes - maze_nodes

        def target_reached() -> bool:
            return bool(n_nodes) and len(maze_nodes) >= n_nodes

        while remaining_nodes and not target_reached():
            walk_start = self.rng.choice(list(remaining_nodes))
            # Reversed so that path[0] is the node already in the maze and
            # each following node hangs off its predecessor.
            path = self.looped_erased_random_walk(walk_start, maze_nodes)[::-1]

            for attached, new_node in zip(path, path[1:]):
                # Add the edge together with the node it brings in, so that
                # stopping at the cap never leaves an edge to a node that is
                # not part of the maze.
                self.edges.add(tuple(sorted([attached, new_node])))
                maze_nodes.add(new_node)
                if target_reached():
                    break

            remaining_nodes = all_nodes - maze_nodes

        if len(maze_nodes) < 2:
            raise ValueError("Maze needs at least two nodes to pick distinct start and end nodes")

        start_node = self.rng.choice(list(maze_nodes))
        end_node = self.rng.choice(list(maze_nodes - {start_node}))

        return self.edges, start_node, end_node

    def find_path(self, start: int, end: int) -> List[int]:
        """Find path from start to end using BFS.

        Returns:
            The list of nodes from ``start`` to ``end`` inclusive, or an empty
            list when ``end`` cannot be reached through maze edges.
        """
        # Parent pointers replace per-node path copies and list.pop(0).
        parents = {start: None}
        frontier = [start]

        while frontier:
            next_frontier = []
            for current in frontier:
                if current == end:
                    path = [current]
                    while parents[path[-1]] is not None:
                        path.append(parents[path[-1]])
                    return path[::-1]

                for neighbor in self.get_neighbors(current):
                    if neighbor not in parents and self._has_edge(current, neighbor):
                        parents[neighbor] = current
                        next_frontier.append(neighbor)
            frontier = next_frontier

        return []  # No path found

    def find_path_dfs(self, start: int, end: int) -> Tuple[List[int], List[List[int]]]:
        """
        Find path from start to end using DFS, and return the search trace as an array of branches.
        Each branch represents a sequence of nodes explored until reaching a dead end or having to backtrack.

        Returns:
            ``(path, search_branches)``.  ``path`` is the node list from start
            to end (empty if unreachable).  After backtracking, the next
            branch begins with the branch root recorded for the node being
            resumed, followed by the nodes explored from there.
        """
        stack = [(start, [start], start)]  # (node, path so far, branch root)
        visited = {start}
        search_branches = [[]]
        last_explored = None

        while stack:
            current, path, branch_root = stack.pop()

            # A jump to a node not adjacent (in the maze) to the previously
            # explored one means the search backtracked: open a new branch.
            if last_explored is not None and not self._has_edge(current, last_explored):
                search_branches.append([branch_root])

            search_branches[-1].append(current)
            last_explored = current

            if current == end:
                return path, search_branches

            neighbors = []
            for neighbor in self.get_neighbors(current):
                if neighbor not in visited and self._has_edge(current, neighbor):
                    visited.add(neighbor)
                    neighbors.append(neighbor)

            # A node with several unexplored children becomes the root that
            # later branches restart from; otherwise the root is inherited.
            child_root = current if len(neighbors) > 1 else branch_root

            # Add neighbors to stack in reverse order so we explore them in the original order
            for neighbor in reversed(neighbors):
                stack.append((neighbor, path + [neighbor], child_root))

        return [], search_branches

    def find_path_beam_search(self, start: int, end: int, num_beams: int = 3) -> tuple[list[int], list[list[int]]]:
        """
        Find path from start to end using a beam search-like DFS algorithm.

        This maintains multiple parallel beams (stacks) with a shared visited set.
        When a beam's stack is exhausted, it steals work from other active beams.

        Args:
            start: Starting node
            end: Target node
            num_beams: Number of parallel search paths to maintain

        Returns:
            Tuple containing:
            - Path from start to end if found, empty list otherwise
            - Search trace with one entry per beam; each entry is that beam's
              list of branches, and each branch is a list of nodes
        """
        # Every beam starts with the start node on its stack.
        beams = [[(start, [start], start)] for _ in range(num_beams)]  # (node, path, branch root)
        visited = {start}
        search_traces = [[] for _ in range(num_beams)]
        last_explored = [None] * num_beams
        active_beams = set(range(num_beams))

        while active_beams:
            # Each round, every active beam processes at most one node.
            for beam_idx in list(active_beams):
                if not beams[beam_idx]:
                    # Out of work: take the top half of the first other
                    # active beam that holds more than one pending item.
                    stolen = False
                    for other_beam_idx in active_beams:
                        if beam_idx == other_beam_idx:
                            continue
                        if len(beams[other_beam_idx]) > 1:
                            steal_count = len(beams[other_beam_idx]) // 2
                            stolen_work = beams[other_beam_idx][-steal_count:]
                            beams[other_beam_idx] = beams[other_beam_idx][:-steal_count]
                            beams[beam_idx].extend(stolen_work)
                            stolen = True
                            break

                    if not stolen:
                        # Nothing to steal: this beam is finished for good.
                        active_beams.remove(beam_idx)
                        continue

                if not beams[beam_idx]:
                    continue

                current, path, branch_root = beams[beam_idx].pop()
                trace = search_traces[beam_idx]
                previous = last_explored[beam_idx]

                # Non-adjacent to the beam's previous node: start a new branch
                # from the recorded branch root.
                if trace and previous is not None and not self._has_edge(current, previous):
                    trace.append([branch_root])

                if not trace:
                    trace.append([])
                trace[-1].append(current)
                last_explored[beam_idx] = current

                if current == end:
                    return path, search_traces

                neighbors = []
                for neighbor in self.get_neighbors(current):
                    if neighbor not in visited and self._has_edge(current, neighbor):
                        visited.add(neighbor)
                        neighbors.append(neighbor)

                child_root = current if len(neighbors) > 1 else branch_root

                for neighbor in reversed(neighbors):
                    beams[beam_idx].append((neighbor, path + [neighbor], child_root))

        # If we've explored all possible paths and haven't found the end
        return [], search_traces

def visualize_maze(grid: List[List[int]], 
                  edges: Set[Tuple[int, int]], 
                  start: int, 
                  end: int, 
                  solution: List[int] = None,
                  title: str = "Maze Visualization"):
    """Visualize the maze using networkx.

    Args:
        grid: Grid of node ids; every grid cell is drawn as a node.
        edges: Maze edges to draw in gray.
        start: Node drawn in green.
        end: Node drawn in red.
        solution: Optional node sequence to highlight in blue.  May be a flat
            list of nodes or nested lists (search branches, or per-beam lists
            of branches); nested input is flattened in order before drawing.
            ``None`` or an empty list draws the maze without a solution.
        title: Title of the figure.
    """
    # Create graph
    G = nx.Graph()
    G.add_edges_from(edges)

    # Get node positions for visualization
    m, n = len(grid), len(grid[0])
    pos = {grid[i][j]: (j, -i) for i in range(m) for j in range(n)}  # Flip y-axis for better visualization

    # Set up the plot
    plt.figure(figsize=(10, 10))
    plt.title(title)

    # Draw the maze structure
    nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

    # Default appearance for every grid node
    all_nodes = [grid[i][j] for i in range(m) for j in range(n)]
    node_colors = ['lightgray'] * (m * n)
    node_sizes = [500] * (m * n)

    # Highlight start and end nodes
    node_colors[all_nodes.index(start)] = 'green'
    node_colors[all_nodes.index(end)] = 'red'

    # Flatten nested traces level by level.  The truthiness guard keeps
    # ``None`` and ``[]`` (no path found) from being indexed.
    while solution and isinstance(solution[0], list):
        solution = [item for part in solution for item in part]

    if solution:
        # Highlight solution path edges
        solution_edges = list(zip(solution[:-1], solution[1:]))
        nx.draw_networkx_edges(G, pos, edgelist=solution_edges, edge_color='blue', width=2)

        # Highlight solution nodes, leaving the start/end colors untouched
        for node in solution:
            if node not in [start, end]:
                node_colors[all_nodes.index(node)] = 'lightblue'

    # Draw nodes
    nx.draw_networkx_nodes(G, pos, nodelist=all_nodes, node_color=node_colors, node_size=node_sizes)

    # Add node labels (only nodes that appear in at least one edge are in G)
    labels = {node: str(node) for node in G.nodes()}
    nx.draw_networkx_labels(G, pos, labels)

    plt.axis('equal')
    plt.show()

def main_with_subgrid(rng: Generator, k: int, m: int, n: int, n_nodes: int = None,
                      top_left: Tuple[int, int] = None, 
                      visualize: bool = True, padding: bool = False, randomize=False, 
                      dfs=True, beam_search=False, num_beams=3) -> Tuple[WilsonMaze,
                                                               Set[Tuple[int, int]], 
                                                               int, 
                                                               int, 
                                                               List[int]]:
    """
    Build a k x k grid, carve out an m x n subgrid, then generate and solve a maze on it.

    Args:
        rng: NumPy random generator shared by every random step.
        k: Size of the full grid (k x k)
        m: Height of the subgrid
        n: Width of the subgrid
        n_nodes: Optional node limit forwarded to ``WilsonMaze.generate_maze``.
        top_left: Optional (row, col) for subgrid selection. Random if not provided.
        visualize: Whether to visualize the grid selection and maze.
        padding: Accepted for interface compatibility; this function does not
            currently use it.
        randomize: Whether to randomize the node numbering in the grid.
        dfs: Whether to use DFS for pathfinding. If False, uses BFS.
        beam_search: Whether to use beam search for pathfinding. Takes precedence over dfs.
        num_beams: Number of parallel beams to use in beam search.

    Returns:
        ``(maze, edges, start, end, solution)``.  For BFS ``solution`` is the
        node path; for DFS and beam search it is the search trace (the second
        element returned by the corresponding ``WilsonMaze`` method).
    """
    grid = create_grid(rng, k, randomize=randomize)

    # Draw a random corner such that an m x n window fits inside the grid.
    if top_left is None:
        corner_row = int(rng.integers(0, k - m + 1))
        corner_col = int(rng.integers(0, k - n + 1))
        top_left = (corner_row, corner_col)

    window = select_subgrid(grid, top_left, m, n)

    if visualize:
        visualize_grid_selection(grid, window, top_left)

    # The maze lives on the subgrid only.
    maze = WilsonMaze(window, rng)
    edges, start, end = maze.generate_maze(n_nodes=n_nodes)

    # Beam search wins over DFS, which wins over BFS.
    if beam_search:
        _, solution = maze.find_path_beam_search(start, end, num_beams=num_beams)
    elif dfs:
        _, solution = maze.find_path_dfs(start, end)
    else:
        solution = maze.find_path(start, end)

    if visualize:
        # Node positions for the plot come from the full grid.
        visualize_maze(grid, edges, start, end, solution, title="Maze with Solution Path")

    return maze, edges, start, end, solution

import json

def sample_skewed(rng: Generator, range):
    """Draw one integer from ``np.arange(*range)``, favouring larger values.

    Each candidate value ``v`` is picked with probability proportional to
    ``v + 5``.

    Args:
        rng: NumPy random generator used for the draw.
        range: Arguments unpacked into ``np.arange`` (e.g. ``[start, stop]``).

    Returns:
        The sampled value as a built-in ``int``.
    """
    candidates = np.arange(*range, dtype=float)
    shifted = candidates + 5
    probabilities = shifted / shifted.sum()
    return int(rng.choice(candidates, p=probabilities))

@register()
def maze(rng: Generator, full_grid_size, sub_grid_size=None, n_nodes=None, repr='edge', padding=False, randomize=False, dfs=True, beam_search=False, num_beams=3):
    """Generate one maze sample as ``(prompt, solution, None)`` strings.

    Args:
        rng: NumPy random generator driving all sampling.
        full_grid_size: Side length ``k`` of the full grid.
        sub_grid_size: Range passed to ``sample_skewed`` to pick the square
            subgrid's side; only used when ``n_nodes`` is falsy.
        n_nodes: Range passed to ``sample_skewed`` to pick the maze's node
            limit; when given, the subgrid spans the whole grid.
        repr: ``'edge'`` for a comma-separated edge list, ``'adj_list'`` for
            one ``[node]:[nbr]...`` entry per grid node.
        padding: With ``'adj_list'``, emit four direction slots per node and
            mark missing neighbors with ``X``.
        randomize: Randomize grid numbering and, for ``'adj_list'``, the order
            in which nodes are listed.
        dfs, beam_search, num_beams: Forwarded to ``main_with_subgrid`` to
            select the search algorithm.

    Returns:
        ``(maze_str + '?[start]>[end]?', solution_str, None)``.

    Raises:
        ValueError: If ``repr`` is neither ``'edge'`` nor ``'adj_list'``.
    """
    k = full_grid_size

    # Either cap the node count on the full grid, or sample a subgrid side.
    if n_nodes:
        side = full_grid_size
        n_nodes = sample_skewed(rng, n_nodes)
    else:
        side = sample_skewed(rng, sub_grid_size)

    wilson, edges, start, end, solution = main_with_subgrid(
        rng, k, m=side, n=side, n_nodes=n_nodes, visualize=False, padding=padding,
        randomize=randomize, dfs=dfs, beam_search=beam_search, num_beams=num_beams
    )

    if repr == 'edge':
        maze_str = ','.join(f'[{a}][{b}]' for a, b in edges)
    elif repr == 'adj_list':
        if padding:
            wilson.exclude_nodes = {}
            # One slot per direction (right, down, left, up); 'X' = no edge.
            adj = {node: ['X'] * 4 for node in range(k**2)}
            for a, b in edges:
                adj[a][wilson.which_neighbor(a, b)] = b
                adj[b][wilson.which_neighbor(b, a)] = a
        else:
            adj = defaultdict(set)
            for a, b in edges:
                adj[a].add(b)
                adj[b].add(a)

        if randomize:
            node_list = rng.permutation(k**2)
        else:
            node_list = list(range(k**2))

        entries = []
        for node in node_list:
            neighbor_str = ''.join(f'[{nb}]' for nb in adj[node])
            entries.append(f"[{node}]:{neighbor_str}")
        maze_str = ','.join(entries)
    else:
        raise ValueError("Invalid representation type. Choose 'edge' or 'adj_list'.")

    # NOTE(review): with beam_search=True, `solution` holds one list of
    # branches per beam, so both formats below render whole lists inside the
    # brackets - confirm the intended output format for beam search.
    if dfs:
        # Branches are separated by ';', nodes rendered as '[id]'.
        solution_str = ';'.join(
            ''.join(f'[{x}]' for x in branch) for branch in solution
        )
    else:
        solution_str = ''.join(f'[{x}]' for x in solution)

    return maze_str + f'?[{start}]>[{end}]?', solution_str, None

# Demo: generate and print one randomized 8x8 adjacency-list sample solved with BFS
if __name__ == "__main__":
    demo_rng = np.random.default_rng(7)
    sample = maze(demo_rng, 8, [7, 8], repr='adj_list', padding=False, dfs=False, randomize=True)
    print(sample)
