"""
Mesh-based Object Renderer for VLM Feedback System

This module provides functionality to render OBJ mesh files from multiple angles
for VLM analysis. It uses trimesh and pyrender for reliable headless rendering.
"""

import os
import json
import colorsys
import numpy as np
import trimesh
import pyrender
import logging
from typing import List, Tuple, Dict, Optional
from PIL import Image


class MeshRenderer:
    """
    Renders OBJ mesh files from multiple angles for VLM analysis.

    Uses trimesh and pyrender for headless rendering to generate images
    from tilted angles that provide comprehensive object views.
    """

    def __init__(self, image_size: Tuple[int, int] = (384, 384),
                 background_color: Tuple[float, float, float, float] = (0.95, 0.95, 0.95, 1.0)):
        """
        Initialize the mesh renderer.

        Args:
            image_size: (width, height) of rendered images
            background_color: Background color for renders (R, G, B, A)
        """
        self.image_size = image_size
        self.background_color = background_color
        self.logger = logging.getLogger(self.__class__.__name__)

        # Set up headless rendering environment
        import os
        if 'DISPLAY' not in os.environ:
            os.environ['PYOPENGL_PLATFORM'] = 'egl'

        # Define four tilted camera angles for comprehensive object viewing
        self.camera_angles = [
            {"name": "front_right_top", "azimuth": 45, "elevation": 30, "description": "front-right-top view"},
            {"name": "front_left_top", "azimuth": -45, "elevation": 30, "description": "front-left-top view"},
            {"name": "back_right_top", "azimuth": 135, "elevation": 30, "description": "back-right-top view"},
            {"name": "back_left_top", "azimuth": -135, "elevation": 30, "description": "back-left-top view"}
        ]

    def render_mesh_from_file(self, obj_path: str, output_dir: str, links_json_path: Optional[str] = None) -> List[str]:
        """
        Render an OBJ mesh file from multiple angles.

        Uses link-based grouping if links_hierarchy.json is available,
        otherwise falls back to connected component coloring.

        Args:
            obj_path: Path to the OBJ file to render
            output_dir: Directory to save rendered images
            links_json_path: Optional path to links_hierarchy.json for link-based coloring

        Returns:
            List of paths to rendered images
        """
        return self._render_mesh_internal(obj_path, output_dir, links_json_path)

    def _render_mesh_internal(self, obj_path: str, output_dir: str, links_json_path: Optional[str] = None) -> List[str]:
        """
        Internal method to render an OBJ mesh file from multiple angles.

        Args:
            obj_path: Path to the OBJ file to render
            output_dir: Directory to save rendered images

        Returns:
            List of paths to rendered images
        """
        if not os.path.exists(obj_path):
            raise FileNotFoundError(f"OBJ file not found: {obj_path}")

        # Create renders subdirectory
        renders_dir = os.path.join(output_dir, "renders")
        os.makedirs(renders_dir, exist_ok=True)

        try:
            # Load the mesh
            self.logger.info(f"Loading mesh from: {obj_path}")
            mesh = trimesh.load(obj_path)

            # Handle case where trimesh returns a Scene object
            if isinstance(mesh, trimesh.Scene):
                # Convert scene to a single mesh
                mesh = mesh.dump(concatenate=True)

            if not isinstance(mesh, trimesh.Trimesh):
                raise ValueError("Failed to load a valid mesh from the OBJ file")

            self.logger.info(f"Loaded mesh with {len(mesh.vertices)} vertices and {len(mesh.faces)} faces")

            # Split mesh into connected components
            mesh_components = mesh.split(only_watertight=False)
            if not mesh_components:
                mesh_components = [mesh]

            self.logger.info(f"Mesh has {len(mesh_components)} connected components")

            # Group components by links if available
            if links_json_path and os.path.exists(links_json_path):
                link_groups, link_names = self._group_by_links(mesh_components, links_json_path)
                self.logger.info(f"Using link-based coloring: {len(link_groups)} links")
                num_colors = len(link_groups)

                # Create components list with proper colors for each link
                colored_components = []
                component_colors = self._generate_component_colors(num_colors)

                for idx, (link_name, components) in enumerate(link_groups.items()):
                    color = component_colors[idx]
                    for component in components:
                        colored_components.append((component, color))

                # Log link color mapping for debugging
                color_names = self._get_color_names(num_colors)
                for idx, link_name in enumerate(link_names):
                    self.logger.info(f"  {link_name}: {color_names[idx]}")

            else:
                # Fallback to per-component coloring
                if len(mesh_components) > 50:
                    self.logger.warning(f"No links file found and mesh has {len(mesh_components)} components. This may be slow.")

                component_colors = self._generate_component_colors(len(mesh_components))
                colored_components = [(comp, color) for comp, color in zip(mesh_components, component_colors)]
                self.logger.info(f"Using per-component coloring: {len(mesh_components)} colors")

            # Get mesh bounds for camera positioning
            bounds = mesh.bounds
            center = mesh.centroid
            size = bounds[1] - bounds[0]  # [max_x - min_x, max_y - min_y, max_z - min_z]
            max_dimension = np.max(size)


            # Render from each angle
            rendered_images = []
            for angle in self.camera_angles:
                try:
                    image_path = self._render_single_angle(
                        colored_components,
                        center,
                        max_dimension,
                        renders_dir,
                        angle
                    )
                    if image_path:
                        rendered_images.append(image_path)
                        self.logger.info(f"Rendered {angle['description']}: {os.path.basename(image_path)}")
                except Exception as e:
                    self.logger.warning(f"Failed to render from {angle['description']}: {e}")

            if not rendered_images:
                raise RuntimeError("Failed to render any images")

            self.logger.info(f"Successfully rendered {len(rendered_images)} images")
            return rendered_images

        except Exception as e:
            self.logger.error(f"Failed to load or render mesh: {e}")
            raise

    def _render_single_angle(self, colored_components: List[Tuple[trimesh.Trimesh, Tuple[float, float, float, float]]],
                           center: np.ndarray, max_dimension: float, output_dir: str, angle: Dict) -> Optional[str]:
        """
        Render mesh from a single camera angle.

        Args:
            colored_components: List of (mesh_component, color) tuples
            center: Center point of the mesh
            max_dimension: Maximum dimension of the mesh
            output_dir: Output directory for images
            angle: Camera angle configuration

        Returns:
            Path to the rendered image or None if failed
        """
        try:
            # Create pyrender scene with softer ambient lighting for better shadows
            scene = pyrender.Scene(ambient_light=[0.2, 0.2, 0.2], bg_color=self.background_color)

            from pyrender import MetallicRoughnessMaterial

            for component, color in colored_components:
                material = MetallicRoughnessMaterial(
                    baseColorFactor=list(color),
                    metallicFactor=0.05,
                    roughnessFactor=0.8
                )
                pyrender_mesh = pyrender.Mesh.from_trimesh(component, material=material, smooth=False)
                scene.add(pyrender_mesh)

            # Calculate camera position
            azimuth_rad = np.radians(angle['azimuth'])
            elevation_rad = np.radians(angle['elevation'])

            # Distance adjusted to make object larger in frame
            # Object should occupy about 70-80% of the image
            distance = max_dimension * 1.15


            # Calculate camera position relative to center
            camera_x = center[0] + distance * np.cos(elevation_rad) * np.cos(azimuth_rad)
            camera_y = center[1] + distance * np.sin(elevation_rad)
            camera_z = center[2] + distance * np.cos(elevation_rad) * np.sin(azimuth_rad)

            camera_pos = np.array([camera_x, camera_y, camera_z])

            # Create camera
            camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=self.image_size[0] / self.image_size[1])

            # Look at the center
            camera_pose = self._look_at(camera_pos, center, np.array([0, 1, 0]))
            scene.add(camera, pose=camera_pose)

            # Add multiple lights for better shadow effects
            # Main directional light from above-right
            main_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
            main_light_pose = self._look_at(
                camera_pos + np.array([2, 5, -2]),  # Light from above-right
                center,
                np.array([0, 1, 0])
            )
            scene.add(main_light, pose=main_light_pose)

            # Secondary fill light from left to reduce harsh shadows
            fill_light = pyrender.DirectionalLight(color=[0.7, 0.7, 0.9], intensity=1.5)
            fill_light_pose = self._look_at(
                camera_pos + np.array([-3, 2, 1]),  # Light from left side
                center,
                np.array([0, 1, 0])
            )
            scene.add(fill_light, pose=fill_light_pose)

            # Subtle rim light from behind for edge definition
            rim_light = pyrender.DirectionalLight(color=[0.9, 0.9, 1.0], intensity=1.0)
            rim_light_pose = self._look_at(
                center - (camera_pos - center) + np.array([0, 3, 0]),  # Light from behind-above
                center,
                np.array([0, 1, 0])
            )
            scene.add(rim_light, pose=rim_light_pose)

            # Render with headless setup
            try:
                renderer = pyrender.OffscreenRenderer(self.image_size[0], self.image_size[1])
                color, depth = renderer.render(scene)
                renderer.delete()
            except Exception as e:
                # Fallback: try with osmesa
                try:
                    os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
                    renderer = pyrender.OffscreenRenderer(self.image_size[0], self.image_size[1])
                    color, depth = renderer.render(scene)
                    renderer.delete()
                except Exception as e2:
                    self.logger.error(f"Both EGL and OSMesa rendering failed: {e}, {e2}")
                    raise e

            # Save image
            image_filename = f"{angle['name']}.png"
            image_path = os.path.join(output_dir, image_filename)

            # Convert to PIL Image and save
            image = Image.fromarray(color)
            image.save(image_path)

            return image_path

        except Exception as e:
            self.logger.error(f"Failed to render angle {angle['name']}: {e}")
            return None


    def _look_at(self, camera_pos: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
        """
        Create a look-at transformation matrix.

        Args:
            camera_pos: Camera position
            target: Target position to look at
            up: Up vector

        Returns:
            4x4 transformation matrix
        """
        # Calculate direction vectors
        forward = target - camera_pos
        forward = forward / np.linalg.norm(forward)

        right = np.cross(forward, up)
        right = right / np.linalg.norm(right)

        up_corrected = np.cross(right, forward)

        # Build camera pose matrix with basis vectors as columns so the camera
        # looks down the -Z axis (OpenGL convention used by pyrender).
        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((right, up_corrected, -forward))
        pose[:3, 3] = camera_pos

        return pose

    def _generate_component_colors(self, count: int) -> List[Tuple[float, float, float, float]]:
        """Generate visually distinct RGBA colors for mesh components."""
        if count <= 0:
            return [(1.0, 0.0, 0.0, 1.0)]

        hues = np.linspace(0.0, 1.0, count, endpoint=False)
        colors = []
        for idx, hue in enumerate(hues):
            # Alternate saturation and value slightly for better contrast while keeping colors bright
            sat = 0.65 + 0.25 * (idx % 2)
            val = 0.85 if idx % 3 else 0.95
            r, g, b = colorsys.hsv_to_rgb(float(hue), min(sat, 1.0), val)
            colors.append((float(r), float(g), float(b), 1.0))

        return colors

    def _get_color_names(self, count: int) -> List[str]:
        """Get human-readable color names for debugging."""
        base_names = [
            "red", "orange", "yellow", "green", "cyan",
            "blue", "purple", "magenta", "pink", "brown"
        ]

        color_names = []
        for i in range(count):
            if i < len(base_names):
                color_names.append(base_names[i])
            else:
                base_idx = i % len(base_names)
                suffix = (i // len(base_names)) + 1
                color_names.append(f"{base_names[base_idx]}{suffix}")

        return color_names

    def _group_by_links(self, mesh_components: List[trimesh.Trimesh], links_json_path: str) -> Tuple[Dict[str, List[trimesh.Trimesh]], List[str]]:
        """
        Group mesh components by links based on links_hierarchy.json.

        Args:
            mesh_components: List of mesh components
            links_json_path: Path to links_hierarchy.json

        Returns:
            Tuple of (link_groups, link_names)
        """
        # Load links hierarchy
        with open(links_json_path, 'r', encoding='utf-8') as f:
            links_data = json.load(f)

        # Extract main link names from hierarchy
        link_names = []
        hierarchy = links_data.get('hierarchy', {}).get('structure', [])

        def extract_links(items):
            """Extract main_link entries."""
            links = []
            for item in items:
                if item.get('type') == 'main_link':
                    links.append(item['name'])
                # Also check the root
                elif 'root' in links_data.get('hierarchy', {}) and item.get('name') == links_data['hierarchy']['root']:
                    links.append(item['name'])
            return links

        link_names = extract_links(hierarchy)

        # Add root if it's marked as main_link
        if 'root' in links_data.get('hierarchy', {}):
            root_name = links_data['hierarchy']['root']
            if root_name not in link_names:
                link_names.insert(0, root_name)

        if not link_names:
            # Fallback: use all top-level entries
            link_names = [item['name'] for item in hierarchy if 'name' in item]

        self.logger.info(f"Found {len(link_names)} links: {link_names}")

        # Group components by links
        # Sort components by size (largest first)
        mesh_components.sort(key=lambda m: len(m.vertices), reverse=True)

        link_groups = {name: [] for name in link_names}

        # Distribute components to links
        # This is a heuristic - larger components tend to be main structural parts
        components_per_link = max(1, len(mesh_components) // len(link_names))

        for i, comp in enumerate(mesh_components):
            # Determine which link this component belongs to
            link_idx = min(i // components_per_link, len(link_names) - 1)
            link_name = link_names[link_idx]
            link_groups[link_name].append(comp)

        # Remove empty groups
        link_groups = {k: v for k, v in link_groups.items() if v}

        return link_groups, list(link_groups.keys())

    def get_mesh_info(self, obj_path: str) -> Dict:
        """
        Get basic information about a mesh file.

        Args:
            obj_path: Path to the OBJ file

        Returns:
            Dictionary with mesh information
        """
        try:
            mesh = trimesh.load(obj_path)

            # Handle scene objects
            if isinstance(mesh, trimesh.Scene):
                mesh = mesh.dump(concatenate=True)

            if not isinstance(mesh, trimesh.Trimesh):
                return {"error": "Invalid mesh file"}

            bounds = mesh.bounds
            center = mesh.centroid
            size = bounds[1] - bounds[0]

            return {
                "vertices": len(mesh.vertices),
                "faces": len(mesh.faces),
                "bounds": bounds.tolist(),
                "center": center.tolist(),
                "size": size.tolist(),
                "max_dimension": float(np.max(size)),
                "volume": float(mesh.volume) if mesh.is_volume else None,
                "surface_area": float(mesh.area)
            }

        except Exception as e:
            return {"error": str(e)}


# Test function for development
def test_mesh_renderer():
    """
    Development smoke test: render a sample OBJ into a temporary directory.

    Prints the mesh summary, renders all configured angles and reports the
    size of every produced image.

    Returns:
        True when rendering completed, False when the sample mesh is missing
        or rendering raised.
    """
    import tempfile

    sample_obj = "./output/test/sample/combined_assembly.obj"

    if not os.path.exists(sample_obj):
        print(f"Test mesh not found at: {sample_obj}")
        return False

    mesh_renderer = MeshRenderer()

    # Report basic mesh statistics before rendering
    summary = mesh_renderer.get_mesh_info(sample_obj)
    print(f"Mesh info: {summary}")

    # Render into a scratch directory that is removed afterwards
    with tempfile.TemporaryDirectory() as scratch_dir:
        try:
            outputs = mesh_renderer.render_mesh_from_file(sample_obj, scratch_dir)
            print(f"✓ Test render completed. Images: {outputs}")

            # Verify every reported image actually exists on disk
            for rendered in outputs:
                label = os.path.basename(rendered)
                if not os.path.exists(rendered):
                    print(f"  - {label}: FILE NOT FOUND")
                    continue
                byte_count = os.path.getsize(rendered)
                print(f"  - {label}: {byte_count} bytes")
        except Exception as err:
            print(f"✗ Test render failed: {err}")
            return False
        else:
            return True


if __name__ == "__main__":
    # Script entry point: configure root logging at INFO so the renderer's
    # progress messages are printed, then run the development smoke test.
    logging.basicConfig(level=logging.INFO)
    test_mesh_renderer()
