"""
Point Cloud Converter for PointLLM Integration

This module converts OBJ mesh files to colored point clouds suitable for PointLLM input.
It maintains color consistency with the 2D rendering system by using link-based color
grouping to reduce the number of distinct components.
"""

import os
import json
import colorsys
import numpy as np
import trimesh
import logging
from typing import List, Tuple, Dict, Optional


class PointCloudConverter:
    """
    Converts OBJ mesh files to colored point clouds for PointLLM analysis.

    Uses link-based grouping from links_hierarchy.json to reduce the number
    of distinct colored components, avoiding token overflow in PointLLM.
    Every link (or connected component, when no hierarchy is available) is
    painted one flat RGB color; the result is an (N, 6) array of xyz + RGB.
    """

    def __init__(self, sample_points: int = 8192):
        """
        Initialize the point cloud converter.

        Args:
            sample_points: Number of points in every generated point cloud

        Raises:
            ValueError: If ``sample_points`` is not positive.
        """
        if sample_points <= 0:
            raise ValueError(f"sample_points must be positive, got {sample_points}")
        self.sample_points = sample_points
        self.logger = logging.getLogger(self.__class__.__name__)

    def convert_obj_to_pointcloud(self,
                                  obj_path: str,
                                  links_json_path: Optional[str] = None) -> Tuple[np.ndarray, Dict[int, str], List[str]]:
        """
        Convert an OBJ file to a colored point cloud with link-based grouping.

        Args:
            obj_path: Path to the OBJ file
            links_json_path: Optional path to links_hierarchy.json for link-based
                grouping; when omitted or nonexistent, connected components are used

        Returns:
            Tuple of:
                - point_cloud: Array of shape (sample_points, 6); xyz normalized to
                  the unit sphere, RGB in [0, 1]
                - component_mapping: Dict mapping link index to link name
                - color_names: List of color names, one per link index

        Raises:
            FileNotFoundError: If ``obj_path`` does not exist.
            ValueError: If the file does not yield a non-empty triangle mesh.
        """
        if not os.path.exists(obj_path):
            raise FileNotFoundError(f"OBJ file not found: {obj_path}")

        try:
            mesh = self._load_mesh(obj_path)

            # Get link-based grouping if available
            if links_json_path and os.path.exists(links_json_path):
                link_groups, _ = self._group_by_links(mesh, links_json_path)
                self.logger.info(f"Using link-based grouping: {len(link_groups)} links")
            else:
                # Fallback to connected components
                mesh_components = self._split_components(mesh)
                link_groups = self._groups_from_components(mesh_components)
                self.logger.info(f"Using component-based grouping: {len(mesh_components)} components")

            # Colors and names are indexed by link position
            num_links = len(link_groups)
            link_colors = self._generate_component_colors(num_links)
            color_names = self._generate_color_names(num_links)
            mesh_vertex_count = len(mesh.vertices)  # > 0, guaranteed by _load_mesh

            all_points = []
            component_mapping = {}

            for idx, (link_name, link_meshes) in enumerate(link_groups.items()):
                link_points = self._sample_link(link_name, link_meshes, mesh_vertex_count)
                if not link_points:
                    continue

                combined_points = np.vstack(link_points)
                num_points = len(combined_points)
                colored_points = np.zeros((num_points, 6))
                colored_points[:, :3] = combined_points  # xyz
                colored_points[:, 3:6] = link_colors[idx]  # RGB in 0-1 range for PointLLM

                all_points.append(colored_points)
                component_mapping[idx] = link_name

                self.logger.info(f"Link '{link_name}' ({color_names[idx]}): sampled {num_points} points")

            if all_points:
                point_cloud = np.vstack(all_points)
            else:
                # Emergency fallback: a single gray cloud for the whole mesh
                points = mesh.sample(self.sample_points)
                rgb = np.tile([0.5, 0.5, 0.5], (len(points), 1))
                point_cloud = np.hstack([points, rgb])
                component_mapping = {0: "entire_mesh"}
                color_names = ["gray"]

            # Ensure exactly sample_points (farthest point sampling when downsampling)
            point_cloud = self._resample_to_target_size(point_cloud, self.sample_points)

            # Normalize point cloud (xyz only, preserve RGB)
            point_cloud = self._normalize_point_cloud(point_cloud)

            self.logger.info(f"Generated point cloud: {len(point_cloud)} points, {len(component_mapping)} colored links")

            return point_cloud, component_mapping, color_names

        except Exception as e:
            self.logger.error(f"Failed to convert mesh to point cloud: {e}")
            raise

    def _load_mesh(self, obj_path: str) -> "trimesh.Trimesh":
        """
        Load ``obj_path`` as a single triangle mesh.

        Scenes are concatenated into one mesh.

        Raises:
            ValueError: If the result is not a ``trimesh.Trimesh`` or has no vertices.
        """
        self.logger.info(f"Loading mesh from: {obj_path}")
        mesh = trimesh.load(obj_path)

        # Handle scene objects
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump(concatenate=True)

        if not isinstance(mesh, trimesh.Trimesh):
            raise ValueError("Failed to load a valid mesh from the OBJ file")

        # Per-link point budgets divide by the total vertex count
        if len(mesh.vertices) == 0:
            raise ValueError(f"Mesh loaded from {obj_path} has no vertices")

        self.logger.info(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")
        return mesh

    @staticmethod
    def _split_components(mesh: "trimesh.Trimesh") -> List["trimesh.Trimesh"]:
        """Split ``mesh`` into connected components, falling back to ``[mesh]`` if none."""
        # list() normalises the return value: depending on the trimesh version,
        # split() may not hand back a plain list, and callers sort it in place.
        components = list(mesh.split(only_watertight=False))
        return components if components else [mesh]

    @staticmethod
    def _groups_from_components(components: List["trimesh.Trimesh"]) -> Dict[str, List["trimesh.Trimesh"]]:
        """Build one single-component group per connected component, named ``component_<i>``."""
        return {f"component_{i}": [comp] for i, comp in enumerate(components)}

    def _sample_link(self,
                     link_name: str,
                     link_meshes: List["trimesh.Trimesh"],
                     mesh_vertex_count: int) -> List[np.ndarray]:
        """
        Sample surface points for one link.

        The link's point budget is proportional to its share of the full mesh's
        vertices (at least 100 points). That budget is split across the link's
        meshes by vertex count (at least 10 points each). If surface sampling
        raises for a mesh, up to 100 of its raw vertices are used instead.

        Args:
            link_name: Link name, used only for log messages
            link_meshes: Mesh components belonging to the link
            mesh_vertex_count: Vertex count of the full mesh (must be > 0)

        Returns:
            One xyz array per entry of ``link_meshes``
        """
        total_vertices = sum(len(m.vertices) for m in link_meshes)
        link_sample_points = max(100, int(self.sample_points * (total_vertices / mesh_vertex_count)))

        link_points = []
        for mesh_comp in link_meshes:
            try:
                mesh_proportion = len(mesh_comp.vertices) / total_vertices if total_vertices > 0 else 1.0
                points_to_sample = max(10, int(link_sample_points * mesh_proportion))
                points, _ = trimesh.sample.sample_surface(mesh_comp, points_to_sample)
            except Exception as e:
                # Best effort: fall back to (a random subset of) the raw vertices
                self.logger.warning(f"Failed to sample from mesh in link {link_name}: {e}")
                points = mesh_comp.vertices
                if len(points) > 100:
                    indices = np.random.choice(len(points), 100, replace=False)
                    points = points[indices]
            link_points.append(points)
        return link_points

    def _group_by_links(self, mesh: "trimesh.Trimesh", links_json_path: str) -> Tuple[Dict[str, List["trimesh.Trimesh"]], List[str]]:
        """
        Group mesh components by links based on links_hierarchy.json.

        Link names are every ``name`` found under ``hierarchy.structure``,
        collected recursively through ``children`` and de-duplicated in order.
        Any ``caster_*_group`` directories found in ``part_meshes`` next to the
        config directory are added as well. Components are assigned by a size
        heuristic: they are sorted largest first and dealt out in equal
        consecutive runs, and any surplus goes to the last link.

        Args:
            mesh: The combined mesh
            links_json_path: Path to links_hierarchy.json

        Returns:
            Tuple of:
                - link_groups: Dict mapping link names to non-empty lists of mesh components
                - link_names: The keys of ``link_groups``, in order

            If no link names can be found at all, one group per connected
            component (``component_<i>``) is returned instead.
        """
        with open(links_json_path, 'r', encoding='utf-8') as f:
            links_data = json.load(f)

        hierarchy = links_data.get('hierarchy', {}).get('structure', [])

        def extract_all_links(items):
            """Depth-first collect every named entry (children of unnamed entries are skipped)."""
            links = []
            for item in items:
                if 'name' not in item:
                    continue
                links.append(item['name'])
                if item.get('children'):
                    links.extend(extract_all_links(item['children']))
            return links

        # Remove duplicates while preserving first-occurrence order
        link_names = list(dict.fromkeys(extract_all_links(hierarchy)))

        # Individual caster groups found on disk supersede a generic 'casters'
        # link: the first one takes over its slot, the rest are appended.
        for caster in self._find_caster_groups(links_json_path):
            if caster in link_names:
                continue
            if 'casters' in link_names:
                link_names[link_names.index('casters')] = caster
            else:
                link_names.append(caster)

        self.logger.info(f"Found links: {link_names}")

        mesh_components = self._split_components(mesh)
        self.logger.info(f"Mesh has {len(mesh_components)} connected components")

        if not link_names:
            # Nothing to deal components out to (and dividing by zero links
            # would crash), so use one group per component instead.
            self.logger.warning(f"No link names found in {links_json_path}; using connected components")
            link_groups = self._groups_from_components(mesh_components)
            return link_groups, list(link_groups.keys())

        # Spatial heuristic: larger components tend to be main structural bodies
        mesh_components.sort(key=lambda m: len(m.vertices), reverse=True)
        components_per_link = max(1, len(mesh_components) // len(link_names))
        last_link = len(link_names) - 1

        link_groups = {name: [] for name in link_names}
        for i, comp in enumerate(mesh_components):
            link_idx = min(i // components_per_link, last_link)
            link_groups[link_names[link_idx]].append(comp)

        # Drop links that received no component
        link_groups = {k: v for k, v in link_groups.items() if v}

        return link_groups, list(link_groups.keys())

    @staticmethod
    def _find_caster_groups(links_json_path: str) -> List[str]:
        """
        Return the sorted ``caster_*_group`` directory names in the sibling ``part_meshes`` dir.

        ``part_meshes`` is looked up two levels above the JSON file, e.g.
        ``<root>/part_meshes`` for ``<root>/configs/links_hierarchy.json``.
        """
        parent_dir = os.path.dirname(os.path.dirname(links_json_path))
        part_meshes_dir = os.path.join(parent_dir, 'part_meshes')
        if not os.path.isdir(part_meshes_dir):
            return []
        return sorted(
            d for d in os.listdir(part_meshes_dir)
            if d.startswith('caster_') and d.endswith('_group')
            and os.path.isdir(os.path.join(part_meshes_dir, d))
        )

    def _normalize_point_cloud(self, point_cloud: np.ndarray) -> np.ndarray:
        """
        Normalize point cloud xyz coordinates to the unit sphere while preserving RGB.
        This matches PointLLM's expected input format.

        xyz is centered at its centroid and scaled so that the farthest point
        lies at distance 1. The input array is not modified.

        Args:
            point_cloud: Nx6 array (xyz + RGB)

        Returns:
            Normalized copy of the point cloud (an empty input is returned as an empty copy)
        """
        pc = point_cloud.copy()
        if len(pc) == 0:
            return pc

        # Normalize only xyz coordinates (first 3 columns); RGB stays in 0-1 range
        xyz = pc[:, :3] - np.mean(pc[:, :3], axis=0)
        max_dist = np.max(np.linalg.norm(xyz, axis=1))
        if max_dist > 0:
            xyz = xyz / max_dist
        pc[:, :3] = xyz

        return pc

    def _resample_to_target_size(self, point_cloud: np.ndarray, target_size: int) -> np.ndarray:
        """
        Resample point cloud to exact target size.

        Downsampling uses farthest point sampling. Upsampling draws random rows
        with replacement.

        Args:
            point_cloud: Nx6 array
            target_size: Target number of points

        Returns:
            Resampled point cloud

        Raises:
            ValueError: If the point cloud is empty and ``target_size`` is not 0.
        """
        current_size = len(point_cloud)

        if current_size == target_size:
            return point_cloud
        if current_size == 0:
            raise ValueError("Cannot resample an empty point cloud")
        if current_size > target_size:
            # Downsample using farthest point sampling for better coverage
            return self._farthest_point_sample(point_cloud, target_size)

        # Upsample by repeating points
        indices = np.random.choice(current_size, target_size, replace=True)
        return point_cloud[indices]

    def _farthest_point_sample(self, point_cloud: np.ndarray, n_points: int) -> np.ndarray:
        """
        Use farthest point sampling for better point distribution.

        Starts from a random point and then repeatedly picks the point farthest
        (in xyz) from everything chosen so far.

        Args:
            point_cloud: Nx6 array (xyz + RGB)
            n_points: Number of points to sample

        Returns:
            Sampled point cloud (the input unchanged if it already has <= n_points rows)
        """
        if len(point_cloud) <= n_points:
            return point_cloud

        # Use only xyz for distance calculations
        xyz = point_cloud[:, :3]
        num_total = len(xyz)

        centroids = np.zeros(n_points, dtype=np.intp)
        # inf (not a large finite sentinel) so that any real distance updates it,
        # whatever the raw coordinate scale is
        distance = np.full(num_total, np.inf)
        farthest = np.random.randint(0, num_total)

        for i in range(n_points):
            centroids[i] = farthest
            dist = np.sum((xyz - xyz[farthest]) ** 2, axis=1)
            np.minimum(distance, dist, out=distance)
            farthest = int(np.argmax(distance))

        return point_cloud[centroids]

    def _generate_component_colors(self, count: int) -> List[Tuple[float, float, float]]:
        """
        Generate visually distinct RGB colors for links.

        Hues are evenly spaced over the color wheel, starting at red.
        Saturation alternates between 0.65 and 0.90, and every third color
        (starting with the first) is brighter (value 0.95 vs 0.85).

        Args:
            count: Number of colors to generate

        Returns:
            List of RGB tuples (values in [0, 1]); a single red entry if count <= 0
        """
        if count <= 0:
            return [(1.0, 0.0, 0.0)]

        colors = []
        for idx, hue in enumerate(np.linspace(0.0, 1.0, count, endpoint=False)):
            # Per the original note, this mirrors the mesh_renderer palette for
            # consistent visualization -- keep the two in sync.
            sat = 0.65 + 0.25 * (idx % 2)
            val = 0.85 if idx % 3 else 0.95
            r, g, b = colorsys.hsv_to_rgb(float(hue), sat, val)
            colors.append((float(r), float(g), float(b)))

        return colors

    def _generate_color_names(self, count: int) -> List[str]:
        """
        Generate human-readable color names for components.

        The first ten names come from a fixed list. After that the list repeats
        with a numeric suffix ("red2", "orange2", ..., "red3", ...).

        NOTE(review): names are assigned by position and are not derived from
        the hues produced by ``_generate_component_colors`` for the same count
        (e.g. with 2 links the colors are red/cyan but the names are
        red/orange). Confirm whether consumers, such as the 2D renderer, rely
        on this exact list before changing it.

        Args:
            count: Number of color names to generate

        Returns:
            List of unique color names
        """
        base_names = [
            "red", "orange", "yellow", "green", "cyan",
            "blue", "purple", "magenta", "pink", "brown"
        ]

        color_names = []
        for i in range(count):
            cycle, base_idx = divmod(i, len(base_names))
            name = base_names[base_idx]
            color_names.append(name if cycle == 0 else f"{name}{cycle + 1}")

        return color_names

    def save_colored_pointcloud(self, point_cloud: np.ndarray, output_path: str):
        """
        Save the colored point cloud for debugging and visualization.

        Writes ``<stem>.npy`` and, on a best-effort basis, ``<stem>.ply``.
        ``.npy`` is appended to ``output_path`` when it is missing, which is the
        same thing ``np.save`` does. Errors are logged, not raised.

        Args:
            point_cloud: Nx6 array (xyz + RGB in 0-1 range)
            output_path: Path to save the point cloud
        """
        try:
            # Compute the real .npy path up front so the log and the PLY name match it
            npy_path = output_path if output_path.endswith('.npy') else output_path + '.npy'
            np.save(npy_path, point_cloud)
            self.logger.info(f"Saved colored point cloud to {npy_path}")

            # Also save as PLY (same stem) for visualization
            ply_path = os.path.splitext(npy_path)[0] + '.ply'
            self._save_as_ply(point_cloud, ply_path)

        except Exception as e:
            self.logger.error(f"Failed to save point cloud: {e}")

    def _save_as_ply(self, point_cloud: np.ndarray, ply_path: str):
        """
        Save point cloud as PLY file for visualization.

        Best effort: failures are logged as warnings.

        Args:
            point_cloud: Nx6 array (xyz + RGB in 0-1 range)
            ply_path: Path to save PLY file
        """
        try:
            xyz = point_cloud[:, :3]
            # RGB 0-1 -> 0-255: round instead of truncating, and clip so that
            # out-of-range values cannot wrap around in the uint8 cast
            rgb = np.clip(np.rint(point_cloud[:, 3:6] * 255), 0, 255).astype(np.uint8)
            alpha = np.full((len(rgb), 1), 255, dtype=np.uint8)

            pc = trimesh.points.PointCloud(vertices=xyz, colors=np.hstack([rgb, alpha]))
            pc.export(ply_path)
            self.logger.info(f"Saved PLY visualization to {ply_path}")
        except Exception as e:
            self.logger.warning(f"Failed to save PLY: {e}")


# Test function for development
def test_pointcloud_converter():
    """
    Smoke-test the converter on the development cabinet asset.

    Converts ./output/test_cabinet/combined_assembly.obj with its
    links_hierarchy.json and prints summary statistics. It then saves .npy and
    .ply copies to the system temporary directory.

    Returns:
        True on success, False if the test mesh is missing or conversion fails.
    """
    import tempfile
    import traceback

    # Test with the cabinet that had issues
    test_mesh_path = "./output/test_cabinet/combined_assembly.obj"
    links_json_path = "./output/test_cabinet/configs/links_hierarchy.json"

    if not os.path.exists(test_mesh_path):
        print(f"Test mesh not found at: {test_mesh_path}")
        return False

    try:
        converter = PointCloudConverter(sample_points=8192)

        # Test with link-based grouping
        point_cloud, component_mapping, color_names = converter.convert_obj_to_pointcloud(
            test_mesh_path,
            links_json_path,
        )

        print(f"Point cloud shape: {point_cloud.shape}")
        print(f"Number of links: {len(component_mapping)}")
        print(f"Link mapping: {component_mapping}")
        print(f"Color names: {color_names}")
        print(f"XYZ range: [{point_cloud[:, :3].min():.3f}, {point_cloud[:, :3].max():.3f}]")
        print(f"RGB range: [{point_cloud[:, 3:].min():.3f}, {point_cloud[:, 3:].max():.3f}]")

        # Save for inspection; gettempdir() is portable, unlike a hard-coded /tmp
        output_path = os.path.join(tempfile.gettempdir(), "test_cabinet_pointcloud.npy")
        converter.save_colored_pointcloud(point_cloud, output_path)
        ply_path = os.path.splitext(output_path)[0] + ".ply"
        print(f"Saved to {output_path} and {ply_path}")

        return True

    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Run the development smoke test and report its outcome via the exit status
    # (0 on success, 1 if the mesh is missing or conversion failed).
    import sys

    sys.exit(0 if test_pointcloud_converter() else 1)