"""
Create an ant robot with torso, four upper legs, and four lower legs (8 DOF).

The robot consists of:
- 1 spherical torso body (body_sid=0)
- 4 upper leg bodies (body_sid=1,2,3,4) extending along +X, +Y, -X, -Y axes
- 4 lower leg bodies (body_sid=5,6,7,8) all extending along -Z axis (downward)

Key difference from create_ant_8dof.py:
- Bodies are positioned using TRANSLATION only (as in the biped approach)
- No rotation is applied to body orientations (every body uses the identity quaternion)
- Each upper leg therefore extends along a different axis of its own (unrotated) frame,
  while every lower leg hangs straight down along -Z
"""

import os
import pickle
from typing import Tuple

import numpy as np
from scipy.spatial.transform import Rotation

from rise import (
    RQuat3rf,
    RVec3rf,
    RS_NULL_INDEX,
    RSE_StructureConstraintType,
    RS_StructureBodyConfig,
    RS_StructureConstraintConfig,
    RS_StructureConfig,
)


def create_body_config(
    body_sid: int,
    x_voxels: int,
    y_voxels: int,
    z_voxels: int,
    relative_origin_position: Tuple[float, float, float],
    relative_orientation: Tuple[float, float, float, float],
    is_rigid: bool = True,
    soft_shell_thickness: int = 1,
    voxel_size: float = 0.01,
) -> Tuple[RS_StructureBodyConfig, dict]:
    """
    Build a box-shaped body: a core block wrapped in soft voxel layers.

    Every voxel of the ``(x + 2*shell) x (y + 2*shell) x (z + 2*shell)`` grid is
    filled. Voxels within ``soft_shell_thickness`` of any grid face get material 0,
    segment 0 and segment type 0 (soft). All other voxels form the core: material 1,
    segment 1, and segment type 1 when ``is_rigid`` else 0. Voxels are emitted
    with x varying fastest, then y, then z.

    Args:
        body_sid: Body structure ID
        x_voxels, y_voxels, z_voxels: Voxel grid dimensions of the core
        relative_origin_position: Position of the TOTAL body corner (soft shell
                                  included) in structure coordinates
        relative_orientation: Orientation quaternion (x, y, z, w)
        is_rigid: Whether the core voxels are rigid (True) or soft (False)
        soft_shell_thickness: Number of soft voxel layers around the core
        voxel_size: Size of each voxel in meters (not used by this function)

    Returns:
        Tuple of (body_config, body_info) where body_info contains:
            - rigid_size: (x, y, z) dimensions of the core in voxels
            - total_size: (x, y, z) dimensions including soft shell in voxels
            - soft_shell_thickness: thickness of soft shell in voxels
    """
    shell = soft_shell_thickness
    total_size = (
        x_voxels + 2 * shell,
        y_voxels + 2 * shell,
        z_voxels + 2 * shell,
    )
    nx, ny, nz = total_size

    config = RS_StructureBodyConfig()
    config.body_sid = body_sid
    config.relative_origin_position = RVec3rf(*relative_origin_position)
    config.relative_orientation = RQuat3rf(*relative_orientation)
    config.x_voxels = nx
    config.y_voxels = ny
    config.z_voxels = nz

    # (material reference sid, segment bid, segment type) per voxel kind.
    shell_entry = (0, 0, 0)
    core_entry = (1, 1, 1 if is_rigid else 0)

    def in_shell(ix: int, iy: int, iz: int) -> bool:
        # A non-positive thickness means there is no shell at all.
        if shell <= 0:
            return False
        inside_core = (
            (shell <= ix < nx - shell)
            and (shell <= iy < ny - shell)
            and (shell <= iz < nz - shell)
        )
        return not inside_core

    for iz in range(nz):
        for iy in range(ny):
            for ix in range(nx):
                material, segment, seg_type = (
                    shell_entry if in_shell(ix, iy, iz) else core_entry
                )
                config.material_reference_sid.append(material)
                config.segment_bid.append(segment)
                config.segment_type.append(seg_type)

    return config, {
        "rigid_size": (x_voxels, y_voxels, z_voxels),
        "total_size": total_size,
        "soft_shell_thickness": soft_shell_thickness,
    }


def create_spherical_body_config(
    body_sid: int,
    radius_voxels: int,
    relative_origin_position: Tuple[float, float, float],
    relative_orientation: Tuple[float, float, float, float],
    is_rigid: bool = True,
    soft_shell_thickness: int = 0,
    voxel_size: float = 0.01,
) -> Tuple[RS_StructureBodyConfig, dict]:
    """
    Build a ball-shaped body inside a cubic voxel grid.

    The grid is ``2 * (radius_voxels + soft_shell_thickness)`` voxels per side.
    Each cell is classified by the distance from its centre to the grid centre:
    - ``<= radius_voxels``: core (material 1, segment 1, type 1 if ``is_rigid`` else 0)
    - ``<= radius_voxels + soft_shell_thickness``: soft shell (material 0, segment 0,
      type 0)
    - otherwise: empty (all three fields set to ``RS_NULL_INDEX``)
    Cells are emitted with x varying fastest, then y, then z.

    Args:
        body_sid: Body structure ID
        radius_voxels: Radius of the core sphere in voxels
        relative_origin_position: Position of the TOTAL body corner (soft shell
                                  included) in structure coordinates
        relative_orientation: Orientation quaternion (x, y, z, w)
        is_rigid: Whether the core voxels are rigid (True) or soft (False)
        soft_shell_thickness: Number of soft voxel layers around the core
        voxel_size: Size of each voxel in meters (not used by this function)

    Returns:
        Tuple of (body_config, body_info) where body_info contains:
            - shape: "sphere"
            - rigid_radius: radius of the core in voxels
            - total_radius: radius including soft shell in voxels
            - soft_shell_thickness: thickness of soft shell in voxels
    """
    outer_radius = radius_voxels + soft_shell_thickness
    side = 2 * outer_radius

    config = RS_StructureBodyConfig()
    config.body_sid = body_sid
    config.relative_origin_position = RVec3rf(*relative_origin_position)
    config.relative_orientation = RQuat3rf(*relative_orientation)
    config.x_voxels = side
    config.y_voxels = side
    config.z_voxels = side

    grid_center = np.full(3, outer_radius)
    # (material reference sid, segment bid, segment type) per cell kind.
    core_entry = (1, 1, 1 if is_rigid else 0)
    shell_entry = (0, 0, 0)
    empty_entry = (RS_NULL_INDEX, RS_NULL_INDEX, RS_NULL_INDEX)

    cells = (
        (ix, iy, iz)
        for iz in range(side)
        for iy in range(side)
        for ix in range(side)
    )
    for ix, iy, iz in cells:
        # Measure from the cell centre, hence the half-voxel shift.
        offset = np.array([ix + 0.5, iy + 0.5, iz + 0.5]) - grid_center
        distance = np.linalg.norm(offset)
        if distance <= radius_voxels:
            material, segment, seg_type = core_entry
        elif distance <= outer_radius:
            material, segment, seg_type = shell_entry
        else:
            material, segment, seg_type = empty_entry
        config.material_reference_sid.append(material)
        config.segment_bid.append(segment)
        config.segment_type.append(seg_type)

    return config, {
        "shape": "sphere",
        "rigid_radius": radius_voxels,
        "total_radius": outer_radius,
        "soft_shell_thickness": soft_shell_thickness,
    }


# Upper-leg layout in body_sid order: (leg index, axis index [0 = X, 1 = Y],
# direction sign). Leg 0 points along +X, leg 1 along +Y, leg 2 along -X and
# leg 3 along -Y.
_ANT_LEG_LAYOUT = (
    (0, 0, 1.0),
    (1, 1, 1.0),
    (2, 0, -1.0),
    (3, 1, -1.0),
)


def _box_total_size(
    dims_voxels, soft_shell_thickness: int, voxel_size: float
) -> np.ndarray:
    """Return the metric (x, y, z) extent of a box body, soft shell included."""
    return (
        np.array([d + 2 * soft_shell_thickness for d in dims_voxels]) * voxel_size
    )


def _end_face_center(
    total_size: np.ndarray, axis: int, positive_end: bool
) -> np.ndarray:
    """Return the offset from a box's origin corner to the centre of an end face.

    The face is perpendicular to ``axis``. It is the one at ``total_size[axis]``
    when ``positive_end`` is True, otherwise the one through the origin corner.
    """
    offset = np.asarray(total_size, dtype=float) / 2.0
    offset[axis] = total_size[axis] if positive_end else 0.0
    return offset


def _vec3_to_array(vec) -> np.ndarray:
    """Copy the ``x``, ``y``, ``z`` attributes of a vector object into an array."""
    return np.array([vec.x, vec.y, vec.z])


def create_ant_8dof_fix(
    structure_name: str,
    position: Tuple[float, float, float],
    soft_material_name: str,
    rigid_material_name: str,
    orientation: Tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
    voxel_size: float = 0.01,
    is_fixed: bool = False,
    soft_shell_thickness: int = 1,
    torso_radius_voxels: int = 6,
    upper_leg_length_voxels: int = 12,
    leg_thickness_voxels: int = 2,
    gap_voxels: int = 2,
    knee_gap_voxels: int = 0,
) -> Tuple[RS_StructureConfig, dict]:
    """
    Create an ant robot with spherical torso, 4 upper legs, and 4 lower legs (8 DOF).

    Upper legs extend along different axes of their (unrotated) local frames:
    - Upper Leg 0 (body_sid 1): extends along +X axis (right)
    - Upper Leg 1 (body_sid 2): extends along +Y axis (forward)
    - Upper Leg 2 (body_sid 3): extends along -X axis (left)
    - Upper Leg 3 (body_sid 4): extends along -Y axis (backward)

    Lower legs (body_sid 5-8) extend along -Z (downward), starting below the
    outer end of the matching upper leg. All bodies use identity orientation,
    so positioning is done by translation only.

    Joints (every hinge is limited to [-0.7, 0.7] with max torque 6.0):
    - hips: torso <-> upper leg i, rotation about Z, angle signal i (0-3)
    - knees: upper leg i <-> lower leg i, rotation about the horizontal axis
      perpendicular to that leg, angle signal 4 + i (4-7)

    Args:
        structure_name: Name stored on the structure config.
        position: Structure origin position.
        soft_material_name: Material reference 0 (soft shell voxels).
        rigid_material_name: Material reference 1 (core voxels).
        orientation: Structure orientation quaternion (x, y, z, w).
        voxel_size: Size of each voxel in meters.
        is_fixed: Stored as ``structure_config.is_fixed``.
        soft_shell_thickness: Soft voxel layers wrapped around every body's core.
        torso_radius_voxels: Radius of the torso's core sphere in voxels.
        upper_leg_length_voxels: Core length of each upper leg in voxels (the
            lower legs use the same length).
        leg_thickness_voxels: Core cross-section edge of every leg in voxels.
        gap_voxels: Distance from the torso's outer (shell) surface to each hip
            anchor, in voxels.
        knee_gap_voxels: Vertical distance from each knee anchor down to the top
            of its lower leg, in voxels.

    Returns:
        ``(structure_config, structure)``. ``structure`` is a dict holding:
        - ``is_not_empty``, ``is_rigid``, ``segment_id``: 64^3 grids in which only
          the torso core is rasterised
        - ``connections``: an empty list
        - ``body_infos``: per-body info dicts in body_sid order
        - ``soft_shell_thickness`` and ``voxel_size``
    """
    structure_config = RS_StructureConfig()
    structure_config.name = structure_name
    structure_config.is_fixed = is_fixed
    structure_config.voxel_size = voxel_size
    structure_config.origin_position = RVec3rf(*position)
    structure_config.orientation = RQuat3rf(*orientation)
    # Material reference 0 is the soft material and 1 the rigid one; the body
    # builders write these ids into material_reference_sid.
    structure_config.material_references.append(soft_material_name)
    structure_config.material_references.append(rigid_material_name)

    identity = (0.0, 0.0, 0.0, 1.0)

    total_torso_radius = (torso_radius_voxels + soft_shell_thickness) * voxel_size

    # The torso grid's minimum corner sits at torso_origin, so the sphere's
    # centre lies one total (shell-inclusive) radius further along every axis.
    torso_origin = (0.0, 0.0, 0.2)
    torso_center_struct = np.array(
        [total_torso_radius, total_torso_radius, total_torso_radius]
    ) + np.array(torso_origin)

    torso_body, torso_info = create_spherical_body_config(
        body_sid=0,
        radius_voxels=torso_radius_voxels,
        relative_origin_position=torso_origin,
        relative_orientation=identity,
        soft_shell_thickness=soft_shell_thickness,
        voxel_size=voxel_size,
    )
    all_bodies = [(0, torso_body, torso_info)]

    # ---- Upper legs and hip hinges (angle signals 0-3) ----------------------
    hip_reach = total_torso_radius + gap_voxels * voxel_size
    upper_legs = []
    for leg_idx, axis, sign in _ANT_LEG_LAYOUT:
        leg_dims = [leg_thickness_voxels] * 3
        leg_dims[axis] = upper_leg_length_voxels
        leg_total_size = _box_total_size(leg_dims, soft_shell_thickness, voxel_size)

        # Hip anchor: on the leg's axis through the torso centre, gap_voxels
        # beyond the torso's outer surface.
        hip_offset = np.zeros(3)
        hip_offset[axis] = sign * hip_reach
        hip_anchor_struct = torso_center_struct + hip_offset

        # The hip is the centre of the leg's end face nearest the torso: the
        # face through the origin corner for a leg pointing along +axis, the
        # far face for a leg pointing along -axis.
        upper_leg_origin_to_hip_anchor_local = _end_face_center(
            leg_total_size, axis, positive_end=sign < 0
        )
        leg_origin_struct = hip_anchor_struct - upper_leg_origin_to_hip_anchor_local

        upper_leg_body, upper_leg_info = create_body_config(
            body_sid=leg_idx + 1,
            x_voxels=leg_dims[0],
            y_voxels=leg_dims[1],
            z_voxels=leg_dims[2],
            relative_origin_position=tuple(leg_origin_struct),
            relative_orientation=identity,
            soft_shell_thickness=soft_shell_thickness,
            voxel_size=voxel_size,
        )
        all_bodies.append((leg_idx + 1, upper_leg_body, upper_leg_info))
        upper_legs.append((leg_idx, axis, sign, leg_total_size, upper_leg_body))

        # The torso has identity orientation, so its local frame is the
        # structure frame shifted to torso_origin.
        torso_origin_to_hip_anchor_local = hip_anchor_struct - np.array(torso_origin)

        constraint = RS_StructureConstraintConfig()
        constraint.type = RSE_StructureConstraintType.RSE_HINGE_JOINT

        constraint.a_body_sid = 0
        constraint.a_segment_bid = 1
        constraint.a_local_anchor = RVec3rf(*torso_origin_to_hip_anchor_local)

        constraint.b_body_sid = leg_idx + 1
        constraint.b_segment_bid = 1
        constraint.b_local_anchor = RVec3rf(*upper_leg_origin_to_hip_anchor_local)

        # Hips rotate about Z, swinging the leg in the horizontal plane.
        constraint.hinge_a_local_axis = RVec3rf(0.0, 0.0, 1.0)
        constraint.hinge_b_local_axis = RVec3rf(0.0, 0.0, -1.0)

        constraint.hinge_rotation_angle_signal_sid = leg_idx
        constraint.hinge_min = -0.7
        constraint.hinge_max = 0.7
        constraint.hinge_max_torque = 6.0

        structure_config.constraints.append(constraint)

    # ---- Lower legs and knee hinges (angle signals 4-7) ---------------------
    # Appended after all hip hinges so that constraint order matches signal order.
    knee_gap_distance = knee_gap_voxels * voxel_size
    lower_leg_total_size = _box_total_size(
        (leg_thickness_voxels, leg_thickness_voxels, upper_leg_length_voxels),
        soft_shell_thickness,
        voxel_size,
    )
    # Knee anchor relative to the lower leg's origin corner: centred over the
    # top face, knee_gap_distance above it.
    lower_leg_origin_to_knee_anchor_local = (
        lower_leg_total_size[0] / 2.0,
        lower_leg_total_size[1] / 2.0,
        lower_leg_total_size[2] + knee_gap_distance,
    )

    for leg_idx, axis, sign, upper_leg_total_size, upper_leg_body in upper_legs:
        # Use the origin exactly as stored on the upper-leg body config.
        upper_leg_origin_struct = _vec3_to_array(
            upper_leg_body.relative_origin_position
        )
        # The knee is the centre of the upper leg's end face farthest from the
        # torso.
        upper_leg_origin_to_knee_anchor_local = _end_face_center(
            upper_leg_total_size, axis, positive_end=sign > 0
        )
        knee_anchor_struct = (
            upper_leg_origin_struct + upper_leg_origin_to_knee_anchor_local
        )

        lower_leg_origin_struct = knee_anchor_struct + np.array(
            [
                -lower_leg_total_size[0] / 2.0,
                -lower_leg_total_size[1] / 2.0,
                -lower_leg_total_size[2] - knee_gap_distance,
            ]
        )

        lower_leg_body, lower_leg_info = create_body_config(
            body_sid=leg_idx + 5,
            x_voxels=leg_thickness_voxels,
            y_voxels=leg_thickness_voxels,
            z_voxels=upper_leg_length_voxels,
            relative_origin_position=tuple(lower_leg_origin_struct),
            relative_orientation=identity,
            soft_shell_thickness=soft_shell_thickness,
            voxel_size=voxel_size,
        )
        all_bodies.append((leg_idx + 5, lower_leg_body, lower_leg_info))

        constraint = RS_StructureConstraintConfig()
        constraint.type = RSE_StructureConstraintType.RSE_HINGE_JOINT

        constraint.a_body_sid = leg_idx + 1
        constraint.a_segment_bid = 1
        constraint.a_local_anchor = RVec3rf(*upper_leg_origin_to_knee_anchor_local)

        constraint.b_body_sid = leg_idx + 5
        constraint.b_segment_bid = 1
        constraint.b_local_anchor = RVec3rf(*lower_leg_origin_to_knee_anchor_local)

        # Knees rotate about the horizontal axis perpendicular to the leg.
        if axis == 0:
            constraint.hinge_a_local_axis = RVec3rf(0.0, 1.0, 0.0)
            constraint.hinge_b_local_axis = RVec3rf(0.0, -1.0, 0.0)
        else:
            constraint.hinge_a_local_axis = RVec3rf(1.0, 0.0, 0.0)
            constraint.hinge_b_local_axis = RVec3rf(-1.0, 0.0, 0.0)

        constraint.hinge_rotation_angle_signal_sid = 4 + leg_idx
        constraint.hinge_min = -0.7
        constraint.hinge_max = 0.7
        constraint.hinge_max_torque = 6.0

        structure_config.constraints.append(constraint)

    # Bodies must be stored in body_sid order.
    all_bodies.sort(key=lambda entry: entry[0])
    body_infos = []
    for _, body_config, body_info in all_bodies:
        structure_config.bodies.append(body_config)
        body_infos.append(body_info)

    structure_config.rotation_angle_signal_num = 8

    # ---- Coarse occupancy grid ---------------------------------------------
    # NOTE(review): only the torso core is rasterised here. Its soft shell and
    # the legs are not. The sphere is sampled at voxel corners (not centres as
    # in create_spherical_body_config) and without the soft-shell offset.
    # Confirm whether consumers of ``structure`` expect the full robot.
    combined_size = 64
    grid_shape = (combined_size, combined_size, combined_size)
    is_not_empty = np.zeros(grid_shape, dtype=bool)
    is_rigid = np.zeros(grid_shape, dtype=bool)
    segment_id = np.zeros(grid_shape, dtype=int)

    def structure_to_voxel_idx(
        pos: np.ndarray, offset: np.ndarray
    ) -> Tuple[int, int, int]:
        # floor() plus a tiny epsilon rather than int truncation. In binary
        # floating point a nominally whole quotient can come out slightly below
        # the integer, and truncation would push it into the previous cell.
        # No clipping here: out-of-range indices are rejected by the caller
        # instead of being piled onto the grid boundary.
        voxel_pos = np.floor(pos / voxel_size + offset + 1e-6).astype(int)
        return tuple(int(v) for v in voxel_pos)

    # Grid cell of the structure origin: centred in x/y, a quarter up in z.
    grid_offset = np.array(
        [combined_size // 2, combined_size // 2, combined_size // 4]
    )
    torso_size_voxels = torso_radius_voxels * 2
    torso_center_voxel = np.array(
        [torso_size_voxels // 2, torso_size_voxels // 2, torso_size_voxels // 2]
    )
    for ix in range(torso_size_voxels):
        for iy in range(torso_size_voxels):
            for iz in range(torso_size_voxels):
                voxel_pos = np.array([ix, iy, iz])
                distance = np.linalg.norm(voxel_pos - torso_center_voxel)
                if distance > torso_radius_voxels:
                    continue
                struct_pos = np.array(torso_origin) + voxel_pos * voxel_size
                vx, vy, vz = structure_to_voxel_idx(struct_pos, grid_offset)
                if (
                    0 <= vx < combined_size
                    and 0 <= vy < combined_size
                    and 0 <= vz < combined_size
                ):
                    is_rigid[vx, vy, vz] = True
                    is_not_empty[vx, vy, vz] = True
                    segment_id[vx, vy, vz] = 1

    connections = []

    structure = {
        "is_not_empty": is_not_empty,
        "is_rigid": is_rigid,
        "segment_id": segment_id,
        "connections": connections,
        "body_infos": body_infos,
        "soft_shell_thickness": soft_shell_thickness,
        "voxel_size": voxel_size,
    }

    return structure_config, structure


if __name__ == "__main__":
    # Build the ant with the sizes this script exports.
    ant_structure_config, _ant_structure = create_ant_8dof_fix(
        "ant_8dof_fix",
        (0.0, 0.0, 0.0),
        "material_0",
        "material_1",
        soft_shell_thickness=1,
        torso_radius_voxels=6,
        upper_leg_length_voxels=12,
        leg_thickness_voxels=2,
        gap_voxels=1,
        knee_gap_voxels=3,
    )

    # Climb three directory levels above this file (presumably the project root).
    project_root = os.path.abspath(__file__)
    for _ in range(3):
        project_root = os.path.dirname(project_root)

    output_dir = os.path.join(project_root, "data", "robot_config")
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): the file is named "ant_8dof.data" although the structure is
    # named "ant_8dof_fix"; confirm it is meant to overwrite the non-"fix" output.
    output_path = os.path.join(output_dir, "ant_8dof.data")
    with open(output_path, "wb") as out_file:
        pickle.dump(ant_structure_config, out_file)
