import numpy as np
from scipy.ndimage import binary_dilation, binary_closing
from skimage.measure import label, regionprops
from skimage.morphology import disk, skeletonize
from scipy.spatial import cKDTree
import scipy.ndimage as ndi
from sklearn.cluster import DBSCAN


class NodeDetector:
    """Detect traversable gaps between obstacle boundary segments on a 2D map.

    Pipeline (see ``run``): build a skeleton of the obstacles that lie next to
    the explored-area boundary, pair up skeleton endpoints whose distance is a
    plausible gap width, cluster the resulting midpoints, drop the ones that
    lie close to the given path, and return one single-pixel mask per
    remaining gap.
    """

    def __init__(self, search_radius_list=None, min_gap_width=10, max_gap_width=40,
                 cluster_eps=15, dilation_radius=4, check_interval=3, check_radius=10,
                 max_region_area=8000, gap_point_radius=8, obstacle_threshold=0.1,
                 min_region_area=15):
        """
        Fully encapsulated gap detector

        Parameters:
            search_radius_list: List of candidate search radii; ``None`` means [25, 35, 45]
            min_gap_width: Minimum gap width 10
            max_gap_width: Maximum gap width 40
            cluster_eps: Clustering radius 15
            dilation_radius: Dilation radius 4
            check_interval: Path check interval 3
            check_radius: Path check radius 10
            max_region_area: Maximum region area 8000
            gap_point_radius: Gap point check radius 8
            obstacle_threshold: Obstacle ratio threshold 0.1
            min_region_area: Minimum region area 15
        """
        # Store parameters. The radius list is copied so that neither a shared
        # default nor the caller's list can be mutated through this instance.
        if search_radius_list is None:
            search_radius_list = [25, 35, 45]
        self.search_radius_list = list(search_radius_list)
        self.min_gap_width = min_gap_width
        self.max_gap_width = max_gap_width
        self.cluster_eps = cluster_eps
        self.dilation_radius = dilation_radius
        self.check_interval = check_interval
        self.check_radius = check_radius
        self.max_region_area = max_region_area
        self.min_region_area = min_region_area
        self.gap_point_radius = gap_point_radius
        self.obstacle_threshold = obstacle_threshold

        # Initialize data
        self.obstacle_map = None
        self.exp_mask = None
        self.path = None
        self.agent_mask = None
        self.boundary = None
        self.endpoints = None
        self.gap_info = None
        self.merged_gap_info = None
        self.final_search_radius = None
        self.safe_gap_points = None
        self.dangerous_gap_indices = None

    def load_data(self, obstacle_exp, path_mask, agent_mask):
        """Load and preprocess all input data.

        Parameters:
            obstacle_exp: Array indexed as [0] = obstacle map, [1] = explorable-area mask
            path_mask: 2D mask whose cells equal to 1 form the path, or None for "no path"
            agent_mask: Agent position mask (stored as-is), or None
        """
        # Process obstacle and explorable area maps
        self.obstacle_map = obstacle_exp[0, :, :]
        self.exp_mask = obstacle_exp[1, :, :]

        if path_mask is None:
            # run() allows omitting the path; treat that as an empty path.
            self.path = []
        else:
            # argwhere yields (row, col) pairs in row-major order, i.e. already
            # sorted by row and then by column.
            on_path = np.argwhere(np.asarray(path_mask) == 1)
            self.path = [tuple(p) for p in on_path.tolist()]

        self.agent_mask = agent_mask

    def detect_boundary(self):
        """Detect boundary skeleton line and its endpoints.

        Writes ``self.boundary`` (0/1 int skeleton image) and, through
        ``find_endpoints``, ``self.endpoints``.
        """
        # Work on a boolean view so XOR is well defined for int/float masks too.
        exp_mask = np.asarray(self.exp_mask).astype(bool)
        dilated_exp = binary_dilation(exp_mask, structure=disk(self.dilation_radius))
        exp_boundary = dilated_exp ^ exp_mask

        # Keep obstacle pixels that have an explored-boundary pixel inside the
        # (2r+1) x (2r+1) square window around them.
        window = 2 * self.dilation_radius + 1
        near_boundary = binary_dilation(exp_boundary,
                                        structure=np.ones((window, window), dtype=bool))
        boundary_obstacles = np.logical_and(np.asarray(self.obstacle_map) == 1, near_boundary)

        boundary_obstacles = binary_closing(boundary_obstacles, structure=disk(2))

        # Filter small and large regions
        labeled = label(boundary_obstacles)
        for region in regionprops(labeled):
            if region.area < self.min_region_area or region.area > self.max_region_area:
                boundary_obstacles[tuple(region.coords.T)] = False

        self.boundary = skeletonize(boundary_obstacles > 0).astype(int)
        self.find_endpoints()

    def find_endpoints(self):
        """Detect skeleton endpoints and store them in ``self.endpoints``.

        An endpoint is a skeleton pixel with exactly one skeleton pixel among
        its 8 neighbours.
        """
        # Centre weight 10 + neighbour weights 1: a response of exactly 11
        # means "pixel is set and has exactly one set neighbour".
        kernel = np.array([[1, 1, 1], [1, 10, 1], [1, 1, 1]])
        conv = ndi.convolve(self.boundary.astype(int), kernel, mode='constant')
        self.endpoints = [tuple(p) for p in np.argwhere(conv == 11).tolist()]

    def is_valid_gap_point(self, point):
        """Check if gap point is valid.

        Parameters:
            point: (y, x) position inside the obstacle map

        Returns:
            True when the share of obstacle values inside the disc of radius
            ``gap_point_radius`` around ``point`` (clipped to the map) does not
            exceed ``obstacle_threshold``.
        """
        y, x = point
        height, width = self.obstacle_map.shape
        radius = self.gap_point_radius

        # Create circular mask
        y_grid, x_grid = np.ogrid[-radius:radius + 1, -radius:radius + 1]
        mask = y_grid ** 2 + x_grid ** 2 <= radius ** 2

        # Calculate valid area (window clipped to the map)
        y_min, y_max = max(0, y - radius), min(height, y + radius + 1)
        x_min, x_max = max(0, x - radius), min(width, x + radius + 1)

        # Adjust mask for boundary cases: crop it exactly like the window
        mask = mask[(radius - (y - y_min)):(radius + (y_max - y)),
                    (radius - (x - x_min)):(radius + (x_max - x))]

        # Extract local obstacle area
        local_obstacle = self.obstacle_map[y_min:y_max, x_min:x_max]

        # Calculate obstacle ratio
        obstacle_area = np.sum(local_obstacle[mask])
        total_area = np.sum(mask)
        obstacle_ratio = obstacle_area / total_area

        return bool(obstacle_ratio <= self.obstacle_threshold)

    def detect_gap_midpoints(self):
        """Detect gap points between pairs of skeleton endpoints.

        Returns:
            Dict with parallel lists 'gap_points', 'start_points' and
            'end_points'; each gap point is the integer midpoint of an endpoint
            pair whose distance lies in [min_gap_width, max_gap_width].
        """
        gap_info = {
            'gap_points': [],
            'start_points': [],
            'end_points': []
        }

        if self.endpoints is None or len(self.endpoints) < 2:
            return gap_info

        # logical_not (rather than bitwise ~) so integer/float 0-1 maps are
        # handled correctly: ~1 == -2 is truthy and would mark obstacles free.
        traversable = np.logical_and(self.exp_mask, np.logical_not(self.obstacle_map))
        height, width = self.obstacle_map.shape

        # Pairs farther apart than max_gap_width are rejected anyway, so that
        # is a safe radius when no search radius has been selected yet.
        search_radius = (self.final_search_radius if self.final_search_radius is not None
                         else self.max_gap_width)

        coords = np.asarray(self.endpoints)
        tree = cKDTree(coords)

        for i, start in enumerate(self.endpoints):
            # Find other endpoints within search_radius; sorted for a
            # deterministic output order.
            for idx in sorted(tree.query_ball_point(start, search_radius)):
                if idx <= i:
                    # Each unordered pair is handled once.
                    continue

                end = self.endpoints[idx]
                dist = np.linalg.norm(coords[i] - coords[idx])
                if not (self.min_gap_width <= dist <= self.max_gap_width):
                    continue

                # Integer midpoint; avoids float round-off before truncation.
                midpoint = (int((start[0] + end[0]) // 2),
                            int((start[1] + end[1]) // 2))

                # Verify if midpoint is traversable and meets obstacle ratio constraint
                if (0 <= midpoint[0] < height and
                        0 <= midpoint[1] < width and
                        traversable[midpoint[0], midpoint[1]] and
                        self.is_valid_gap_point(midpoint)):
                    gap_info['gap_points'].append(midpoint)
                    gap_info['start_points'].append(start)
                    gap_info['end_points'].append(end)

        return gap_info

    @staticmethod
    def _mean_point(points):
        """Return the truncated integer mean of an (n, 2) point array as a tuple."""
        return tuple(int(v) for v in np.mean(points, axis=0).astype(int))

    def cluster_and_merge_gaps(self, gap_info):
        """Cluster and merge nearby gap points.

        Parameters:
            gap_info: Dict as returned by ``detect_gap_midpoints``

        Returns:
            A dict of the same layout with one entry per DBSCAN cluster (the
            mean gap/start/end point). ``gap_info`` itself is returned when it
            holds no gap points.
        """
        if not gap_info['gap_points']:
            return gap_info

        # Prepare clustering data
        points = np.array(gap_info['gap_points'])
        starts = np.array(gap_info['start_points'])
        ends = np.array(gap_info['end_points'])

        # DBSCAN clustering
        labels = DBSCAN(eps=self.cluster_eps, min_samples=1).fit(points).labels_

        # Merge clustering results
        merged_gap_info = {
            'gap_points': [],
            'start_points': [],
            'end_points': []
        }

        # np.unique gives a sorted, deterministic cluster order.
        for cluster_id in np.unique(labels):
            if cluster_id == -1:  # Noise points, do not merge
                continue

            members = labels == cluster_id
            merged_gap_info['gap_points'].append(self._mean_point(points[members]))
            merged_gap_info['start_points'].append(self._mean_point(starts[members]))
            merged_gap_info['end_points'].append(self._mean_point(ends[members]))

        # Add unclustered points (noise points) unchanged
        for idx in np.flatnonzero(labels == -1):
            merged_gap_info['gap_points'].append(gap_info['gap_points'][idx])
            merged_gap_info['start_points'].append(gap_info['start_points'][idx])
            merged_gap_info['end_points'].append(gap_info['end_points'][idx])

        return merged_gap_info

    def check_and_remove_gaps_in_path(self):
        """Check gap points on path and remove dangerous points.

        Writes ``self.safe_gap_points`` (gap points with no sampled path point
        within ``check_radius``) and ``self.dangerous_gap_indices`` (indices of
        the removed points in ``merged_gap_info['gap_points']``).
        """
        # Resolve the gap list first so a missing merged_gap_info cannot raise.
        gap_points = self.merged_gap_info['gap_points'] if self.merged_gap_info else []

        if not self.path or not gap_points:
            self.safe_gap_points = gap_points
            self.dangerous_gap_indices = set()
            return

        gap_tree = cKDTree(gap_points)
        dangerous_gap_indices = set()

        # Check every check_interval-th path point
        for path_point in self.path[::self.check_interval]:
            # Find Gap points within check_radius pixels
            dangerous_gap_indices.update(gap_tree.query_ball_point(path_point, self.check_radius))

        # Filtered Gap points (removed dangerous points)
        self.safe_gap_points = [p for idx, p in enumerate(gap_points)
                                if idx not in dangerous_gap_indices]
        self.dangerous_gap_indices = dangerous_gap_indices

    def create_sorted_gap_masks(self):
        """Create 3D gap mask array sorted by agent distance.

        Returns:
            uint8 array of shape (num_gaps, H, W); slice i has a single 1 at
            the i-th gap point. Gaps are ordered by distance to the first cell
            of ``agent_mask`` equal to 1, or kept in their original order when
            no such cell exists.
        """
        map_shape = self.obstacle_map.shape
        if not self.safe_gap_points:
            return np.zeros((0, *map_shape), dtype=np.uint8)

        # Get agent position
        agent_pos = None
        if self.agent_mask is not None:
            agent_positions = np.argwhere(self.agent_mask == 1)
            if len(agent_positions) > 0:
                agent_pos = agent_positions[0]  # Take first agent position

        gap_points = np.array(self.safe_gap_points)
        if agent_pos is not None:
            distances = np.linalg.norm(gap_points - agent_pos, axis=1)
            # Stable sort: equidistant gaps keep their original relative order.
            order = np.argsort(distances, kind='stable')
        else:
            # Maintain original order when no agent position
            order = np.arange(len(gap_points))

        # Fill 3D array in the chosen order
        gap_masks = np.zeros((len(gap_points), *map_shape), dtype=np.uint8)
        for layer, idx in enumerate(order):
            y, x = self.safe_gap_points[idx]
            gap_masks[layer, y, x] = 1

        return gap_masks

    def run(self, obstacle_exp, path_mask=None, agent_mask=None):
        """
        Execute complete gap detection process

        Parameters:
            obstacle_exp: Obstacle and explorable area numpy array (2, H, W)
            path_mask: Path mask numpy array (H, W)
            agent_mask: Agent position mask numpy array (H, W)

        Returns:
            3D gap mask array sorted by agent distance (num_gaps, H, W)
        """
        # Load data
        self.load_data(obstacle_exp, path_mask, agent_mask)

        # Drop results of a previous run so they cannot leak into this one
        self.gap_info = None
        self.merged_gap_info = None
        self.final_search_radius = None

        # Detect boundary
        self.detect_boundary()

        # Try different search radii; stop as soon as more than one gap is
        # found, otherwise keep the result of the last radius tried.
        for search_radius in self.search_radius_list:
            self.final_search_radius = search_radius
            self.gap_info = self.detect_gap_midpoints()
            self.merged_gap_info = self.cluster_and_merge_gaps(self.gap_info)

            if len(self.merged_gap_info['gap_points']) > 1:
                break

        # Check and remove dangerous gap points near path
        self.check_and_remove_gaps_in_path()

        # Create and return sorted 3D gap mask array
        return self.create_sorted_gap_masks()