import numpy as np
import io
from PIL import Image
import matplotlib.pyplot as plt
import torch


class TraversableAreaDetector:
    """
    A class for detecting traversable areas in a scene and managing exploration goals.

    Args:
        args: Configuration arguments containing camera parameters and thresholds
    """

    def __init__(self, args) -> None:
        """
        Store the configuration and reset every piece of per-frame state.

        Args:
            args: Configuration object. Methods of this class read `hfov`,
                  `camera_height`, `traversable_height_threshold`,
                  `env_frame_width`, `env_frame_height` and `frame_width` from it.
        """
        self.args = args

        # --- Inputs of the most recent frame ---
        self.rgb: np.ndarray = None  # Color image (0-255 range)
        # NOTE(review): originally annotated as "0-255 range", but run_current_view
        # adds args.camera_height to heights derived from it - confirm the units.
        self.depth: np.ndarray = None

        # --- Geometry derived from the depth image ---
        self.height_map: np.ndarray = None
        self.traversable_map: np.ndarray = None  # Binary mask (0/1)

        # --- Candidate exploration goals, stored as (x, y) pixel tuples ---
        self.horizontal_points: list = []
        self.vertical_points: list = []
        self.merged_points: list = []

        # --- Cached visualizations (0-255 range) ---
        self.rgb_annotated_np: np.ndarray = None
        # The attribute name (missing leading "T") is kept as-is because other
        # code refers to it under this spelling.
        self.raversableAreaMask: np.ndarray = None
        self.highlight_select_goal: np.ndarray = None

    def _depth_to_relative_height(self) -> np.ndarray:
        """
        Calculate relative height difference from depth image to camera.

        Returns:
            np.ndarray: Relative height map with shape (480, 640), where positive values 
                       indicate heights above camera, negative values below camera
        """
        img_height, img_width = self.depth.shape
        # Calculate focal length (in pixels)
        focal_length_px = img_width / \
            (2 * np.tan(np.radians(self.args.hfov / 2)))

        # Generate pixel coordinate grid
        i_idx, j_idx = np.indices((img_height, img_width))
        # Convert pixel coordinates to offsets from image center
        y_prime = (i_idx - img_height / 2)

        # Calculate relative height
        # y_local represents height relative to camera
        y_local = y_prime * self.depth / focal_length_px

        # Note: Negative sign needed because image y-axis points downward
        return -y_local

    def _find_segment_boundaries(self, arr: np.ndarray) -> list:
        """
        Extract boundary points of connected regions in a 1D array.

        Args:
            arr (np.ndarray): 1D binary array where 1 indicates connected region

        Returns:
            list: List of boundary indices marking start and end points of connected regions
        """
        boundaries = []
        start = None
        for i in range(len(arr)):
            if arr[i] == 1:
                if start is None:
                    start = i  # Start of connected region
            else:
                if start is not None:
                    end = i - 1  # End of connected region
                    boundaries.extend([start, end])
                    start = None
        # Handle last connected region
        if start is not None:
            end = len(arr) - 1
            boundaries.extend([start, end])
        return boundaries

    def _process_boundaries(self, boundaries: list) -> list:
        """
        Turn a list of boundary indices into one midpoint per consecutive pair.

        The indices are sorted, then paired as (1st, 2nd), (3rd, 4th), ...; a
        trailing unpaired index is ignored.

        Args:
            boundaries (list): List of boundary indices

        Returns:
            list: Floor-averaged midpoint of each pair, in ascending order
        """
        ordered = sorted(boundaries)
        # zip() stops at the shorter slice, which drops a trailing odd element.
        return [(low + high) // 2
                for low, high in zip(ordered[0::2], ordered[1::2])]

    def _sample_candidate_points(self) -> tuple:
        """
        Sample candidate points from the traversable mask array.

        Returns:
            tuple: (horizontal_points, vertical_points) where each point is (x,y) coordinate
                  horizontal_points: list of points sampled horizontally
                  vertical_points: list of points sampled vertically
        """
        # Clear existing candidate points
        self.horizontal_points = []
        self.vertical_points = []

        # Extract boundaries of mask=1 regions
        rows, cols = np.where(self.traversable_map == 1)
        if len(rows) == 0:
            return  # Early return if no traversable areas

        # Horizontal sampling: Process x-axis
        x_min, x_max = cols.min(), cols.max()
        x_min += 15  # Add margin
        x_max -= 15  # Add margin
        x_samples = np.linspace(x_min, x_max, 6, dtype=int)
        for x in x_samples:
            mask_col = self.traversable_map[:, x]
            boundaries = self._find_segment_boundaries(mask_col)
            averaged_y = self._process_boundaries(boundaries)
            for y in averaged_y:
                self.horizontal_points.append((x, y))

        # Vertical sampling: Process y-axis
        y_min, y_max = rows.min(), rows.max()
        y_min += 15  # Add margin
        y_max -= 15  # Add margin
        y_samples = np.linspace(y_min, y_max, 6, dtype=int)
        for y in y_samples:
            mask_row = self.traversable_map[y, :]
            boundaries = self._find_segment_boundaries(mask_row)
            averaged_x = self._process_boundaries(boundaries)
            for x in averaged_x:
                self.vertical_points.append((x, y))

        return self.horizontal_points, self.vertical_points

    def merge_candidate_points(self) -> list:
        """
        Fuse the horizontal and vertical candidates into one list of goals.

        Each point of the smaller set is greedily paired with its nearest
        still-unpaired point of the larger set and replaced by their
        floor-averaged midpoint. Points of the larger set left without a
        partner are appended unchanged.

        Returns:
            list: List of merged points where each point is (x,y) coordinate
        """
        # Start from a clean result list
        merged = self.merged_points = []

        # The smaller set drives the pairing; the larger set is the pool to pick from
        if len(self.horizontal_points) <= len(self.vertical_points):
            anchors, pool = self.horizontal_points, np.array(self.vertical_points)
        else:
            anchors, pool = self.vertical_points, np.array(self.horizontal_points)

        for anchor_x, anchor_y in anchors:
            # Euclidean distance from this anchor to every point still in the pool
            gaps = np.linalg.norm(pool - np.array([anchor_x, anchor_y]), axis=1)
            nearest = np.argmin(gaps)
            partner = pool[nearest]

            # Integer midpoint between the anchor and its partner
            merged.append(((anchor_x + partner[0]) // 2,
                           (anchor_y + partner[1]) // 2))

            # A pool point can be paired only once
            pool = np.delete(pool, nearest, axis=0)

        # Whatever is left in the pool had no partner
        merged.extend(tuple(leftover) for leftover in pool.tolist())

        return self.merged_points

    def _generate_expanded_masks_of_selected_goal(self, select_point: int, target_size=3000, max_radius=50) -> np.ndarray:
        """
        Generate circular expanded mask for selected goal point.

        Args:
            select_point (int): Index of selected point (1-based indexing)
            target_size (int): Target number of pixels to cover (default: 2000)
            max_radius (int): Maximum expansion radius in pixels (default: 50)

        Returns:
            np.ndarray: Binary mask array with shape (args.env_frame_height, args.env_frame_width)
        """
        # Get sorted points
        sorted_points = sorted(self.merged_points, key=lambda p: (p[0], -p[1]))
        (x, y) = sorted_points[select_point - 1]  # select_point is 1-based index

        mask = np.zeros((self.args.env_frame_height, self.args.env_frame_width), dtype=np.uint8)
        covered = 0
        radius = 1

        # Gradually increase radius until target coverage or max radius reached
        while covered < target_size and radius <= max_radius:
            # Generate circular region for current radius
            new_pixels = []
            for dy in range(-radius, radius+1):
                for dx in range(-radius, radius+1):
                    # Check if within circle
                    if dx**2 + dy**2 <= radius**2:
                        px = x + dx
                        py = y + dy
                        # Check if within image bounds and not already covered
                        if 0 <= px < self.args.env_frame_width and 0 <= py < self.args.env_frame_height and mask[py, px] == 0:
                            new_pixels.append((py, px))

            # Terminate if no new pixels can be added
            if not new_pixels:
                break

            # Update mask
            for (py, px) in new_pixels:
                mask[py, px] = 1
            covered += len(new_pixels)
            radius += 1

        return mask

    def _annotated_candidate_exp_goal(self) -> np.ndarray:
        """
        Generate annotated image with numbered candidate exploration goals.

        The current RGB frame is drawn on a matplotlib figure sized to the
        environment frame, every merged candidate point gets a numbered circular
        label, and the figure is rasterized through an in-memory PNG. The result
        is also cached in `self.rgb_annotated_np`.

        Returns:
            np.ndarray: RGB image array with shape (args.env_frame_height, args.env_frame_width, 3) containing numbered annotations
        """
        # 100 dpi with a figure of (width/100, height/100) inches renders at
        # the environment frame resolution.
        fig = plt.figure(figsize=(self.args.env_frame_width/100, self.args.env_frame_height/100),
                         dpi=100)
        try:
            ax = fig.add_axes([0, 0, 1, 1])  # Cover entire canvas
            ax.imshow(self.rgb.astype(np.uint8))  # Ensure uint8 type
            ax.axis('off')

            # Sort points from bottom-left to top-right: x ascending, y descending (image origin at top-left)
            sorted_points = sorted(self.merged_points, key=lambda p: (p[0], -p[1]))

            # Add numbered annotations
            for idx, (x, y) in enumerate(sorted_points, start=1):
                ax.text(
                    x, y, str(idx),
                    fontsize=10, weight='bold',
                    color='black',
                    ha='center', va='center',
                    bbox=dict(
                        facecolor='white', edgecolor='black',
                        boxstyle='circle,pad=0.3', linewidth=1.5
                    )
                )

            buffer = io.BytesIO()
            # Save this specific figure (not whatever pyplot considers current).
            # bbox_inches='tight' is deliberately not used: labels overhanging
            # the image border would enlarge the tight bounding box and change
            # the output size.
            fig.savefig(buffer, format='png', dpi=100)
            buffer.seek(0)  # Rewind so PIL reads from the start

            with Image.open(buffer) as img:
                # Drop a possible alpha channel
                self.rgb_annotated_np = np.array(img)[:, :, :3].astype(np.uint8)
        finally:
            # Always release the figure, even if rendering failed
            plt.close(fig)
        return self.rgb_annotated_np

    # Visualization methods

    def get_numpy_of_annotated_candidate_exp_goals(self) -> np.ndarray:
        """
        Get the numpy array of annotated candidate exploration goals.

        Returns:
            np.ndarray: RGB image array with shape (args.env_frame_height, args.env_frame_width, 3) containing numbered annotations
        """
        return self.rgb_annotated_np

    def run_current_view(self, rgb: np.ndarray, depth: np.ndarray) -> bool:
        """
        Process current view to detect traversable areas and generate candidate points.

        Args:
            rgb (np.ndarray): RGB image array with shape (args.env_frame_height, args.env_frame_width, 3)
            depth (np.ndarray): Depth image array with shape (args.env_frame_height, args.env_frame_width)

        Returns:
            bool: True if traversable areas found, False otherwise
        """
        self.rgb = rgb
        self.depth = depth
        self.height_map = self._depth_to_relative_height() + self.args.camera_height
        self.traversable_map = np.where(
            self.height_map < self.args.traversable_height_threshold, 1, 0)
        if self.traversable_map.sum() == 0:  # No traversable areas in current view
            return False
        self._sample_candidate_points()
        self.merge_candidate_points()
        self._annotated_candidate_exp_goal()  # Generate annotated image
        return True

    def get_goal_mask_of_select_point(self, selected_goal: int) -> torch.Tensor:
        """
        Get downsampled mask for selected goal point.

        Args:
            selected_goal (int): Index of selected goal point (1-based indexing)

        Returns:
            torch.Tensor: Downsampled binary mask tensor with shape (120, 160) or None if invalid
        """
        if not self.merged_points:
            return None
        if selected_goal < 1 or selected_goal > len(self.merged_points):
            return None

        ds = self.args.env_frame_width // self.args.frame_width
        mask = self._generate_expanded_masks_of_selected_goal(selected_goal)
        if ds != 1:
            # Downsample to match main obs shape
            selected_goal_mask = mask[ds // 2::ds, ds // 2::ds]
        else:
            selected_goal_mask = mask

        # Convert numpy array to CUDA tensor
        selected_goal_mask = torch.from_numpy(
            selected_goal_mask).float().cuda()
        return selected_goal_mask

    def highlight_selected_goal(self, selected_goal: int) -> np.ndarray:
        """
        Generate annotated image with selected goal highlighted in green.

        Draws the same numbered labels as the candidate-goal annotation, but the
        label whose number equals `selected_goal` gets a lime background. The
        result is also cached in `self.highlight_select_goal`.

        Args:
            selected_goal (int): Index of goal point to highlight (1-based indexing)

        Returns:
            np.ndarray: RGB image array with shape (args.env_frame_height, args.env_frame_width, 3) with highlighted goal
        """
        # 100 dpi with a figure of (width/100, height/100) inches renders at
        # the environment frame resolution.
        fig = plt.figure(figsize=(self.args.env_frame_width/100, self.args.env_frame_height/100),
                         dpi=100)
        try:
            ax = fig.add_axes([0, 0, 1, 1])  # Cover entire canvas
            ax.imshow(self.rgb.astype(np.uint8))  # Ensure uint8 type
            ax.axis('off')

            # Sort points from bottom-left to top-right
            sorted_points = sorted(self.merged_points, key=lambda p: (p[0], -p[1]))

            # Add numbered annotations; only the selected goal is filled green
            for idx, (x, y) in enumerate(sorted_points, start=1):
                facecolor = 'lime' if idx == selected_goal else 'white'
                ax.text(
                    x, y, str(idx),
                    fontsize=10, weight='bold',
                    color='black',
                    ha='center', va='center',
                    bbox=dict(
                        facecolor=facecolor, edgecolor='black',
                        boxstyle='circle,pad=0.3', linewidth=1.5
                    )
                )

            buffer = io.BytesIO()
            # Save this specific figure (not whatever pyplot considers current).
            # bbox_inches='tight' is deliberately not used: labels overhanging
            # the image border would enlarge the tight bounding box and change
            # the output size.
            fig.savefig(buffer, format='png', dpi=100)
            buffer.seek(0)

            with Image.open(buffer) as img:
                # Drop a possible alpha channel, keep the 0-255 range
                self.highlight_select_goal = np.array(img)[:, :, :3].astype(np.uint8)
        finally:
            # Always release the figure, even if rendering failed
            plt.close(fig)
        return self.highlight_select_goal

    def get_TraversableAreaMask(self) -> np.ndarray:
        """
        Get visualization of traversable areas marked in black.

        Returns:
            np.ndarray: RGB image array with shape (args.env_frame_height, args.env_frame_width, 3) with traversable areas in black
        """
        self.raversableAreaMask = self.rgb.copy()
        # Set traversable areas to black
        self.raversableAreaMask[self.traversable_map == 1] = [0, 0, 0]
        return self.raversableAreaMask