"""
Point Cloud Renderer for Visual Logging

This module renders colored point clouds to images for visual logging and debugging.
It works with the PointCloudConverter to create visual representations of the
point clouds that are sent to PointLLM for analysis.
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import logging
from typing import Optional, Tuple, List


class PointCloudRenderer:
    """
    Renders colored point clouds to images with multiple viewing angles.

    Maintains visual consistency with the mesh renderer by using the same
    background color and rendering style.

    Public methods report success as ``True``/``False`` instead of raising, so a
    rendering problem never interrupts the caller; failures are logged with a
    traceback. Every figure created here is closed again, even when rendering
    fails, so repeated calls do not accumulate open matplotlib figures.
    """

    def __init__(self,
                 background_color: Tuple[float, float, float] = (0.95, 0.95, 0.95),
                 seed: Optional[int] = None):
        """
        Initialize the point cloud renderer.

        Args:
            background_color: Background color for the rendered images (RGB in 0-1 range)
            seed: Optional seed for the generator used when subsampling large
                point clouds. A private generator is used (not the global NumPy
                RNG), so rendering never changes the random state of the caller.
        """
        self.background_color = background_color
        self.logger = logging.getLogger(self.__class__.__name__)
        self._rng = np.random.default_rng(seed)

    def render_to_image(self,
                        point_cloud: np.ndarray,
                        output_path: str,
                        title: Optional[str] = None,
                        views: Optional[List[Tuple[float, float, str]]] = None,
                        max_display_points: int = 25000,
                        figsize: Tuple[int, int] = (16, 12),
                        dpi: int = 150) -> bool:
        """
        Render a colored point cloud to an image file with multiple views.

        Args:
            point_cloud: Nx6 array (xyz + RGB in 0-1 range)
            output_path: Path to save the rendered image
            title: Optional title for the visualization
            views: Optional list of (azimuth, elevation, view_name) tuples;
                must not be empty when given
            max_display_points: Maximum number of points to display (randomly
                subsampled if exceeded; values below 1 disable subsampling)
            figsize: Figure size in inches
            dpi: Resolution of the output image

        Returns:
            True if successful, False otherwise
        """
        try:
            # Default views if not specified
            if views is None:
                views = [
                    (45, 30, "Front-Right"),
                    (135, 30, "Front-Left"),
                    (225, 30, "Back-Left"),
                    (315, 30, "Back-Right")
                ]
            if len(views) == 0:
                # An empty grid would divide by zero when laying out subplots.
                self.logger.error("No views provided for rendering")
                return False

            # Validate before creating a figure so bad input cannot leak one.
            point_cloud = np.asarray(point_cloud)
            error = self._validate_point_cloud(point_cloud)
            if error is not None:
                self.logger.error(error)
                return False

            xyz, colors = self._subsample(point_cloud, max_display_points)
            if len(xyz) < len(point_cloud):
                self.logger.debug(f"Subsampled from {len(point_cloud)} to {len(xyz)} points for visualization")

            # Grid of at most two columns, as many rows as needed.
            num_views = len(views)
            subplot_cols = min(2, num_views)
            subplot_rows = (num_views + subplot_cols - 1) // subplot_cols

            fig = plt.figure(figsize=figsize, dpi=dpi)
            try:
                fig.patch.set_facecolor(self.background_color)
                for idx, (azim, elev, view_name) in enumerate(views, 1):
                    ax = fig.add_subplot(subplot_rows, subplot_cols, idx, projection='3d')
                    self._draw_cloud(ax, xyz, colors, azim=azim, elev=elev, label=f'{view_name}')
                self._save_figure(fig, output_path, title, dpi)
            finally:
                plt.close(fig)

            self.logger.info(f"Saved point cloud visualization to {output_path}")
            return True

        except Exception as e:
            self.logger.exception(f"Failed to render point cloud: {e}")
            return False

    def render_from_file(self,
                         npy_path: str,
                         output_path: str,
                         title: Optional[str] = None,
                         **kwargs) -> bool:
        """
        Render a point cloud from a saved NPY file.

        Args:
            npy_path: Path to the NPY file containing the point cloud
            output_path: Path to save the rendered image
            title: Optional title for the visualization; defaults to the point count
            **kwargs: Additional arguments passed to render_to_image

        Returns:
            True if successful, False otherwise
        """
        try:
            if not os.path.exists(npy_path):
                self.logger.error(f"NPY file not found: {npy_path}")
                return False

            # Load point cloud from file
            point_cloud = np.load(npy_path)

            # Check ndim first: indexing shape[1] on a 1-D array would raise.
            if point_cloud.ndim != 2 or point_cloud.shape[1] != 6:
                self.logger.error(f"Invalid point cloud shape: {point_cloud.shape}. Expected Nx6 array.")
                return False

            # Generate title if not provided
            if title is None:
                title = f"Point Cloud ({len(point_cloud)} points)"

            return self.render_to_image(point_cloud, output_path, title, **kwargs)

        except Exception as e:
            self.logger.exception(f"Failed to load and render point cloud from {npy_path}: {e}")
            return False

    def create_comparison_image(self,
                                point_clouds: List[np.ndarray],
                                output_path: str,
                                titles: Optional[List[str]] = None,
                                main_title: Optional[str] = None,
                                figsize: Tuple[int, int] = (20, 12),
                                dpi: int = 150,
                                max_display_points: int = 15000) -> bool:
        """
        Create a comparison image showing multiple point clouds side by side.

        All clouds are shown from the same viewpoint (azimuth 45, elevation 30)
        in a single row.

        Args:
            point_clouds: List of Nx6 point cloud arrays (xyz + RGB in 0-1 range)
            output_path: Path to save the comparison image
            titles: Optional list of titles for each point cloud; clouds without
                a title (list missing or too short) get "Point Cloud <k>"
            main_title: Optional main title for the entire figure
            figsize: Figure size in inches
            dpi: Resolution of the output image
            max_display_points: Maximum number of points displayed per cloud
                (values below 1 disable subsampling)

        Returns:
            True if successful, False otherwise
        """
        try:
            num_clouds = len(point_clouds)
            if num_clouds == 0:
                self.logger.error("No point clouds provided for comparison")
                return False

            # Validate every cloud before creating the figure.
            clouds = [np.asarray(pc) for pc in point_clouds]
            for number, cloud in enumerate(clouds, 1):
                error = self._validate_point_cloud(cloud)
                if error is not None:
                    self.logger.error(f"Point cloud {number}: {error}")
                    return False

            # Pad missing titles so zip() below never silently drops a cloud.
            labels = list(titles) if titles is not None else []
            labels += [f"Point Cloud {i+1}" for i in range(len(labels), num_clouds)]

            fig = plt.figure(figsize=figsize, dpi=dpi)
            try:
                fig.patch.set_facecolor(self.background_color)
                for idx, (cloud, label) in enumerate(zip(clouds, labels), 1):
                    ax = fig.add_subplot(1, num_clouds, idx, projection='3d')
                    xyz, colors = self._subsample(cloud, max_display_points)
                    self._draw_cloud(ax, xyz, colors, azim=45, elev=30, label=label)
                self._save_figure(fig, output_path, main_title, dpi)
            finally:
                plt.close(fig)

            self.logger.info(f"Saved comparison image to {output_path}")
            return True

        except Exception as e:
            self.logger.exception(f"Failed to create comparison image: {e}")
            return False

    @staticmethod
    def _validate_point_cloud(point_cloud: np.ndarray) -> Optional[str]:
        """Return why *point_cloud* cannot be rendered, or None if it can."""
        shape = np.shape(point_cloud)
        if len(shape) != 2 or shape[1] < 6:
            return f"Invalid point cloud shape: {shape}. Expected Nx6 array."
        if shape[0] == 0:
            return "Point cloud is empty; nothing to render"
        return None

    def _subsample(self, point_cloud: np.ndarray, max_points: int) -> Tuple[np.ndarray, np.ndarray]:
        """Split into (xyz, colors), keeping at most *max_points* random rows (<1 = keep all)."""
        xyz = point_cloud[:, :3]
        colors = point_cloud[:, 3:]  # RGB in 0-1 range
        if 0 < max_points < len(xyz):
            # Same row indices for both arrays keep each point paired with its color.
            indices = self._rng.choice(len(xyz), size=max_points, replace=False)
            xyz = xyz[indices]
            colors = colors[indices]
        return xyz, colors

    def _draw_cloud(self, ax, xyz: np.ndarray, colors: np.ndarray,
                    azim: float, elev: float, label) -> None:
        """Scatter *xyz* on a 3D axes, apply the shared minimal style and caption it."""
        ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2],
                   c=colors, s=1.0, alpha=0.9, marker='.')
        ax.view_init(elev=elev, azim=azim)

        # Fill the panes with the background color so the plotting area blends
        # into the figure background once ticks and labels are removed.
        pane_rgba = (*self.background_color, 1.0)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_facecolor(pane_rgba)
            axis.pane.fill = True
            axis.set_pane_color(pane_rgba)
        ax.grid(False)

        # Remove axes labels and ticks but keep the background panes
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_zlabel('')

        # Subtle caption at the bottom of the subplot
        ax.text2D(0.5, 0.02, label,
                  transform=ax.transAxes,
                  ha='center', va='bottom',
                  fontsize=12, color='#666666')

        self._set_equal_aspect(ax, xyz)

    def _save_figure(self, fig, output_path: str, title: Optional[str], dpi: int) -> None:
        """Add the optional overall title, lay out *fig* and write it to *output_path*."""
        if title:
            fig.suptitle(title, fontsize=16, fontweight='bold', y=0.98)
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.96])
        fig.savefig(output_path, bbox_inches='tight',
                    facecolor=self.background_color, edgecolor='none', dpi=dpi)

    def _set_equal_aspect(self, ax, xyz):
        """
        Set equal aspect ratio for a 3D plot.

        All three axes get the same half-width (half the largest extent),
        centered on the points' bounding box, so no axis is visually stretched.
        If every point coincides, a unit-sized box is used instead of
        zero-width limits, which matplotlib flags as singular.
        """
        xyz = np.asarray(xyz)[:, :3]
        mins = xyz.min(axis=0)
        maxs = xyz.max(axis=0)

        max_range = float((maxs - mins).max()) / 2.0
        if not max_range > 0:
            max_range = 0.5

        mid_x, mid_y, mid_z = (maxs + mins) * 0.5

        ax.set_xlim(mid_x - max_range, mid_x + max_range)
        ax.set_ylim(mid_y - max_range, mid_y + max_range)
        ax.set_zlim(mid_z - max_range, mid_z + max_range)


def render_pointcloud_for_iteration(mesh_path: str,
                                    output_folder: str,
                                    iteration_num: int,
                                    description: str = "",
                                    links_json_path: Optional[str] = None,
                                    sample_points: int = 8192) -> bool:
    """
    Helper function to render a point cloud for a specific pipeline iteration.

    Converts the mesh with ``PointCloudConverter`` and renders it to
    ``<output_folder>/pointcloud_render_iter<N>.png``. If rendering succeeds, it
    also saves the raw array to ``<output_folder>/pointcloud_data_iter<N>.npy``.
    ``output_folder`` is created if it does not exist yet.

    Args:
        mesh_path: Path to the OBJ mesh file
        output_folder: Folder to save the visualization
        iteration_num: Current iteration number
        description: Object description for the title (omitted when empty)
        links_json_path: Optional path to links_hierarchy.json
        sample_points: Number of points to sample

    Returns:
        True if successful, False otherwise
    """
    try:
        # Project-local import kept lazy, as in the original module layout.
        from utils.pointcloud_converter import PointCloudConverter

        # Convert mesh to point cloud; the color names are not needed here.
        converter = PointCloudConverter(sample_points=sample_points)
        point_cloud, component_mapping, _color_names = converter.convert_obj_to_pointcloud(
            mesh_path,
            links_json_path=links_json_path
        )

        # Neither savefig nor np.save creates missing parent directories.
        os.makedirs(output_folder, exist_ok=True)

        renderer = PointCloudRenderer(background_color=(0.95, 0.95, 0.95))
        output_path = os.path.join(output_folder, f"pointcloud_render_iter{iteration_num}.png")

        # Avoid a dangling "Iteration N: " heading when no description is given.
        heading = f"Iteration {iteration_num}: {description}" if description else f"Iteration {iteration_num}"
        title = f"{heading}\n({len(component_mapping)} components, {len(point_cloud)} points)"

        success = renderer.render_to_image(
            point_cloud,
            output_path,
            title=title,
            max_display_points=25000
        )

        # Also save the point cloud data for later use
        if success:
            npy_path = os.path.join(output_folder, f"pointcloud_data_iter{iteration_num}.npy")
            np.save(npy_path, point_cloud)
            logging.info(f"Saved point cloud data to {npy_path}")

        return success

    except Exception as e:
        # logging.exception logs at ERROR level and includes the traceback.
        logging.exception(f"Failed to render point cloud for iteration {iteration_num}: {e}")
        return False