from typing import Tuple
import os
import pickle
import numpy as np
import plotly.graph_objects as go
from plotly.colors import qualitative as plotly_qualitative
from scipy.ndimage import binary_dilation

from rise import (
    RQuat3rf,
    RVec3rf,
    RS_NULL_INDEX,
    RSE_StructureConstraintType,
    RS_StructureBodyConfig,
    RS_StructureConstraintConfig,
    RS_StructureConfig,
)


def create_quadruped_simple(
    structure_name: str,
    position: Tuple[float, float, float],
    material_name: str,
    orientation: Tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
    voxel_size: float = 0.01,
    is_fixed: bool = False,
) -> Tuple[RS_StructureConfig, dict]:
    """
    Build a single-body quadruped structure on a 64x64x64 voxel grid.

    Segment labels (as written to ``segment_bid`` and ``segment_id``):
    - 0: soft voxels, a 2-voxel dilation shell around every rigid voxel
    - 1: rigid torso box of 35x25x2 voxels centred in the grid
    - 2..5: rigid 4x4x20 legs, centred in X on the torso's X edges and placed
      just outside the torso's -Y / +Y faces, hanging downward from one
      voxel above the torso's Z mid-index
    Four hip hinge joints connect the torso (segment 1) to each leg
    (segments 2..5) and are driven by rotation-angle signals 0..3.

    Args:
        structure_name: Name assigned to the structure config.
        position: Structure origin position.
        material_name: Single material reference; every occupied voxel uses
            material sid 0, i.e. this material.
        orientation: Structure orientation passed to ``RQuat3rf``. The default
            (0, 0, 0, 1) is presumably the identity with w last -- TODO confirm.
        voxel_size: Voxel edge length; also scales the hinge anchor positions.
        is_fixed: Copied to ``structure_config.is_fixed``.

    Returns:
        ``(structure_config, structure)``. ``structure`` holds the
        (64, 64, 64) arrays ``is_not_empty`` (bool, includes rigid voxels),
        ``is_rigid`` (bool) and ``segment_id`` (int), plus ``connections``:
        one dict per hip with ``components`` (torso label, leg label),
        voxel-space ``position``, hinge ``axis`` and ``size``.
    """

    # Body voxel grid size
    x_voxels = 64
    y_voxels = 64
    z_voxels = 64
    grid_shape = (x_voxels, y_voxels, z_voxels)

    structure_config = RS_StructureConfig()
    structure_config.name = structure_name
    structure_config.is_fixed = is_fixed
    structure_config.voxel_size = voxel_size
    structure_config.origin_position = RVec3rf(*position)
    structure_config.orientation = RQuat3rf(*orientation)
    # Only material; every occupied voxel below references it as sid 0.
    structure_config.material_references.append(material_name)

    body_config = RS_StructureBodyConfig()
    body_config.body_sid = 0
    body_config.relative_origin_position = RVec3rf(0.0, 0.0, 0.0)
    body_config.relative_orientation = RQuat3rf(0.0, 0.0, 0.0, 1.0)
    body_config.x_voxels = x_voxels
    body_config.y_voxels = y_voxels
    body_config.z_voxels = z_voxels

    # Dense per-voxel arrays returned to the caller in the `structure` dict.
    is_rigid = np.zeros(grid_shape, dtype=bool)
    segment_id = np.zeros(grid_shape, dtype=int)

    # Every voxel starts empty (RS_NULL_INDEX); occupied voxels are
    # overwritten in place below.
    total_voxels = x_voxels * y_voxels * z_voxels
    for _ in range(total_voxels):
        body_config.material_reference_sid.append(RS_NULL_INDEX)
        body_config.segment_bid.append(RS_NULL_INDEX)
        body_config.segment_type.append(RS_NULL_INDEX)

    def linear_index(ix: int, iy: int, iz: int) -> int:
        # Flat index into the per-voxel lists: x varies fastest, then y, then z.
        return ix + iy * x_voxels + iz * x_voxels * y_voxels

    def center_index(n: int) -> float:
        # Centre of cells 0..n-1 (fractional when n is even).
        return (n - 1) / 2.0

    def centered_range(total_n: int, box_n: int) -> Tuple[int, int]:
        # Inclusive [start, end] of a box_n-long run centred in 0..total_n-1,
        # clamped to the grid.
        start = int(np.floor(center_index(total_n) - (box_n - 1) / 2.0))
        start = max(0, min(start, total_n - box_n))
        end = min(total_n - 1, start + box_n - 1)
        return start, end

    def range_from_center(center_ix: int, length: int, total_n: int) -> Tuple[int, int]:
        # Inclusive [start, end] of a length-long run centred on center_ix,
        # shifted to stay inside 0..total_n-1 when length <= total_n.
        start = int(np.floor(center_ix - (length - 1) / 2.0))
        start = max(0, min(start, total_n - length))
        end = start + length - 1
        return start, end

    def mark_rigid_box(
        x_range: Tuple[int, int],
        y_range: Tuple[int, int],
        z_range: Tuple[int, int],
        label: int,
    ) -> None:
        # Write a rigid box (inclusive index ranges) labelled `label` into both
        # the body config lists and the dense arrays. Later calls overwrite
        # earlier ones where boxes overlap.
        for ix in range(x_range[0], x_range[1] + 1):
            for iy in range(y_range[0], y_range[1] + 1):
                for iz in range(z_range[0], z_range[1] + 1):
                    idx = linear_index(ix, iy, iz)
                    body_config.material_reference_sid[idx] = 0
                    body_config.segment_bid[idx] = label
                    body_config.segment_type[idx] = 1  # Rigid
                    is_rigid[ix, iy, iz] = True
                    segment_id[ix, iy, iz] = label

    # Torso dimensions:
    torso_x, torso_y, torso_z = 35, 25, 2
    # Legs dimensions:
    leg_x, leg_y, leg_z = 4, 4, 20

    tx0, tx1 = centered_range(x_voxels, torso_x)
    ty0, ty1 = centered_range(y_voxels, torso_y)
    tz0, tz1 = centered_range(z_voxels, torso_z)
    mark_rigid_box((tx0, tx1), (ty0, ty1), (tz0, tz1), 1)  # Torso = segment 1

    # Use exact torso edges (XY corners) for leg placement
    corner_centers = [
        (tx0, ty0),  # front-left
        (tx1, ty0),  # front-right
        (tx0, ty1),  # back-left
        (tx1, ty1),  # back-right
    ]

    # Leg top layer is one voxel above the torso's Z mid-index.
    leg_top_z = (tz0 + tz1) // 2 + 1
    leg_bottom_z = max(0, leg_top_z - (leg_z - 1))

    for leg_idx, (cx_leg, cy_leg) in enumerate(corner_centers, start=1):
        # X: centred on the torso's X edge.
        lx0, lx1 = range_from_center(cx_leg, leg_x, x_voxels)

        # Y: adjacent to the torso's outer face on this corner's side.
        if cy_leg == ty0:
            # Outward in -Y
            ly1 = max(0, ty0 - 1)
            ly0 = max(0, ly1 - (leg_y - 1))
        else:
            # Outward in +Y
            ly0 = min(y_voxels - 1, ty1 + 1)
            ly1 = min(y_voxels - 1, ly0 + (leg_y - 1))

        # Legs are segments 2..5
        mark_rigid_box((lx0, lx1), (ly0, ly1), (leg_bottom_z, leg_top_z), leg_idx + 1)

    # Occupied region: rigid voxels grown by a 2-voxel shell (3x3x3 structuring
    # element applied twice). The result therefore also contains the rigid voxels.
    is_not_empty = binary_dilation(
        is_rigid, structure=np.ones((3, 3, 3), dtype=bool), iterations=2
    )

    # Label shell voxels that are not rigid as soft matrix (segment 0).
    # Vectorised instead of a 64^3 Python loop. np.nonzero yields indices in
    # the same x-major order the loop would visit; .tolist() gives plain ints.
    soft_x, soft_y, soft_z = np.nonzero(is_not_empty & ~is_rigid)
    soft_indices = soft_x + soft_y * x_voxels + soft_z * x_voxels * y_voxels
    for idx in soft_indices.tolist():
        body_config.material_reference_sid[idx] = 0
        body_config.segment_bid[idx] = 0  # Soft matrix
        body_config.segment_type[idx] = 0  # Soft

    # Append only after all voxels are filled in.
    structure_config.bodies.append(body_config)

    def voxel_to_local(ix: int, iy: int, iz: int) -> Tuple[float, float, float]:
        # Voxel indices -> body-local coordinates, scaled by voxel_size.
        return (
            ix * voxel_size,
            iy * voxel_size,
            iz * voxel_size,
        )

    # Hip joints: hinge between torso (segment 1) and each leg (segments 2..5),
    # anchored at the torso corner voxel on the legs' top layer.
    hip_voxel_anchors = [(cx, cy, leg_top_z) for cx, cy in corner_centers]

    for signal_sid, (ax, ay, az) in enumerate(hip_voxel_anchors):
        anchor = voxel_to_local(ax, ay, az)
        constraint = RS_StructureConstraintConfig()
        constraint.type = RSE_StructureConstraintType.RSE_HINGE_JOINT
        constraint.a_body_sid = 0
        constraint.b_body_sid = 0
        constraint.a_segment_bid = 1  # Torso
        constraint.b_segment_bid = signal_sid + 2  # Legs 2..5
        constraint.a_local_anchor = RVec3rf(*anchor)
        constraint.b_local_anchor = RVec3rf(*anchor)
        constraint.hinge_rotation_angle_signal_sid = signal_sid
        # Hinge around local +/-Y. The a and b axes point in opposite directions.
        # NOTE(review): confirm this a/b convention with the engine.
        if signal_sid < 2:
            # Joints 0-1: legs on the torso's -Y side
            constraint.hinge_a_local_axis = RVec3rf(0.0, 1.0, 0.0)
            constraint.hinge_b_local_axis = RVec3rf(0.0, -1.0, 0.0)
        else:
            # Joints 2-3: legs on the +Y side, axis reversed
            constraint.hinge_a_local_axis = RVec3rf(0.0, -1.0, 0.0)
            constraint.hinge_b_local_axis = RVec3rf(0.0, 1.0, 0.0)
        constraint.hinge_min = -1.5
        constraint.hinge_max = 1.5

        constraint.hinge_max_torque = 6.0

        structure_config.constraints.append(constraint)

    # One rotation-angle signal per hip joint.
    structure_config.rotation_angle_signal_num = 4

    # Build connections in the same format as sim/builder.py:
    # components use segment labels from segment_id (torso=1, legs=2..5);
    # position is in voxel indices; axis matches hinge_a_local_axis.
    connections = []
    for i, (ax, ay, az) in enumerate(hip_voxel_anchors):
        connections.append(
            {
                "components": (1, i + 2),
                "position": np.array([float(ax), float(ay), float(az)], dtype=float),
                "axis": np.array([0.0, (1.0 if i < 2 else -1.0), 0.0], dtype=float),
                "size": float(leg_x * leg_y),
            }
        )

    structure = {
        "is_not_empty": is_not_empty,
        "is_rigid": is_rigid,
        "segment_id": segment_id,
        "connections": connections,
    }

    return structure_config, structure


def visualize_body_voxels_plotly(
    body_config: RS_StructureBodyConfig,
    voxel_size: float,
    output_html_path: str,
    title: str | None = None,
) -> None:
    """
    Render a body's occupied voxels as a 3D scatter, one trace per segment_bid,
    and write the figure to an HTML file.

    Voxels whose ``segment_bid`` equals ``RS_NULL_INDEX`` are skipped. Flat
    voxel indices are decoded with x varying fastest, then y, then z, and
    scaled by ``voxel_size``. The voxel-space origin and a short X/Y/Z axis
    triad are drawn as well. The HTML loads plotly.js from the CDN.

    Args:
        body_config: Body whose ``x_voxels``/``y_voxels``/``z_voxels`` and
            per-voxel ``segment_bid`` list are plotted.
        voxel_size: Size of a single voxel (meters), used to scale coordinates.
        output_html_path: Destination path of the HTML file.
        title: Optional plot title. If omitted, a default naming the body is used.
    """
    nx = int(body_config.x_voxels)
    ny = int(body_config.y_voxels)
    nz = int(body_config.z_voxels)
    layer = nx * ny

    # segment label -> {"x": [...], "y": [...], "z": [...]} in voxel-space units
    points_by_segment = {}
    for flat, seg in enumerate(body_config.segment_bid):
        if seg == RS_NULL_INDEX:
            continue
        coords = points_by_segment.setdefault(seg, {"x": [], "y": [], "z": []})
        coords["x"].append((flat % nx) * voxel_size)
        coords["y"].append(((flat // nx) % ny) * voxel_size)
        coords["z"].append((flat // layer) * voxel_size)

    if not points_by_segment:
        empty_fig = go.Figure()
        empty_fig.update_layout(title=title or "No occupied voxels to display")
        empty_fig.write_html(output_html_path, include_plotlyjs="cdn", auto_open=False)
        return

    # Chain several qualitative palettes so many segments get distinct colors.
    palette = (
        list(plotly_qualitative.Plotly)
        + list(plotly_qualitative.D3)
        + getattr(plotly_qualitative, "Dark24", [])
        + getattr(plotly_qualitative, "Set3", [])
    )

    traces = []
    for rank, seg in enumerate(sorted(points_by_segment)):
        coords = points_by_segment[seg]
        seg_color = palette[rank % len(palette)] if palette else None
        traces.append(
            go.Scatter3d(
                x=coords["x"],
                y=coords["y"],
                z=coords["z"],
                mode="markers",
                name=f"segment {seg}",
                marker=dict(size=3, color=seg_color, opacity=0.9),
            )
        )

    # Size the axis triad relative to the extent of the plotted voxels.
    xs = [v for coords in points_by_segment.values() for v in coords["x"]]
    ys = [v for coords in points_by_segment.values() for v in coords["y"]]
    zs = [v for coords in points_by_segment.values() for v in coords["z"]]
    if xs and ys and zs:
        max_extent = max(max(xs) - min(xs), max(ys) - min(ys), max(zs) - min(zs))
    else:
        max_extent = max(nx, ny, nz) * voxel_size
    axis_len = max(2.0 * voxel_size, 0.15 * max_extent)

    # Origin marker
    traces.append(
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            name="origin",
            marker=dict(size=5, color="black", opacity=1.0),
        )
    )

    # One line per axis; only the coordinate of that axis grows to axis_len.
    axis_specs = (
        ("X axis", "X", "red"),
        ("Y axis", "Y", "green"),
        ("Z axis", "Z", "blue"),
    )
    for slot, (axis_name, label, colour) in enumerate(axis_specs):
        ends = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
        ends[slot][1] = axis_len
        traces.append(
            go.Scatter3d(
                x=ends[0],
                y=ends[1],
                z=ends[2],
                mode="lines+markers+text",
                text=["", label],
                textposition="top center",
                name=axis_name,
                line=dict(color=colour, width=6),
                marker=dict(size=2, color=colour),
            )
        )

    figure = go.Figure(data=traces)
    figure.update_layout(
        title=title or f"Body {getattr(body_config, 'body_sid', 0)} voxels by segment",
        scene=dict(
            xaxis=dict(visible=True, showticklabels=True, title="X"),
            yaxis=dict(visible=True, showticklabels=True, title="Y"),
            zaxis=dict(visible=True, showticklabels=True, title="Z"),
            aspectmode="data",
        ),
        legend=dict(itemsizing="constant"),
        margin=dict(l=0, r=0, t=40, b=0),
    )
    figure.write_html(output_html_path, include_plotlyjs="cdn", auto_open=False)


def visualize_boolean_voxels_plotly(
    mask: np.ndarray,
    voxel_size: float,
    output_html_path: str,
    title: str | None = None,
) -> None:
    """
    Visualize a 3D boolean voxel mask using Plotly as a point cloud.

    Args:
        mask: Boolean array of shape (x, y, z).
        voxel_size: Size of a single voxel (meters).
        output_html_path: Path to save the HTML visualization.
        title: Optional title.
    """
    assert mask.ndim == 3, "mask must be 3D"
    x_voxels, y_voxels, z_voxels = mask.shape

    ix, iy, iz = np.nonzero(mask)
    if ix.size == 0:
        fig = go.Figure()
        fig.update_layout(title=title or "No voxels to display")
        fig.write_html(output_html_path, include_plotlyjs="cdn", auto_open=False)
        return

    x = ix.astype(float) * voxel_size
    y = iy.astype(float) * voxel_size
    z = iz.astype(float) * voxel_size

    extent_x = float(np.max(x) - np.min(x)) if x.size else 0.0
    extent_y = float(np.max(y) - np.min(y)) if y.size else 0.0
    extent_z = float(np.max(z) - np.min(z)) if z.size else 0.0
    max_extent = max(extent_x, extent_y, extent_z)
    if max_extent == 0.0:
        max_extent = max(x_voxels, y_voxels, z_voxels) * voxel_size
    axis_len = max(2.0 * voxel_size, 0.15 * max_extent)

    traces = [
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode="markers",
            name="voxels",
            marker=dict(size=3, color="#1f77b4", opacity=0.9),
        ),
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            name="origin",
            marker=dict(size=5, color="black", opacity=1.0),
        ),
        go.Scatter3d(
            x=[0.0, axis_len],
            y=[0.0, 0.0],
            z=[0.0, 0.0],
            mode="lines+markers+text",
            text=["", "X"],
            textposition="top center",
            name="X axis",
            line=dict(color="red", width=6),
            marker=dict(size=2, color="red"),
        ),
        go.Scatter3d(
            x=[0.0, 0.0],
            y=[0.0, axis_len],
            z=[0.0, 0.0],
            mode="lines+markers+text",
            text=["", "Y"],
            textposition="top center",
            name="Y axis",
            line=dict(color="green", width=6),
            marker=dict(size=2, color="green"),
        ),
        go.Scatter3d(
            x=[0.0, 0.0],
            y=[0.0, 0.0],
            z=[0.0, axis_len],
            mode="lines+markers+text",
            text=["", "Z"],
            textposition="top center",
            name="Z axis",
            line=dict(color="blue", width=6),
            marker=dict(size=2, color="blue"),
        ),
    ]

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=title or "Boolean voxel mask",
        scene=dict(
            xaxis=dict(visible=True, showticklabels=True, title="X"),
            yaxis=dict(visible=True, showticklabels=True, title="Y"),
            zaxis=dict(visible=True, showticklabels=True, title="Z"),
            aspectmode="data",
        ),
        legend=dict(itemsizing="constant"),
        margin=dict(l=0, r=0, t=40, b=0),
    )

    fig.write_html(output_html_path, include_plotlyjs="cdn", auto_open=False)


if __name__ == "__main__":
    # Voxel edge length shared by the builder and every visualization below.
    voxel_size = 0.01

    structure_config, structure = create_quadruped_simple(
        "quadruped_simple", (0.0, 0.0, 0.0), "material_0", voxel_size=voxel_size
    )

    # Per-segment view of the body, then the two boolean masks.
    visualize_body_voxels_plotly(
        structure_config.bodies[0], voxel_size, "quadruped_simple.html"
    )
    for mask_name, html_path in (
        ("is_rigid", "quadruped_simple_is_rigid.html"),
        ("is_not_empty", "quadruped_simple_is_not_empty.html"),
    ):
        visualize_boolean_voxels_plotly(
            structure[mask_name], voxel_size, html_path, title=mask_name
        )

    # Pickle the config to <three directories above this file>/data/robot_config/.
    base_dir = os.path.abspath(__file__)
    for _ in range(3):
        base_dir = os.path.dirname(base_dir)
    output_dir = os.path.join(base_dir, "data", "robot_config")
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "quadruped_simple.data"), "wb") as handle:
        pickle.dump(structure_config, handle)
