import numpy as np
import cv2
from skimage.morphology import skeletonize
from PIL import Image
import os
import math
from collections import deque


def depth_to_relative_height(depth_image, hfov=79):
    """Estimate each pixel's vertical offset from the image's centre row.

    Uses a pinhole model: offset = -(row - rows / 2) * depth / focal, where the
    focal length (in pixels) is derived from the horizontal field of view and
    the image width.  The result is in the same units as the depth values; for
    positive depth, rows above the image centre get positive values.

    Args:
        depth_image: 2-D depth array, or a multi-channel array whose channel 0
            holds depth.
        hfov: horizontal field of view in degrees.

    Returns:
        float array with the same height/width as the depth channel.
    """
    depth = depth_image[:, :, 0] if depth_image.ndim > 2 else depth_image
    n_rows, n_cols = depth.shape
    # Focal length from the horizontal FOV is reused for the vertical axis,
    # i.e. square pixels and a principal point at the image centre are assumed.
    focal_px = n_cols / (2 * np.tan(np.radians(hfov / 2)))
    row_offset = np.arange(n_rows)[:, np.newaxis] - n_rows / 2
    return -(row_offset * depth / focal_px)


def filter_small_components(binary_mask, min_area=500):
    """Keep only the 8-connected foreground components of at least ``min_area`` pixels.

    Args:
        binary_mask: single-channel 8-bit mask (the input format required by
            ``cv2.connectedComponentsWithStats``); non-zero pixels are foreground.
        min_area: minimum component area, in pixels, for a component to be kept.

    Returns:
        Array with the same shape and dtype as ``binary_mask``: 255 on pixels of
        kept components, 0 everywhere else.

    A per-component report (area and keep/filter decision) is printed to stdout.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        binary_mask, connectivity=8)
    areas = stats[:, cv2.CC_STAT_AREA]

    print(f"\nDetected {num_labels-1} connected components:")

    # keep[label] says whether that label survives; label 0 is the background
    # and always stays False.
    keep = np.zeros(num_labels, dtype=bool)
    kept_components = 0
    for i in range(1, num_labels):
        area = areas[i]
        print(f"Component {i}: Area = {area} pixels")
        if area >= min_area:
            keep[i] = True
            kept_components += 1
            print(f"  -> Kept (Area >= {min_area})")
        else:
            print(f"  -> Filtered (Area < {min_area})")

    print(f"Final kept components: {kept_components}")

    # Build the output with one lookup-table pass over the label image instead
    # of a full-image comparison per component.
    filtered_mask = np.zeros_like(binary_mask)
    filtered_mask[keep[labels]] = 255
    return filtered_mask


def get_clean_skeleton(binary_mask):
    """Thin the foreground of ``binary_mask`` down to a skeleton.

    Every non-zero pixel counts as foreground.  The foreground is dilated twice
    with a 5x5 elliptical kernel before ``skeletonize`` is applied.

    Returns:
        uint8 image of the same size: 255 on skeleton pixels, 0 elsewhere.
    """
    foreground = np.where(binary_mask > 0, 255, 0).astype(np.uint8)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    grown = cv2.dilate(foreground, ellipse, iterations=2)
    # skeletonize expects a 0/1 (or boolean) image.
    thin = skeletonize(grown // 255)
    return np.where(thin, 255, 0).astype(np.uint8)


def find_skeleton_endpoints(skeleton):
    """Locate skeleton endpoints: skeleton pixels with exactly one 8-neighbour on the skeleton.

    Args:
        skeleton: 2-D skeleton image; any non-zero pixel is part of the skeleton.

    Returns:
        List of ``(x, y)`` tuples of Python ints, in row-major scan order.
    """
    # Centre weight 10 separates "pixel is on the skeleton" from the neighbour
    # count: an on-skeleton pixel with k skeleton neighbours scores 10 + k,
    # while an off-skeleton pixel scores at most 8.
    kernel = np.array([[1, 1, 1],
                       [1, 10, 1],
                       [1, 1, 1]], dtype=np.uint8)
    on_skeleton = (skeleton > 0).astype(np.uint8)
    # Zero padding: the default BORDER_REFLECT_101 mirrors the first inner
    # row/column across the image edge, giving border pixels phantom neighbours
    # so genuine endpoints on the border would be missed.
    scores = cv2.filter2D(on_skeleton, -1, kernel,
                          borderType=cv2.BORDER_CONSTANT)
    ys, xs = np.nonzero(scores == 11)
    return list(zip(xs.tolist(), ys.tolist()))


def find_skeleton_junctions(skeleton):
    """Locate skeleton junctions: skeleton pixels with three or more 8-neighbours on the skeleton.

    Args:
        skeleton: 2-D skeleton image; any non-zero pixel is part of the skeleton.

    Returns:
        List of ``(x, y)`` tuples of Python ints, in row-major scan order.
    """
    # Centre weight 10 separates "pixel is on the skeleton" from the neighbour
    # count: an on-skeleton pixel with k skeleton neighbours scores 10 + k,
    # while an off-skeleton pixel scores at most 8.
    kernel = np.array([[1, 1, 1],
                       [1, 10, 1],
                       [1, 1, 1]], dtype=np.uint8)
    on_skeleton = (skeleton > 0).astype(np.uint8)
    # Zero padding: the default BORDER_REFLECT_101 mirrors the first inner
    # row/column across the image edge, double-counting neighbours of border
    # pixels and producing spurious junctions there.
    scores = cv2.filter2D(on_skeleton, -1, kernel,
                          borderType=cv2.BORDER_CONSTANT)
    ys, xs = np.nonzero(scores >= 13)
    return list(zip(xs.tolist(), ys.tolist()))


def _skeleton_neighbors(skeleton, x, y):
    """Return the (x, y) coordinates of the in-bounds non-zero 8-neighbours of pixel (x, y)."""
    height, width = skeleton.shape
    return [(x + dx, y + dy)
            for dx in (-1, 0, 1)
            for dy in (-1, 0, 1)
            if (dx or dy)
            and 0 <= x + dx < width
            and 0 <= y + dy < height
            and skeleton[y + dy, x + dx] > 0]


def remove_redundant_edges(skeleton):
    """Prune the shortest branch at every skeleton junction that has more than two branches.

    For each junction reported by ``find_skeleton_junctions``, the junction
    pixel and all of its skeleton neighbours are claimed first.  Each neighbour
    then starts one branch, which is followed pixel by pixel for as long as the
    current pixel has exactly one unclaimed skeleton neighbour.  A branch
    therefore ends at an endpoint (no unclaimed neighbour) or where the
    skeleton forks again (two or more unclaimed neighbours).  When a junction
    has more than two branches, the shortest one (the first found, on ties) is
    erased except for its first pixel, which stays attached to the junction.

    Branches are always traced on the input ``skeleton``, so pixels erased while
    handling one junction are still seen when handling later ones.

    NOTE(review): the shortest branch is removed even when it ends at another
    fork rather than at an endpoint, which can disconnect the skeleton —
    confirm this is the intended pruning rule.

    Args:
        skeleton: 2-D skeleton image; non-zero pixels are skeleton.

    Returns:
        A pruned copy of ``skeleton``; the input array is not modified.
    """
    clean_skeleton = skeleton.copy()

    for jx, jy in find_skeleton_junctions(skeleton):
        jx, jy = int(jx), int(jy)
        starts = _skeleton_neighbors(skeleton, jx, jy)
        # Claim the junction and every neighbour up front: a trace must not walk
        # back through the junction, nor step sideways into a sibling branch
        # (neighbours of a junction are frequently adjacent to one another).
        claimed = {(jx, jy), *starts}

        branches = []
        for start in starts:
            branch = [start]
            while True:
                cx, cy = branch[-1]
                next_pixels = [p for p in _skeleton_neighbors(skeleton, cx, cy)
                               if p not in claimed]
                if len(next_pixels) != 1:
                    break  # 0: endpoint reached; >= 2: skeleton forks again
                claimed.add(next_pixels[0])
                branch.append(next_pixels[0])
            branches.append(branch)

        if len(branches) > 2:
            shortest_branch = min(branches, key=len)
            # Keep the first pixel so the junction neighbourhood stays intact.
            for x, y in shortest_branch[1:]:
                clean_skeleton[y, x] = 0

    return clean_skeleton


def calculate_endpoint_distances(endpoints, image_shape):
    """Rank points by Euclidean pixel distance to the bottom-centre of an image.

    Args:
        endpoints: iterable of ``(x, y)`` pixel coordinates.
        image_shape: array shape whose first two entries are (height, width).

    Returns:
        List of ``(x, y, distance)`` tuples, nearest first; ties keep their
        input order.  The reference point is (width // 2, height - 1).
    """
    anchor_x = image_shape[1] // 2
    anchor_y = image_shape[0] - 1  # last row of the image

    scored = [
        (x, y, math.sqrt((x - anchor_x) ** 2 + (y - anchor_y) ** 2))
        for x, y in endpoints
    ]
    return sorted(scored, key=lambda entry: entry[2])


def extend_skeleton_from_endpoints(skeleton, traversable_mask, max_extension=50):
    """Extend skeleton lines straight outward from their endpoints across traversable pixels.

    For every endpoint from ``find_skeleton_endpoints`` that has exactly one
    skeleton neighbour, the line is continued in the direction pointing away
    from that neighbour, one pixel per step, for at most ``max_extension``
    steps.  The walk stops before the first obstacle pixel
    (``traversable_mask == 0``) or at the image border, whichever comes first.
    Every traversable pixel passed on the way is set to 255.

    Args:
        skeleton: 2-D skeleton image; non-zero pixels are skeleton.
        traversable_mask: 2-D mask of the same size; 0 marks obstacles.
        max_extension: maximum number of pixels added per endpoint.

    Returns:
        A copy of ``skeleton`` with the extensions drawn in.
    """
    extended_skeleton = skeleton.copy()
    height, width = skeleton.shape[:2]
    mask_height, mask_width = traversable_mask.shape[:2]

    for x, y in find_skeleton_endpoints(skeleton):
        x, y = int(x), int(y)
        neighbors = [(dx, dy)
                     for dx in (-1, 0, 1)
                     for dy in (-1, 0, 1)
                     if (dx or dy)
                     and 0 <= x + dx < width
                     and 0 <= y + dy < height
                     and skeleton[y + dy, x + dx] > 0]

        # Only a pixel with a single neighbour is the end of a line.
        if len(neighbors) != 1:
            continue

        # Step away from the neighbour, i.e. keep the line's direction.
        step_x, step_y = -neighbors[0][0], -neighbors[0][1]
        for i in range(1, max_extension + 1):
            nx, ny = x + step_x * i, y + step_y * i
            out_of_bounds = not (0 <= nx < mask_width and 0 <= ny < mask_height)
            if out_of_bounds or traversable_mask[ny, nx] == 0:
                break  # stop at the border or just before an obstacle
            # Paint as we go so the extension is kept whether the walk ends at
            # an obstacle, at the border, or after max_extension steps.
            extended_skeleton[ny, nx] = 255

    return extended_skeleton


def process_rgbd(rgbd_array, height_offset=0.88, height_low_threshold=0.8746, height_up_threshold=0.8755, min_component_area=500):
    """Build three visualisations of the traversable area in one RGB-D frame.

    Channels 0-2 of ``rgbd_array`` are used as colour and channel 3 as depth;
    a leading axis of size 1 is dropped first.  Colour is multiplied by 255
    when its maximum is <= 1, then cast to uint8.

    A pixel is traversable when its relative height (``depth_to_relative_height``)
    plus ``height_offset`` lies in ``[height_low_threshold, height_up_threshold)``.
    That mask is closed with a 7x7 elliptical kernel, and components smaller
    than ``min_component_area`` pixels are removed.

    Returns:
        ``(original_img, traversable_img, marked_img)``:
        - a copy of the uint8 colour image;
        - a 3-channel image that is 0 on traversable pixels and 255 elsewhere;
        - the colour image with traversable pixels scaled to 70% brightness,
          the pruned and extended skeleton painted ``[0, 255, 255]``, and each
          endpoint of the pruned (un-extended) skeleton marked with a
          ``(0, 255, 0)`` dot plus its rank by distance to the image's
          bottom-centre.

    NOTE(review): ``[0, 255, 255]`` is yellow only if the input is in BGR
    order (it is cyan in RGB) — confirm the expected channel order.

    Diagnostics are printed to stdout along the way.
    """
    if rgbd_array.ndim == 4 and rgbd_array.shape[0] == 1:
        rgbd_array = rgbd_array[0]  # drop a leading axis of size 1

    color = rgbd_array[:, :, :3]
    if color.max() <= 1:
        rgb_img = (color * 255).astype(np.uint8)  # colour given in [0, 1]
    else:
        rgb_img = color.astype(np.uint8)

    depth_channel = rgbd_array[:, :, 3]
    if depth_channel.ndim == 2:
        depth_map = depth_channel
    else:
        depth_map = rgbd_array[:, :, 3:4].squeeze()

    # Traversable = height inside the [low, up) band.
    height_map = depth_to_relative_height(depth_map) + height_offset
    print(height_map.min(), height_map.max())
    below_upper = height_map < height_up_threshold
    below_lower = height_map < height_low_threshold
    traversable_mask = (below_upper & ~below_lower).astype(np.uint8) * 255
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    traversable_mask = cv2.morphologyEx(
        traversable_mask, cv2.MORPH_CLOSE, closing_kernel)
    traversable_mask = filter_small_components(
        traversable_mask, min_component_area)
    is_traversable = traversable_mask == 255

    original_img = rgb_img.copy()

    # Black where traversable, white everywhere else.
    traversable_img = np.full_like(rgb_img, 255)
    traversable_img[is_traversable] = 0

    # Blend 30% black into traversable pixels; other pixels blend with themselves.
    shaded = rgb_img.copy()
    shaded[is_traversable] = [0, 0, 0]
    marked_img = cv2.addWeighted(rgb_img, 0.7, shaded, 0.3, 0)

    if not np.any(traversable_mask > 0):
        print("\nWarning: No traversable areas remain after filtering!")
        return original_img, traversable_img, marked_img

    skeleton = remove_redundant_edges(get_clean_skeleton(traversable_mask))
    extended_skeleton = extend_skeleton_from_endpoints(
        skeleton, traversable_mask)
    marked_img[extended_skeleton == 255] = [0, 255, 255]

    ranked = calculate_endpoint_distances(
        find_skeleton_endpoints(skeleton), marked_img.shape)

    print("\nEndpoint coordinates and distances (nearest to farthest):")
    for rank, (x, y, dist) in enumerate(ranked, 1):
        print(f"{rank}. Coordinates: ({x}, {y}), Distance: {dist:.2f} pixels")
        cv2.circle(marked_img, (x, y), 2, (0, 255, 0), -1)
        cv2.putText(marked_img, f"{rank}", (x + 10, y + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    return original_img, traversable_img, marked_img


