"""
Visualize any robot configuration using plotly.

This script creates an interactive 3D visualization of robot configurations
by reading the geometry directly from the robot configuration.

Supports:
- Quadruped 8-DOF robot
- Ant 8-DOF robot
- Biped 6-DOF robot
- Any other robot config with box or spherical bodies

Configs can come from:
- data/robot_config/*.data files (pickle format)
- New configs created on the fly
"""

import argparse
import os
import pickle
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
from scipy.spatial.transform import Rotation


def rotate_points_by_quaternion(
    points: np.ndarray, quat: Tuple[float, float, float, float]
) -> np.ndarray:
    """
    Apply a quaternion rotation to a batch of points.

    The quaternion uses scalar-last order, as expected by
    ``scipy.spatial.transform.Rotation.from_quat``.

    Args:
        points: Array of 3D points (N, 3)
        quat: Quaternion as (x, y, z, w)

    Returns:
        Rotated points (N, 3)
    """
    return Rotation.from_quat(quat).apply(points)


def get_box_vertices(
    size: Tuple[float, float, float],
    origin: Tuple[float, float, float] = (0.0, 0.0, 0.0),
) -> np.ndarray:
    """
    Return the 8 corners of an axis-aligned box.

    Corners 0-3 lie on the ``z = origin.z`` plane (going around from the
    origin corner via +x, then +y), and corners 4-7 sit directly above
    them at ``z = origin.z + size.z``.

    Args:
        size: Box dimensions (x, y, z)
        origin: Position of the box corner (vertex 0)

    Returns:
        Array of 8 vertices (8, 3)
    """
    sx, sy, sz = size
    x0, y0, z0 = origin
    x1, y1, z1 = x0 + sx, y0 + sy, z0 + sz

    # 1 selects the far coordinate on that axis, 0 the near (origin) one.
    corner_bits = (
        (0, 0, 0),  # 0: origin corner
        (1, 0, 0),  # 1
        (1, 1, 0),  # 2
        (0, 1, 0),  # 3
        (0, 0, 1),  # 4
        (1, 0, 1),  # 5
        (1, 1, 1),  # 6
        (0, 1, 1),  # 7
    )
    return np.array(
        [
            [x1 if bx else x0, y1 if by else y0, z1 if bz else z0]
            for bx, by, bz in corner_bits
        ]
    )


def get_box_faces() -> List[List[int]]:
    """
    Return the six quad faces of a box as lists of vertex indices.

    Indices follow the corner order of ``get_box_vertices``: 0-3 on the
    bottom plane and 4-7 directly above them. A fresh list is built on
    every call.

    Returns:
        List of faces, each face is a list of 4 vertex indices
    """
    bottom = [0, 1, 2, 3]
    top = [corner + 4 for corner in bottom]
    front = [0, 1, 5, 4]
    back = [2, 3, 7, 6]
    left = [0, 3, 7, 4]
    right = [1, 2, 6, 5]
    return [bottom, top, front, back, left, right]


def create_box_mesh_data(
    size: Tuple[float, float, float],
    origin: Tuple[float, float, float],
    orientation: Tuple[float, float, float, float],
    color: str = "blue",
    opacity: float = 0.7,
    name: str = "body",
) -> Dict[str, Any]:
    """
    Build a plotly Mesh3d dict for a solid, oriented box.

    The box is built with its corner at the local origin and rotated about
    that corner. It is then translated so that the corner lands on
    ``origin``. Each quad face is split into two triangles.

    Args:
        size: Box dimensions (x, y, z) in meters
        origin: Position of the box corner in structure space
        orientation: Quaternion (x, y, z, w) for body rotation
        color: Color of the mesh
        opacity: Opacity of the mesh
        name: Name for the trace

    Returns:
        Dictionary with mesh data for plotly
    """
    corners = get_box_vertices(size)
    corners = rotate_points_by_quaternion(corners, orientation)
    corners = corners + np.array(origin)

    # Quad (a, b, c, d) -> triangles (a, b, c) and (a, c, d).
    tri_i: List[int] = []
    tri_j: List[int] = []
    tri_k: List[int] = []
    for a, b, c, d in get_box_faces():
        tri_i.extend((a, a))
        tri_j.extend((b, c))
        tri_k.extend((c, d))

    return {
        "type": "mesh3d",
        "x": corners[:, 0].tolist(),
        "y": corners[:, 1].tolist(),
        "z": corners[:, 2].tolist(),
        "i": tri_i,
        "j": tri_j,
        "k": tri_k,
        "color": color,
        "opacity": opacity,
        "name": name,
        "flatshading": True,
    }


def create_hollow_box_mesh_data(
    outer_size: Tuple[float, float, float],
    inner_size: Tuple[float, float, float],
    origin: Tuple[float, float, float],
    orientation: Tuple[float, float, float, float],
    shell_thickness: float,
    color: str = "blue",
    opacity: float = 0.7,
    name: str = "shell",
) -> Dict[str, Any]:
    """
    Build a plotly Mesh3d dict for a hollow box (the shell between two boxes).

    The mesh has 16 vertices:
    - indices 0-7 are the outer box, with its corner at the local origin;
    - indices 8-15 are the inner box, with its corner offset by
      ``shell_thickness`` on every axis.

    All 16 vertices are rotated about the local origin and translated to
    ``origin``. The triangles cover:
    - the 6 outer faces;
    - the 6 inner faces, with reversed winding;
    - 12 quads joining each outer edge to the matching inner edge.

    Args:
        outer_size: Outer box dimensions (x, y, z) in meters
        inner_size: Inner box dimensions (x, y, z) in meters
        origin: Position of the outer box corner in structure space
        orientation: Quaternion (x, y, z, w) for body rotation
        shell_thickness: Offset of the inner box corner from the outer one, in meters
        color: Color of the mesh
        opacity: Opacity of the mesh
        name: Name for the trace

    Returns:
        Dictionary with mesh data for plotly
    """
    t = shell_thickness
    stacked = np.vstack(
        [
            get_box_vertices(outer_size, (0, 0, 0)),
            get_box_vertices(inner_size, (t, t, t)),
        ]
    )
    world = rotate_points_by_quaternion(stacked, orientation) + np.array(origin)

    # Quads joining an outer edge (0-7) to the corresponding inner edge (8-15).
    bridges = [
        # Bottom ring
        (0, 1, 9, 8),
        (1, 2, 10, 9),
        (2, 3, 11, 10),
        (3, 0, 8, 11),
        # Top ring
        (4, 5, 13, 12),
        (5, 6, 14, 13),
        (6, 7, 15, 14),
        (7, 4, 12, 15),
        # Vertical edges
        (0, 4, 12, 8),
        (1, 5, 13, 9),
        (2, 6, 14, 10),
        (3, 7, 15, 11),
    ]

    box_faces = get_box_faces()
    triangles = []
    for a, b, c, d in box_faces:
        triangles += [(a, b, c), (a, c, d)]
    for a, b, c, d in box_faces:
        # Shift onto the inner box and swap two indices to flip the normal.
        a, b, c, d = a + 8, b + 8, c + 8, d + 8
        triangles += [(a, c, b), (a, d, c)]
    for a, b, c, d in bridges:
        triangles += [(a, b, c), (a, c, d)]

    return {
        "type": "mesh3d",
        "x": world[:, 0].tolist(),
        "y": world[:, 1].tolist(),
        "z": world[:, 2].tolist(),
        "i": [tri[0] for tri in triangles],
        "j": [tri[1] for tri in triangles],
        "k": [tri[2] for tri in triangles],
        "color": color,
        "opacity": opacity,
        "name": name,
        "flatshading": True,
    }


def create_voxelized_box_mesh_data(
    body_config,
    body_origin: Tuple[float, float, float],
    body_orientation: Tuple[float, float, float, float],
    voxel_size: float,
    color: str = "blue",
    opacity: float = 0.7,
    name: str = "voxels",
    render_soft: bool = True,
    render_rigid: bool = True,
) -> Optional[Dict[str, Any]]:
    """
    Create mesh data for a voxelized box by reading actual voxel data from body_config.

    Emits one cube per non-empty voxel, filtered by ``segment_type``. Cubes
    are built in body-local coordinates (voxel (ix, iy, iz) has its corner at
    ``(ix, iy, iz) * voxel_size``). They are then rotated about the body
    origin and translated in a single batch, so the body rotation is built
    once per call rather than once per voxel.

    Args:
        body_config: Body config exposing ``x_voxels``, ``y_voxels``,
            ``z_voxels`` and flattened ``segment_type`` / ``segment_bid``
            sequences in which x varies fastest, then y, then z.
        body_origin: Position of body origin (corner) in structure space
        body_orientation: Quaternion (x, y, z, w) for body rotation
        voxel_size: Size of each voxel in meters
        color: Mesh color, used unless only soft voxels are rendered
        opacity: Opacity of the mesh
        name: Name for the trace
        render_soft: Whether to render soft voxels (segment_type == 0)
        render_rigid: Whether to render rigid voxels (segment_type == 1)
            Voxels of any other non-null segment_type are always rendered.

    Returns:
        Dictionary with mesh data for plotly. Returns None if the segment
        data does not match the voxel grid size, or if no voxel passes the
        filters.
    """
    x_voxels = body_config.x_voxels
    y_voxels = body_config.y_voxels
    z_voxels = body_config.z_voxels

    segment_type = list(body_config.segment_type)
    segment_bid = list(body_config.segment_bid)

    if len(segment_type) != x_voxels * y_voxels * z_voxels:
        # Segment data does not describe this grid; let the caller fall back.
        return None

    # Light blue, semi-transparent; used when only soft voxels are rendered.
    soft_color = "rgba(173, 216, 230, 0.4)"

    # Empty voxels carry a null-index sentinel (presumably RS_NULL_INDEX,
    # i.e. -1 or a large unsigned value such as UINT32_MAX -- TODO confirm).
    # Negative values and values >= 2^31 - 1 are treated as null.
    null_index_threshold = 2**31 - 1

    def _is_null(value) -> bool:
        return value < 0 or value >= null_index_threshold

    cube_size = (voxel_size, voxel_size, voxel_size)
    local_cubes = []
    for iz in range(z_voxels):
        for iy in range(y_voxels):
            for ix in range(x_voxels):
                # Flattened index: x varies fastest, then y, then z. It is
                # always < len(segment_type) thanks to the size check above.
                idx = ix + iy * x_voxels + iz * x_voxels * y_voxels
                voxel_type = segment_type[idx]
                voxel_bid = segment_bid[idx] if idx < len(segment_bid) else 0

                if _is_null(voxel_bid) or _is_null(voxel_type):
                    continue
                if voxel_type == 0 and not render_soft:
                    continue
                if voxel_type == 1 and not render_rigid:
                    continue

                voxel_origin_local = np.array(
                    [ix * voxel_size, iy * voxel_size, iz * voxel_size]
                )
                local_cubes.append(
                    get_box_vertices(cube_size, tuple(voxel_origin_local))
                )

    if not local_cubes:
        return None

    # One rotation + translation for all voxel corners at once (row-wise
    # identical to rotating each cube separately).
    combined_vertices = rotate_points_by_quaternion(
        np.vstack(local_cubes), body_orientation
    ) + np.array(body_origin)

    # Each cube owns 8 consecutive vertices; each quad face becomes
    # triangles (a, b, c) and (a, c, d).
    cube_triangles = []
    for a, b, c, d in get_box_faces():
        cube_triangles.append((a, b, c))
        cube_triangles.append((a, c, d))
    offsets = range(0, 8 * len(local_cubes), 8)
    all_i = [off + a for off in offsets for a, _, _ in cube_triangles]
    all_j = [off + b for off in offsets for _, b, _ in cube_triangles]
    all_k = [off + c for off in offsets for _, _, c in cube_triangles]

    # Only a soft-only pass gets the dedicated soft color.
    final_color = soft_color if (render_soft and not render_rigid) else color

    return {
        "type": "mesh3d",
        "x": combined_vertices[:, 0].tolist(),
        "y": combined_vertices[:, 1].tolist(),
        "z": combined_vertices[:, 2].tolist(),
        "i": all_i,
        "j": all_j,
        "k": all_k,
        "color": final_color,
        "opacity": opacity,
        "name": name,
        "flatshading": True,
    }


def create_voxelized_sphere_mesh_data(
    radius_voxels: int,
    center: Tuple[float, float, float],
    voxel_size: float,
    color: str = "blue",
    opacity: float = 0.7,
    name: str = "sphere",
) -> Dict[str, Any]:
    """
    Build a plotly Mesh3d dict for a sphere drawn as voxel cubes.

    A ``2 * radius_voxels`` cube grid is laid out around ``center``. A voxel
    is occupied when its centre lies within ``radius_voxels`` (in voxel units)
    of the grid centre. Only surface voxels are emitted: occupied voxels that
    have a face neighbour outside the grid or outside the sphere. If no voxel
    qualifies (e.g. radius 0), a single voxel centred on ``center`` is drawn
    instead.

    Args:
        radius_voxels: Sphere radius in voxels
        center: Center position of the sphere in meters
        voxel_size: Size of each voxel in meters
        color: Color of the mesh
        opacity: Opacity of the mesh
        name: Name for the trace

    Returns:
        Dictionary with mesh data for plotly
    """
    r = radius_voxels
    r_sq = radius_voxels * radius_voxels
    span = 2 * radius_voxels
    cell = (voxel_size, voxel_size, voxel_size)
    face_neighbors = (
        (-1, 0, 0),
        (1, 0, 0),
        (0, -1, 0),
        (0, 1, 0),
        (0, 0, -1),
        (0, 0, 1),
    )

    def inside(ix, iy, iz):
        # Compare the voxel centre's squared distance against radius^2.
        dx = ix + 0.5 - r
        dy = iy + 0.5 - r
        dz = iz + 0.5 - r
        return dx * dx + dy * dy + dz * dz <= r_sq

    def on_surface(ix, iy, iz):
        # Exposed if any face neighbour is off-grid or outside the sphere.
        for sx, sy, sz in face_neighbors:
            nx, ny, nz = ix + sx, iy + sy, iz + sz
            in_grid = 0 <= nx < span and 0 <= ny < span and 0 <= nz < span
            if not in_grid or not inside(nx, ny, nz):
                return True
        return False

    cubes = []
    for ix in range(span):
        for iy in range(span):
            for iz in range(span):
                if not (inside(ix, iy, iz) and on_surface(ix, iy, iz)):
                    continue
                corner = (
                    center[0] - r * voxel_size + ix * voxel_size,
                    center[1] - r * voxel_size + iy * voxel_size,
                    center[2] - r * voxel_size + iz * voxel_size,
                )
                cubes.append(get_box_vertices(cell, corner))

    if not cubes:
        half = voxel_size / 2
        cubes.append(
            get_box_vertices(
                cell, (center[0] - half, center[1] - half, center[2] - half)
            )
        )

    # Every cube owns 8 consecutive vertices; split each quad into 2 triangles.
    i_list: List[int] = []
    j_list: List[int] = []
    k_list: List[int] = []
    quads = get_box_faces()
    for cube_index in range(len(cubes)):
        base = 8 * cube_index
        for a, b, c, d in quads:
            i_list += [base + a, base + a]
            j_list += [base + b, base + c]
            k_list += [base + c, base + d]

    vertices = np.vstack(cubes)

    return {
        "type": "mesh3d",
        "x": vertices[:, 0].tolist(),
        "y": vertices[:, 1].tolist(),
        "z": vertices[:, 2].tolist(),
        "i": i_list,
        "j": j_list,
        "k": k_list,
        "color": color,
        "opacity": opacity,
        "name": name,
        "flatshading": True,
    }


def get_body_colors(num_bodies: int) -> List[str]:
    """
    Return a plotly color string for each body.

    Two body counts get a fixed palette grouped by body role:
    - 9 bodies: torso, 4 upper legs, 4 lower legs;
    - 7 bodies: torso, 2 upper legs, 2 lower legs, 2 feet.

    Any other count gets evenly spaced HSV hues (saturation 0.7, value 0.8)
    formatted as ``"rgb(r, g, b)"`` strings.

    Args:
        num_bodies: Number of bodies in the robot

    Returns:
        List of color strings, one per body
    """
    torso, upper_leg, lower_leg, foot = (
        "royalblue",
        "forestgreen",
        "darkorange",
        "crimson",
    )
    if num_bodies == 9:
        return [torso] + [upper_leg] * 4 + [lower_leg] * 4
    if num_bodies == 7:
        return [torso] + [upper_leg] * 2 + [lower_leg] * 2 + [foot] * 2

    import colorsys

    def to_css(hue: float) -> str:
        red, green, blue = colorsys.hsv_to_rgb(hue, 0.7, 0.8)
        return f"rgb({int(red*255)}, {int(green*255)}, {int(blue*255)})"

    return [to_css(index / num_bodies) for index in range(num_bodies)]


def get_body_names(num_bodies: int, robot_type: str = "generic") -> List[str]:
    """
    Return a human-readable name for each body.

    Three (body count, robot type) combinations have descriptive names:
    - 9 bodies with "quadruped";
    - 9 bodies with "ant";
    - 7 bodies with "biped".

    Every other combination falls back to "Body 0", "Body 1", ...

    Args:
        num_bodies: Number of bodies in the robot
        robot_type: Type of robot for context-specific naming

    Returns:
        List of body name strings
    """
    known_layouts = {
        (9, "quadruped"): [
            "Torso",
            "Upper Leg Front-Left",
            "Upper Leg Front-Right",
            "Upper Leg Back-Left",
            "Upper Leg Back-Right",
            "Lower Leg Front-Left",
            "Lower Leg Front-Right",
            "Lower Leg Back-Left",
            "Lower Leg Back-Right",
        ],
        (9, "ant"): [
            "Torso",
            "Upper Leg +X",
            "Upper Leg +Y",
            "Upper Leg -X",
            "Upper Leg -Y",
            "Lower Leg +X",
            "Lower Leg +Y",
            "Lower Leg -X",
            "Lower Leg -Y",
        ],
        (7, "biped"): [
            "Torso",
            "Upper Leg Left",
            "Upper Leg Right",
            "Lower Leg Left",
            "Lower Leg Right",
            "Foot Left",
            "Foot Right",
        ],
    }

    names = known_layouts.get((num_bodies, robot_type))
    if names is None:
        names = [f"Body {index}" for index in range(num_bodies)]
    return names


def visualize_structure_config(
    structure_config,
    structure_dict: Optional[Dict[str, Any]] = None,
    voxel_size: float = 0.01,
    output_file: Optional[str] = None,
    show: bool = True,
    robot_type: str = "generic",
    title: Optional[str] = None,
):
    """
    Visualize a robot structure config using plotly.

    This function reads the body positions and orientations directly from
    the structure config and renders them appropriately based on shape.

    Args:
        structure_config: The RS_StructureConfig object
        structure_dict: Optional structure dictionary containing body_infos
        voxel_size: Size of each voxel in meters
        output_file: Optional path to save the HTML visualization
        show: Whether to display the figure
        robot_type: Type of robot for context-specific naming ("quadruped", "ant", "generic")
        title: Optional custom title for the visualization
    """
    try:
        import plotly.graph_objects as go
    except ImportError:
        print(
            "Error: plotly is required for visualization. "
            "Install with: pip install plotly"
        )
        return

    num_bodies = len(structure_config.bodies)
    body_colors = get_body_colors(num_bodies)
    body_names = get_body_names(num_bodies, robot_type)

    # Single transparent color for all soft voxel shells
    soft_shell_color = "rgba(173, 216, 230, 0.3)"  # Light blue with transparency

    traces = []

    # Extract joint positions for markers
    hip_positions = []
    knee_positions = []
    ankle_positions = []

    # Store body origins for visualization
    body_origins = []
    body_origin_names = []

    # Get body infos if available
    body_infos = structure_dict.get("body_infos", []) if structure_dict else []
    soft_shell_thickness = (
        structure_dict.get("soft_shell_thickness", 0) if structure_dict else 0
    )

    # Iterate through all bodies in the structure config
    for i, body in enumerate(structure_config.bodies):
        # Get body dimensions in meters (this is the total size including soft shell)
        total_body_size = (
            body.x_voxels * voxel_size,
            body.y_voxels * voxel_size,
            body.z_voxels * voxel_size,
        )

        # Get body position (origin) - this is the corner of the full body (including soft shell)
        pos = body.relative_origin_position
        body_origin = (pos.x, pos.y, pos.z)

        # Get body orientation (quaternion x, y, z, w)
        ori = body.relative_orientation
        body_quat = (ori.x, ori.y, ori.z, ori.w)

        # Get body info for shell thickness
        body_info = body_infos[i] if i < len(body_infos) else {}
        shell_thickness = body_info.get("soft_shell_thickness", 0)

        color = body_colors[i] if i < len(body_colors) else "gray"
        name = (
            f"{body_names[i]} (Body {body.body_sid})"
            if i < len(body_names)
            else f"Body {body.body_sid}"
        )

        # Store body origin for visualization
        body_origins.append(body_origin)
        body_origin_names.append(f"Origin B{body.body_sid}")

        # Always render using actual voxel data from config (shows exact physics body)
        if shell_thickness > 0:
            # Render soft voxels (segment_type == 0)
            soft_mesh_data = create_voxelized_box_mesh_data(
                body_config=body,
                body_origin=body_origin,
                body_orientation=body_quat,
                voxel_size=voxel_size,
                color=soft_shell_color,
                opacity=0.4,
                name=f"Soft Voxels ({shell_thickness} layers) - {body_names[i] if i < len(body_names) else f'Body {body.body_sid}'}",
                render_soft=True,
                render_rigid=False,
            )
            if soft_mesh_data:
                traces.append(go.Mesh3d(**soft_mesh_data))

            # Render rigid voxels (segment_type == 1)
            rigid_mesh_data = create_voxelized_box_mesh_data(
                body_config=body,
                body_origin=body_origin,
                body_orientation=body_quat,
                voxel_size=voxel_size,
                color=color,
                opacity=0.9,
                name=name,
                render_soft=False,
                render_rigid=True,
            )
            if rigid_mesh_data:
                traces.append(go.Mesh3d(**rigid_mesh_data))
        else:
            # No soft shell, render all voxels
            mesh_data = create_voxelized_box_mesh_data(
                body_config=body,
                body_origin=body_origin,
                body_orientation=body_quat,
                voxel_size=voxel_size,
                color=color,
                opacity=0.8,
                name=name,
                render_soft=True,
                render_rigid=True,
            )
            if mesh_data:
                traces.append(go.Mesh3d(**mesh_data))
            else:
                # Fallback to simple box if voxel data not available
                mesh_data = create_box_mesh_data(
                    size=total_body_size,
                    origin=body_origin,
                    orientation=body_quat,
                    color=color,
                    opacity=0.8,
                    name=name,
                )
                traces.append(go.Mesh3d(**mesh_data))

    # Create a mapping from body_sid to body object
    body_by_sid = {body.body_sid: body for body in structure_config.bodies}

    # Extract joint positions and axes from constraints
    hip_axes = []
    knee_axes = []
    ankle_axes = []

    for constraint in structure_config.constraints:
        # Get the anchor position from body A
        anchor = constraint.a_local_anchor
        body_a_sid = constraint.a_body_sid
        body_b_sid = constraint.b_body_sid

        # Get the body by its body_sid
        body_a = body_by_sid[body_a_sid]
        body_a_pos = np.array(
            [
                body_a.relative_origin_position.x,
                body_a.relative_origin_position.y,
                body_a.relative_origin_position.z,
            ]
        )
        body_a_quat = (
            body_a.relative_orientation.x,
            body_a.relative_orientation.y,
            body_a.relative_orientation.z,
            body_a.relative_orientation.w,
        )

        # Transform anchor from body local coords to world coords
        anchor_local = np.array([anchor.x, anchor.y, anchor.z])
        r = Rotation.from_quat(body_a_quat)
        anchor_world = body_a_pos + r.apply(anchor_local)

        # Transform axis from body local coords to world coords
        axis = constraint.hinge_a_local_axis
        axis_local = np.array([axis.x, axis.y, axis.z])
        axis_world = r.apply(axis_local)

        # Classify joints based on which bodies are connected
        if constraint.a_body_sid == 0:  # Torso to upper leg = hip joint
            hip_positions.append(anchor_world)
            hip_axes.append(axis_world)
        elif (
            num_bodies == 7 and body_b_sid >= 5
        ):  # Biped: lower leg to foot = ankle joint
            ankle_positions.append(anchor_world)
            ankle_axes.append(axis_world)
        else:  # Upper leg to lower leg = knee joint
            knee_positions.append(anchor_world)
            knee_axes.append(axis_world)

    # Add hip joint markers
    if hip_positions:
        hip_x = [p[0] for p in hip_positions]
        hip_y = [p[1] for p in hip_positions]
        hip_z = [p[2] for p in hip_positions]
        traces.append(
            go.Scatter3d(
                x=hip_x,
                y=hip_y,
                z=hip_z,
                mode="markers",
                marker=dict(size=8, color="red", symbol="circle"),
                name="Hip Joints",
            )
        )

    # Add knee joint markers
    if knee_positions:
        knee_x = [p[0] for p in knee_positions]
        knee_y = [p[1] for p in knee_positions]
        knee_z = [p[2] for p in knee_positions]
        traces.append(
            go.Scatter3d(
                x=knee_x,
                y=knee_y,
                z=knee_z,
                mode="markers",
                marker=dict(size=8, color="purple", symbol="circle"),
                name="Knee Joints",
            )
        )

    # Add ankle joint markers
    if ankle_positions:
        ankle_x = [p[0] for p in ankle_positions]
        ankle_y = [p[1] for p in ankle_positions]
        ankle_z = [p[2] for p in ankle_positions]
        traces.append(
            go.Scatter3d(
                x=ankle_x,
                y=ankle_y,
                z=ankle_z,
                mode="markers",
                marker=dict(size=8, color="green", symbol="circle"),
                name="Ankle Joints",
            )
        )

    # Add structure origin marker at (0, 0, 0)
    traces.append(
        go.Scatter3d(
            x=[0],
            y=[0],
            z=[0],
            mode="markers+text",
            marker=dict(
                size=10,
                color="gold",
                symbol="diamond-open",
                line=dict(color="black", width=2),
            ),
            text=["Structure Origin (0,0,0)"],
            textposition="top center",
            textfont=dict(size=12, color="gold"),
            name="Structure Origin",
        )
    )

    # Add global coordinate axes at structure origin
    global_axes_length = 0.05
    global_axes = [
        (np.array([global_axes_length, 0, 0]), "red", "Global X"),
        (np.array([0, global_axes_length, 0]), "green", "Global Y"),
        (np.array([0, 0, global_axes_length]), "blue", "Global Z"),
    ]

    for axis_vec, axis_color, axis_name in global_axes:
        traces.append(
            go.Scatter3d(
                x=[0, axis_vec[0]],
                y=[0, axis_vec[1]],
                z=[0, axis_vec[2]],
                mode="lines+text",
                line=dict(width=6, color=axis_color),
                text=["", axis_name],
                textposition="top center",
                textfont=dict(size=10, color=axis_color),
                name=axis_name,
                showlegend=True,
            )
        )

    # Add body origin markers
    if body_origins:
        origin_x = [p[0] for p in body_origins]
        origin_y = [p[1] for p in body_origins]
        origin_z = [p[2] for p in body_origins]
        traces.append(
            go.Scatter3d(
                x=origin_x,
                y=origin_y,
                z=origin_z,
                mode="markers+text",
                marker=dict(size=6, color="black", symbol="diamond"),
                text=body_origin_names,
                textposition="top center",
                textfont=dict(size=10, color="black"),
                name="Body Origins",
            )
        )

        # Add coordinate axes at each body origin
        axes_length = 0.02
        for i, (body, origin) in enumerate(zip(structure_config.bodies, body_origins)):
            # Get body orientation
            ori = body.relative_orientation
            body_quat = (ori.x, ori.y, ori.z, ori.w)
            r = Rotation.from_quat(body_quat)

            # Define local axes (X: red, Y: green, Z: blue)
            local_axes = [
                (np.array([axes_length, 0, 0]), "red", "X"),
                (np.array([0, axes_length, 0]), "green", "Y"),
                (np.array([0, 0, axes_length]), "blue", "Z"),
            ]

            for axis_vec, axis_color, axis_name in local_axes:
                # Rotate axis by body orientation
                rotated_axis = r.apply(axis_vec)
                end_point = np.array(origin) + rotated_axis

                traces.append(
                    go.Scatter3d(
                        x=[origin[0], end_point[0]],
                        y=[origin[1], end_point[1]],
                        z=[origin[2], end_point[2]],
                        mode="lines",
                        line=dict(width=3, color=axis_color),
                        name=(
                            f"Body {body.body_sid} {axis_name}-axis" if i == 0 else None
                        ),
                        showlegend=False,
                        hovertext=f"Body {body.body_sid} {axis_name}-axis",
                    )
                )

    # Add joint axes
    axis_length = 0.03
    for i, (joint_pos, axis_dir) in enumerate(zip(hip_positions, hip_axes)):
        # Normalize axis direction
        if np.linalg.norm(axis_dir) > 0:
            axis_dir = axis_dir / np.linalg.norm(axis_dir)

        # Hip axis
        hip_end = joint_pos + axis_dir * axis_length
        traces.append(
            go.Scatter3d(
                x=[joint_pos[0], hip_end[0]],
                y=[joint_pos[1], hip_end[1]],
                z=[joint_pos[2], hip_end[2]],
                mode="lines",
                line=dict(width=4, color="yellow"),
                name="Hip Axis" if i == 0 else None,
                showlegend=(i == 0),
            )
        )

    for i, (joint_pos, axis_dir) in enumerate(zip(knee_positions, knee_axes)):
        # Normalize axis direction
        if np.linalg.norm(axis_dir) > 0:
            axis_dir = axis_dir / np.linalg.norm(axis_dir)

        # Knee axis
        knee_end = joint_pos + axis_dir * axis_length
        traces.append(
            go.Scatter3d(
                x=[joint_pos[0], knee_end[0]],
                y=[joint_pos[1], knee_end[1]],
                z=[joint_pos[2], knee_end[2]],
                mode="lines",
                line=dict(width=4, color="cyan"),
                name="Knee Axis" if i == 0 else None,
                showlegend=(i == 0),
            )
        )

    for i, (joint_pos, axis_dir) in enumerate(zip(ankle_positions, ankle_axes)):
        # Normalize axis direction
        if np.linalg.norm(axis_dir) > 0:
            axis_dir = axis_dir / np.linalg.norm(axis_dir)

        # Ankle axis
        ankle_end = joint_pos + axis_dir * axis_length
        traces.append(
            go.Scatter3d(
                x=[joint_pos[0], ankle_end[0]],
                y=[joint_pos[1], ankle_end[1]],
                z=[joint_pos[2], ankle_end[2]],
                mode="lines",
                line=dict(width=4, color="lime"),
                name="Ankle Axis" if i == 0 else None,
                showlegend=(i == 0),
            )
        )

    # Add ground plane
    ground_size = 0.5
    ground_vertices = np.array(
        [
            [-ground_size, -ground_size, 0],
            [ground_size, -ground_size, 0],
            [ground_size, ground_size, 0],
            [-ground_size, ground_size, 0],
        ]
    )
    traces.append(
        go.Mesh3d(
            x=ground_vertices[:, 0],
            y=ground_vertices[:, 1],
            z=ground_vertices[:, 2],
            i=[0, 0],
            j=[1, 2],
            k=[2, 3],
            color="lightgray",
            opacity=0.5,
            name="Ground",
        )
    )

    # Create figure
    fig = go.Figure(data=traces)

    # Determine title
    if title is None:
        title = f"Robot Visualization: {structure_config.name}"

    # Update layout
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            font=dict(size=20),
        ),
        scene=dict(
            xaxis_title="X (m)",
            yaxis_title="Y (m)",
            zaxis_title="Z (m)",
            aspectmode="data",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.0),
                up=dict(x=0, y=0, z=1),
            ),
        ),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255, 255, 255, 0.8)",
        ),
        margin=dict(l=0, r=0, t=50, b=0),
    )

    # Add annotation
    fig.add_annotation(
        text=(
            f"Bodies: {len(structure_config.bodies)}<br>"
            f"Joints: {len(structure_config.constraints)}<br>"
            f"Voxel size: {voxel_size*1000:.1f} mm"
        ),
        xref="paper",
        yref="paper",
        x=0.99,
        y=0.01,
        showarrow=False,
        font=dict(size=12),
        align="right",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="gray",
        borderwidth=1,
    )

    if output_file:
        fig.write_html(output_file)
        print(f"Visualization saved to: {output_file}")

    if show:
        fig.show()

    return fig


def load_robot_config(
    config_path: str, soft_shell_thickness: int = 1
) -> Tuple[Any, Optional[Dict[str, Any]]]:
    """
    Load a robot configuration from a pickle file.

    The unpickled object must expose ``voxel_size`` and ``bodies``; those are
    the only attributes read here.

    Security note: ``pickle.load`` can execute arbitrary code while
    deserializing, so only load config files from trusted sources.

    Args:
        config_path: Path to the .data file containing the pickled config
        soft_shell_thickness: The soft shell thickness used when creating the config
                              (needed to correctly infer rigid core dimensions).
                              The value is taken from this argument, not from
                              the pickled config.

    Returns:
        Tuple of (structure_config, structure_dict). structure_dict is always
        built here (never None) and contains:
            - "body_infos": one {"soft_shell_thickness": ...} dict per body,
              in the same order as structure_config.bodies
            - "soft_shell_thickness": the supplied thickness
            - "voxel_size": structure_config.voxel_size

    Raises:
        FileNotFoundError: If config_path does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # WARNING: unpickling untrusted data can execute arbitrary code.
    with open(config_path, "rb") as f:
        structure_config = pickle.load(f)

    voxel_size = structure_config.voxel_size

    # Every body gets its own (identical) info dict carrying the caller-supplied
    # shell thickness.
    body_infos = [
        {"soft_shell_thickness": soft_shell_thickness}
        for _ in structure_config.bodies
    ]

    structure_dict = {
        "body_infos": body_infos,
        "soft_shell_thickness": soft_shell_thickness,
        "voxel_size": voxel_size,
    }

    return structure_config, structure_dict


def get_available_configs(config_dir: str) -> List[str]:
    """
    List available robot config files in the config directory.

    Only regular files whose names end in ".data" are returned; sub-directories
    are skipped even if their names match, since they cannot be loaded as
    configs.

    Args:
        config_dir: Path to the config directory

    Returns:
        Alphabetically sorted list of config file names (without path).
        Empty if config_dir does not exist or is not a directory.
    """
    # isdir (not exists) so that a regular-file path returns [] instead of
    # making os.listdir raise NotADirectoryError.
    if not os.path.isdir(config_dir):
        return []

    return sorted(
        entry
        for entry in os.listdir(config_dir)
        if entry.endswith(".data")
        and os.path.isfile(os.path.join(config_dir, entry))
    )


def create_robot_config(
    robot_type: str, voxel_size: float, soft_shell_thickness: int, **kwargs
):
    """
    Create a robot configuration based on type.

    Each builder module is imported lazily inside its branch, so only the
    module for the requested robot type has to be importable.

    Args:
        robot_type: Type of robot ("quadruped", "ant", "biped")
        voxel_size: Size of each voxel in meters
        soft_shell_thickness: Soft shell thickness in voxels
        **kwargs: Additional arguments for specific robot types:
            - quadruped: ``leg_angle`` (degrees, default 45.0)
            - biped: ``upper_leg_y_offset``, ``lower_leg_z_offset``,
              ``foot_z_offset`` (voxels, each default 1)
            - ant: ``hip_offset``, ``knee_offset`` and
              ``lower_leg_outward_offset`` are accepted but NOT forwarded
              to the builder (see the note in the ant branch).
            Keys not used by the selected robot type are ignored.

    Returns:
        The builder's return value, which callers unpack as
        (structure_config, structure_dict).

    Raises:
        ValueError: If robot_type is not one of the supported types.
    """
    if robot_type == "quadruped":
        from create_quadruped_8dof import create_quadruped_8dof

        return create_quadruped_8dof(
            "quadruped_8dof",
            (0.0, 0.0, 0.0),
            "material_0",
            "material_1",
            voxel_size=voxel_size,
            leg_angle_deg=kwargs.get("leg_angle", 45.0),
            soft_shell_thickness=soft_shell_thickness,
        )

    if robot_type == "ant":
        from create_ant_8dof import create_ant_8dof

        # NOTE(review): hip_offset, knee_offset and lower_leg_outward_offset
        # (and the --hip-offset / --knee-offset / --lower-leg-outward-offset
        # CLI flags) are accepted via **kwargs but are not passed to
        # create_ant_8dof, so they currently have no effect. Forward them once
        # the builder's parameter names for these offsets are confirmed.
        return create_ant_8dof(
            "ant_8dof",
            (0.0, 0.0, 0.0),
            "material_0",
            "material_1",
            voxel_size=voxel_size,
            soft_shell_thickness=soft_shell_thickness,
        )

    if robot_type == "biped":
        from create_biped_6dof import create_biped_6dof

        # The biped builder takes a single material, unlike the 8-DOF robots.
        return create_biped_6dof(
            "biped_6dof",
            (0.0, 0.0, 0.0),
            "material_0",
            voxel_size=voxel_size,
            soft_shell_thickness=soft_shell_thickness,
            upper_leg_y_offset_voxels=kwargs.get("upper_leg_y_offset", 1),
            lower_leg_z_offset_voxels=kwargs.get("lower_leg_z_offset", 1),
            foot_z_offset_voxels=kwargs.get("foot_z_offset", 1),
        )

    raise ValueError(f"Unknown robot type: {robot_type}")


if __name__ == "__main__":
    # Default config directory: <project_root>/data/robot_config, where the
    # project root is two directory levels above this script's directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(os.path.dirname(script_dir))
    default_config_dir = os.path.join(project_root, "data", "robot_config")

    parser = argparse.ArgumentParser(
        description="Visualize robot configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Load and visualize a saved config
  python visualize_robot.py --load ant_8dof.data

  # List available configs
  python visualize_robot.py --list

  # Create and visualize a new ant robot
  python visualize_robot.py --robot ant --hip-offset 4 --knee-offset 5

  # Load config with custom output file
  python visualize_robot.py --load quadruped_8dof.data -o my_robot.html
""",
    )

    # Config loading options
    parser.add_argument(
        "--load",
        "-l",
        type=str,
        default=None,
        help="Load config from data/robot_config/ (e.g., 'ant_8dof.data')",
    )
    parser.add_argument(
        "--config-dir",
        type=str,
        default=default_config_dir,
        help=f"Directory containing robot configs (default: {default_config_dir})",
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List available robot configs and exit",
    )

    # Robot creation options (used when not loading)
    parser.add_argument(
        "--robot",
        "-r",
        type=str,
        choices=["quadruped", "ant", "biped"],
        default="ant",
        help="Type of robot to create (default: ant) - ignored if --load is used",
    )
    parser.add_argument(
        "--leg-angle",
        type=float,
        default=45.0,
        help="Angle of upper legs from vertical (degrees) - for quadruped only",
    )
    parser.add_argument(
        "--voxel-size",
        type=float,
        default=0.01,
        help="Size of each voxel in meters (used when creating new config)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="Output HTML file for visualization",
    )
    parser.add_argument(
        "--no-show",
        action="store_true",
        help="Don't display the visualization (only save to file)",
    )
    parser.add_argument(
        "--soft-shell",
        type=int,
        default=1,
        help="Soft voxel shell thickness around rigid core (used for both creating and loading configs)",
    )
    parser.add_argument(
        "--hip-offset",
        type=int,
        default=4,
        help="Gap between torso and upper leg in voxels (ant only)",
    )
    parser.add_argument(
        "--knee-offset",
        type=int,
        default=5,
        help="Gap between upper leg and lower leg in voxels (ant only)",
    )
    parser.add_argument(
        "--lower-leg-outward-offset",
        type=int,
        default=0,
        help="Outward offset for lower legs in voxels (ant only)",
    )
    parser.add_argument(
        "--upper-leg-y-offset",
        type=int,
        default=1,
        help="Y offset for upper legs in voxels (outward, biped only)",
    )
    parser.add_argument(
        "--lower-leg-z-offset",
        type=int,
        default=4,
        help="Z offset for lower legs in voxels (downward, biped only)",
    )
    parser.add_argument(
        "--foot-z-offset",
        type=int,
        default=4,
        help="Z offset for feet in voxels (downward, biped only)",
    )

    args = parser.parse_args()

    # Handle --list option: print the available configs, then exit cleanly.
    if args.list:
        print("=" * 60)
        print("Available Robot Configs")
        print("=" * 60)
        print(f"Config directory: {args.config_dir}")
        print()
        configs = get_available_configs(args.config_dir)
        if configs:
            for config in configs:
                print(f"  - {config}")
            print()
            print("Use --load <config_name> to visualize a config")
        else:
            print("  No configs found in directory")
        # parser.exit -> sys.exit; unlike the site-provided exit() builtin it
        # is always available (e.g. under `python -S`).
        parser.exit(0)

    # Load or create robot configuration
    if args.load:
        # Load from file; relative names are resolved against --config-dir.
        config_path = args.load
        if not os.path.isabs(config_path):
            config_path = os.path.join(args.config_dir, config_path)

        print("=" * 60)
        print("Loading Robot Configuration")
        print("=" * 60)
        print(f"Config file: {config_path}")
        print()

        try:
            structure_config, structure = load_robot_config(
                config_path, soft_shell_thickness=args.soft_shell
            )
            voxel_size = structure_config.voxel_size

            # Infer robot type from name (substring match, first hit wins)
            name_lower = structure_config.name.lower()
            if "ant" in name_lower:
                robot_type = "ant"
            elif "quadruped" in name_lower:
                robot_type = "quadruped"
            elif "biped" in name_lower:
                robot_type = "biped"
            else:
                robot_type = "generic"

            print(f"  Name: {structure_config.name}")
            print(f"  Bodies: {len(structure_config.bodies)}")
            print(f"  Constraints: {len(structure_config.constraints)}")
            print(f"  Voxel size: {voxel_size * 1000:.1f} mm")
            print(f"  Soft shell thickness: {args.soft_shell} voxels")
            print(f"  Inferred type: {robot_type}")
            print()

        except Exception as e:
            # Top-level boundary: any load failure (missing file, corrupt
            # pickle, object lacking expected attributes) is reported on
            # stderr and the script exits with status 1.
            parser.exit(1, f"Error loading config: {e}\n")

        # Set default output file based on loaded config
        if args.output is None:
            config_name = os.path.splitext(os.path.basename(args.load))[0]
            args.output = f"{config_name}_visualization.html"

    else:
        # Create new configuration
        robot_type = args.robot
        voxel_size = args.voxel_size

        # The biped has 6 degrees of freedom; the quadruped and ant have 8.
        if args.robot == "biped":
            dof_tag, dof_label = "6dof", "6-DOF"
        else:
            dof_tag, dof_label = "8dof", "8-DOF"

        # Set default output file based on robot type
        if args.output is None:
            args.output = f"{args.robot}_{dof_tag}_visualization.html"

        print("=" * 60)
        print(f"{args.robot.title()} {dof_label} Robot Visualization")
        print("=" * 60)
        if args.robot == "quadruped":
            print(f"Upper leg angle: {args.leg_angle}°")
        elif args.robot == "ant":
            print(f"Hip offset: {args.hip_offset} voxels")
            print(f"Knee offset: {args.knee_offset} voxels")
            print(f"Lower leg outward offset: {args.lower_leg_outward_offset} voxels")
        elif args.robot == "biped":
            print(f"Upper leg Y offset: {args.upper_leg_y_offset} voxels (outward)")
            print(f"Lower leg Z offset: {args.lower_leg_z_offset} voxels (downward)")
            print(f"Foot Z offset: {args.foot_z_offset} voxels (downward)")
        print(f"Voxel size: {voxel_size * 1000:.1f} mm")
        print(f"Soft shell thickness: {args.soft_shell} voxels")
        print()

        # Create the robot configuration
        print("Creating robot configuration...")
        structure_config, structure = create_robot_config(
            robot_type=args.robot,
            voxel_size=voxel_size,
            soft_shell_thickness=args.soft_shell,
            leg_angle=args.leg_angle,
            hip_offset=args.hip_offset,
            knee_offset=args.knee_offset,
            lower_leg_outward_offset=args.lower_leg_outward_offset,
            upper_leg_y_offset=args.upper_leg_y_offset,
            lower_leg_z_offset=args.lower_leg_z_offset,
            foot_z_offset=args.foot_z_offset,
        )

        print(f"  Bodies: {len(structure_config.bodies)}")
        print(f"  Constraints: {len(structure_config.constraints)}")
        print()

    print(f"Output file: {args.output}")
    print()

    # Visualize the configuration
    print("Generating visualization...")
    visualize_structure_config(
        structure_config,
        structure_dict=structure,
        voxel_size=voxel_size,
        output_file=args.output,
        show=not args.no_show,
        robot_type=robot_type,
        title=f"Robot: {structure_config.name}",
    )
