"""
Visualize the quadruped 8-DOF robot using plotly.

This script creates an interactive 3D visualization of the quadruped robot
by reading the geometry directly from the robot configuration.
"""

import argparse
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
from scipy.spatial.transform import Rotation

# Import the robot creation function
from create_quadruped_8dof import create_quadruped_8dof


def rotate_points_by_quaternion(
    points: np.ndarray, quat: Tuple[float, float, float, float]
) -> np.ndarray:
    """
    Apply the rotation described by a quaternion to a batch of points.

    Args:
        points: Array of 3D points with shape (N, 3)
        quat: Quaternion in scalar-last order (x, y, z, w)

    Returns:
        Array of rotated points with shape (N, 3)
    """
    rotation = Rotation.from_quat(quat)
    rotated = rotation.apply(points)
    return rotated


def get_box_vertices(
    size: Tuple[float, float, float],
    origin: Tuple[float, float, float] = (0.0, 0.0, 0.0),
) -> np.ndarray:
    """
    Build the 8 corner points of an axis-aligned box.

    Args:
        size: Box extents along (x, y, z)
        origin: Coordinates of the box's minimum corner (vertex 0)

    Returns:
        Array of shape (8, 3): vertices 0-3 walk the bottom face (z = origin z),
        vertices 4-7 repeat the same walk on the top face (z = origin z + size z)
    """
    sx, sy, sz = size
    ox, oy, oz = origin
    # In-plane walk shared by both faces; the index order matters because
    # get_box_faces() refers to vertices by position.
    ring_x = (ox, ox + sx, ox + sx, ox)
    ring_y = (oy, oy, oy + sy, oy + sy)
    corners = [
        [cx, cy, cz]
        for cz in (oz, oz + sz)
        for cx, cy in zip(ring_x, ring_y)
    ]
    return np.array(corners)


def get_box_faces() -> List[List[int]]:
    """
    Describe the six quadrilateral faces of a box as vertex-index lists.

    Indices refer to the vertex ordering produced by get_box_vertices().

    Returns:
        A fresh list of 6 faces, each a list of 4 vertex indices, in the order
        bottom, top, front, back, left, right
    """
    bottom = [0, 1, 2, 3]
    top = [4, 5, 6, 7]
    front = [0, 1, 5, 4]
    back = [2, 3, 7, 6]
    left = [0, 3, 7, 4]
    right = [1, 2, 6, 5]
    return [bottom, top, front, back, left, right]


def create_box_mesh_data(
    size: Tuple[float, float, float],
    origin: Tuple[float, float, float],
    orientation: Tuple[float, float, float, float],
    color: str = "blue",
    opacity: float = 0.7,
    name: str = "body",
) -> Dict[str, Any]:
    """
    Build plotly Mesh3d keyword data for an oriented box.

    The box is first built with its minimum corner at the local origin, rotated
    about that corner by ``orientation``, then translated so the corner lands
    on ``origin``.

    Args:
        size: Box dimensions (x, y, z) in meters
        origin: Position of the box corner in structure space
        orientation: Quaternion (x, y, z, w) for body rotation
        color: Color of the mesh
        opacity: Opacity of the mesh
        name: Name for the trace

    Returns:
        Dictionary of Mesh3d properties (vertex coordinates, triangle indices,
        and styling) suitable for ``go.Mesh3d(**data)``
    """
    local_corners = get_box_vertices(size)
    corners = rotate_points_by_quaternion(local_corners, orientation) + np.array(origin)

    # Mesh3d draws triangles only, so fan each quad (a, b, c, d) into
    # (a, b, c) and (a, c, d).
    triangles = []
    for a, b, c, d in get_box_faces():
        triangles.append((a, b, c))
        triangles.append((a, c, d))
    tri_i, tri_j, tri_k = (list(column) for column in zip(*triangles))

    return {
        "type": "mesh3d",
        "x": corners[:, 0].tolist(),
        "y": corners[:, 1].tolist(),
        "z": corners[:, 2].tolist(),
        "i": tri_i,
        "j": tri_j,
        "k": tri_k,
        "color": color,
        "opacity": opacity,
        "name": name,
        "flatshading": True,
    }


def _body_pose(body) -> Tuple[np.ndarray, Tuple[float, float, float, float]]:
    """
    Read a body's placement from the structure config.

    Returns:
        (origin, quat): ``relative_origin_position`` as a length-3 array and
        ``relative_orientation`` as an (x, y, z, w) tuple.
    """
    pos = body.relative_origin_position
    ori = body.relative_orientation
    return np.array([pos.x, pos.y, pos.z]), (ori.x, ori.y, ori.z, ori.w)


def _collect_joints(structure_config):
    """
    Convert every constraint's anchor and hinge axis into world space.

    Constraints whose body A has ``body_sid == 0`` (the torso) are treated as
    hips; all others are treated as knees.

    Returns:
        (hip_positions, hip_axes, knee_positions, knee_axes), each a list of
        length-3 arrays in constraint order.
    """
    # The bodies list order may differ from body_sid values, so look bodies
    # up by sid rather than by list index.
    body_by_sid = {body.body_sid: body for body in structure_config.bodies}

    hip_positions, hip_axes = [], []
    knee_positions, knee_axes = [], []
    for constraint in structure_config.constraints:
        body_a_pos, body_a_quat = _body_pose(body_by_sid[constraint.a_body_sid])
        rotation = Rotation.from_quat(body_a_quat)

        anchor = constraint.a_local_anchor
        anchor_world = body_a_pos + rotation.apply(
            np.array([anchor.x, anchor.y, anchor.z])
        )
        axis = constraint.hinge_a_local_axis
        axis_world = rotation.apply(np.array([axis.x, axis.y, axis.z]))

        if constraint.a_body_sid == 0:  # Torso to upper leg = hip joint
            hip_positions.append(anchor_world)
            hip_axes.append(axis_world)
        else:  # Upper leg to lower leg = knee joint
            knee_positions.append(anchor_world)
            knee_axes.append(axis_world)
    return hip_positions, hip_axes, knee_positions, knee_axes


def _joint_marker_kwargs(positions, color: str, name: str) -> Dict[str, Any]:
    """Build Scatter3d keyword arguments drawing one marker per joint position."""
    return dict(
        x=[p[0] for p in positions],
        y=[p[1] for p in positions],
        z=[p[2] for p in positions],
        mode="markers",
        marker=dict(size=8, color=color, symbol="circle"),
        name=name,
    )


def _joint_axis_kwargs(
    positions, axes, length: float, color: str, legend_name: str
) -> List[Dict[str, Any]]:
    """
    Build Scatter3d keyword arguments for one line segment per joint axis.

    Each segment starts at the joint and extends ``length`` along the axis,
    normalized when it has non-zero length (a zero axis is left as-is, giving
    a degenerate segment). Only the first segment is shown in the legend.
    """
    segments = []
    for i, (joint_pos, axis_dir) in enumerate(zip(positions, axes)):
        norm = np.linalg.norm(axis_dir)
        if norm > 0:
            axis_dir = axis_dir / norm
        end = joint_pos + axis_dir * length
        segments.append(
            dict(
                x=[joint_pos[0], end[0]],
                y=[joint_pos[1], end[1]],
                z=[joint_pos[2], end[2]],
                mode="lines",
                line=dict(width=4, color=color),
                name=legend_name if i == 0 else None,
                showlegend=(i == 0),
            )
        )
    return segments


def visualize_structure_config(
    structure_config,
    structure_dict: Optional[Dict[str, Any]] = None,
    voxel_size: float = 0.01,
    output_file: Optional[str] = None,
    show: bool = True,
):
    """
    Visualize a robot structure config using plotly.

    This function reads the body positions and orientations directly from
    the structure config created by create_quadruped_8dof.

    Each body becomes a box mesh. When ``structure_dict["body_infos"][i]``
    reports a positive ``soft_shell_thickness`` for body ``i``, that body is
    drawn as a transparent outer shell plus an opaque rigid core. Joint anchors
    and hinge axes from ``structure_config.constraints`` are overlaid, together
    with a ground plane at z = 0.

    Args:
        structure_config: The RS_StructureConfig object from create_quadruped_8dof
        structure_dict: Optional structure dictionary containing body_infos for soft shell rendering
        voxel_size: Size of each voxel in meters
        output_file: Optional path to save the HTML visualization
        show: Whether to display the figure

    Returns:
        The plotly Figure, or None if plotly is not installed.
    """
    try:
        import plotly.graph_objects as go
    except ImportError:
        print(
            "Error: plotly is required for visualization. "
            "Install with: pip install plotly"
        )
        return

    # Colors for different body parts (rigid core)
    body_colors = [
        "royalblue",  # Body 0: Torso
        "forestgreen",  # Body 1: Upper leg FL
        "forestgreen",  # Body 2: Upper leg FR
        "forestgreen",  # Body 3: Upper leg BL
        "forestgreen",  # Body 4: Upper leg BR
        "darkorange",  # Body 5: Lower leg FL
        "darkorange",  # Body 6: Lower leg FR
        "darkorange",  # Body 7: Lower leg BL
        "darkorange",  # Body 8: Lower leg BR
    ]

    # Single transparent color for all soft voxel shells
    soft_shell_color = "rgba(173, 216, 230, 0.3)"  # Light blue with transparency

    body_names = [
        "Torso",
        "Upper Leg Front-Left",
        "Upper Leg Front-Right",
        "Upper Leg Back-Left",
        "Upper Leg Back-Right",
        "Lower Leg Front-Left",
        "Lower Leg Front-Right",
        "Lower Leg Back-Left",
        "Lower Leg Back-Right",
    ]

    traces = []
    body_infos = structure_dict.get("body_infos", []) if structure_dict else []

    for i, body in enumerate(structure_config.bodies):
        # Total size in meters (includes the soft shell, if any)
        total_body_size = (
            body.x_voxels * voxel_size,
            body.y_voxels * voxel_size,
            body.z_voxels * voxel_size,
        )
        # Origin is the corner of the full body (including soft shell)
        body_origin, body_quat = _body_pose(body)

        color = body_colors[i] if i < len(body_colors) else "gray"
        has_name = i < len(body_names)
        name = (
            f"{body_names[i]} (Body {body.body_sid})"
            if has_name
            else f"Body {body.body_sid}"
        )

        # body_infos is indexed like structure_config.bodies; a missing entry
        # or missing thickness key means "no soft shell".
        body_info = body_infos[i] if i < len(body_infos) else None
        shell_thickness = (
            body_info.get("soft_shell_thickness", 0) if body_info else 0
        )

        if shell_thickness > 0:
            rigid_size = body_info["rigid_size"]
            shell_offset = shell_thickness * voxel_size
            rigid_body_size = tuple(rigid_size[k] * voxel_size for k in range(3))

            # The rigid core is inset by shell_offset along each local axis;
            # rotate that inset into the structure frame before translating.
            rigid_origin_world = body_origin + rotate_points_by_quaternion(
                np.full((1, 3), shell_offset), body_quat
            ).flatten()

            short_name = body_names[i] if has_name else f"Body {body.body_sid}"
            soft_mesh_data = create_box_mesh_data(
                size=total_body_size,
                origin=tuple(body_origin),
                orientation=body_quat,
                color=soft_shell_color,
                opacity=0.3,
                name=f"Soft Shell - {short_name}",
            )
            traces.append(go.Mesh3d(**soft_mesh_data))

            rigid_mesh_data = create_box_mesh_data(
                size=rigid_body_size,
                origin=tuple(rigid_origin_world),
                orientation=body_quat,
                color=color,
                opacity=0.9,
                name=name,
            )
            traces.append(go.Mesh3d(**rigid_mesh_data))
        else:
            mesh_data = create_box_mesh_data(
                size=total_body_size,
                origin=tuple(body_origin),
                orientation=body_quat,
                color=color,
                opacity=0.8,
                name=name,
            )
            traces.append(go.Mesh3d(**mesh_data))

    hip_positions, hip_axes, knee_positions, knee_axes = _collect_joints(
        structure_config
    )

    # Joint markers
    if hip_positions:
        traces.append(
            go.Scatter3d(**_joint_marker_kwargs(hip_positions, "red", "Hip Joints"))
        )
    if knee_positions:
        traces.append(
            go.Scatter3d(
                **_joint_marker_kwargs(knee_positions, "purple", "Knee Joints")
            )
        )

    # Joint axes (world space)
    axis_length = 0.03
    for kwargs in _joint_axis_kwargs(
        hip_positions, hip_axes, axis_length, "yellow", "Hip Axis"
    ):
        traces.append(go.Scatter3d(**kwargs))
    for kwargs in _joint_axis_kwargs(
        knee_positions, knee_axes, axis_length, "cyan", "Knee Axis"
    ):
        traces.append(go.Scatter3d(**kwargs))

    # Add ground plane
    ground_size = 0.5
    ground_vertices = np.array(
        [
            [-ground_size, -ground_size, 0],
            [ground_size, -ground_size, 0],
            [ground_size, ground_size, 0],
            [-ground_size, ground_size, 0],
        ]
    )
    traces.append(
        go.Mesh3d(
            x=ground_vertices[:, 0],
            y=ground_vertices[:, 1],
            z=ground_vertices[:, 2],
            i=[0, 0],
            j=[1, 2],
            k=[2, 3],
            color="lightgray",
            opacity=0.5,
            name="Ground",
        )
    )

    fig = go.Figure(data=traces)

    fig.update_layout(
        title=dict(
            text=f"Quadruped 8-DOF Robot: {structure_config.name}",
            x=0.5,
            font=dict(size=20),
        ),
        scene=dict(
            xaxis_title="X (m)",
            yaxis_title="Y (m)",
            zaxis_title="Z (m)",
            aspectmode="data",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.0),
                up=dict(x=0, y=0, z=1),
            ),
        ),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255, 255, 255, 0.8)",
        ),
        margin=dict(l=0, r=0, t=50, b=0),
    )

    fig.add_annotation(
        text=(
            f"Bodies: {len(structure_config.bodies)}<br>"
            f"Joints: {len(structure_config.constraints)}<br>"
            f"Voxel size: {voxel_size*1000:.1f} mm"
        ),
        xref="paper",
        yref="paper",
        x=0.99,
        y=0.01,
        showarrow=False,
        font=dict(size=12),
        align="right",
        bgcolor="rgba(255, 255, 255, 0.8)",
        bordercolor="gray",
        borderwidth=1,
    )

    if output_file:
        fig.write_html(output_file)
        print(f"Visualization saved to: {output_file}")

    if show:
        fig.show()

    return fig


def main() -> None:
    """Parse CLI options, build the quadruped configuration, and visualize it."""
    parser = argparse.ArgumentParser(description="Visualize the quadruped 8-DOF robot")
    parser.add_argument(
        "--leg-angle",
        type=float,
        default=45.0,
        help="Angle of upper legs from vertical (degrees)",
    )
    parser.add_argument(
        "--voxel-size",
        type=float,
        default=0.01,
        help="Size of each voxel in meters",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="quadruped_8dof_visualization.html",
        help="Output HTML file for visualization",
    )
    parser.add_argument(
        "--no-show",
        action="store_true",
        help="Don't display the visualization (only save to file)",
    )
    parser.add_argument(
        "--soft-shell",
        type=int,
        default=1,
        help="Soft voxel shell thickness around rigid core (0 to disable)",
    )

    args = parser.parse_args()

    # Reject values that cannot describe a voxel grid before doing any work;
    # parser.error prints usage and exits with status 2.
    if args.voxel_size <= 0:
        parser.error("--voxel-size must be positive")
    if args.soft_shell < 0:
        parser.error("--soft-shell must be >= 0")

    print("=" * 60)
    print("Quadruped 8-DOF Robot Visualization")
    print("=" * 60)
    print(f"Upper leg angle: {args.leg_angle}°")
    print(f"Voxel size: {args.voxel_size * 1000:.1f} mm")
    print(f"Soft shell thickness: {args.soft_shell} voxels")
    print(f"Output file: {args.output}")
    print()

    # Create the robot configuration
    print("Creating robot configuration...")
    structure_config, structure = create_quadruped_8dof(
        "quadruped_8dof",
        (0.0, 0.0, 0.0),
        "material_0",
        voxel_size=args.voxel_size,
        leg_angle_deg=args.leg_angle,
        soft_shell_thickness=args.soft_shell,
    )

    print(f"  Bodies: {len(structure_config.bodies)}")
    print(f"  Constraints: {len(structure_config.constraints)}")
    print()

    # Visualize the configuration
    print("Generating visualization...")
    visualize_structure_config(
        structure_config,
        structure_dict=structure,
        voxel_size=args.voxel_size,
        output_file=args.output,
        show=not args.no_show,
    )


if __name__ == "__main__":
    main()
