import numpy as np
from skimage.morphology import skeletonize
from scipy.ndimage import binary_fill_holes, binary_dilation
import cv2
import math
from collections import deque


class EndpointDetector:
    """Detect skeleton endpoints of the traversable area in a depth map.

    The traversable area is derived from a per-pixel relative-height map,
    cleaned up morphologically, and skeletonized. The endpoints of that
    skeleton are returned as one single-pixel binary mask per endpoint.
    This class is used together with other modules to construct the
    region-junction graph.
    """

    def __init__(self, height_offset=0.88, height_low_threshold=0, height_up_threshold=0.8749, min_component_area=100, hfov=79):
        """
        Initialize the endpoint detector.

        Args:
            height_offset (float): Value added to the camera-relative height
                map before thresholding (same units as the depth values).
            height_low_threshold (float): Pixels whose offset height is below
                this value are not traversable.
            height_up_threshold (float): Pixels whose offset height is at or
                above this value are not traversable.
            min_component_area (int): Minimum area, in pixels, for a connected
                component to be kept. It is not applied when only a single
                component exists.
            hfov (float): Horizontal field of view of the camera, in degrees.
        """
        self.height_offset = height_offset
        self.height_low_threshold = height_low_threshold
        self.height_up_threshold = height_up_threshold
        self.min_component_area = min_component_area
        self.hfov = hfov

    def _depth_to_relative_height(self, depth_image):
        """Compute a camera-relative height map from a depth image.

        A pinhole model is used. The focal length (in pixels) is derived from
        the image width and ``self.hfov``, and the same focal length is also
        used vertically. Each pixel's row offset from the image's middle row
        is back-projected with its depth. The result is therefore in the same
        units as ``depth_image`` and is positive above the middle row.

        Args:
            depth_image (np.ndarray): (H, W) depth map. For (H, W, C) input,
                only channel 0 is used.

        Returns:
            np.ndarray: (H, W) float height map.
        """
        if depth_image.ndim > 2:
            depth_image = depth_image[:, :, 0]
        h, w = depth_image.shape
        focal = w / (2 * np.tan(np.radians(self.hfov / 2)))
        # An (H, 1) column of row offsets broadcasts against the (H, W) depth map.
        row_offsets = np.arange(h)[:, None] - h / 2
        return -(row_offsets * depth_image / focal)

    def _filter_small_components(self, binary_mask):
        """Keep at most the two largest 8-connected components of a mask.

        A lone component is always kept, whatever its area. Otherwise the two
        largest components are considered (ties broken by ascending label),
        and each one is kept only if its area is at least
        ``self.min_component_area``. All further components are discarded.

        Args:
            binary_mask (np.ndarray): Single-channel uint8 mask. Non-zero
                pixels are foreground.

        Returns:
            np.ndarray: Mask with the same shape and dtype. Kept components
            are set to 255 and everything else is 0.
        """
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            binary_mask, connectivity=8)
        filtered_mask = np.zeros_like(binary_mask)

        # Label 0 is the background, so num_labels - 1 is the component count.
        if num_labels <= 1:
            return filtered_mask
        if num_labels == 2:
            # A single component is kept without any area filtering.
            filtered_mask[labels == 1] = 255
            return filtered_mask

        areas = stats[:, cv2.CC_STAT_AREA]
        # The sort is stable, so components with equal areas keep ascending label order.
        by_area = sorted(range(1, num_labels), key=lambda label: areas[label], reverse=True)
        for label_id in by_area[:2]:
            if areas[label_id] >= self.min_component_area:
                filtered_mask[labels == label_id] = 255
        return filtered_mask

    def _get_clean_skeleton(self, binary_mask):
        """Skeletonize a binary mask after a slight dilation.

        The mask is binarized, dilated twice with a 5x5 elliptical kernel
        (presumably to bridge small gaps and smooth the outline before
        thinning), and then thinned with ``skeletonize``. Because of the
        dilation, the skeleton may lie slightly outside the original
        foreground.

        Args:
            binary_mask (np.ndarray): Mask whose non-zero pixels are
                foreground.

        Returns:
            np.ndarray: uint8 skeleton with values 0/255.
        """
        binary_mask = (binary_mask > 0).astype(np.uint8) * 255
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        dilated = cv2.dilate(binary_mask, kernel, iterations=2)
        skeleton = skeletonize(dilated // 255)
        return skeleton.astype(np.uint8) * 255

    @staticmethod
    def _count_skeleton_neighbors(skeleton):
        """Return the foreground mask and the 8-neighbour foreground counts.

        Pixels outside the image count as background (zero padding). This
        matches the explicit bounds checks used when walking the skeleton in
        ``_extend_skeleton_from_endpoints``. ``cv2.filter2D``'s default
        BORDER_REFLECT_101 mirrors pixels across the edge instead, which made
        a line ending on the image border look like it had two neighbours.

        Args:
            skeleton (np.ndarray): 2-D skeleton image. Non-zero pixels are on.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(on, counts)``. ``on`` is the
            boolean foreground mask. ``counts[y, x]`` is the number of
            foreground pixels among the 8 neighbours of (x, y).
        """
        on = np.asarray(skeleton) > 0
        h, w = on.shape
        padded = np.pad(on, 1, mode="constant").astype(np.uint8)
        counts = np.zeros((h, w), dtype=np.uint8)
        for dy in range(3):
            for dx in range(3):
                if dy == 1 and dx == 1:
                    continue  # skip the centre pixel itself
                counts += padded[dy:dy + h, dx:dx + w]
        return on, counts

    def _find_skeleton_endpoints(self, skeleton):
        """Return skeleton pixels with exactly one 8-connected skeleton neighbour.

        Args:
            skeleton (np.ndarray): 2-D skeleton image. Non-zero pixels are on.

        Returns:
            list[tuple]: ``(x, y)`` coordinates in row-major order
            (by y, then x).
        """
        on, counts = self._count_skeleton_neighbors(skeleton)
        ys, xs = np.nonzero(on & (counts == 1))
        return list(zip(xs, ys))  # (x, y) format

    def _find_skeleton_junctions(self, skeleton):
        """Return skeleton pixels with three or more 8-connected skeleton neighbours.

        Args:
            skeleton (np.ndarray): 2-D skeleton image. Non-zero pixels are on.

        Returns:
            list[tuple]: ``(x, y)`` coordinates in row-major order
            (by y, then x).
        """
        on, counts = self._count_skeleton_neighbors(skeleton)
        ys, xs = np.nonzero(on & (counts >= 3))
        return list(zip(xs, ys))

    def _remove_redundant_edges(self, skeleton):
        """Prune the shortest branch at skeleton junctions that have more than two branches.

        The stated intent is to remove redundant loop edges that do not
        connect to other traversable areas. For every junction pixel (>= 3
        skeleton neighbours), the skeleton is flood-filled from each
        unvisited neighbour to collect "branches". If more than two branches
        are found, every pixel of the shortest branch except its first one is
        cleared in the output.

        NOTE(review): the implementation does not match that intent:
          * the junction pixel itself is never marked visited, so the fill
            from the first neighbour can pass back through the junction and
            absorb the other branches;
          * the fill stops as soon as any dequeued pixel has no unvisited
            neighbour, rather than at the next endpoint or junction;
          * loops are never detected explicitly.
        As a result it presumably rarely finds more than two branches.
        Behaviour is kept as-is pending confirmation of the intended pruning
        rule, because marking the junction visited would start deleting real
        branches at ordinary T-junctions.

        Args:
            skeleton (np.ndarray): uint8 skeleton (non-zero = on). It is not
                modified.

        Returns:
            np.ndarray: Pruned copy of ``skeleton``.
        """
        clean_skeleton = skeleton.copy()
        junctions = self._find_skeleton_junctions(skeleton)

        for jx, jy in junctions:
            # Get all connected branches
            branches = []
            visited = np.zeros_like(skeleton, dtype=bool)

            # Check 8-neighborhood
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    if dx == 0 and dy == 0:
                        continue
                    nx, ny = jx + dx, jy + dy
                    if 0 <= nx < skeleton.shape[1] and 0 <= ny < skeleton.shape[0]:
                        if skeleton[ny, nx] > 0 and not visited[ny, nx]:
                            # Track this branch (breadth-first from the neighbour)
                            branch = []
                            queue = deque([(nx, ny)])
                            visited[ny, nx] = True

                            while queue:
                                cx, cy = queue.popleft()
                                branch.append((cx, cy))

                                # Count and enqueue unvisited skeleton neighbours
                                neighbors = 0
                                for ddx in [-1, 0, 1]:
                                    for ddy in [-1, 0, 1]:
                                        if ddx == 0 and ddy == 0:
                                            continue
                                        nnx, nny = cx + ddx, cy + ddy
                                        if 0 <= nnx < skeleton.shape[1] and 0 <= nny < skeleton.shape[0]:
                                            if skeleton[nny, nnx] > 0 and not visited[nny, nnx]:
                                                neighbors += 1
                                                queue.append((nnx, nny))
                                                visited[nny, nnx] = True

                                # Stop as soon as a dequeued pixel adds nothing new
                                if neighbors == 0:
                                    break

                            branches.append(branch)

            # If more than 2 branches, redundant paths exist
            if len(branches) > 2:
                # Find shortest branch (assumed redundant)
                shortest_branch = min(branches, key=len)
                # Remove this branch (keep connection to first node)
                for i in range(1, len(shortest_branch)):
                    x, y = shortest_branch[i]
                    clean_skeleton[y, x] = 0

        return clean_skeleton

    def _calculate_endpoint_distances(self, endpoints, image_shape):
        """Rank endpoints by Euclidean distance to the bottom-centre pixel.

        Args:
            endpoints (list[tuple]): ``(x, y)`` pixel coordinates.
            image_shape (tuple): ``(height, width)`` of the image.

        Returns:
            list[tuple]: ``(x, y, distance)`` entries sorted by ascending
            distance to ``(width // 2, height - 1)``.
        """
        center_x = image_shape[1] // 2
        bottom_y = image_shape[0] - 1  # last image row
        distances = [
            (x, y, math.sqrt((x - center_x) ** 2 + (y - bottom_y) ** 2))
            for x, y in endpoints
        ]
        distances.sort(key=lambda item: item[2])
        return distances

    def _extend_skeleton_from_endpoints(self, skeleton, traversable_mask, max_extension=50):
        """Extend each skeleton endpoint straight outward through free space.

        For every endpoint with exactly one skeleton neighbour, the line is
        continued in the opposite direction from that neighbour (one of the 8
        compass directions) for up to ``max_extension`` pixels. The extension
        stops before the first pixel that is outside the image or is an
        obstacle (``traversable_mask == 0``). Every pixel stepped over until
        then is set to 255, whether the walk ended at an obstacle, at the
        image boundary, or at ``max_extension``.

        Args:
            skeleton (np.ndarray): uint8 skeleton (non-zero = on). It is not
                modified.
            traversable_mask (np.ndarray): Mask with the same height and
                width. 0 marks an obstacle.
            max_extension (int): Maximum number of pixels added per endpoint.

        Returns:
            np.ndarray: Copy of ``skeleton`` with the extensions drawn in.
        """
        extended_skeleton = skeleton.copy()
        sk_h, sk_w = skeleton.shape[:2]
        mask_h, mask_w = traversable_mask.shape[:2]

        for x, y in self._find_skeleton_endpoints(skeleton):
            x, y = int(x), int(y)
            neighbors = [
                (dx, dy)
                for dx in (-1, 0, 1)
                for dy in (-1, 0, 1)
                if (dx or dy)
                and 0 <= x + dx < sk_w
                and 0 <= y + dy < sk_h
                and skeleton[y + dy, x + dx] > 0
            ]
            # Only extend a line endpoint (a single skeleton neighbour).
            if len(neighbors) != 1:
                continue
            # Grow in the opposite direction from the single neighbour.
            step_x, step_y = -neighbors[0][0], -neighbors[0][1]

            steps = 0
            for i in range(1, max_extension + 1):
                px, py = x + step_x * i, y + step_y * i
                if not (0 <= px < mask_w and 0 <= py < mask_h):
                    break  # image boundary
                if traversable_mask[py, px] == 0:
                    break  # obstacle
                steps = i

            for i in range(1, steps + 1):
                extended_skeleton[y + step_y * i, x + step_x * i] = 255

        return extended_skeleton

    def process_depth(self, depth_array):
        """
        Detect the skeleton endpoints of the traversable area in a depth map.

        Args:
            depth_array (np.ndarray): Depth map array with shape (H, W).

        Returns:
            np.ndarray: uint8 array of shape (N, H, W), where N is the number
                of endpoints. Mask ``i`` is 1 only at endpoint ``i`` and 0
                elsewhere. Endpoints are in row-major order (by y, then x).
                The shape is (0, H, W) when no traversable area or no
                endpoint is found.

        Raises:
            ValueError: If ``depth_array`` is not 2-D.
        """
        if depth_array.ndim != 2:
            raise ValueError(f"Input depth map should be 2D array, but got shape {depth_array.shape}")

        height, width = depth_array.shape

        # Traversable pixels satisfy: low_threshold <= offset height < up_threshold.
        # NaN heights fail the upper-bound test and are therefore excluded.
        height_map = self._depth_to_relative_height(depth_array) + self.height_offset
        traversable = ((height_map < self.height_up_threshold)
                       & (height_map >= self.height_low_threshold))
        traversable_mask = traversable.astype(np.uint8) * 255

        # Close small holes/gaps, then keep only the dominant components.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        traversable_mask = cv2.morphologyEx(traversable_mask, cv2.MORPH_CLOSE, kernel)
        traversable_mask = self._filter_small_components(traversable_mask)

        if not np.any(traversable_mask > 0):
            return np.zeros((0, height, width), dtype=np.uint8)

        skeleton = self._get_clean_skeleton(traversable_mask)
        skeleton = self._remove_redundant_edges(skeleton)

        # NOTE(review): endpoints are taken from the un-extended skeleton, and
        # masks are returned in raster order. The results of
        # _extend_skeleton_from_endpoints() and _calculate_endpoint_distances()
        # never influenced this output, so they are not computed here. Use
        # them if the masks should be placed on the extended skeleton or
        # ordered by distance.
        endpoints = self._find_skeleton_endpoints(skeleton)
        if not endpoints:
            return np.zeros((0, height, width), dtype=np.uint8)

        points = np.asarray(endpoints, dtype=np.intp)  # (N, 2) of (x, y)
        endpoint_masks = np.zeros((len(endpoints), height, width), dtype=np.uint8)
        endpoint_masks[np.arange(len(endpoints)), points[:, 1], points[:, 0]] = 1
        return endpoint_masks