"""
Articulation Renderer for VLM Feedback System

This module renders articulated objects with colored child links to visualize
joint movements for VLM analysis. Uses trimesh and pyrender for headless rendering.
"""

import os
import json
import numpy as np
import trimesh
import pyrender
import logging
from typing import List, Dict, Tuple, Optional, Any
from PIL import Image
import colorsys
import xml.etree.ElementTree as ET


class ArticulationRenderer:
    """
    Renders articulated objects with colored child links for joint visualization.

    Features:
    - Colors child links of movable joints (revolute, continuous, prismatic)
    - Renders object in 2 joint states (initial and moved)
    - Captures from 4 camera angles for comprehensive view
    - Generates color mapping for VLM interpretation

    The ``color_name`` reported for a joint always describes the RGB value
    actually used to paint its child link, so names such as "BLUE" in the
    mapping can be trusted when reading the renders.
    """

    # Joint types whose child links are colored and moved between states.
    MOVABLE_JOINT_TYPES = ('revolute', 'continuous', 'prismatic')

    # RGBA color for links that are not the child of a movable joint.
    STATIC_LINK_COLOR = (0.7, 0.7, 0.7, 1.0)

    # (name, RGB) pairs handed out to movable joints in order. Keeping name and
    # value together guarantees the reported name matches the rendered color.
    # Grays are deliberately absent: they would blend into STATIC_LINK_COLOR.
    COLOR_PALETTE = (
        ('RED', (0.90, 0.10, 0.10)),
        ('BLUE', (0.10, 0.35, 0.95)),
        ('GREEN', (0.10, 0.70, 0.20)),
        ('YELLOW', (0.95, 0.85, 0.10)),
        ('CYAN', (0.10, 0.85, 0.90)),
        ('MAGENTA', (0.90, 0.10, 0.85)),
        ('ORANGE', (1.00, 0.50, 0.00)),
        ('PURPLE', (0.50, 0.15, 0.70)),
        ('PINK', (1.00, 0.55, 0.70)),
        ('LIME', (0.60, 0.95, 0.15)),
        ('BROWN', (0.50, 0.28, 0.10)),
        ('NAVY', (0.05, 0.10, 0.45)),
        ('TEAL', (0.00, 0.50, 0.50)),
        ('OLIVE', (0.50, 0.50, 0.00)),
        ('MAROON', (0.50, 0.00, 0.10)),
        ('GOLD', (0.85, 0.65, 0.10)),
        ('CORAL', (1.00, 0.45, 0.35)),
        ('INDIGO', (0.30, 0.00, 0.55)),
        ('VIOLET', (0.75, 0.45, 0.95)),
        ('TURQUOISE', (0.25, 0.88, 0.80)),
        ('CRIMSON', (0.86, 0.08, 0.24)),
    )

    # Scene extent assumed when there is no usable (non-degenerate) geometry.
    _FALLBACK_EXTENT = 10.0

    def __init__(self,
                 image_size: Tuple[int, int] = (384, 384),
                 background_color: Tuple[float, float, float, float] = (0.95, 0.95, 0.95, 1.0)):
        """
        Initialize the articulation renderer.

        Args:
            image_size: (width, height) of rendered images
            background_color: Background color for renders (R, G, B, A)
        """
        self.image_size = image_size
        self.background_color = background_color
        self.logger = logging.getLogger(self.__class__.__name__)

        # Headless rendering: fall back to EGL when no X display is available,
        # but keep an explicit user choice (e.g. 'osmesa') if one was exported.
        # NOTE(review): pyrender documents that PYOPENGL_PLATFORM must be set
        # before pyrender/OpenGL is imported; this module imports pyrender at
        # top level, so setting it here may be too late — prefer exporting it
        # before importing this module.
        if 'DISPLAY' not in os.environ:
            os.environ.setdefault('PYOPENGL_PLATFORM', 'egl')

        # Camera angles for comprehensive viewing (degrees)
        self.camera_angles = [
            {"name": "front_right_top", "azimuth": 45, "elevation": 30},
            {"name": "front_left_top", "azimuth": -45, "elevation": 30},
            {"name": "back_right_top", "azimuth": 135, "elevation": 30},
            {"name": "back_left_top", "azimuth": -135, "elevation": 30}
        ]

        # Joint states to render; "value" is the fraction of each joint's range
        self.joint_states = [
            {"name": "initial", "value": 0.0, "description": "Initial position"},
            {"name": "moved", "value": 0.75, "description": "75% of range or max"}
        ]

    def render_articulated_object(self, base_dir: str, output_dir: str) -> Tuple[List[str], Dict[str, Any]]:
        """
        Render an articulated object with colored child links.

        Writes ``<output_dir>/color_mapping.json`` and one PNG per
        (joint state, camera angle) into ``<output_dir>/renders``.

        Args:
            base_dir: Base directory containing URDF, links, and configs
            output_dir: Directory to save rendered images

        Returns:
            Tuple of (list of image paths, color mapping dictionary)

        Raises:
            ValueError: If no articulation.json is found, or it lists no joints.
            FileNotFoundError: If ``generated.urdf`` is missing from base_dir.
        """
        articulation = self._load_articulation(base_dir)
        if articulation is None:
            raise ValueError(f"No articulation.json found in {base_dir}")
        if not articulation:
            raise ValueError(f"articulation.json in {base_dir} defines no joints")

        urdf_path = os.path.join(base_dir, "generated.urdf")
        if not os.path.exists(urdf_path):
            raise FileNotFoundError(f"URDF file not found: {urdf_path}")

        color_mapping = self._create_color_mapping(articulation)

        renders_dir = os.path.join(output_dir, "renders")
        os.makedirs(renders_dir, exist_ok=True)

        # Save color mapping next to the renders for reference
        mapping_path = os.path.join(output_dir, "color_mapping.json")
        with open(mapping_path, 'w', encoding='utf-8') as f:
            json.dump(color_mapping, f, indent=2)

        self.logger.info(f"Created color mapping for {len(color_mapping)} movable joints")

        rendered_images = []
        for state in self.joint_states:
            rendered_images.extend(
                self._render_state(base_dir, renders_dir, state, articulation, color_mapping)
            )

        self.logger.info(f"Successfully rendered {len(rendered_images)} images")
        return rendered_images, color_mapping

    def _load_articulation(self, base_dir: str) -> Optional[List[Dict[str, Any]]]:
        """
        Load the articulation configuration JSON.

        Looks in ``<base_dir>/configs/articulation.json`` first, then the legacy
        ``<base_dir>/articulation.json``. Returns None if neither exists.
        """
        candidates = (
            os.path.join(base_dir, "configs", "articulation.json"),
            os.path.join(base_dir, "articulation.json"),  # legacy layout
        )
        for path in candidates:
            if os.path.exists(path):
                with open(path, 'r', encoding='utf-8') as f:
                    return json.load(f)
        return None

    def _create_color_mapping(self, articulation: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Create color mapping for child links of movable joints.

        Args:
            articulation: List of joint specifications

        Returns:
            Dictionary mapping joint names to color information. Joints without
            a ``joint_name`` are keyed ``joint_<i>``, where i is their index
            among the movable joints.
        """
        movable_joints = [
            joint for joint in articulation
            if joint.get('type') in self.MOVABLE_JOINT_TYPES
        ]
        colors = self._generate_distinct_colors(len(movable_joints))
        names = self._generate_color_names(len(movable_joints))

        color_mapping = {}
        for idx, (joint, rgba, name) in enumerate(zip(movable_joints, colors, names)):
            joint_name = joint.get('joint_name', f"joint_{idx}")
            color_mapping[joint_name] = {
                'child_link': joint.get('child', ''),
                'color_rgb': rgba,
                'color_name': name,
                'joint_type': joint.get('type'),
                'parent_link': joint.get('parent', ''),
                'axis': joint.get('axis', [0, 0, 1]),
                'limits': joint.get('limit', {})
            }

        return color_mapping

    def _generate_color_names(self, count: int) -> List[str]:
        """
        Return ``count`` color names aligned with ``_generate_distinct_colors``.

        The first names come from COLOR_PALETTE; any further joints get
        ``COLOR_<n>`` (1-based).
        """
        if count <= 0:
            return []
        names = [name for name, _ in self.COLOR_PALETTE[:count]]
        names.extend(f"COLOR_{i + 1}" for i in range(len(names), count))
        return names

    def _generate_distinct_colors(self, count: int) -> List[Tuple[float, float, float, float]]:
        """
        Return ``count`` opaque RGBA colors aligned with ``_generate_color_names``.

        The first colors are the named COLOR_PALETTE entries; beyond that,
        evenly spaced vivid hues are generated (these carry ``COLOR_<n>`` names).
        """
        if count <= 0:
            return []

        colors = [(*rgb, 1.0) for _, rgb in self.COLOR_PALETTE[:count]]
        extra = count - len(colors)
        if extra > 0:
            for hue in np.linspace(0.0, 1.0, extra, endpoint=False):
                # High saturation and value for vivid colors
                r, g, b = colorsys.hsv_to_rgb(float(hue), 0.85, 0.9)
                colors.append((float(r), float(g), float(b), 1.0))
        return colors

    def _render_state(self, base_dir: str, output_dir: str,
                     state: Dict[str, Any], articulation: List[Dict[str, Any]],
                     color_mapping: Dict[str, Any]) -> List[str]:
        """
        Render the object in a specific joint state from multiple angles.

        Args:
            base_dir: Base directory with mesh files
            output_dir: Output directory for images
            state: Joint state configuration
            articulation: Articulation specification
            color_mapping: Color assignments for joints

        Returns:
            List of rendered image paths (views that failed are omitted)
        """
        links_dir = os.path.join(base_dir, "links")

        scene_meshes = self._load_meshes_with_colors(links_dir, color_mapping)
        scene_meshes = self._apply_joint_transforms(
            scene_meshes, articulation, state['value'], color_mapping, base_dir
        )

        # Frame the camera on the *posed* geometry so moved parts stay in view
        center, max_dimension = self._compute_scene_bounds(scene_meshes)

        rendered_images = []
        for angle in self.camera_angles:
            image_path = self._render_single_view(
                scene_meshes, center, max_dimension,
                output_dir, state['name'], angle
            )
            if image_path:
                rendered_images.append(image_path)
                self.logger.info(f"Rendered {state['name']}_{angle['name']}")

        return rendered_images

    def _compute_scene_bounds(self, scene_meshes: Dict[str, Any]) -> Tuple[np.ndarray, float]:
        """
        Return (center, largest axis-aligned extent) of all meshes after
        applying each mesh's ``transform``.

        Falls back to the origin and ``_FALLBACK_EXTENT`` when there is no
        geometry or the geometry has zero extent (which would otherwise place
        the camera on the target).
        """
        posed = []
        for mesh_data in scene_meshes.values():
            if 'mesh' not in mesh_data:
                continue
            vertices = np.asarray(mesh_data['mesh'].vertices, dtype=float).reshape(-1, 3)
            if len(vertices) == 0:
                continue
            transform = np.asarray(mesh_data.get('transform', np.eye(4)), dtype=float)
            posed.append(vertices @ transform[:3, :3].T + transform[:3, 3])

        if not posed:
            return np.zeros(3), self._FALLBACK_EXTENT

        points = np.vstack(posed)
        bounds_min = points.min(axis=0)
        bounds_max = points.max(axis=0)
        center = (bounds_min + bounds_max) / 2
        max_dimension = float(np.max(bounds_max - bounds_min))
        if max_dimension <= 0:
            max_dimension = self._FALLBACK_EXTENT
        return center, max_dimension

    def _load_meshes_with_colors(self, links_dir: str,
                                 color_mapping: Dict[str, Any]) -> Dict[str, Any]:
        """
        Load link mesh files (``*.obj``, any case) and attach render colors.

        The link name is the file name without extension. A link that is the
        child of a movable joint gets that joint's color; others get
        STATIC_LINK_COLOR. Files that fail to load are logged and skipped.

        Args:
            links_dir: Directory containing link mesh files
            color_mapping: Color assignments for joints

        Returns:
            Dict link name -> {'mesh', 'color', 'is_movable', 'transform'}
        """
        scene_meshes = {}

        if not os.path.exists(links_dir):
            self.logger.warning(f"Links directory not found: {links_dir}")
            return scene_meshes

        # child link -> color; the first joint listing a child wins
        link_colors = {}
        for joint_info in color_mapping.values():
            link_colors.setdefault(joint_info['child_link'], joint_info['color_rgb'])

        # Sorted for a deterministic load order
        for filename in sorted(os.listdir(links_dir)):
            link_name, extension = os.path.splitext(filename)
            if extension.lower() != '.obj':
                continue
            mesh_path = os.path.join(links_dir, filename)

            try:
                mesh = trimesh.load(mesh_path, force='mesh')
                if isinstance(mesh, trimesh.Scene):
                    mesh = mesh.dump(concatenate=True)
            except Exception as e:
                self.logger.warning(f"Failed to load mesh {filename}: {e}")
                continue

            scene_meshes[link_name] = {
                'mesh': mesh,
                'color': link_colors.get(link_name, self.STATIC_LINK_COLOR),
                'is_movable': link_name in link_colors,
                'transform': np.eye(4)
            }

        return scene_meshes

    def _parse_urdf_hierarchy(self, base_dir: str) -> Dict[str, Any]:
        """
        Parse ``<base_dir>/generated.urdf`` into a joint hierarchy.

        Args:
            base_dir: Base directory containing URDF file

        Returns:
            ``{}`` if the URDF is missing; otherwise a dict with
            'joints' (name -> {'type', 'parent', 'child'}),
            'links' (currently left empty),
            'fixed_children' (parent link -> children attached by fixed joints) and
            'children' (parent link -> children attached by any joint type).
            On a parse error the (empty) structure is returned and a warning logged.
        """
        urdf_path = os.path.join(base_dir, "generated.urdf")
        if not os.path.exists(urdf_path):
            return {}

        hierarchy = {
            'joints': {},
            'links': {},
            'fixed_children': {},
            'children': {}
        }

        try:
            root = ET.parse(urdf_path).getroot()
        except (ET.ParseError, OSError) as e:
            self.logger.warning(f"Failed to parse URDF hierarchy: {e}")
            return hierarchy

        for joint_elem in root.findall('.//joint'):
            parent_elem = joint_elem.find('parent')
            child_elem = joint_elem.find('child')
            # <joint> references without parent/child (e.g. inside
            # <transmission>) are not kinematic joints
            if parent_elem is None or child_elem is None:
                continue

            joint_type = joint_elem.get('type')
            parent_link = parent_elem.get('link')
            child_link = child_elem.get('link')

            hierarchy['joints'][joint_elem.get('name')] = {
                'type': joint_type,
                'parent': parent_link,
                'child': child_link
            }

            if parent_link is None or child_link is None:
                continue
            hierarchy['children'].setdefault(parent_link, []).append(child_link)
            if joint_type == 'fixed':
                hierarchy['fixed_children'].setdefault(parent_link, []).append(child_link)

        return hierarchy

    def _propagate_transform(self, link_name: str, transform: np.ndarray,
                            scene_meshes: Dict[str, Any],
                            hierarchy: Dict[str, Any],
                            processed: set) -> None:
        """
        Left-multiply ``transform`` onto a link and everything attached below it.

        Descendants are read from ``hierarchy['children']`` when that key is
        present (all joint types), otherwise from ``hierarchy['fixed_children']``.

        Args:
            link_name: Name of the link
            transform: 4x4 transform to pre-multiply
            scene_meshes: Dictionary of meshes (links without a mesh are still
                traversed so their descendants move)
            hierarchy: Joint hierarchy
            processed: Links already visited in this propagation (cycle guard)
        """
        if link_name in processed:
            return
        processed.add(link_name)

        mesh_data = scene_meshes.get(link_name)
        if mesh_data is not None:
            mesh_data['transform'] = transform @ mesh_data.get('transform', np.eye(4))

        if 'children' in hierarchy:
            edges = hierarchy['children']
        else:
            edges = hierarchy.get('fixed_children', {})
        for child_link in edges.get(link_name, []):
            self._propagate_transform(child_link, transform, scene_meshes, hierarchy, processed)

    @staticmethod
    def _limit_value(limit: Dict[str, Any], key: str, default: float) -> float:
        """Read a numeric joint limit, treating a missing or null value as ``default``."""
        value = limit.get(key)
        return float(default) if value is None else float(value)

    def _joint_motion(self, joint: Dict[str, Any], state_value: float) -> Optional[np.ndarray]:
        """
        Compute the 4x4 motion of one joint at ``state_value``.

        The joint position is ``lower + state_value * (upper - lower)``.
        Revolute joints default to [0, pi/2]; continuous joints sweep [0, 2*pi];
        prismatic joints default to [0, 0.1] and translate along the
        *normalized* axis. The rotation pivot is ``joint['origin']['xyz']``.

        Returns:
            The transform, or None for non-movable joints, zero motion, or a
            zero-length axis.
        """
        joint_type = joint.get('type')
        limit = joint.get('limit') or {}

        if joint_type in ('revolute', 'continuous'):
            raw_axis = joint.get('axis')
            axis = np.asarray([0, 0, 1] if raw_axis is None else raw_axis, dtype=float)
            if joint_type == 'revolute':
                lower = self._limit_value(limit, 'lower', 0.0)
                upper = self._limit_value(limit, 'upper', np.pi / 2)
            else:  # continuous
                lower, upper = 0.0, 2 * np.pi

            angle = lower + state_value * (upper - lower)
            if angle == 0 or not np.any(axis):
                return None
            pivot = (joint.get('origin') or {}).get('xyz') or [0, 0, 0]
            return trimesh.transformations.rotation_matrix(angle, axis, point=pivot)

        if joint_type == 'prismatic':
            raw_axis = joint.get('axis')
            axis = np.asarray([1, 0, 0] if raw_axis is None else raw_axis, dtype=float)
            lower = self._limit_value(limit, 'lower', 0.0)
            upper = self._limit_value(limit, 'upper', 0.1)

            distance = lower + state_value * (upper - lower)
            axis_length = np.linalg.norm(axis)
            if distance == 0 or axis_length == 0:
                return None
            transform = np.eye(4)
            transform[:3, 3] = axis / axis_length * distance
            return transform

        return None

    def _apply_joint_transforms(self, scene_meshes: Dict[str, Any],
                               articulation: List[Dict[str, Any]],
                               state_value: float,
                               color_mapping: Dict[str, Any],
                               base_dir: str = None) -> Dict[str, Any]:
        """
        Pose all links for a joint state.

        Each movable joint's motion moves its child link and the whole subtree
        below it (fixed *and* movable descendants), so parts mounted on a moving
        link follow it. Joints are applied deepest-first; because each motion
        is left-multiplied, a link ends up with
        ``M_outermost_joint @ ... @ M_own_joint``.

        Args:
            scene_meshes: Dictionary of meshes (mutated in place)
            articulation: Joint specifications
            state_value: Joint position as a fraction of range (0.0 to 1.0)
            color_mapping: Unused; kept for interface compatibility
            base_dir: Base directory for URDF parsing (adds fixed/URDF-only edges)

        Returns:
            The same ``scene_meshes`` dict with transforms applied
        """
        hierarchy = self._parse_urdf_hierarchy(base_dir) if base_dir else {}

        # parent link -> child links, merged from the URDF and the articulation
        # spec so the tree is known even when no URDF is available
        children: Dict[str, List[str]] = {
            parent: list(kids) for parent, kids in (hierarchy.get('children') or {}).items()
        }
        for joint in articulation:
            parent_link, child_link = joint.get('parent'), joint.get('child')
            if not parent_link or not child_link:
                continue
            siblings = children.setdefault(parent_link, [])
            if child_link not in siblings:
                siblings.append(child_link)

        parent_of: Dict[str, str] = {}
        for parent_link, kids in children.items():
            for kid in kids:
                parent_of.setdefault(kid, parent_link)

        def depth(link: str) -> int:
            """Number of joints between ``link`` and its root (cycle-safe)."""
            steps, seen = 0, set()
            while link in parent_of and link not in seen:
                seen.add(link)
                link = parent_of[link]
                steps += 1
            return steps

        motions = []
        moved_links = set()
        for joint in articulation:
            child_link = joint.get('child')
            # Only the first joint driving a given child is applied
            if not child_link or child_link in moved_links:
                continue
            motion = self._joint_motion(joint, state_value)
            if motion is None:
                continue
            moved_links.add(child_link)
            motions.append((depth(child_link), child_link, motion))

        # Deepest first; the sort is stable so spec order breaks ties
        motions.sort(key=lambda item: item[0], reverse=True)

        tree = {'children': children}
        for _, child_link, motion in motions:
            self._propagate_transform(child_link, motion, scene_meshes, tree, set())

        return scene_meshes

    def _camera_position(self, center: np.ndarray, max_dimension: float,
                         angle: Dict[str, Any]) -> np.ndarray:
        """
        Place the camera at distance ``1.5 * max_dimension`` from ``center``.

        Azimuth is measured in the x-z plane and elevation towards +y, i.e. the
        world is treated as y-up.
        NOTE(review): URDF models are conventionally z-up; confirm the link
        meshes are y-up, otherwise the "top" views are actually side views.
        """
        azimuth = np.radians(angle['azimuth'])
        elevation = np.radians(angle['elevation'])
        distance = max_dimension * 1.5
        offset = distance * np.array([
            np.cos(elevation) * np.cos(azimuth),
            np.sin(elevation),
            np.cos(elevation) * np.sin(azimuth),
        ])
        return np.asarray(center, dtype=float) + offset

    def _render_single_view(self, scene_meshes: Dict[str, Any],
                           center: np.ndarray, max_dimension: float,
                           output_dir: str, state_name: str,
                           angle: Dict[str, Any]) -> Optional[str]:
        """
        Render a single view of the articulated object.

        Args:
            scene_meshes: Dictionary of meshes with transforms
            center: Scene center point
            max_dimension: Maximum dimension of scene
            output_dir: Output directory
            state_name: Name of joint state
            angle: Camera angle configuration

        Returns:
            Path to ``<output_dir>/<state_name>_<angle name>.png``, or None if
            rendering failed (the error is logged)
        """
        try:
            scene = pyrender.Scene(ambient_light=[0.3, 0.3, 0.3],
                                   bg_color=self.background_color)

            for mesh_data in scene_meshes.values():
                if 'mesh' not in mesh_data:
                    continue
                material = pyrender.MetallicRoughnessMaterial(
                    baseColorFactor=list(mesh_data['color']),
                    metallicFactor=0.1,
                    roughnessFactor=0.7
                )
                pyrender_mesh = pyrender.Mesh.from_trimesh(mesh_data['mesh'], material=material)
                scene.add(pyrender_mesh, pose=mesh_data['transform'])

            world_up = np.array([0, 1, 0])
            camera_pos = self._camera_position(center, max_dimension, angle)

            camera = pyrender.PerspectiveCamera(
                yfov=np.pi / 3.0,
                aspectRatio=self.image_size[0] / self.image_size[1]
            )
            scene.add(camera, pose=self._look_at(camera_pos, center, world_up))

            # Main light and a dimmer, slightly blue fill light, each offset
            # from the camera and aimed at the scene center
            main_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
            scene.add(main_light,
                      pose=self._look_at(camera_pos + np.array([2, 5, -2]), center, world_up))
            fill_light = pyrender.DirectionalLight(color=[0.8, 0.8, 0.9], intensity=1.5)
            scene.add(fill_light,
                      pose=self._look_at(camera_pos + np.array([-3, 2, 1]), center, world_up))

            renderer = pyrender.OffscreenRenderer(self.image_size[0], self.image_size[1])
            try:
                color_img, _ = renderer.render(scene)
            finally:
                # Release the GL context even if rendering fails
                renderer.delete()

            image_path = os.path.join(output_dir, f"{state_name}_{angle['name']}.png")
            Image.fromarray(color_img).save(image_path)
            return image_path

        except Exception as e:
            self.logger.error(f"Failed to render {state_name}_{angle['name']}: {e}")
            return None

    def _look_at(self, camera_pos: np.ndarray, target: np.ndarray,
                up: np.ndarray) -> np.ndarray:
        """
        Build a 4x4 pose placed at ``camera_pos`` whose local -Z axis points at
        ``target`` and whose local +Y is as close to ``up`` as possible.

        If ``up`` is parallel to the viewing direction, a substitute helper
        axis is used instead of producing NaNs.

        Raises:
            ValueError: If camera_pos and target coincide.
        """
        forward = np.asarray(target, dtype=float) - np.asarray(camera_pos, dtype=float)
        forward_length = np.linalg.norm(forward)
        if forward_length < 1e-12:
            raise ValueError("Camera position and look-at target coincide")
        forward = forward / forward_length

        right = np.cross(forward, up)
        right_length = np.linalg.norm(right)
        if right_length < 1e-8:
            # up is (anti)parallel to forward: pick any axis not parallel to it
            helper = np.array([1.0, 0.0, 0.0]) if abs(forward[0]) < 0.9 else np.array([0.0, 0.0, 1.0])
            right = np.cross(forward, helper)
            right_length = np.linalg.norm(right)
        right = right / right_length

        up_corrected = np.cross(right, forward)

        pose = np.eye(4)
        pose[:3, :3] = np.column_stack((right, up_corrected, -forward))
        pose[:3, 3] = camera_pos
        return pose


def test_articulation_renderer(test_dir: str = "./output/test_articulation/sample_object") -> bool:
    """
    Smoke-test the articulation renderer on a sample object.

    Renders into a temporary directory that is deleted afterwards and prints
    a short success/failure report.

    Args:
        test_dir: Object directory to render. It must contain generated.urdf,
            a links/ folder and an articulation.json. Defaults to the sample
            object used historically.

    Returns:
        True if rendering succeeded; False if the directory is missing or
        rendering raised.
    """
    import tempfile

    if not os.path.exists(test_dir):
        print(f"Test directory not found: {test_dir}")
        return False

    renderer = ArticulationRenderer()

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            images, color_mapping = renderer.render_articulated_object(test_dir, temp_dir)
            print(f"✓ Rendered {len(images)} images")
            print(f"✓ Color mapping: {json.dumps(color_mapping, indent=2)}")
            return True
        except Exception as e:  # smoke-test boundary: report any failure
            print(f"✗ Test failed: {e}")
            return False


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Report the smoke-test outcome through the exit status so shell scripts
    # and CI can detect a failed render (previously always exited 0).
    raise SystemExit(0 if test_articulation_renderer() else 1)