# file: region_graph.py
from collections import deque
from typing import Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
from scipy.ndimage import distance_transform_edt

# Occupancy-grid cell labels expected in `persistent_known_map`.
# NOTE(review): these are presumably meant to mirror constants owned by the
# module that produces the map; keep the values in sync with it.
# Only FREE_CELL is referenced by the functions below.
FREE_CELL = 0
OBSTACLE_CELL = 1
UNKNOWN_CELL = 2

def heuristic(a, b):
    """Return the Manhattan (L1) distance between grid points *a* and *b*.

    Only the first two components, (row, col), of each point are used.
    """
    row_delta = a[0] - b[0]
    col_delta = a[1] - b[1]
    return abs(row_delta) + abs(col_delta)

def _bfs_free_cluster(visited: np.ndarray, r0: int, c0: int) -> List[Tuple[int, int]]:
    """Flood-fill the 4-connected unvisited cluster containing (r0, c0).

    Marks every reached cell in ``visited`` and returns the cells in BFS order.
    """
    h, w = visited.shape
    visited[r0, c0] = True
    queue = deque([(r0, c0)])
    cells = []
    while queue:
        r, c = queue.popleft()
        cells.append((r, c))
        for dr, dc in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc]:
                visited[nr, nc] = True
                queue.append((nr, nc))
    return cells


def _label_free_clusters(persistent_known_map: np.ndarray, max_region_size: int) -> np.ndarray:
    """Label 4-connected free clusters, splitting any larger than max_region_size.

    Returns an int32 map where 0 marks non-free cells and IDs start at 1.
    """
    region_map = np.zeros_like(persistent_known_map, dtype=np.int32)
    visited = persistent_known_map != FREE_CELL
    next_id = 1
    # np.nonzero yields seeds in row-major order, i.e. the same order as a
    # nested row/column scan; seeds absorbed by an earlier fill are skipped.
    for r, c in zip(*np.nonzero(~visited)):
        if visited[r, c]:
            continue
        cells = _bfs_free_cluster(visited, int(r), int(c))
        n_cells = len(cells)
        # Split into ceil(n / max) chunks of near-equal size along the BFS
        # order: no chunk exceeds max_region_size, and no tiny leftover tail
        # is produced that min_region_size pruning would then discard.
        n_chunks = -(-n_cells // max_region_size)
        for i, (cr, cc) in enumerate(cells):
            region_map[cr, cc] = next_id + (i * n_chunks) // n_cells
        next_id += n_chunks
    return region_map


def _prune_and_place_centroids(
    region_map: np.ndarray, min_region_size: int
) -> Dict[int, Tuple[int, int]]:
    """Zero out regions smaller than min_region_size (in place).

    Returns {region_id: (row, col)} for every region that is kept.
    """
    centroids = {}
    for rid in np.unique(region_map):
        if rid == 0:
            continue  # 0 marks non-free (or already pruned) cells
        mask = region_map == rid
        if np.count_nonzero(mask) < min_region_size:
            region_map[mask] = 0
            continue
        # The cell farthest from the region's background always lies inside
        # the region, even for non-convex shapes where the mean cell may not.
        dist_map = distance_transform_edt(mask)
        r, c = np.unravel_index(np.argmax(dist_map), dist_map.shape)
        centroids[int(rid)] = (int(r), int(c))
    return centroids


def _adjacent_region_pairs(region_map: np.ndarray) -> set:
    """Return {(lo_id, hi_id)} for every pair of distinct, 4-adjacent nonzero regions."""
    pairs = set()
    horizontal = (region_map[:, :-1], region_map[:, 1:])
    vertical = (region_map[:-1, :], region_map[1:, :])
    for a, b in (horizontal, vertical):
        touching = (a != b) & (a != 0) & (b != 0)
        lo = np.minimum(a[touching], b[touching])
        hi = np.maximum(a[touching], b[touching])
        pairs.update(zip(lo.tolist(), hi.tolist()))
    return pairs


def build_adaptive_region_graph(
    persistent_known_map: np.ndarray,
    min_region_size: int = 100,
    max_region_size: int = 500
) -> Tuple[nx.Graph, np.ndarray, Dict[int, Tuple[int, int]]]:
    """
    Builds an Adaptive Region Graph (ARG) from the known map.

    Cells equal to FREE_CELL are grouped into 4-connected clusters. A cluster
    with more than ``max_region_size`` cells is split into ceil(n / max)
    near-equal chunks taken consecutively along its BFS order. Regions with
    fewer than ``min_region_size`` cells are then discarded. Every pair of
    remaining regions that touch (4-adjacency) is joined by an edge.

    NOTE(review): the first chunk of a split cluster is a BFS prefix and hence
    connected, but later chunks are consecutive BFS slices and are not
    guaranteed to be connected. A watershed/skeleton-based split would fix
    this.

    Args:
        persistent_known_map: 2-D grid of cell labels.
        min_region_size: regions with fewer cells are pruned (labelled 0).
        max_region_size: upper bound on cells per region; must be >= 1.

    Returns:
        - nx.Graph: nodes are region IDs with a ``centroid`` attribute. Edges
          connect adjacent regions and carry ``weight`` = Manhattan distance
          between the two centroids.
        - np.ndarray: int32 label map with the input's shape. 0 marks non-free
          or pruned cells; any other value is a region ID present in the graph.
        - Dict[int, Tuple[int, int]]: region ID -> centroid (row, col), the cell
          of the region farthest from its background.

    Raises:
        ValueError: if ``max_region_size`` < 1 or the map is not 2-D.
    """
    if max_region_size < 1:
        raise ValueError(f"max_region_size must be >= 1, got {max_region_size}")

    region_map = _label_free_clusters(persistent_known_map, max_region_size)
    centroids = _prune_and_place_centroids(region_map, min_region_size)

    G = nx.Graph()
    for rid, centroid in centroids.items():
        G.add_node(rid, centroid=centroid)

    # Pruned regions are already 0 in region_map, so every pair refers to a
    # kept node. Sorting keeps edge insertion order deterministic
    # (ascending by first ID, then second ID).
    for rid1, rid2 in sorted(_adjacent_region_pairs(region_map)):
        G.add_edge(rid1, rid2, weight=heuristic(centroids[rid1], centroids[rid2]))

    return G, region_map, centroids

def find_region_path(
    start_pos: Tuple[int, int],
    goal_pos: Tuple[int, int],
    region_map: np.ndarray,
    region_graph: nx.Graph
) -> Optional[List[int]]:
    """
    Finds a path of region IDs from start to goal.

    A* runs over ``region_graph`` using edge ``weight`` as cost and the
    Manhattan distance between node ``centroid`` attributes as the heuristic.

    Args:
        start_pos: (row, col) of the start cell in ``region_map``.
        goal_pos: (row, col) of the goal cell in ``region_map``.
        region_map: label map (0 = no region), as returned by
            ``build_adaptive_region_graph``.
        region_graph: region graph whose nodes carry a ``centroid`` attribute.

    Returns:
        The region IDs from the start region to the goal region (inclusive),
        or None if either endpoint is not inside a region of the graph or no
        path connects them.

    Raises:
        IndexError: if either position lies outside ``region_map``. Negative
            coordinates are rejected explicitly; otherwise numpy would wrap
            them to the opposite edge of the map.
    """
    h, w = region_map.shape
    for label, (r, c) in (("start_pos", start_pos), ("goal_pos", goal_pos)):
        if not (0 <= r < h and 0 <= c < w):
            raise IndexError(f"{label} {(r, c)} is outside the {h}x{w} region map")

    start_region = int(region_map[start_pos[0], start_pos[1]])
    goal_region = int(region_map[goal_pos[0], goal_pos[1]])

    if (start_region == 0 or goal_region == 0
            or not region_graph.has_node(start_region)
            or not region_graph.has_node(goal_region)):
        return None  # Start or goal is in an invalid/unreachable area

    if start_region == goal_region:
        return [start_region]

    nodes = region_graph.nodes
    try:
        return nx.astar_path(
            region_graph,
            start_region,
            goal_region,
            heuristic=lambda u, v: heuristic(nodes[u]['centroid'], nodes[v]['centroid']),
            weight='weight',
        )
    except nx.NetworkXNoPath:
        return None