import cc3d
import torch as t
import vedo
import numpy as np
from PIL import Image
from collections import Counter
from typing import Tuple, List, Union
from scipy.ndimage import binary_dilation, binary_closing

from utils.plot import (
    plot_binary_vedo,
    plot_segment_id_vedo,
)

from rise import *


class RobotBuilder:
    def __init__(
        self,
        voxel_size: float = 0.01,
        camera_offset_length: float = 0.2,
        valid_min_largest_component_ratio: float = 0.5,
        valid_min_rigid_ratio: float = 0.1,
        valid_min_joint_num: int = 3,
        valid_max_joint_num: int = 20,
        valid_min_camera_num: int = 1,
        cover_exposed_rigid_dilation_size: int = 1,
        cover_exposed_rigid_iterations: int = 3,
        cover_exposed_grid_boundary_bleeding: int = 0,
        rigid_void_dilation_size: int = 2,
        joint_closing_dilation_size: int = 2,
        soft_closing_dilation_size: int = 1,
        min_joint_volume: int = 10,
        min_camera_volume: int = 1,
        min_rigid_segment_volume: int = 80,
        joint_dilation_factors: List[int] = [0.5,0.8,1,1.2,1.5,1.7,2],
        min_flap_elongation_ratio: float = 2,
        limit_max_camera_num: int = 1,
        remove_boundary_rigid_voxel_size: int = 2,
    ):
        self.voxel_size = voxel_size
        self.camera_offset_length = camera_offset_length

        self.valid_min_largest_component_ratio = valid_min_largest_component_ratio
        self.valid_min_rigid_ratio = valid_min_rigid_ratio
        self.valid_min_joint_num = valid_min_joint_num
        self.valid_max_joint_num = valid_max_joint_num
        self.valid_min_camera_num = valid_min_camera_num

        self.cover_exposed_rigid_dilation_size = cover_exposed_rigid_dilation_size
        self.cover_exposed_rigid_iterations = cover_exposed_rigid_iterations
        self.cover_exposed_grid_boundary_bleeding = cover_exposed_grid_boundary_bleeding
        self.rigid_void_dilation_size = rigid_void_dilation_size
        self.joint_closing_dilation_size = joint_closing_dilation_size
        self.soft_closing_dilation_size = soft_closing_dilation_size
        self.min_joint_volume = min_joint_volume
        self.min_camera_volume = min_camera_volume
        self.min_rigid_segment_volume = min_rigid_segment_volume
        self.joint_dilation_factors = joint_dilation_factors
        self.min_flap_elongation_ratio = min_flap_elongation_ratio
        self.limit_max_camera_num = limit_max_camera_num
        self.remove_boundary_rigid_voxel_size = remove_boundary_rigid_voxel_size
        # For recording intermediate processing data for visualization etc.
        self.current_robot_stats = {}

    def build(
        self,
        voxels: Union[t.Tensor, np.ndarray],
        robot_name: str = "robot",
        soft_material_name: str = "material_0",
        rigid_material_name: str = "material_1",
        camera_max_depth: int = 5,
        camera_width: int = 128,
        camera_height: int = 128,
        camera_samples: int = 4,
        camera_fov: float = 45.0,
        hinge_limit_min: float = -1.5,
        hinge_limit_max: float = 1.5,
        hinge_max_torque: float = 5,
        debug: bool = False,
        print_summary: bool = True,
    ) -> Union[Tuple[None, None, None], Tuple[RS_StructureConfig, List[RS_CameraConfig], dict]]:
        """
        Parse a voxel grid and assemble the structure and camera configs for it.

        Args:
            voxels: 4D array or tensor of shape [4, X, Y, Z] with channels laid
                out as [is_not_empty, is_rigid, is_joint, is_camera]. Non-boolean
                input is binarised inside ``parse`` (values above 0.5 are set).
            robot_name: Name of the structure; also the prefix of camera names.
            soft_material_name: First material reference of the structure.
            rigid_material_name: Second material reference of the structure.
            camera_max_depth: Forwarded to ``build_camera_configs``.
            camera_width: Forwarded to ``build_camera_configs``.
            camera_height: Forwarded to ``build_camera_configs``.
            camera_samples: Forwarded to ``build_camera_configs``.
            camera_fov: Forwarded to ``build_camera_configs``.
            hinge_limit_min: Forwarded to ``build_constraint_configs``.
            hinge_limit_max: Forwarded to ``build_constraint_configs``.
            hinge_max_torque: Forwarded to ``build_constraint_configs``.
            debug: Forwarded to ``parse``.
            print_summary: Print a one-line summary (or the failure reason).

        Returns:
            Robot structure config, list of robot camera configs, and the robot
            structure dict; ``(None, None, None)`` when ``parse`` rejects the input.
        """
        structure_config = RS_StructureConfig()
        structure_config.name = robot_name
        structure_config.voxel_size = self.voxel_size
        # The third material is always registered under a fixed name.
        for material_name in (
            soft_material_name,
            rigid_material_name,
            "material_smooth_soft",
        ):
            structure_config.material_references.append(material_name)

        build_success, process_result = self.parse(voxels, debug=debug)

        # Guard clause: on failure the second element is the rejection reason.
        if not build_success:
            if print_summary:
                invalid_reason = process_result
                print(
                    f"[{robot_name}] Build failed, invalid reason: {invalid_reason}"
                )
            return None, None, None

        robot_structure = process_result
        processed_voxels = robot_structure["voxels"]
        rigid_segment_labels = robot_structure["rigid_segment_labels"]
        rigid_segment_num = robot_structure["rigid_segment_num"]
        joints = robot_structure["joints"]
        cameras = robot_structure["cameras"]

        if print_summary:
            print(
                f"[{robot_name}] "
                f"Voxel num: {np.sum(processed_voxels[0])} "
                f"Occupancy rate: {np.sum(processed_voxels[0]) / processed_voxels.shape[1] / processed_voxels.shape[2] / processed_voxels.shape[3]:.2f} "
                f"rigid segment num: {rigid_segment_num} "
                f"joint num: {len(joints)}"
            )

        # The bounding-box size returned as third element is not needed here.
        body, z_offset, _ = self.build_body_config(
            processed_voxels[0], processed_voxels[1], rigid_segment_labels
        )
        structure_config.bodies.append(body)

        hinge_constraints = self.build_constraint_configs(
            joints, z_offset, hinge_limit_min, hinge_limit_max, hinge_max_torque
        )
        for hinge in hinge_constraints:
            structure_config.constraints.append(hinge)
        # One rotation-angle signal per joint.
        structure_config.rotation_angle_signal_num = len(joints)

        camera_configs = self.build_camera_configs(
            cameras,
            robot_name,
            camera_max_depth,
            camera_width,
            camera_height,
            camera_samples,
            camera_fov,
            z_offset,
        )

        return structure_config, camera_configs, robot_structure

    def parse(
        self,
        voxels: Union[t.Tensor, np.ndarray],
        debug: bool = False,
    ):
        """
        Clean up a raw voxel grid and extract the robot structure from it.

        Args:
            voxels: 4D array or tensor indexed as [channel, X, Y, Z]; channels 0-3
                are read as [is_not_empty, is_rigid, is_joint, is_camera].
                Non-boolean input is binarised with ``> 0.5``. The caller's data
                is not modified.
            debug: Forwarded to the joint-separation and camera-extraction steps.

        Returns:
            ``(True, robot_structure)`` on success, where ``robot_structure`` is a
            dict with the keys "voxels", "rigid_segment_labels",
            "rigid_segment_num", "joints" and "cameras"; otherwise
            ``(False, invalid_reason)`` with a human-readable reason string.
        """
        self.current_robot_stats = {}
        assert voxels.ndim == 4

        # Step1: Convert to a private boolean numpy array
        if t.is_tensor(voxels):
            # detach() so tensors that require grad can be converted as well
            voxels = voxels.detach().cpu().numpy()

        if voxels.dtype != bool:
            print("Warning: voxels are not boolean, converting to boolean")
            voxels = voxels > 0.5
        else:
            # The channel updates below write in place; work on a copy so the
            # caller's array (or the memory shared with a CPU tensor) stays intact.
            voxels = voxels.copy()

        # voxels = self.preprocess_closing_joints(
        #     voxels,
        #     joint_closing_dilation_size=self.joint_closing_dilation_size,
        # )

        # Step2: Update rigid and non-empty masks
        # mark joints and cameras as non-empty and rigid
        joint_or_camera = np.logical_or(voxels[2], voxels[3])
        voxels[0] = np.logical_or(voxels[0], joint_or_camera)
        voxels[1] = np.logical_or(voxels[1], joint_or_camera)

        voxels = self.preprocess_closing_rigid(
            voxels,
            rigid_void_dilation_size=self.rigid_void_dilation_size,
        )

        voxels = self.preprocess_remove_boundary_rigid_voxels(
            voxels,
            remove_boundary_rigid_voxel_size=self.remove_boundary_rigid_voxel_size,
        )
        # mark rigid as non-empty
        voxels[0] = np.logical_or(voxels[0], voxels[1])

        # Step3: Cover exposed rigid regions and close soft voxels
        voxels = self.preprocess_cover_exposed_rigid_regions(
            voxels,
            cover_exposed_rigid_dilation_size=self.cover_exposed_rigid_dilation_size,
            cover_exposed_rigid_iterations=self.cover_exposed_rigid_iterations,
            cover_exposed_grid_boundary_bleeding=self.cover_exposed_grid_boundary_bleeding,
        )
        voxels = self.preprocess_closing_soft(
            voxels,
            soft_closing_dilation_size=self.soft_closing_dilation_size,
        )

        # Step4: Eliminate small joints and cameras, and finally rigid segments
        voxels = self.preprocess_eliminate_small_joints(
            voxels, min_joint_volume=self.min_joint_volume
        )
        voxels = self.preprocess_eliminate_small_cameras(
            voxels, min_camera_volume=self.min_camera_volume
        )
        voxels = self.preprocess_eliminate_small_rigid_segments(
            voxels, min_rigid_segment_volume=self.min_rigid_segment_volume
        )

        # Step5: Filter largest connected component
        voxels, largest_component_ratio, rigid_ratio_in_largest_component = (
            self.preprocess_filter_largest_connected_component(voxels)
        )

        if largest_component_ratio < self.valid_min_largest_component_ratio:
            invalid_reason = (
                f"Below min largest component ratio {self.valid_min_largest_component_ratio}, "
                f"Ratio: {largest_component_ratio}"
            )
            return False, invalid_reason
        if rigid_ratio_in_largest_component < self.valid_min_rigid_ratio:
            invalid_reason = (
                f"Below min rigid ratio in largest component {self.valid_min_rigid_ratio}, "
                f"Ratio: {rigid_ratio_in_largest_component}"
            )
            return False, invalid_reason

        # self.current_robot_stats["largest_component_ratio"] = largest_component_ratio
        # self.current_robot_stats["rigid_ratio_in_largest_component"] = (
        #     rigid_ratio_in_largest_component
        # )

        # Step6: Separate rigid segments by joints
        voxels, rigid_segment_labels, rigid_segment_num, joints = (
            self.separate_rigid_segments_by_joints(
                voxels,
                joint_dilation_factors=self.joint_dilation_factors,
                min_rigid_segment_volume=self.min_rigid_segment_volume,
                min_flap_elongation_ratio=self.min_flap_elongation_ratio,
                debug=debug,
            )
        )

        if (
            len(joints) < self.valid_min_joint_num
            or len(joints) > self.valid_max_joint_num
        ):
            invalid_reason = (
                f"Below min joint num {self.valid_min_joint_num} or above max joint num {self.valid_max_joint_num}, "
                f"Joint num: {len(joints)}"
            )
            return False, invalid_reason

        # Additional validation: the largest rigid segment must be at least this
        # many times the size of the second largest one. The threshold is kept in
        # one place so the check and the reported reason cannot drift apart.
        min_largest_to_second_ratio = 3.0
        largest_vol, second_vol = RobotBuilder.compute_top_two_rigid_segment_volumes(
            rigid_segment_labels
        )
        if second_vol > 0:
            ratio = largest_vol / second_vol
            if ratio < min_largest_to_second_ratio:
                invalid_reason = (
                    f"Largest/second rigid segment voxel ratio {ratio:.2f} < {min_largest_to_second_ratio}, "
                    f"largest={largest_vol}, second={second_vol}"
                )
                return False, invalid_reason

        # Step7: Extract cameras, passing the rigid segment labels
        cameras = self.extract_cameras(
            voxels,
            rigid_segment_labels,
            camera_offset_voxels=self.camera_offset_length / self.voxel_size,
            limit_max_camera_num=self.limit_max_camera_num,
            debug=debug,
        )

        if len(cameras) < self.valid_min_camera_num:
            invalid_reason = (
                f"Below min camera num {self.valid_min_camera_num}, Camera num: {len(cameras)}"
            )
            return False, invalid_reason

        # Step8: Calculate joint axis based on camera direction and voxel center
        joints = self.calculate_joint_axis(voxels, joints, cameras)

        robot_structure = {
            "voxels": voxels,
            "rigid_segment_labels": rigid_segment_labels,
            "rigid_segment_num": rigid_segment_num,
            "joints": joints,
            "cameras": cameras,
        }

        return True, robot_structure

    def build_body_config(
        self,
        is_not_empty: np.ndarray,
        is_rigid: np.ndarray,
        rigid_segments: np.ndarray,
        smooth_band_half_width: int = 20,
    ):
        """
        Build the single body config of the robot from its voxel masks.

        The grid is cropped to the occupied layers along Z only; X and Y keep the
        full grid extent. Per-voxel arrays are flattened in Z-major (Z, Y, X) order.

        Args:
            is_not_empty: Boolean [X, Y, Z] occupancy mask.
            is_rigid: Boolean [X, Y, Z] mask of rigid voxels.
            rigid_segments: [X, Y, Z] label array; its values become the segment
                id of rigid voxels.
            smooth_band_half_width: Occupied voxels whose Y index lies within this
                many voxels of the grid's Y centre get material reference 2.

        Returns:
            ``(body_config, start_layer, bounding_box_size)`` where ``start_layer``
            is the first occupied Z index (the Z crop offset) and
            ``bounding_box_size`` is the occupied extent along X, Y, Z multiplied
            by ``self.voxel_size``.
        """
        body_config = RS_StructureBodyConfig()
        body_config.relative_orientation = RQuat3rf(0, 0, 0, 1)
        body_config.relative_origin_position = RVec3rf(0, 0, 0)
        body_config.body_sid = 0

        def occupied_range(axis):
            # [start, end) of the indices along `axis` that contain any voxel.
            other_axes = tuple(a for a in range(3) if a != axis)
            occupied = np.sum(is_not_empty.astype(int), axis=other_axes) > 0
            start = np.argmax(occupied)
            end = len(occupied) - np.argmax(np.flip(occupied))
            return start, end

        start_x, end_x = occupied_range(0)
        start_y, end_y = occupied_range(1)
        start_layer, end_layer = occupied_range(2)
        x_size = end_x - start_x
        y_size = end_y - start_y
        z_size = end_layer - start_layer

        x_layer_size = is_not_empty.shape[0]
        y_layer_size = is_not_empty.shape[1]

        self.current_robot_stats["size"] = (x_size, y_size, z_size)

        m_id = np.full(is_not_empty.shape, RS_NULL_INDEX, dtype=int)
        s_id = np.full(is_not_empty.shape, RS_NULL_INDEX, dtype=int)
        s_type = np.full(is_not_empty.shape, RS_NULL_INDEX, dtype=int)

        # Occupied voxels inside a band around the Y centre use material 2.
        y_indices = np.arange(is_not_empty.shape[1])
        y_center = is_not_empty.shape[1] // 2
        y_mask = np.abs(y_indices - y_center) <= smooth_band_half_width
        smooth_mask = is_not_empty & y_mask[None, :, None]

        # Later assignments override earlier ones.
        m_id[is_not_empty] = 0
        m_id[is_rigid] = 1
        # NOTE(review): this also overrides material 1 on rigid voxels inside the
        # band -- confirm that is intended.
        m_id[smooth_mask] = 2
        s_id[is_not_empty] = 0
        s_id = np.where(is_rigid, rigid_segments, s_id)
        s_type[is_rigid] = 1

        def crop_and_flatten(grid):
            # Crop to the occupied Z layers, then transform from X, Y, Z to ZYX.
            return grid[:, :, start_layer:end_layer].transpose(2, 1, 0).flatten()

        m_id = crop_and_flatten(m_id)
        s_id = crop_and_flatten(s_id)
        s_type = crop_and_flatten(s_type)

        body_config.x_voxels = x_layer_size
        body_config.y_voxels = y_layer_size
        body_config.z_voxels = z_size

        for material, segment, kind in zip(m_id, s_id, s_type):
            body_config.material_reference_sid.append(material)
            body_config.segment_bid.append(segment)
            # 0 is soft and 1 is rigid
            body_config.segment_type.append(kind)

        return body_config, start_layer, (
            x_size * self.voxel_size,
            y_size * self.voxel_size,
            z_size * self.voxel_size,
        )

    def build_constraint_configs(
        self,
        joints: List[dict],
        z_offset,
        hinge_limit_min,
        hinge_limit_max,
        hinge_max_torque,
    ):
        """
        Create one hinge-joint constraint config per joint.

        Args:
            joints: Joint dicts; the keys "components" (two segment ids),
                "position" (voxel coordinates) and "axis" are read.
            z_offset: Subtracted from the joint's Z coordinate before scaling.
            hinge_limit_min: Assigned to ``hinge_min`` of every constraint.
            hinge_limit_max: Assigned to ``hinge_max`` of every constraint.
            hinge_max_torque: Assigned to ``hinge_max_torque`` of every constraint.

        Returns:
            List of constraint configs in the same order as ``joints``; the list
            index doubles as the rotation-angle signal id.
        """
        scale = self.voxel_size

        def make_anchor(position):
            # Voxel coordinates -> scaled coordinates, shifted by the Z offset.
            return RVec3rf(
                position[0] * scale,
                position[1] * scale,
                (position[2] - z_offset) * scale,
            )

        hinges = []
        for signal_id, joint in enumerate(joints):
            hinge = RS_StructureConstraintConfig()
            hinge.type = RSE_StructureConstraintType.RSE_HINGE_JOINT

            # Both sides live in body 0; only the segment differs.
            hinge.a_body_sid = 0
            hinge.b_body_sid = 0
            hinge.a_segment_bid = joint["components"][0]
            hinge.b_segment_bid = joint["components"][1]

            # The same point is used as anchor on both sides.
            hinge.a_local_anchor = make_anchor(joint["position"])
            hinge.b_local_anchor = make_anchor(joint["position"])

            hinge.hinge_rotation_angle_signal_sid = signal_id

            # Side B gets the negated axis of side A.
            axis = joint["axis"]
            hinge.hinge_a_local_axis = RVec3rf(axis[0], axis[1], axis[2])
            hinge.hinge_b_local_axis = RVec3rf(-axis[0], -axis[1], -axis[2])

            hinge.hinge_min = hinge_limit_min
            hinge.hinge_max = hinge_limit_max
            hinge.hinge_max_torque = hinge_max_torque
            hinges.append(hinge)
        return hinges

    def build_camera_configs(
        self,
        cameras: List[dict],
        robot_name: str,
        camera_max_depth: int,
        camera_width: int,
        camera_height: int,
        camera_samples: int,
        camera_fov: float,
        z_offset: float,
    ):
        """
        Create one rigid-body-linked perspective camera config per camera.

        Args:
            cameras: Camera dicts; the keys "component", "position" (voxel
                coordinates) and "orientation" are read.
            robot_name: Structure the cameras are linked to; also used as the
                prefix of each camera name.
            camera_max_depth: Assigned to ``max_depth``.
            camera_width: Assigned to ``image_width``.
            camera_height: Assigned to ``image_height``.
            camera_samples: Assigned to ``samples_per_pixel``.
            camera_fov: Assigned to ``field_of_view``.
            z_offset: Subtracted from the camera's Z coordinate before scaling.

        Returns:
            List of camera configs in the same order as ``cameras``.
        """
        scale = self.voxel_size
        result = []
        for idx, camera in enumerate(cameras):
            camera_config = RS_CameraConfig()
            camera_config.name = f"{robot_name}_camera_{idx}"
            camera_config.type = RSE_CameraType.RSE_CAMERA_RIGID_BODY_LINKED_PERSPECTIVE

            linked = RS_RigidBodyLinkedPerspectiveCameraConfig()
            # Imaging parameters
            linked.max_depth = camera_max_depth
            linked.image_width = camera_width
            linked.image_height = camera_height
            linked.samples_per_pixel = camera_samples
            linked.field_of_view = camera_fov
            # Attachment: body 0 of the robot, at the camera's rigid segment
            linked.structure_name = robot_name
            linked.body_sid = 0
            linked.segment_bid = int(camera["component"])
            position = camera["position"]
            linked.local_anchor = RVec3rf(
                position[0] * scale,
                position[1] * scale,
                (position[2] - z_offset) * scale,
            )
            linked.local_orientation = RQuat3rf(*camera["orientation"])
            linked.defocus_angle = 0
            linked.focus_distance = 0.05

            # Attach the fully populated inner config last.
            camera_config.config = linked
            result.append(camera_config)
        return result

    @staticmethod
    def preprocess_closing_rigid(
        voxels: np.ndarray, rigid_void_dilation_size: int
    ):
        """
        Return a copy of ``voxels`` whose rigid channel (index 1) has been closed.

        The structuring element is built by ``create_dilation_structure`` from
        ``rigid_void_dilation_size`` and the closing is delegated to
        ``safe_binary_closing`` with two iterations. The input array is not
        modified.
        """
        closed = np.array(voxels, copy=True)
        kernel = RobotBuilder.create_dilation_structure(rigid_void_dilation_size)
        closed[1] = RobotBuilder.safe_binary_closing(
            closed[1], iterations=2, structure=kernel
        )
        return closed

    
    @staticmethod
    def preprocess_closing_soft(voxels: np.ndarray, soft_closing_dilation_size: int):
        """
        Return a copy of ``voxels`` whose occupancy channel (index 0) has been
        closed.

        The structuring element is built by ``create_dilation_structure`` from
        ``soft_closing_dilation_size`` and the closing is delegated to
        ``safe_binary_closing`` with two iterations. The input array is not
        modified.
        """
        closed = np.array(voxels, copy=True)
        kernel = RobotBuilder.create_dilation_structure(soft_closing_dilation_size)
        closed[0] = RobotBuilder.safe_binary_closing(
            closed[0], iterations=2, structure=kernel
        )
        return closed

    @staticmethod
    def preprocess_eliminate_small_rigid_segments(
        voxels: np.ndarray, min_rigid_segment_volume: int
    ):
        """
        Eliminate small rigid segments in the second channel of voxels array.

        Rigid voxels are grouped into 26-connected components; every component
        with fewer than ``min_rigid_segment_volume`` voxels is cleared from
        channel 1. Other channels are left as they are.

        Returns:
            A modified copy of ``voxels``; the input is not changed.
        """
        rigid_segment_labels, num = cc3d.connected_components(
            voxels[1], connectivity=26, return_N=True
        )
        new_voxels = np.copy(voxels)
        if num == 0:
            return new_voxels

        # voxel_counts is indexed by label; turn it into a per-label lookup table
        # so all small segments are removed in one pass over the grid instead of
        # one full-grid comparison per label.
        voxel_counts = np.asarray(
            cc3d.statistics(rigid_segment_labels)["voxel_counts"]
        )
        is_small_label = voxel_counts < min_rigid_segment_volume
        is_small_label[0] = False  # label 0 is the background
        new_voxels[1, is_small_label[rigid_segment_labels]] = 0
        return new_voxels

    @staticmethod
    def preprocess_eliminate_small_joints(voxels: np.ndarray, min_joint_volume: int):
        """
        Eliminate small joints in the third channel of voxels array.

        Joint voxels are grouped into 26-connected components; every component
        with fewer than ``min_joint_volume`` voxels is cleared from channel 2.
        Other channels are left as they are.

        Returns:
            A modified copy of ``voxels``; the input is not changed.
        """
        joint_labels, num = cc3d.connected_components(
            voxels[2], connectivity=26, return_N=True
        )
        new_voxels = np.copy(voxels)
        if num == 0:
            return new_voxels

        # voxel_counts is indexed by label; a per-label lookup table removes all
        # small joints in one pass instead of one full-grid comparison per label.
        voxel_counts = np.asarray(cc3d.statistics(joint_labels)["voxel_counts"])
        is_small_label = voxel_counts < min_joint_volume
        is_small_label[0] = False  # label 0 is the background
        new_voxels[2, is_small_label[joint_labels]] = 0
        return new_voxels

    @staticmethod
    def preprocess_eliminate_small_cameras(voxels: np.ndarray, min_camera_volume: int):
        """
        Eliminate small cameras in the fourth channel of voxels array.

        Camera voxels are grouped into 26-connected components; every component
        with fewer than ``min_camera_volume`` voxels is cleared from channel 3.
        Other channels are left as they are.

        Returns:
            A modified copy of ``voxels``; the input is not changed.
        """
        camera_labels, num = cc3d.connected_components(
            voxels[3], connectivity=26, return_N=True
        )
        new_voxels = np.copy(voxels)
        if num == 0:
            return new_voxels

        # voxel_counts is indexed by label; a per-label lookup table removes all
        # small cameras in one pass instead of one full-grid comparison per label.
        voxel_counts = np.asarray(cc3d.statistics(camera_labels)["voxel_counts"])
        is_small_label = voxel_counts < min_camera_volume
        is_small_label[0] = False  # label 0 is the background
        new_voxels[3, is_small_label[camera_labels]] = 0

        return new_voxels

    @staticmethod
    def preprocess_remove_boundary_rigid_voxels(voxels: np.ndarray, remove_boundary_rigid_voxel_size: int):
        """
        Cover boundary voxels in the first channel of voxels array
        """
        new_voxels = np.copy(voxels)
        # Get the dimensions of the voxel grid
        rx, ry, rz = voxels.shape[1:]
        
        # Create boundary mask for the specified boundary size
        boundary_mask = np.zeros((rx, ry, rz), dtype=bool)
        
        # Mark all voxels within cover_boundary_voxel_size distance from any boundary as boundary voxels
        boundary_mask[:remove_boundary_rigid_voxel_size, :, :] = True  # x-axis lower boundary
        boundary_mask[-remove_boundary_rigid_voxel_size:, :, :] = True  # x-axis upper boundary
        boundary_mask[:, :remove_boundary_rigid_voxel_size, :] = True  # y-axis lower boundary
        boundary_mask[:, -remove_boundary_rigid_voxel_size:, :] = True  # y-axis upper boundary
        boundary_mask[:, :, :remove_boundary_rigid_voxel_size] = True  # z-axis lower boundary
        boundary_mask[:, :, -remove_boundary_rigid_voxel_size:] = True  # z-axis upper boundary
        
        # Find voxels that are both in boundary region and have voxel[0]=True
        boundary_non_empty_mask = boundary_mask & new_voxels[0] #|new_voxels[1])
        # print("boundary_non_empty_mask.sum()", boundary_non_empty_mask.sum())
        
        # For these boundary voxels, set channels 1, 2, 3 to False
        new_voxels[1, boundary_non_empty_mask] = 0  # rigid channel
        new_voxels[2, boundary_non_empty_mask] = 0  # joint channel
        new_voxels[3, boundary_non_empty_mask] = 0  # camera channel
        
        return new_voxels


    @staticmethod
    def preprocess_cover_exposed_rigid_regions(
        voxels: np.ndarray,
        cover_exposed_rigid_dilation_size: int,
        cover_exposed_rigid_iterations: int,
        cover_exposed_grid_boundary_bleeding: int,
    ):
        """
        Mark the dilated neighbourhood of the rigid channel as occupied.

        The rigid channel (index 1) is dilated with the structure built from
        ``cover_exposed_rigid_dilation_size`` for
        ``cover_exposed_rigid_iterations`` iterations. Dilated voxels closer than
        ``cover_exposed_grid_boundary_bleeding`` voxels to a grid face are
        ignored; the rest are set in channel 0.

        Returns:
            A modified copy of ``voxels``; the input is not changed.
        """
        margin = cover_exposed_grid_boundary_bleeding
        rx, ry, rz = voxels.shape[1:]

        # True everywhere except in a shell of `margin` voxels along each face.
        interior = np.zeros((rx, ry, rz), dtype=bool)
        interior[margin : rx - margin, margin : ry - margin, margin : rz - margin] = True

        kernel = RobotBuilder.create_dilation_structure(
            cover_exposed_rigid_dilation_size
        )
        grown_rigid = binary_dilation(
            voxels[1],
            structure=kernel,
            iterations=cover_exposed_rigid_iterations,
        )
        cover_mask = grown_rigid & interior

        covered = np.array(voxels, copy=True)
        covered[0, cover_mask] = True
        return covered

    @staticmethod
    def preprocess_filter_largest_connected_component(voxels: np.ndarray):
        """
        Filter the largest connected component in the first channel of voxels array, and clear other
        channels accordingly.

        Returns:
            ``(new_voxels, largest_component_ratio, rigid_ratio)`` where
            ``new_voxels`` keeps every channel only inside the largest
            6-connected component of channel 0, ``largest_component_ratio`` is
            that component's share of all channel-0 voxels, and ``rigid_ratio``
            is the share of the component that is set in channel 1. For a grid
            with no occupied voxel the result is an all-cleared array and two
            zero ratios.
        """
        # Must use stricter 6 connectivity
        total_voxels = voxels[0]
        labels, num = cc3d.connected_components(
            total_voxels, connectivity=6, return_N=True
        )

        if num == 0:
            # No component at all. Without this guard the background (label 0)
            # would be picked as the "largest component" and the ratio below
            # would divide by zero.
            return np.zeros_like(voxels), 0.0, 0.0

        component_sizes = np.bincount(labels.ravel())

        # Remove background
        component_sizes[0] = 0
        largest_component_label = component_sizes.argmax()
        largest_component_mask = labels == largest_component_label
        new_voxels = np.where(largest_component_mask[None, ...], voxels, False)

        total_largest_component_volume = np.sum(largest_component_mask)
        total_largest_component_ratio = total_largest_component_volume / np.sum(
            total_voxels
        )
        total_rigid_ratio_in_largest_component = (
            np.sum(new_voxels[1]) / total_largest_component_volume
        )

        return (
            new_voxels,
            total_largest_component_ratio,
            total_rigid_ratio_in_largest_component,
        )

    @staticmethod
    def separate_rigid_segments_by_joints(
        voxels: np.ndarray,
        joint_dilation_factors: List[int],
        min_rigid_segment_volume: int,
        min_flap_elongation_ratio: float,
        debug: bool = False,
    ):
        # Compute surface rigid voxels
        # neighbor_structure = np.ones((3, 3, 3), dtype=bool)
        # neighbor_structure[1, 1, 1] = False
        
        # 18 connectivity
        neighbor_structure = np.array(
            [
                [
                    [0, 1, 0],
                    [1, 1, 1],
                    [0, 1, 0],
                ],
                [
                    [1, 1, 1],
                    [1, 0, 1],
                    [1, 1, 1],
                ],
                [
                    [0, 1, 0],
                    [1, 1, 1],
                    [0, 1, 0],
                ],
            ],
            dtype=bool,
        )

        # To-do: not neccessary correct
        neighbors = RobotBuilder.convolve_gather(voxels[1], neighbor_structure)
        is_rigid_surface = voxels[1] & np.any(~neighbors, axis=3)

        if debug:
            plotter = vedo.Plotter(
                shape=(1, 1), size=(1600, 1600), title="Voxel Visualization", bg="white"
            )
            rigid_vis = plot_binary_vedo(voxels[1], (1, 0, 0, 0.3), False)
            joint_vis = plot_binary_vedo(voxels[2], (0, 0, 1, 1), False)
            plotter.show([rigid_vis, joint_vis], at=0, resetcam=True, interactive=True)

        # First, identify possible joint locations
        joint_labels, label_num = cc3d.connected_components(
            voxels[2],
            connectivity=26,
            return_N=True,
            out_dtype=np.uint32,
        )

        # For each joint, dilate it and check if it causes rigid segments to separate
        # If so, then it is a valid joint
        # Otherwise, it is not a valid joint
        rx, ry, rz = voxels.shape[1:]
        x_indices, y_indices, z_indices = np.meshgrid(
            np.arange(rx), np.arange(ry), np.arange(rz), indexing="ij"
        )

        joint_info = []  # Store information about valid joints

        # Create a temporary copy of rigid body mask for testing separability of joints
        rigid_body_voxels = np.copy(voxels[1])

        for label in range(1, label_num + 1):
            joint_mask = joint_labels == label

            current_rigid_body_labels, current_rigid_body_num = (
                cc3d.connected_components(
                    rigid_body_voxels,
                    connectivity=26,
                    return_N=True,
                )
            )

            for dilation_factor in joint_dilation_factors:
                dilated_joint_mask, joint_center = RobotBuilder.sphericalize_voxels(joint_mask, dilation_factor)

                new_rigid_body_labels, new_rigid_body_num = cc3d.connected_components(
                    np.logical_and(rigid_body_voxels, ~dilated_joint_mask),
                    connectivity=26,
                    return_N=True,
                )

                new_component_sizes = np.bincount(new_rigid_body_labels.flatten())
                new_component_sizes[0] = 0  # Ignore background

                # Remove dust
                new_component_sizes[new_component_sizes < min_rigid_segment_volume] = 0
                new_rigid_body_num = np.count_nonzero(new_component_sizes)
                dust_mask = np.zeros_like(rigid_body_voxels)
                for i in range(len(new_component_sizes)):
                    if new_component_sizes[i] == 0:
                        current_dusk_mask = new_rigid_body_labels == i
                        dust_mask = np.logical_or(dust_mask, current_dusk_mask)
                        new_rigid_body_labels[current_dusk_mask] = 0

                if debug:
                    plotter = vedo.Plotter(
                        shape=(1, 1),
                        size=(1600, 1600),
                        title="Voxel Visualization",
                        bg="white",
                    )
                    print(f"Joint {label} is testing dilation size {dilation_factor}")
                    rigid_vis = plot_binary_vedo(np.logical_and(rigid_body_voxels, ~dilated_joint_mask), (1, 0, 0, 0.3), False)
                    dil_joint_x = plot_binary_vedo(dilated_joint_mask, (0, 1, 0, 1), False)
                    plotter.show(
                        [rigid_vis, dil_joint_x], at=0, resetcam=True, interactive=True
                    )
                    

                if new_rigid_body_num > current_rigid_body_num:
                    if debug:
                        print(
                            f"Joint {label} is separable at dilation size {dilation_factor}"
                        )
                        
                    break
            else:
                if debug:
                    print(f"Joint {label} is not separable")
                continue

            # if debug:
            #     plotter = vedo.Plotter(
            #         shape=(1, 3),
            #         size=(2400, 800),
            #         title=f"Separation Test label={label}",
            #         bg="white",
            #     )
            #     old_id_x = plot_segment_id_vedo(current_rigid_body_labels, smooth=False)
            #     new_id_x = plot_segment_id_vedo(new_rigid_body_labels, smooth=False)
            #     dil_joint_x = plot_binary_vedo(dilated_joint_mask, smooth=False)
            #     plotter.show(old_id_x, at=0)
            #     plotter.show(new_id_x, at=1)
            #     plotter.show(
            #         [dil_joint_x], at=2, resetcam=True, interactive=True
            #     ).close()

            if new_rigid_body_num > current_rigid_body_num:
                # Test if new segments meet minimum size requirement
                # (Make sure to filter global small rigid segments before this step!)

                # Skip if any component is too small
                if np.any(
                    (new_component_sizes > 0)
                    & (new_component_sizes < min_rigid_segment_volume)
                ):
                    continue

                # cutting_normal /= np.linalg.norm(cutting_normal)
                joint_axis = np.array([0.3, 0.0, 1.0])
                joint_axis /= np.linalg.norm(joint_axis)

                joint_info.append(
                    {
                        "components": None,  # Tuple (A, B), we will label later
                        "position": joint_center,
                        "axis": joint_axis,
                        "voxels": dilated_joint_mask,
                    }
                )


                rigid_body_voxels = np.logical_and(
                    ~dilated_joint_mask, rigid_body_voxels
                )
                rigid_body_voxels[dust_mask] = 0
        if debug:
            print(f"Before removing joints, robot has {len(joint_info)} joints")
        # Finally, remove all joint surface voxels from the voxels[1]
        new_voxels = np.copy(voxels)
        new_voxels[1] = rigid_body_voxels

        # Redo a connected component check
        rigid_segment_labels, rigid_segment_num = cc3d.connected_components(
            new_voxels[1], connectivity=26, return_N=True, out_dtype=np.uint32
        )
        rigid_segment_sizes = np.bincount(rigid_segment_labels.flatten())
        rigid_segment_sizes[0] = 0  # Ignore background
        # Remove dust
        rigid_segment_sizes[rigid_segment_sizes < min_rigid_segment_volume] = 0

        # Then add joints, and check if their neighbors have different labels
        valid_joints = []
        for joint in joint_info:
            # Check if any neighbors have different labels
            neighbors = RobotBuilder.convolve_gather(
                rigid_segment_labels, neighbor_structure
            )[joint["voxels"], :]
            neighbors = set(np.unique(neighbors).tolist())
            neighbors.discard(0)  # Ignore background

            neighbor_and_sizes = [
                (neighbor, rigid_segment_sizes[neighbor])
                for neighbor in neighbors
                if rigid_segment_sizes[neighbor] != 0
            ]
            neighbor_and_sizes = sorted(
                neighbor_and_sizes,
                key=lambda x: x[1],
                reverse=True,
            )

            # If have at least two different neighbors, then it is a valid joint
            if len(neighbor_and_sizes) >= 2:
                # Select two biggest neighbors
                joint["components"] = (
                    neighbor_and_sizes[0][0],
                    neighbor_and_sizes[1][0],
                )
                valid_joints.append(joint)
            else:
                if debug:
                    print(f"Joint invalid: Joint {label} has {len(neighbor_and_sizes)} neighbors")

        # Ensure that between two connected rigid segments, there is only one joint
        existing_connections = set()
        deduplicated_joints = []
        for joint in valid_joints:
            if joint["components"] in existing_connections:
                continue
            existing_connections.add(joint["components"])
            existing_connections.add((joint["components"][1], joint["components"][0]))
            deduplicated_joints.append(joint)
        valid_joints = deduplicated_joints

        # Ensure that no cycle is created
        while True:
            connection_graph = [connection["components"] for connection in valid_joints]
            connection_lut = {
                connection["components"]: idx
                for idx, connection in enumerate(valid_joints)
            }
            cycle = RobotBuilder.find_cycle(connection_graph, rigid_segment_num)
            if cycle is None:
                break
            # Find all corresponding edges
            cycle_edges = []
            for i in range(len(cycle)):
                end_1 = cycle[i - 1]
                end_2 = cycle[i]
                if (end_1, end_2) in connection_lut:
                    cycle_edges.append(connection_lut[(end_1, end_2)])
                else:
                    cycle_edges.append(connection_lut[(end_2, end_1)])

            # Just remove the first one
            valid_joints.pop(cycle_edges[0])

        # if debug:
        #     plotter = vedo.Plotter(
        #         shape=(1, 1),
        #         size=(1600, 1600),
        #         title="Final Rigid Segments and Joints",
        #         bg="white",
        #     )
        #     final_id_vis = plot_segment_id_vedo(rigid_segment_labels, smooth=False)
        #     final_joint_vis = []
        #     for joint in valid_joints:
        #         joint_vis = vedo.Arrow(
        #             joint["position"],
        #             joint["position"] + 20 * joint["axis"],
        #             c="red",
        #         )
        #         final_joint_vis.append(joint_vis)
        #     plotter.show(
        #         [final_id_vis] + final_joint_vis, at=0, resetcam=True, interactive=True
        #     ).close()

        return new_voxels, rigid_segment_labels, rigid_segment_num, valid_joints

    @staticmethod
    def extract_cameras(
        voxels: np.ndarray,
        rigid_segment_labels: np.ndarray,
        camera_offset_voxels: float,
        limit_max_camera_num: int,
        debug: bool = False,
    ):
        """
        Extract camera information from voxels

        Camera voxels are grouped into 26-connected components. Components are
        visited from largest to smallest; each one yields a pose looking away
        from the body centroid and is attached to the rigid segment that
        dominates its dilated neighborhood.

        Args:
            voxels: 4D numpy array of shape [4, X, Y, Z]. Channel 0 is used as the
                occupancy mask, channel 1 as the rigid mask and channel 3 as the
                camera mask (channels are combined with ``&`` and used as masks,
                so boolean data is expected).
            rigid_segment_labels: 3D numpy array of shape [X, Y, Z] with rigid segment IDs
                (0 is treated as background)
            camera_offset_voxels: Offset (in voxel units) from the camera component
                center along the look-at axis
            limit_max_camera_num: Maximum number of cameras to return
            debug: Whether to show debug visualization

        Returns:
            List of camera information dictionaries, each containing:
                position: np.ndarray of shape [3] - camera position
                orientation: np.ndarray of shape [4] - quaternion (x, y, z, w)
                    rotating [1, 0, 0] onto the look-at axis
                component: ID of the rigid segment the camera is attached to
                look_at_axis: np.ndarray of shape [3] - unit look-at direction
        """
        # Find camera connected components, preferring camera voxels that are rigid
        camera_labels, num = cc3d.connected_components(
            voxels[3] & voxels[1], connectivity=26, return_N=True, out_dtype=np.uint32
        )
        stats = cc3d.statistics(camera_labels)
        # voxel_counts is indexed by label and index 0 is the background, which
        # must not take part in the "largest camera component" test.
        max_camera_size = np.max(stats["voxel_counts"][1:]) if num > 0 else 0
        if max_camera_size < 2:
            # No usable rigid camera component: fall back to all camera voxels
            camera_labels, num = cc3d.connected_components(
                voxels[3], connectivity=26, return_N=True, out_dtype=np.uint32
            )
            stats = cc3d.statistics(camera_labels)

        # Body center of mass (mean index of the occupied voxels)
        body_positions = np.argwhere(voxels[0])
        body_center = (
            np.mean(body_positions, axis=0) if len(body_positions) > 0 else np.zeros(3)
        )

        # Sort camera labels by voxel size (largest first), background excluded
        camera_sizes = [(label, stats["voxel_counts"][label]) for label in range(1, num + 1)]
        camera_sizes.sort(key=lambda x: x[1], reverse=True)
        sorted_camera_labels = [label for label, _ in camera_sizes]

        default_axis = np.array([1.0, 0.0, 0.0])
        # Loop-invariant structuring element used to probe surrounding rigid voxels
        attach_structure = RobotBuilder.create_dilation_structure(2)

        # For each camera component, compute properties
        cameras = []
        for label in sorted_camera_labels:
            if len(cameras) >= limit_max_camera_num:
                # Remaining (smaller) components could not be added anyway
                break

            camera_mask = camera_labels == label
            camera_center = np.mean(np.argwhere(camera_mask), axis=0)

            # Look-at axis points from the body center to the camera center
            look_at_axis = camera_center - body_center
            look_at_norm = np.linalg.norm(look_at_axis)
            look_at_axis = (
                look_at_axis / look_at_norm if look_at_norm > 0 else default_axis.copy()
            )

            # Camera position is pushed away from the body along the look-at axis
            camera_position = camera_center + camera_offset_voxels * look_at_axis

            # Orientation quaternion rotating default_axis onto look_at_axis
            rotation_axis = np.cross(default_axis, look_at_axis)
            rotation_axis_norm = np.linalg.norm(rotation_axis)
            if rotation_axis_norm < 1e-6:
                # Parallel vectors: rotation is either 0 or 180 degrees
                if np.dot(look_at_axis, default_axis) > 0:
                    quaternion = np.array([0.0, 0.0, 0.0, 1.0])  # Identity
                else:
                    quaternion = np.array([0.0, 1.0, 0.0, 0.0])  # 180 degrees around y
            else:
                rotation_axis = rotation_axis / rotation_axis_norm
                cos_angle = np.dot(default_axis, look_at_axis)
                angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
                # Axis-angle to quaternion
                sin_half_angle = np.sin(angle / 2)
                quaternion = np.array(
                    [
                        rotation_axis[0] * sin_half_angle,
                        rotation_axis[1] * sin_half_angle,
                        rotation_axis[2] * sin_half_angle,
                        np.cos(angle / 2),
                    ]
                )

            # Find the rigid component that the camera is attached to, using a
            # slightly dilated camera mask to reach surrounding rigid voxels
            dilated_camera = binary_dilation(camera_mask, structure=attach_structure)
            neighboring_segments = rigid_segment_labels[dilated_camera & voxels[1]]
            if neighboring_segments.size == 0:
                # Reject this camera: no rigid voxel nearby
                continue

            segment_counts = np.bincount(neighboring_segments.flatten())
            segment_counts[0] = 0  # Exclude background
            if np.max(segment_counts) <= 2:
                # Reject this camera: contact with every segment is too small
                continue
            most_common_segment = np.argmax(segment_counts)

            cameras.append(
                {
                    "position": camera_position,
                    "orientation": quaternion,
                    "component": most_common_segment,
                    "look_at_axis": look_at_axis,
                }
            )

        # Debug visualization (only when there is at least one camera to show,
        # so that no plotter is created without being shown and closed)
        if debug and cameras:
            plotter = vedo.Plotter(
                shape=(1, 1),
                size=(1600, 1600),
                title="Camera Visualization",
                bg="white",
            )
            # Occupancy as transparent red, rigid as transparent green, camera as blue
            body_vis = plot_binary_vedo(voxels[0], color=(1, 0, 0, 0.3), smooth=False)
            rigid_vis = plot_binary_vedo(voxels[1], color=(0, 1, 0, 0.3), smooth=False)
            camera_vis = plot_binary_vedo(voxels[3], color=(0, 0, 1, 0.6), smooth=False)
            body_center_point = vedo.Point(body_center, c="black", r=15)

            arrow_scale = 20.0  # Adjust based on scene scale
            camera_points = [
                vedo.Point(camera["position"], c="red", r=10) for camera in cameras
            ]
            camera_arrows = [
                vedo.Arrow(
                    camera["position"],
                    camera["position"] + arrow_scale * camera["look_at_axis"],
                    c="blue",
                )
                for camera in cameras
            ]

            plotter.show(
                [body_vis, rigid_vis, camera_vis, body_center_point]
                + camera_points
                + camera_arrows,
                resetcam=True,
                interactive=True,
            ).close()

        return cameras

    @staticmethod
    def compute_top_two_rigid_segment_volumes(rigid_segment_labels: np.ndarray):
        """
        Compute the voxel counts of the largest and the second largest rigid segments.

        Args:
            rigid_segment_labels: 3D int array, background=0, positive integers are segment ids.

        Returns:
            (largest_volume, second_largest_volume)
        """
        # Count voxels per label
        counts = np.bincount(rigid_segment_labels.flatten())
        top2 = np.sort(counts[1:])[-2:]
        return int(top2[1]), int(top2[0])


    @staticmethod
    def calculate_joint_axis(voxels: np.ndarray, joints: List[dict], cameras: List[dict]):
        """
        Recompute joint axes using PCA on voxels[0] (occupied). Use the smallest
        principal component as +Z. Among the remaining two components, choose as +X
        the one most aligned (by absolute dot) with cameras[0]['look_at_axis'] and orient it
        to face that look direction. Then set v_world = 0.3*X + 1.0*Z, normalize, and
        assign to each joint['axis'].
        """

        occ = voxels[0]

        # Collect occupied voxel coordinates and center them
        coords = np.array(np.where(occ)).T.astype(float)  # [N,3] in (x,y,z) index space
        center = coords.mean(axis=0, keepdims=True)
        coords_centered = coords - center

        # PCA via covariance eigen-decomposition (symmetric, real)
        # cov shape [3,3]; eigh returns ascending eigenvalues
        cov = np.cov(coords_centered, rowvar=False)
        evals, evecs = np.linalg.eigh(cov)  # columns of evecs are eigenvectors

        # Sort ascending by eigenvalue: smallest -> index 0
        order = np.argsort(evals)
        evecs = evecs[:, order]
        # z_hat: smallest variance (thinnest) direction
        z_hat = evecs[:, 0]

        # Determine x_hat using camera look axis if available
        cand = [evecs[:, 1], evecs[:, 2]]
        
        look = np.asarray(cameras[0]['look_at_axis'], dtype=float)
        look = look / (np.linalg.norm(look) + 1e-12)
        dots = [float(abs(np.dot(look, c))) for c in cand]
        x_hat = cand[int(np.argmax(dots))]
        # Orient x_hat to point towards look
        if np.dot(look, x_hat) < 0:
            x_hat = -x_hat


        # Ensure orthonormal right-handed frame
        z_hat = z_hat / (np.linalg.norm(z_hat) + 1e-12)
        x_hat = x_hat / (np.linalg.norm(x_hat) + 1e-12)
        y_hat = np.cross(z_hat, x_hat)
        if np.linalg.norm(y_hat) < 1e-8:
            # Degenerate case: pick the remaining eigenvector
            rem = evecs[:, 1] if np.allclose(x_hat, evecs[:, 2]) else evecs[:, 2]
            y_hat = np.cross(z_hat, rem)
        y_hat = y_hat / (np.linalg.norm(y_hat) + 1e-12)
        # Re-orthogonalize X for numerical robustness
        x_hat = np.cross(y_hat, z_hat)
        x_hat = x_hat / (np.linalg.norm(x_hat) + 1e-12)

        # Map [0.3, 0, 1] from PCA-frame to world
        v_world = 0.3 * x_hat + 1.0 * z_hat
        v_world = v_world / (np.linalg.norm(v_world) + 1e-12)

        for j in joints:
            j['axis'] = v_world.astype(float)

        return joints

    @staticmethod
    def create_dilation_structure(dilation_size: int):
        structure = np.zeros(
            (dilation_size * 2 + 1, dilation_size * 2 + 1, dilation_size * 2 + 1),
            dtype=bool,
        )
        s_indices = np.indices(structure.shape)
        structure[s_indices[0], s_indices[1], s_indices[2]] = (
            np.linalg.norm(s_indices - dilation_size, axis=0) <= dilation_size
        )
        return structure
    

    
    @staticmethod
    def find_cycle(connection_graph: List[Tuple[int, int]], node_num: int):
        """
        Args:
            connection_graph: Edges of a graph, node index start from 1
            node_num: Number of nodes in the graph, from 1 to N

        Returns:
            None if cycle is not found, else a list containing every node
            in the cycle.
        """
        # DFS algorithm
        adjacency = [set() for _ in range(node_num)]
        # 0: not visited, 1: visited but not finished cycle checking
        # 2: visited and finished cycle checking
        state = [0 for _ in range(node_num)]
        parent = [None for _ in range(node_num)]
        cycle_start, cycle_end = None, None
        for connection in connection_graph:
            adjacency[connection[0] - 1].add(connection[1] - 1)
            adjacency[connection[1] - 1].add(connection[0] - 1)

        def dfs_find_cycle(node: int, parent_node: int):
            nonlocal cycle_start, cycle_end
            state[node] = 1
            for adj_node in adjacency[node]:
                if adj_node == parent_node:
                    continue
                if state[adj_node] == 0:
                    parent[adj_node] = node
                    if dfs_find_cycle(adj_node, node):
                        return True
                else:
                    cycle_end = node
                    cycle_start = adj_node
                    return True
            state[node] = 2
            return False

        for node in range(node_num):
            if state[node] == 0 and dfs_find_cycle(node, parent[node]):
                break

        if cycle_end is None:
            return None
        else:
            cycle = []
            while cycle_end != cycle_start:
                cycle.append(cycle_end + 1)
                cycle_end = parent[cycle_end]
            cycle.append(cycle_start + 1)
            cycle = list(reversed(cycle))
            return cycle

    @staticmethod
    def convolve_gather(input: np.ndarray, structure: np.ndarray):
        """
        Args:
            input: 3D numpy int array of shape [X, Y, Z]
            structure: 3D numpy bool mask array

        Returns:
            Neighboring values selected by structure at every voxel
            Eg: suppose there is a 3x3x3 mask with 6 connectivity,
            7 elements will be selected. so output size would be
            [X, Y, Z, 7]
        """
        padded_input = np.zeros(
            [input.shape[i] + structure.shape[i] - 1 for i in range(3)],
            dtype=input.dtype,
        )
        pad_neg = [structure.shape[i] // 2 for i in range(3)]
        padded_input[
            pad_neg[0] : pad_neg[0] + input.shape[0],
            pad_neg[1] : pad_neg[1] + input.shape[1],
            pad_neg[2] : pad_neg[2] + input.shape[2],
        ] = input
        indices = np.indices(input.shape).reshape(3, -1, 1)
        offsets = np.indices(structure.shape).reshape(3, 1, -1)
        full_indices = indices + offsets
        all_elements = padded_input[full_indices[0], full_indices[1], full_indices[2]]
        selected_elements = all_elements[:, structure.flatten()]
        return selected_elements.reshape(list(input.shape) + [-1])

    @staticmethod
    def raycast_bresenham(start_pos, directions, voxel_grid, max_steps=30):
        """
        Cast rays using 3D Bresenham's line algorithm and return traversable distances.

        Args:
            start_pos: [3] array with starting position
            directions: [N, 3] array with N unit direction vectors
            voxel_grid: 3D boolean array where True (1) represents obstacle voxels
            max_steps: Maximum number of steps to raycast in each direction

        Returns:
            Array of length N with the total traversable distance for each direction
        """
        grid_shape = voxel_grid.shape
        results = np.zeros(len(directions))

        # Round start position to integer coordinates
        start = np.round(start_pos).astype(int)

        for i, direction in enumerate(directions):
            total_distance = 0

            # Cast in both positive and negative directions
            for direction_sign in [1, -1]:
                # Set up Bresenham algorithm variables
                x, y, z = start
                dx, dy, dz = direction * direction_sign

                # Normalize to get step directions
                sx = 1 if dx > 0 else -1 if dx < 0 else 0
                sy = 1 if dy > 0 else -1 if dy < 0 else 0
                sz = 1 if dz > 0 else -1 if dz < 0 else 0

                # Get absolute values for Bresenham
                dx, dy, dz = abs(dx), abs(dy), abs(dz)

                # Determine dominant axis for Bresenham
                if dx >= dy and dx >= dz:
                    err1 = 2 * dy - dx
                    err2 = 2 * dz - dx
                    for _ in range(max_steps):
                        # Check if we're still inside grid bounds
                        if not (
                            0 <= x < grid_shape[0]
                            and 0 <= y < grid_shape[1]
                            and 0 <= z < grid_shape[2]
                        ):
                            total_distance += max_steps  # Hit boundary
                            break

                        # Check if we hit a rigid voxel
                        if voxel_grid[x, y, z]:
                            break

                        # Count empty voxel
                        total_distance += 1

                        # Bresenham stepping
                        if err1 > 0:
                            y += sy
                            err1 -= 2 * dx
                        if err2 > 0:
                            z += sz
                            err2 -= 2 * dx

                        err1 += 2 * dy
                        err2 += 2 * dz
                        x += sx

                elif dy >= dx and dy >= dz:
                    err1 = 2 * dx - dy
                    err2 = 2 * dz - dy
                    for _ in range(max_steps):
                        # Check if we're still inside grid bounds
                        if not (
                            0 <= x < grid_shape[0]
                            and 0 <= y < grid_shape[1]
                            and 0 <= z < grid_shape[2]
                        ):
                            total_distance += max_steps  # Hit boundary
                            break

                        # Check if we hit a rigid voxel
                        if voxel_grid[x, y, z]:
                            break

                        # Count empty voxel
                        total_distance += 1

                        # Bresenham stepping
                        if err1 > 0:
                            x += sx
                            err1 -= 2 * dy
                        if err2 > 0:
                            z += sz
                            err2 -= 2 * dy

                        err1 += 2 * dx
                        err2 += 2 * dz
                        y += sy

                else:  # dz is dominant
                    err1 = 2 * dy - dz
                    err2 = 2 * dx - dz
                    for _ in range(max_steps):
                        # Check if we're still inside grid bounds
                        if not (
                            0 <= x < grid_shape[0]
                            and 0 <= y < grid_shape[1]
                            and 0 <= z < grid_shape[2]
                        ):
                            total_distance += max_steps  # Hit boundary
                            break

                        # Check if we hit a rigid voxel
                        if voxel_grid[x, y, z]:
                            break

                        # Count empty voxel
                        total_distance += 1

                        # Bresenham stepping
                        if err1 > 0:
                            y += sy
                            err1 -= 2 * dz
                        if err2 > 0:
                            x += sx
                            err2 -= 2 * dz

                        err1 += 2 * dy
                        err2 += 2 * dx
                        z += sz

            results[i] = total_distance

        return results

    @staticmethod
    def compute_vertical_joint_axis(cutting_normal: np.ndarray, up_vector: np.ndarray = np.array([0, 0, 1])) -> np.ndarray:
        """
        Compute a hinge joint axis that is as aligned with the provided up_vector (default global +Z)
        as possible while remaining perpendicular to the provided cutting_normal.

        This favors more vertical hinges for near-spherical joints.
        """
        cn = cutting_normal.astype(float)
        cn_norm = np.linalg.norm(cn)
        if cn_norm == 0:
            cn = np.array([0.0, 0.0, 1.0])
            cn_norm = 1.0
        cn = cn / cn_norm

        up = up_vector.astype(float)
        up_norm = np.linalg.norm(up)
        if up_norm == 0:
            up = np.array([0.0, 0.0, 1.0])
            up_norm = 1.0
        up = up / up_norm

        # Project up onto the plane perpendicular to cn
        axis = up - np.dot(up, cn) * cn
        axis_norm = np.linalg.norm(axis)

        if axis_norm < 1e-6:
            # cutting_normal parallel to up; choose any stable perpendicular axis (prefer X, then Y)
            candidate = np.cross(cn, np.array([1.0, 0.0, 0.0]))
            cand_norm = np.linalg.norm(candidate)
            if cand_norm < 1e-6:
                candidate = np.cross(cn, np.array([0.0, 1.0, 0.0]))
                cand_norm = np.linalg.norm(candidate)
            if cand_norm < 1e-12:
                # Degenerate fallback
                axis = np.array([0.0, 0.0, 1.0])
            else:
                axis = candidate / cand_norm
        else:
            axis = axis / axis_norm

        # Make it point upward if possible
        if np.dot(axis, up) < 0:
            axis = -axis

        return axis

    @staticmethod
    def add_random_axis_perturbation(axis: np.ndarray, normal: np.ndarray, max_degrees: float = 5.0) -> np.ndarray:
        """
        Apply a small random rotation to the axis around the provided normal to keep
        the axis close to perpendicular with the cutting plane while adding variety.

        - axis: base hinge axis (will be normalized)
        - normal: cutting plane normal to rotate around (will be normalized)
        - max_degrees: maximum absolute rotation angle in degrees
        """
        a = axis.astype(float)
        if np.linalg.norm(a) == 0:
            a = np.array([0.0, 0.0, 1.0])
        a = a / max(np.linalg.norm(a), 1e-12)

        n = normal.astype(float)
        if np.linalg.norm(n) == 0:
            n = np.array([0.0, 0.0, 1.0])
        n = n / max(np.linalg.norm(n), 1e-12)

        # Ensure axis is perpendicular to normal before rotation (project out any tiny component)
        a = a - np.dot(a, n) * n
        a = a / max(np.linalg.norm(a), 1e-12)

        # Sample a small random angle in radians
        max_radians = np.deg2rad(max_degrees)
        # Choose randomly between negative and positive range
        if np.random.random() < 0.5:
            theta = np.random.uniform(-max_radians, -max_radians/2)
        else:
            theta = np.random.uniform(max_radians/2, max_radians)

        # Rodrigues' rotation formula for rotating vector a around n by theta
        a_rot = (
            a * np.cos(theta)
            + np.cross(n, a) * np.sin(theta)
            + n * np.dot(n, a) * (1 - np.cos(theta))
        )

        # Reproject to be safe and normalize
        a_rot = a_rot - np.dot(a_rot, n) * n
        a_rot = a_rot / max(np.linalg.norm(a_rot), 1e-12)

        return a_rot

    @staticmethod
    def sphericalize_voxels(voxels, smoothing_factor=1.5):

        coords = np.array(np.where(voxels)).T
        center = np.mean(coords, axis=0)
        
        x, y, z = np.meshgrid(
            np.arange(voxels.shape[0]),
            np.arange(voxels.shape[1]),
            np.arange(voxels.shape[2]),
            indexing='ij'
        )
        
        center_distances = np.sqrt(
            (x - center[0])**2 + 
            (y - center[1])**2 + 
            (z - center[2])**2
        )
        
        coords_centered = coords - center

        cov_matrix = np.cov(coords_centered.T)
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

        projections = coords_centered @ eigenvectors
        axis_ranges = np.max(projections, axis=0) - np.min(projections, axis=0)

        radius = np.min(axis_ranges)

        sphere_mask = center_distances <= (radius * smoothing_factor)
        
        result = sphere_mask
        
        return result, center

    @staticmethod
    def safe_binary_closing(volume: np.ndarray, structure: np.ndarray, iterations: int = 1):
        """
        Perform binary closing with boundary padding to avoid edge effects
        
        Args:
            volume: 3D boolean array to process
            structure: Structuring element for morphological operations
            iterations: Number of iterations for closing
            
        Returns:
            Processed volume with same shape as input
        """
        # Calculate padding size based on structure size
        pad_size = max(structure.shape) // 2 + iterations
        
        # Pad the volume with False values
        padded_volume = np.pad(
            volume, 
            pad_width=pad_size, 
            mode='constant', 
            constant_values=False
        )
        
        # Perform closing on padded volume
        closed_padded = binary_closing(
            padded_volume,
            structure=structure,
            iterations=iterations
        )
        
        # Extract the original region (remove padding)
        result = closed_padded[
            pad_size:-pad_size,
            pad_size:-pad_size,
            pad_size:-pad_size
        ]
        
        return result


def calculate_robot_orientation(voxels):
    """
    Calculate the main orientation of the robot based on voxel distribution.

    The orientation is the principal axis (largest-variance direction) of the
    rigid voxel coordinates, negated.

    Args:
        voxels: Robot voxel structure with shape [4, x, y, z]

    Returns:
        np.array: Main orientation vector (normalized). Falls back to
        [0, 0, 1] when fewer than 3 rigid voxels exist.
    """
    # Use the rigid structure (channel 1) to determine orientation
    rigid_voxels = voxels[1] > 0  # Binary mask for rigid parts

    # Coordinates of all rigid voxels as rows of (x, y, z). The mask is
    # transposed so rows come out with z varying slowest and x fastest; row
    # order does not affect the mean or covariance beyond float rounding.
    coords = np.argwhere(rigid_voxels.transpose(2, 1, 0))[:, ::-1]

    if len(coords) < 3:  # Need at least 3 points for PCA
        # Fallback to default orientation if not enough rigid voxels
        return np.array([0, 0, 1])  # Default upward direction

    # Center the coordinates
    centered = coords - np.mean(coords, axis=0)

    # Calculate covariance matrix
    cov_matrix = np.cov(centered.T)

    # Eigen decomposition
    eigenvalues, eigenvectors = np.linalg.eig(cov_matrix)

    # Get the main axis (largest eigenvalue)
    main_axis = eigenvectors[:, np.argmax(eigenvalues)]

    # Normalize and return negative direction (as in the reference code)
    return -main_axis / np.linalg.norm(main_axis)


def visualize(
    robot_structure,
    original_robot_structure=None,
    interactive=False,
    root_dir=None,
    resolution=128,
):
    """
    Render a multi-panel overview of the robot structure.

    Panels, left to right: the original voxels (only when
    ``original_robot_structure`` is given), the voxel channels of
    ``robot_structure``, the rigid segments with joint and camera arrows,
    and a smoothed surface of channel 0.

    Args:
        robot_structure: Dict providing the keys "voxels",
            "rigid_segment_labels", "joints" and "cameras".
        original_robot_structure: Optional voxel array indexed by channel
            (0-3) that is drawn in an additional first panel.
        interactive: Whether to show an interactive visualization window
        root_dir: Accepted but not used by this function; the orientation is
            always computed from the voxels.
        resolution: Grid edge length; the camera focuses on the grid center
            ``resolution // 2`` along each axis.

    Returns:
        PIL Image of the rendered panels when ``interactive`` is False,
        None otherwise
    """
    voxels = robot_structure["voxels"]
    has_original = original_robot_structure is not None

    # Three panels by default, plus one leading panel for the original voxels
    n_panels = 4 if has_original else 3
    first_panel = 1 if has_original else 0
    plotter = vedo.Plotter(
        shape=(1, n_panels),
        size=(800 * n_panels, 800),
        title="Robot Visualization",
        bg="white",
        offscreen=not interactive,
    )

    # Create Voxel Visualization
    orientation = calculate_robot_orientation(voxels)

    # Calculate camera position: look from a direction perpendicular to the
    # robot orientation
    pos_dir = np.cross(orientation, np.array([0, -1, 0]))
    if np.linalg.norm(pos_dir) < 1e-6:  # If cross product is nearly zero
        # Use alternative direction
        pos_dir = np.cross(orientation, np.array([1, 0, 0]))

    if np.linalg.norm(pos_dir) > 1e-6:
        pos_dir = pos_dir / np.linalg.norm(pos_dir)
    else:
        # Fallback to default position
        pos_dir = np.array([1, 0, 0])

    grid_center = np.array([resolution // 2, resolution // 2, resolution // 2])
    cam_pos = {
        "position": list(300 * pos_dir + grid_center),
        "focal_point": grid_center,
        "viewup": -np.array(orientation),
    }

    is_rigid = plot_binary_vedo(voxels[1], color=(0, 1, 0, 0.3), smooth=False)
    is_joint = plot_binary_vedo(voxels[2], color=(0, 0, 1, 1), smooth=False)
    is_camera = plot_binary_vedo(voxels[3], color=(1, 0, 1, 1), smooth=False)
    is_not_empty = plot_binary_vedo(voxels[0], color=(1, 0, 0, 0.2), smooth=False)
    plotter.show(
        (is_not_empty, is_rigid, is_joint, is_camera),
        at=first_panel,
        camera=cam_pos,
        resetcam=False,
    )

    # Create Surface Visualization (last panel)
    is_not_empty_surface = plot_binary_vedo(
        voxels[0], color=(0.5, 0.5, 0.5, 1), smooth=True
    )
    plotter.show(
        is_not_empty_surface,
        at=first_panel + 2,
        camera=cam_pos,
        resetcam=False,
    )

    # Create Rigid Segment And Camera Visualization
    rigid_segment_labels = robot_structure["rigid_segment_labels"]
    joints = robot_structure["joints"]
    cameras = robot_structure["cameras"]

    # Visualize rigid segments with unique colors
    rigid_segments_vis = plot_segment_id_vedo(rigid_segment_labels, smooth=False)

    # Visualize joints with arrows along their axis
    joint_vis = [
        vedo.Arrow(
            joint["position"],
            joint["position"] + 20 * joint["axis"],
            c="red",
        )
        for joint in joints
    ]

    # Visualize cameras with points and arrows
    camera_vis = []
    for camera in cameras:
        # Point for camera position
        camera_point = vedo.Point(camera["position"], c="magenta", r=10)
        camera_vis.append(camera_point)

        # Arrow for camera direction
        arrow_scale = 20.0
        look_direction = camera["look_at_axis"]
        look_at_arrow = vedo.Arrow(
            camera["position"],
            camera["position"] + arrow_scale * look_direction,
            c="blue",
        )
        camera_vis.append(look_at_arrow)

    # Show segments, joints and cameras in the middle panel
    plotter.show(
        [rigid_segments_vis] + joint_vis + camera_vis,
        at=first_panel + 1,
        camera=cam_pos,
        resetcam=False,
    )

    if has_original:
        voxels_ori = original_robot_structure
        is_not_empty_ori = plot_binary_vedo(
            voxels_ori[0], color=(1, 0, 0, 0.2), smooth=False
        )
        is_rigid_ori = plot_binary_vedo(
            voxels_ori[1], color=(0, 1, 0, 0.3), smooth=False
        )
        is_joint_ori = plot_binary_vedo(voxels_ori[2], color=(0, 0, 1, 1), smooth=False)
        is_camera_ori = plot_binary_vedo(
            voxels_ori[3], color=(1, 0, 1, 1), smooth=False
        )
        plotter.show(
            (is_not_empty_ori, is_rigid_ori, is_joint_ori, is_camera_ori),
            at=0,
            camera=cam_pos,
            resetcam=False,
        )

    if interactive:
        plotter.interactive()
        return None

    # Offscreen: capture the rendering, then release the render window so
    # repeated calls do not accumulate open plotters.
    img = Image.fromarray(plotter.screenshot(asarray=True))
    plotter.close()
    return img


def visualize_simple(robot_structure, interactive=False, root_dir=None, resolution=128):
    """
    Create a simple single-panel visualization of the robot voxel channels.

    Args:
        robot_structure: Robot structure data (dict or voxels array)
        interactive: Whether to show an interactive visualization window
        root_dir: Optional root direction vector. If None, will be calculated automatically
        resolution: Resolution of the voxel grid; the camera focuses on the
            grid center ``resolution // 2`` along each axis

    Returns:
        PIL Image of the rendering when ``interactive`` is False,
        None otherwise
    """
    if isinstance(robot_structure, dict):
        voxels = robot_structure["voxels"]
    else:
        voxels = robot_structure

    plotter = vedo.Plotter(
        shape=(1, 1),
        size=(800, 800),
        title="Robot Visualization",
        bg="white",
        offscreen=not interactive,
    )

    # Create Voxel Visualization, one actor per channel
    is_not_empty = plot_binary_vedo(voxels[0], color=(1, 0, 0, 0.2), smooth=False)
    is_rigid = plot_binary_vedo(voxels[1], color=(0, 1, 0, 0.3), smooth=False)
    is_joint = plot_binary_vedo(voxels[2], color=(0, 0, 1, 1), smooth=False)
    is_camera = plot_binary_vedo(voxels[3], color=(1, 0, 1, 1), smooth=False)

    # Calculate camera position based on root_dir or computed orientation
    if root_dir is not None:
        # Use provided root_dir
        orientation = np.array(root_dir)
    else:
        # Calculate orientation automatically
        orientation = calculate_robot_orientation(voxels)

    # Look from a direction perpendicular to the orientation
    pos_dir = np.cross(orientation, np.array([0, -1, 0]))
    if np.linalg.norm(pos_dir) < 1e-6:  # If cross product is nearly zero
        # Use alternative direction
        pos_dir = np.cross(orientation, np.array([1, 0, 0]))

    if np.linalg.norm(pos_dir) > 1e-6:
        pos_dir = pos_dir / np.linalg.norm(pos_dir)
    else:
        # Fallback to default position
        pos_dir = np.array([1, 0, 0])

    grid_center = np.array([resolution // 2, resolution // 2, resolution // 2])
    cam_pos = {
        "position": list(300 * pos_dir + grid_center),
        "focal_point": grid_center,
        "viewup": -np.array(orientation),
    }

    plotter.show(
        (is_not_empty, is_rigid, is_joint, is_camera),
        camera=cam_pos,
        at=0,
        resetcam=False,
    )

    if interactive:
        plotter.interactive()
        return None

    # Offscreen: capture the rendering, then release the render window so
    # repeated calls do not accumulate open plotters.
    img = Image.fromarray(plotter.screenshot(asarray=True))
    plotter.close()
    return img


def visualize_comprehensive(robot_structure, interactive=False):
    """
    Create a comprehensive four-panel visualization of the robot structure.

    The panels are laid out in a single row: complete structure, functional
    view, exploded view of the rigid segments, and cross-sections.

    Args:
        robot_structure: Dict providing the keys "voxels",
            "rigid_segment_labels", "joints" and "cameras", or None.
        interactive: Whether to show an interactive visualization window

    Returns:
        PIL Image of the rendering when ``interactive`` is False; None when
        ``interactive`` is True or when ``robot_structure`` is None
    """
    if robot_structure is None:
        print("Cannot visualize an invalid robot structure. Run build() first.")
        return None

    # Pull the four pieces of data every panel works from
    voxels, rigid_segments, joints, cameras = (
        robot_structure[key]
        for key in ("voxels", "rigid_segment_labels", "joints", "cameras")
    )

    # One row of four panels, each with its own camera
    plotter = vedo.Plotter(
        shape=(1, 4),
        size=(3600, 900),
        title="Robot Structure Visualization",
        bg="white",
        offscreen=not interactive,
        sharecam=False,
    )
    common = (plotter, voxels, rigid_segments, joints)

    # Panel 0: the complete robot
    _visualize_full_structure(
        *common, cameras, at=0, title="Complete Robot Structure"
    )

    # Panel 1: functional annotations
    _visualize_functional_view(
        *common, cameras, at=1, title="Front View with Functional Components"
    )

    # Panel 2: rigid segments pushed apart
    _visualize_exploded_view(
        *common, cameras, at=2, title="Exploded View of Rigid Segments"
    )

    # Panel 3: cross-sections (takes no camera list)
    _visualize_cross_section(*common, at=3, title="Cross-Section View")

    # The detailed joint and camera panels (_visualize_joints_detailed,
    # _visualize_cameras_detailed) are not part of this layout.

    plotter.show(interactive=bool(interactive))
    if interactive:
        return None

    # Offscreen: grab the frame as a PIL image, then free the plotter
    img = Image.fromarray(plotter.screenshot(asarray=True))
    plotter.close()
    return img


def _visualize_full_structure(
    plotter,
    voxels,
    rigid_segments,
    joints,
    cameras,
    at=0,
    title="",
):
    """Visualize the full robot structure with joint and camera glyphs.

    Draws the four voxel channels, an axis arrow, cylinder and rotation arc
    per joint, and a cube, view arrow and wireframe view pyramid per camera,
    then shows everything in panel ``at`` with ``title`` as caption.

    Args:
        plotter: vedo Plotter to draw into.
        voxels: Voxel structure indexed by channel (0-3).
        rigid_segments: Not used by this panel.
        joints: Iterable of dicts with "position" and "axis".
        cameras: Iterable of dicts with "position", "look_at_axis" and
            "orientation" (a 4-element quaternion).
        at: Index of the plotter panel to render into.
        title: Caption text shown at the top of the panel.
    """
    # Create isosurfaces for each channel with appropriate colors and transparencies
    # Soft body (translucent light red)
    iso_non_empty = plot_binary_vedo(voxels[0], color=(1, 0, 0, 0.05), smooth=False)

    # Rigid structure (translucent green)
    iso_rigid = plot_binary_vedo(voxels[1], color=(0, 1, 0, 0.3), smooth=False)

    # Joints (solid blue)
    iso_joint = plot_binary_vedo(voxels[2], color=(0, 0, 1, 1), smooth=False)

    # Cameras (solid magenta)
    iso_camera = plot_binary_vedo(voxels[3], color=(1, 0, 1, 1), smooth=False)

    # Add joint axes and rotation indicators
    joint_vis = []
    for joint in joints:
        # Get joint position and axis
        position = joint["position"]
        axis = joint["axis"]

        # Create an arrow to represent the axis
        arrow = vedo.Arrow(position - axis * 5, position + axis * 5, c="cyan")

        # Create a small cylinder at the joint position
        cylinder = vedo.Cylinder(pos=position, r=3, axis=axis, c="red")

        # Find a unit vector perpendicular to the axis; switch the seed
        # vector when the axis is nearly parallel to x
        perp = np.array([1, 0, 0])
        if np.abs(np.dot(perp, axis)) > 0.9:
            perp = np.array([0, 1, 0])
        perp = np.cross(perp, axis)
        perp = perp / np.linalg.norm(perp)

        # Second in-plane direction, constant for the whole arc
        binormal = np.cross(axis, perp)

        # Sample a radius-4 circle around the axis: 19 points in 20-degree
        # steps, i.e. 0 to 360 degrees
        points = []
        for i in range(19):
            angle = i * np.pi / 9
            points.append(
                position + 4 * (perp * np.cos(angle) + binormal * np.sin(angle))
            )

        arc = vedo.Line(points, c="yellow", lw=4)

        # Add an arrowhead at the end of the arc, oriented along the last
        # segment (from the final point towards the one before it)
        arrow_axis = points[-2] - points[-1]
        arrowhead = vedo.Cone(
            pos=points[-1],
            axis=arrow_axis / np.linalg.norm(arrow_axis),
            r=0.8,
            height=1.5,
            c="yellow",
        )

        joint_vis.extend([arrow, cylinder, arc, arrowhead])

    # Add camera frustums
    camera_vis = []
    for camera in cameras:
        # Draw the camera 10 units behind its stored position along the
        # look-at axis
        position = camera["position"] - camera["look_at_axis"] * 10
        orientation = camera["orientation"]

        # Create a small cube to represent the camera
        cube = vedo.Cube(pos=position, side=3, c="magenta")

        # NOTE(review): the quaternion is unpacked as (x, y, z, w) here,
        # while _visualize_cameras_detailed unpacks the same field as
        # (w, x, y, z) - confirm which order the camera data uses.
        qx, qy, qz, qw = orientation

        # View direction: the rotated x axis (first rotation-matrix column)
        direction = np.array(
            [
                1 - 2 * qy * qy - 2 * qz * qz,
                2 * qx * qy + 2 * qw * qz,
                2 * qx * qz - 2 * qw * qy,
            ]
        )

        # Up vector: the rotated y axis (second rotation-matrix column)
        up = np.array(
            [
                2 * qx * qy - 2 * qw * qz,
                1 - 2 * qx * qx - 2 * qz * qz,
                2 * qy * qz + 2 * qw * qx,
            ]
        )

        # Calculate right vector
        right = np.cross(direction, up)

        # Create FOV pyramid (assuming 45-degree FOV)
        fov_tan = np.tan(np.radians(45 / 2))
        distance = 10  # Distance from the apex to the drawn plane
        height = distance * fov_tan
        width = height * 1.33  # Aspect ratio 4:3

        # Calculate corners of the FOV pyramid
        plane_center = position + distance * direction
        top_left = plane_center + height * up - width * right
        top_right = plane_center + height * up + width * right
        bottom_left = plane_center - height * up - width * right
        bottom_right = plane_center - height * up + width * right

        # Wireframe: four edges from the apex, then the rectangle outline
        edges = [
            (position, top_left),
            (position, top_right),
            (position, bottom_left),
            (position, bottom_right),
            (top_left, top_right),
            (top_left, bottom_left),
            (top_right, bottom_right),
            (bottom_left, bottom_right),
        ]
        lines = [vedo.Line([start, end], c="black", lw=4) for start, end in edges]

        # Add an arrow in the view direction
        view_line = vedo.Arrow(
            position,
            position + direction * 15,
            c="magenta",
        )
        camera_vis.extend([cube, view_line] + lines)

    # Combine all visualizations
    all_objects = (
        [iso_non_empty, iso_rigid, iso_joint, iso_camera] + joint_vis + camera_vis
    )

    # Add a title
    caption = vedo.Text2D(
        title,
        pos="top-center",
        c="black",
        font="Arial",
        s=1.2,
    )

    # Show in the specified panel
    plotter.show(all_objects, at=at)
    plotter.add(caption, at=at)


def _visualize_functional_view(
    plotter,
    voxels,
    rigid_segments,
    joints,
    cameras,
    at=1,
    title="",
):
    """Render the robot with colored segments and text annotations.

    Each positive label in ``rigid_segments`` gets its own color and a
    "Segment <id>" label; each joint gets a sphere, an axis arrow and a
    label; each camera gets a cube, a direction arrow and a label. The
    result is shown in panel ``at`` with ``title`` as caption.
    """
    # Faint shell of the whole occupied volume as context
    shell = plot_binary_vedo(
        voxels[0], color=(0.8, 0.8, 0.8, 0.2), smooth=False
    )

    # Colors are reused cyclically when there are more segments than entries
    palette = [
        (1, 0, 0, 0.7),  # Red
        (0, 1, 0, 0.7),  # Green
        (0, 0, 1, 0.7),  # Blue
        (1, 1, 0, 0.7),  # Yellow
        (1, 0, 1, 0.7),  # Magenta
        (0, 1, 1, 0.7),  # Cyan
        (1, 0.5, 0, 0.7),  # Orange
        (0.5, 0, 1, 0.7),  # Purple
    ]

    # Positive labels only
    segment_ids = np.unique(rigid_segments)
    segment_ids = segment_ids[segment_ids > 0]

    segment_vis, segment_labels = [], []
    for slot, segment_id in enumerate(segment_ids):
        mask = rigid_segments == segment_id
        segment_vis.append(
            plot_binary_vedo(mask, color=palette[slot % len(palette)], smooth=False)
        )

        # Label sits at a fixed offset from the segment centroid
        voxel_idx = np.where(mask)
        if len(voxel_idx[0]) == 0:
            continue
        centroid = np.array([np.mean(voxel_idx[dim]) for dim in range(3)])
        segment_labels.append(
            vedo.Text3D(
                f"Segment {segment_id}",
                pos=centroid + np.array([5, 5, 5]),
                s=2,
                c="black",
            )
        )

    # Joints: sphere at the position, arrow spanning 5 units each way
    joint_vis, joint_labels = [], []
    for i, joint in enumerate(joints):
        position = joint["position"]
        axis = joint["axis"]
        components = joint["components"]

        joint_vis.append(vedo.Sphere(pos=position, r=3, c="blue"))
        joint_vis.append(
            vedo.Arrow(position - axis * 5, position + axis * 5, c="cyan")
        )
        joint_labels.append(
            vedo.Text3D(
                f"Joint {i + 1}\nConnects: {components[0]}-{components[1]}",
                pos=position + np.array([8, 0, 0]),
                s=2,
                c="black",
            )
        )

    # Cameras: cube at the position, arrow along the view direction
    camera_vis, camera_labels = [], []
    for i, camera in enumerate(cameras):
        position = camera["position"]
        orientation = camera["orientation"]
        component = camera["component"]

        body = vedo.Cube(pos=position, side=4, c="magenta")

        # View direction from the quaternion, unpacked as (x, y, z, w)
        qx, qy, qz, qw = orientation
        direction = np.array(
            [
                1 - 2 * qy * qy - 2 * qz * qz,
                2 * qx * qy + 2 * qw * qz,
                2 * qx * qz - 2 * qw * qy,
            ]
        )
        pointer = vedo.Arrow(position, position + direction * 10, c="magenta")
        camera_vis += [body, pointer]

        camera_labels.append(
            vedo.Text3D(
                f"Camera {i + 1}\nOn Segment: {component}",
                pos=position + np.array([0, 8, 0]),
                s=2,
                c="black",
            )
        )

    # Geometry first, then all text labels
    all_objects = [
        shell,
        *segment_vis,
        *joint_vis,
        *camera_vis,
        *segment_labels,
        *joint_labels,
        *camera_labels,
    ]

    caption = vedo.Text2D(title, pos="top-center", c="black", font="Arial", s=1.2)

    # Show in the specified panel
    plotter.show(all_objects, at=at)
    plotter.add(caption, at=at)


def _visualize_exploded_view(
    plotter, voxels, rigid_segments, joints, cameras, at=2, title=""
):
    """Create an exploded view to show individual rigid segments separated.

    Every positive label in ``rigid_segments`` is rendered in a random color
    and shifted 15 units away from the center of the occupied voxels
    (channel 0). Dashed lines join the centroids of segments that share a
    joint; the lines use the original, un-shifted centroids.

    Args:
        plotter: vedo Plotter to draw into.
        voxels: Voxel structure indexed by channel; channel 0 marks occupancy.
        rigid_segments: Label volume; values > 0 identify rigid segments.
        joints: Iterable of dicts with "components" (the two segment ids).
        cameras: Not used by this panel.
        at: Index of the plotter panel to render into.
        title: Caption text shown at the top of the panel.
    """

    # Visualization components
    vis_objects = []

    # Process rigid segments
    unique_segments = np.unique(rigid_segments)
    unique_segments = unique_segments[unique_segments > 0]

    # Fallback displacement vectors for segments centered exactly on the robot
    displacement_directions = [
        np.array([1, 1, 1]),
        np.array([1, -1, 1]),
        np.array([-1, 1, 1]),
        np.array([-1, -1, 1]),
        np.array([1, 1, -1]),
        np.array([1, -1, -1]),
        np.array([-1, 1, -1]),
        np.array([-1, -1, -1]),
    ]

    # Find the center of mass of the entire robot. The explicit "> 0" makes
    # this a boolean mask even when channel 0 is stored as 0/1 integers, and
    # argwhere avoids materializing full-grid index arrays.
    occupied_coords = np.argwhere(np.asarray(voxels[0]) > 0)
    center = occupied_coords.mean(axis=0) if len(occupied_coords) else None

    # Calculate centroids for each segment
    segment_centroids = {}
    for segment_id in unique_segments:
        segment_coords = np.argwhere(rigid_segments == segment_id)
        if len(segment_coords) > 0:
            segment_centroids[segment_id] = segment_coords.mean(axis=0)

    # Draw dashed lines between connected segments
    connections_drawn = set()

    for joint in joints:
        components = joint["components"]

        # Only process each connection once, regardless of component order
        pair = frozenset((components[0], components[1]))
        if pair in connections_drawn:
            continue
        connections_drawn.add(pair)

        # Ensure both components exist in centroids
        if components[0] in segment_centroids and components[1] in segment_centroids:
            start = segment_centroids[components[0]]
            end = segment_centroids[components[1]]

            # Create a dashed line connecting the segments
            connection = vedo.Line(
                start,
                end,
                c="gray",
                lw=2,
                alpha=0.7,
            ).pattern("--")

            vis_objects.append(connection)

    # Create an exploded view of the rigid segments
    explosion_distance = 15
    for i, segment_id in enumerate(unique_segments):
        segment_mask = rigid_segments == segment_id

        # Random color per segment (differs between calls)
        segment_color = (
            0.2 + 0.8 * np.random.random(),
            0.2 + 0.8 * np.random.random(),
            0.2 + 0.8 * np.random.random(),
            0.9,
        )

        # Create an isosurface for this segment
        segment_surface = plot_binary_vedo(
            segment_mask, color=segment_color, smooth=False
        )

        # Get centroid of this segment
        centroid = segment_centroids[segment_id]

        # Push the segment away from the robot center; when there is no
        # usable offset (segment centered on the robot, or no occupied
        # voxels at all) fall back to a fixed diagonal direction
        offset = centroid - center if center is not None else np.zeros(3)
        offset_norm = np.linalg.norm(offset)
        if offset_norm > 0:
            direction = offset / offset_norm
        else:
            direction = displacement_directions[i % len(displacement_directions)]

        # Apply explosion displacement
        segment_surface.shift(*(direction * explosion_distance))

        # Add a segment label slightly beyond the shifted segment
        label = vedo.Text3D(
            f"Segment {segment_id}",
            pos=centroid + direction * (explosion_distance + 5),
            s=2,
            c="black",
        )

        vis_objects.extend([segment_surface, label])

    # Add a title
    caption = vedo.Text2D(
        title,
        pos="top-center",
        c="black",
        font="Arial",
        s=1.2,
    )

    # Show in the specified panel
    plotter.show(vis_objects, at=at)
    plotter.add(caption, at=at)


def _visualize_cross_section(plotter, voxels, rigid_segments, joints, at=3, title=""):
    """Create a series of CT-like cross-sectional slices arranged along the x-axis.

    The four voxel channels are merged into one scalar volume, ten planes
    normal to x are cut through the occupied x-range, and the slices are
    spread out along x next to each other in panel ``at``.

    Prints a message and returns without drawing when channel 0 has no
    occupied voxels.

    Args:
        plotter: vedo Plotter to draw into.
        voxels: Voxel structure indexed by channel (0-3).
        rigid_segments: Not used by this panel.
        joints: Not used by this panel.
        at: Index of the plotter panel to render into.
        title: Caption text shown at the top of the panel.
    """
    volume_shape = voxels.shape[1:]

    # Find the occupied x-range first so nothing is built for an empty robot
    x_occupancy = np.any(voxels[0] > 0, axis=(1, 2))
    non_empty_indices = np.where(x_occupancy)[0]
    if len(non_empty_indices) == 0:
        print("No occupied voxels found!")
        return
    x_min, x_max = int(non_empty_indices[0]), int(non_empty_indices[-1])

    # Merge the channels into one scalar volume. Later assignments override
    # earlier ones, giving the priority camera > joint > rigid > soft:
    # 0.0 = empty space
    # 0.1 = soft body (channel 0)
    # 0.2 = rigid structure (channel 1)
    # 0.6 = joint (channel 2)
    # 0.9 = camera (channel 3)
    scalar_volume = np.zeros(volume_shape, dtype=np.float32)
    scalar_volume = np.where(voxels[0] > 0, 0.1, scalar_volume)  # Soft body
    scalar_volume = np.where(voxels[1] > 0, 0.2, scalar_volume)  # Rigid structure
    scalar_volume = np.where(voxels[2] > 0, 0.6, scalar_volume)  # Joint
    scalar_volume = np.where(voxels[3] > 0, 0.9, scalar_volume)  # Camera

    # Create the volume object
    vol = vedo.Volume(scalar_volume)

    # Create a series of CT-like slices along the X axis
    slices = []
    slice_labels = []
    num_slices = 10  # Number of slices to create

    # Move the first and last slice inward from the ends of the occupied
    # range. The inset is capped at half the extent so that a narrow robot
    # cannot end up with x_min > x_max and slices outside its own range.
    buffer = min(10, (x_max - x_min) // 2)
    x_min += buffer
    x_max -= buffer

    slice_positions = np.linspace(x_min, x_max, num_slices, dtype=int)

    # Create slices and arrange them along the x-axis
    slice_spacing = volume_shape[1] * 0.2  # Space between slices
    for i, x_pos in enumerate(slice_positions):
        # Create a slice at position x_pos; scalar 0 (empty) is transparent
        ct_slice = vol.slice_plane(origin=(x_pos, 0, 0), normal=(1, 0, 0)).cmap(
            "Greys_r", alpha=[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        )

        # Position the slice along the x-axis
        offset_x = i * slice_spacing - (num_slices * slice_spacing) / 2
        ct_slice.shift(offset_x, 0, 0)

        # Add a label for the slice
        label = vedo.Text3D(
            f"X = {x_pos}",
            pos=(offset_x, -volume_shape[1] / 2 - 10, 0),
            s=3,
            c="black",
        )

        slices.append(ct_slice)
        slice_labels.append(label)

    # Create a bounding box to show the original volume extents
    bbox = vedo.Box(vol.bounds()).alpha(0.1).c("gray").wireframe(True)

    # Combine all objects
    all_objects = [bbox] + slices + slice_labels

    # Add a title
    caption = vedo.Text2D(
        title,
        pos="top-center",
        c="black",
        font="Arial",
        s=1.2,
    )

    # Show in the specified panel
    plotter.show(all_objects, at=at)
    plotter.add(caption, at=at)


def _visualize_joints_detailed(plotter, voxels, rigid_segments, joints, at=4, title=""):
    """Create detailed visualization of joint mechanics.

    Each joint is drawn as a small mechanical assembly (cylinder, shaft, two
    bearings, rods towards its connected segments, a rotation circle and
    labels). The assemblies are offset from the joint position on an x/z
    grid, on top of a faint smoothed shell of the robot.

    Args:
        plotter: vedo Plotter to draw into.
        voxels: Voxel structure indexed by channel; channel 0 is the shell.
        rigid_segments: Label volume used to locate segment centroids.
        joints: Sequence of dicts with "position", "axis" and "components".
            An empty sequence draws only the shell and the caption.
        at: Index of the plotter panel to render into.
        title: Caption text shown at the top of the panel.
    """
    # NOTE(review): several vedo keyword arguments below (Torus r/thickness/
    # axis, Cone point_at) and the plotter.camera(...) call at the end could
    # not be confirmed against the vedo API from this file - verify against
    # the installed vedo version before enabling this panel.

    # First, show a semi-transparent view of the full robot
    iso_non_empty = plot_binary_vedo(voxels[0], color=(0.9, 0.9, 0.9, 0.1), smooth=True)

    # Create joint visualizations
    joint_vis = []

    # Create a roughly square grid layout for joints; with no joints there
    # are no rows, so skip the division to avoid a ZeroDivisionError
    rows = int(np.ceil(np.sqrt(len(joints))))
    cols = int(np.ceil(len(joints) / rows)) if rows else 0

    grid_size = 25  # Spacing between joints in the grid

    for i, joint in enumerate(joints):
        # Calculate grid position
        row = i // cols
        col = i % cols

        # Get joint data
        position = joint["position"]
        axis = joint["axis"]
        components = joint["components"]

        # Calculate offset position for the grid
        offset_x = (col - cols / 2) * grid_size
        offset_z = (row - rows / 2) * grid_size

        # Where this joint's assembly is drawn
        anchor = position + np.array([offset_x, 0, offset_z])

        # Add cylindrical joint
        joint_cylinder = vedo.Cylinder(
            pos=anchor - axis * 2,
            height=4,
            r=3,
            axis=axis,
            c="silver",
            alpha=0.8,
        )

        # Add shaft
        shaft = vedo.Cylinder(
            pos=anchor - axis * 3,
            height=6,
            r=1,
            axis=axis,
            c="gold",
        )

        # Add bearings on each end
        bearing1 = vedo.Torus(
            pos=anchor - axis * 2,
            r=3,
            thickness=1,
            c="gray",
            axis=axis,
        )

        bearing2 = vedo.Torus(
            pos=anchor + axis * 2,
            r=3,
            thickness=1,
            c="gray",
            axis=axis,
        )

        # Add connection rods pointing towards the connected rigid segments
        connections = []

        for component_id in components:
            # Create a binary mask for this segment
            segment_mask = rigid_segments == component_id

            # Calculate direction to the segment centroid
            indices = np.where(segment_mask)
            if len(indices[0]) > 0:
                segment_centroid = np.array(
                    [np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])]
                )

                # Direction from the (un-offset) joint position to the segment
                direction = segment_centroid - position
                direction_length = np.linalg.norm(direction)

                if direction_length > 0:
                    direction = direction / direction_length

                    # Create a connection rod, at most 10 units long
                    conn_length = min(10, direction_length * 0.7)
                    connection = vedo.Cylinder(
                        pos=anchor + direction * (conn_length / 2),
                        height=conn_length,
                        r=1.5,
                        axis=direction,
                        c="orange",
                        alpha=0.8,
                    )

                    connections.append(connection)

        # Add rotation indicators
        # Find a unit vector perpendicular to the axis; switch the seed
        # vector when the axis is nearly parallel to x
        perp = np.array([1, 0, 0])
        if np.abs(np.dot(perp, axis)) > 0.9:
            perp = np.array([0, 1, 0])
        perp = np.cross(perp, axis)
        perp = perp / np.linalg.norm(perp)

        # Second in-plane direction, constant for the whole circle
        binormal = np.cross(axis, perp)

        # Create a circular line to indicate rotation: 37 points in
        # 10-degree steps, i.e. 0 to 360 degrees
        points = []
        for j in range(37):
            angle = j * np.pi / 18
            points.append(
                anchor + 5 * (perp * np.cos(angle) + binormal * np.sin(angle))
            )

        arc = vedo.Line(points, c="cyan", lw=2)

        # Add arrowhead halfway around the circle
        arrowhead = vedo.Cone(
            pos=points[18],
            point_at=points[17],
            r=0.8,
            height=1.5,
            c="cyan",
        )

        # Add label
        label = vedo.Text3D(
            f"Joint {i + 1}\nComponents: {components[0]}-{components[1]}",
            pos=position + np.array([offset_x - 5, -8, offset_z]),
            s=2,
            c="black",
            justify="center",
        )

        # Add axis label
        axis_label = vedo.Text3D(
            "Rotation Axis",
            pos=anchor + axis * 8,
            s=1.5,
            c="blue",
            justify="center",
        )

        joint_vis.extend(
            [
                joint_cylinder,
                shaft,
                bearing1,
                bearing2,
                arc,
                arrowhead,
                label,
                axis_label,
            ]
            + connections
        )

    # Add a title
    caption = vedo.Text2D(
        title,
        pos="top-center",
        c="black",
        font="Arial",
        s=1.2,
    )

    # Show all objects
    all_objects = [iso_non_empty] + joint_vis

    plotter.show(all_objects, at=at)
    plotter.add(caption, at=at)

    # Set the camera angle
    plotter.camera(at=at, zoom=1.0)


def _visualize_cameras_detailed(
    plotter, voxels, rigid_segments, cameras, at=5, title=""
):
    """Create detailed visualization of cameras with FOV and sensing properties.

    Renders ``voxels[0]`` as a faint iso-surface and, for every entry in
    ``cameras``, a camera glyph made of a body, lens, lens glass, FOV cone,
    sensor plane, three local-axis arrows with labels, and two text labels.
    Each glyph is shifted in the X/Z plane by a per-camera grid offset.

    Args:
        plotter: Object providing ``show``, ``add`` and ``camera``.
        voxels: Indexable whose first element is passed to
            ``plot_binary_vedo``.
        rigid_segments: Not used by this function; kept so the signature
            stays compatible with existing callers.
        cameras: Sequence of mappings with the keys ``"position"``,
            ``"orientation"`` (unpacked as ``qw, qx, qy, qz``) and
            ``"component"``. May be empty, in which case only the robot
            and the caption are drawn.
        at: Renderer index forwarded to the plotter calls.
        title: Caption text placed at the top centre of the renderer.

    NOTE(review): several keyword arguments used below (e.g. ``angle`` on
    ``vedo.Cone``, ``r``/``normal`` on ``vedo.Disc``, ``width``/``height`` on
    ``vedo.Rectangle``, and calling ``plotter.camera(...)``) should be checked
    against the installed vedo version -- TODO confirm.
    """
    # First, show a semi-transparent view of the full robot
    iso_non_empty = plot_binary_vedo(voxels[0], color=(0.9, 0.9, 0.9, 0.1), smooth=True)

    # Create camera visualizations
    camera_vis = []

    # Create a grid layout for cameras. With no cameras the square root is 0,
    # which previously made the ``cols`` division raise ZeroDivisionError;
    # clamp to one row so the robot and caption still render.
    rows = max(1, int(np.ceil(np.sqrt(len(cameras)))))
    cols = int(np.ceil(len(cameras) / rows))

    grid_size = 30  # Spacing between cameras in the grid
    fov_angle = 45  # degrees; shared by the FOV cone and its label

    for i, camera in enumerate(cameras):
        # Calculate grid position
        row = i // cols
        col = i % cols

        # Get camera data
        position = camera["position"]
        orientation = camera["orientation"]
        component = camera["component"]

        # Calculate offset position for the grid
        offset_x = (col - cols / 2) * grid_size
        offset_z = (row - rows / 2) * grid_size

        # Every glyph part is placed relative to this shifted camera origin.
        anchor = position + np.array([offset_x, 0, offset_z])

        # Direction vector from the quaternion: the rotated local X axis
        # (first column of the rotation matrix, assuming a unit quaternion).
        qw, qx, qy, qz = orientation
        direction = np.array(
            [
                1 - 2 * qy * qy - 2 * qz * qz,
                2 * qx * qy + 2 * qw * qz,
                2 * qx * qz - 2 * qw * qy,
            ]
        )

        # Up vector: the rotated local Y axis (second column).
        up = np.array(
            [
                2 * qx * qy - 2 * qw * qz,
                1 - 2 * qx * qx - 2 * qz * qz,
                2 * qy * qz + 2 * qw * qx,
            ]
        )

        # Third axis completes the frame as direction x up.
        right = np.cross(direction, up)

        # Camera body
        camera_body = vedo.Box(
            pos=anchor,
            size=(4, 3, 3),
            c="darkgray",
        )

        # Lens
        lens = vedo.Cylinder(
            pos=anchor + direction * 2.5,
            height=2,
            r=1.5,
            axis=direction,
            c="black",
        )

        # Add lens glass at the front
        lens_glass = vedo.Disc(
            pos=anchor + direction * 3.5,
            r=1.5,
            normal=direction,
            c="lightblue",
            alpha=0.7,
        )

        # Create FOV visualization
        fov_cone = vedo.Cone(
            pos=anchor + direction * 15,
            height=20,
            angle=fov_angle / 2,
            axis=-direction,  # Pointing towards the camera
            c="yellow",
            alpha=0.2,
        )

        # Add a sensor plane at the back of the camera
        sensor = vedo.Rectangle(
            anchor - direction * 1,
            width=3,
            height=2.25,  # 4:3 aspect ratio
            c="cyan",
            alpha=0.7,
        )
        sensor.orientation(up, right)

        # Add coordinate axes
        dir_arrow = vedo.Arrow(
            anchor,
            anchor + direction * 6,
            c="red",
            s=0.4,
        )

        up_arrow = vedo.Arrow(
            anchor,
            anchor + up * 3,
            c="green",
            s=0.3,
        )

        right_arrow = vedo.Arrow(
            anchor,
            anchor + right * 3,
            c="blue",
            s=0.3,
        )

        # Add labels for axes
        x_label = vedo.Text3D(
            "X",
            pos=anchor + direction * 7,
            s=1.5,
            c="red",
        )

        y_label = vedo.Text3D(
            "Y",
            pos=anchor + up * 4,
            s=1.5,
            c="green",
        )

        z_label = vedo.Text3D(
            "Z",
            pos=anchor + right * 4,
            s=1.5,
            c="blue",
        )

        # Add a camera label, placed below and to the left of the glyph
        camera_label = vedo.Text3D(
            f"Camera {i + 1}\nComponent: {component}",
            pos=position + np.array([offset_x - 5, -8, offset_z]),
            s=2,
            c="black",
            justify="center",
        )

        # Add FOV label beyond the far end of the cone
        fov_label = vedo.Text3D(
            f"FOV: {fov_angle}°",
            pos=anchor + direction * 25,
            s=1.5,
            c="orange",
            justify="center",
        )

        camera_vis.extend(
            [
                camera_body,
                lens,
                lens_glass,
                fov_cone,
                sensor,
                dir_arrow,
                up_arrow,
                right_arrow,
                x_label,
                y_label,
                z_label,
                camera_label,
                fov_label,
            ]
        )

    # Add a title
    caption = vedo.Text2D(
        title,
        pos="top-center",
        c="black",
        font="Arial",
        s=1.2,
    )

    # Show all objects
    all_objects = [iso_non_empty] + camera_vis

    plotter.show(all_objects, at=at)
    plotter.add(caption, at=at)

    # Set the camera angle
    plotter.camera(at=at, zoom=1.0)


def _add_global_annotations(plotter):
    """Attach the figure-wide legend and headline to ``plotter``.

    A four-entry colour legend is added to renderer 0, then a
    "Robot Structure Visualization" headline is added with no explicit
    renderer index.

    Args:
        plotter: Object exposing ``add(obj, at=...)``.
    """
    # (legend text, RGB swatch colour) pairs, in display order.
    legend_entries = (
        ("Soft Body", (1, 0, 0)),
        ("Rigid Structure", (0, 1, 0)),
        ("Joints", (0, 0, 1)),
        ("Cameras", (1, 0, 1)),
    )
    swatches = [vedo.Sphere(r=1, c=rgb) for _, rgb in legend_entries]
    captions = [text for text, _ in legend_entries]

    # Styling shared by the legend box and the headline.
    panel_style = {"font": "Arial", "bg": "white", "alpha": 0.7}

    # Colour legend for the first panel, anchored near the bottom-left corner.
    plotter.add(
        vedo.LegendBox(swatches, texts=captions, pos=(0.1, 0.1), **panel_style),
        at=0,
    )

    # Global headline across the top of the window.
    headline = vedo.Text2D(
        "Robot Structure Visualization",
        pos="top-center",
        c="black",
        s=1.5,
        **panel_style,
    )
    plotter.add(headline)
