import numpy as np
import pygame
import threading
from pyquaternion import Quaternion
from typing import List, Tuple, Any, Union

from rise import *


class LocalRobotTest:
    """Interactive local test environment for robot development with pygame visualization.

    Shows the render of a free-perspective camera in a pygame window and maps
    keyboard and mouse input to joint angle signals and orbit-style camera
    control. The class-level constants below configure the window, the
    camera and the joint angle presets.
    """

    # Camera and rendering constants
    CAMERA_WIDTH = 1280  # Width in pixels of the window and of the free camera image
    CAMERA_HEIGHT = 720  # Height in pixels of the window and of the free camera image
    CAMERA_INITIAL_POSITION = (-0.5, 0.0, 0.3)  # Starting camera position (x, y, z)
    CAMERA_INITIAL_ORIENTATION = (0, 0, 0, 1)  # Starting quaternion, [x, y, z, w] order
    CAMERA_INITIAL_FOCUS_POINT = (0.0, 0.0, 0.15)  # Point the camera initially orbits around
    CAMERA_SAMPLES_PER_PIXEL = 4  # Passed to the free camera as samples_per_pixel
    CAMERA_MOVE_SPEED = 0.05  # Intended WASD focus-point step — presumably per callback; TODO confirm
    MOUSE_SENSITIVITY = 0.001  # Orbit angle added per pixel of left-button mouse drag
    CAMERA_BOX_MARGIN = 0.5  # NOTE(review): not referenced in the visible code — confirm intended use

    # Rigid body camera display size (for resizing and displaying)
    RIGID_BODY_CAMERA_WIDTH = 256
    RIGID_BODY_CAMERA_HEIGHT = 256
    RIGID_BODY_CAMERA_LOCAL_ANCHOR = (0, 0, 1)  # NOTE(review): not referenced in the visible code — confirm intended use

    # Default joint angle ranges
    # Values written to the rotation angle signals: MAX while a number key is
    # held, MIN while Shift+number is held, DEFAULT once the key is released.
    # Presumably radians — TODO confirm against the simulator.
    JOINT_ANGLE_MIN = -1.5
    JOINT_ANGLE_DEFAULT = 0
    JOINT_ANGLE_MAX = 1.5

    def __init__(
        self,
        robot_structure_config: RS_StructureConfig,
        config: RS_Config,
        device: str = "cuda:0",
        robot_camera_configs: List[RS_CameraConfig] = None,
    ):
        """
        Set up the simulator backend and all interactive state.

        Nothing is started here: the pygame window and the simulation are
        both launched by run().

        Args:
            robot_structure_config: Robot structure configuration
            config: Rise simulation configuration
            device: Device to run simulation on (default: "cuda:0")
            robot_camera_configs: Optional list of robot-mounted camera configurations
        """
        # Simulator backend and the configuration objects it will run
        self.rise = Rise(device)
        self.config = config
        self.robot_structure_config = robot_structure_config
        self.robot_camera_configs = robot_camera_configs if robot_camera_configs else []
        self.simulation_handle = None  # type: Union[None, RS_SimulationHandle]

        # Commanded angle per joint index (number keys 1-9 map to indices 0-8)
        self.joint_states = dict.fromkeys(range(9), self.JOINT_ANGLE_DEFAULT)

        # Camera pose: position [x, y, z] and quaternion [x, y, z, w]
        self.camera_position = list(self.CAMERA_INITIAL_POSITION)
        self.camera_quaternion = list(self.CAMERA_INITIAL_ORIENTATION)

        # Orbit controls: the camera circles focus_point at camera_distance
        self.focus_point = list(self.CAMERA_INITIAL_FOCUS_POINT)
        offset_to_focus = np.array(self.focus_point) - np.array(self.camera_position)
        self.camera_distance = np.linalg.norm(offset_to_focus)
        self.focus_point_movement = [0.0, 0.0]  # [A/D axis, W/S axis]

        # Input accumulated by the UI and consumed by the simulation callback
        self.pending_orbit_pitch = self.pending_orbit_yaw = 0.0  # Left mouse
        self.pending_pan_x = self.pending_pan_y = 0.0  # Right mouse
        self.pending_zoom_delta = 0.0  # Scroll wheel / middle mouse
        self.pending_camera_reset = False
        self.termination_requested = False

        # Mouse tracking
        self.last_mouse_pos = (0, 0)
        self.left_mouse_down = False
        self.right_mouse_down = False
        self.middle_mouse_down = False

        # Axis-aligned bounding box of the robot; meaningful once flagged valid
        self.robot_bounds_min = [0.0] * 3
        self.robot_bounds_max = [0.0] * 3
        self.robot_bounds_center = [0.0] * 3
        self.robot_bounds_valid = False

        # Pygame objects, created by the UI thread / simulation callback
        self.surface = None
        self.rigid_body_camera_surface = None
        self.screen = None

    @staticmethod
    def create_structure_record(name="record", structure_name="structure"):
        """Build a record config that captures one structure.

        All voxels and all rigid bodies are recorded; links and joints are not.

        Args:
            name: Name given to the record.
            structure_name: Name of the structure the record refers to.

        Returns:
            The populated RS_RecordConfig.
        """
        record = RS_RecordConfig()
        record.name = name
        record.type = RSE_RecordType.RSE_RECORD_STRUCTURE

        details = RS_StructureRecordConfig()
        details.structure_name = structure_name

        # (attribute, selected range) pairs, applied in this order
        range_selection = (
            (
                "voxel_record_range",
                RSE_StructureVoxelRecordRange.RSE_VOXEL_RECORD_RANGE_ALL,
            ),
            (
                "link_record_range",
                RSE_StructureLinkRecordRange.RSE_LINK_RECORD_RANGE_NONE,
            ),
            (
                "rigid_body_record_range",
                RSE_StructureRigidBodyRecordRange.RSE_RIGID_BODY_RECORD_RANGE_ALL,
            ),
            (
                "joint_record_range",
                RSE_StructureJointRecordRange.RSE_JOINT_RECORD_RANGE_NONE,
            ),
        )
        for attribute, selected_range in range_selection:
            setattr(details, attribute, selected_range)

        record.config = details
        return record

    @staticmethod
    def create_free_perspective_camera(
        name="camera",
        origin=(0.0, 0.0, 0.0),
        orientation=(0.0, 0.0, 0.0, 1.0),
        width=1280,
        height=720,
        samples=4,
    ):
        """Build the config of a free (not robot-mounted) perspective camera.

        Args:
            name: Camera name.
            origin: Initial position, unpacked into RVec3rf.
            orientation: Initial orientation, unpacked into RQuat3rf.
            width: Image width in pixels.
            height: Image height in pixels.
            samples: Samples per pixel.

        Returns:
            The populated RS_CameraConfig.
        """
        camera_config = RS_CameraConfig()
        camera_config.name = name
        camera_config.type = RSE_CameraType.RSE_CAMERA_FREE_PERSPECTIVE

        perspective = RS_FreePerspectiveCameraConfig()
        # Fixed rendering parameters
        perspective.max_depth = 5
        # Caller-controlled image size and quality
        perspective.image_width = width
        perspective.image_height = height
        perspective.samples_per_pixel = samples
        perspective.field_of_view = 45.0
        # Initial pose
        perspective.origin_position = RVec3rf(*origin)
        perspective.orientation = RQuat3rf(*orientation)
        # Zero defocus angle; focus distance is fixed
        perspective.defocus_angle = 0
        perspective.focus_distance = 0.05

        camera_config.config = perspective
        return camera_config

    @staticmethod
    def look_at_quaternion(eye, center, global_up=(0, 0, 1)):
        """
        Create a quaternion that makes an object at eye position look at center.
        For a camera whose local forward vector is (1,0,0) and up vector is (0,0,1).

        The rotation maps the local x-axis onto the viewing direction, the
        local y-axis onto ``cross(global_up, forward)`` and the local z-axis
        onto the remaining orthogonal direction.

        When the viewing direction is (anti)parallel to ``global_up`` (looking
        straight up or down), or ``global_up`` is the zero vector, the roll
        about the viewing direction is undefined. In that case the world axis
        least aligned with the viewing direction is used as the up reference,
        so a valid rotation is still returned.

        Args:
            eye: Position of the object [x, y, z]
            center: Position to look at [x, y, z]
            global_up: Global up direction, defaults to [0, 0, 1]

        Returns:
            list: Quaternion in [x, y, z, w] format

        Raises:
            ValueError: If eye and center coincide (or are not finite), so
                that no viewing direction exists.
        """
        eye = np.array(eye, dtype=np.float64)
        center = np.array(center, dtype=np.float64)
        global_up = np.array(global_up, dtype=np.float64)

        forward = center - eye
        forward_norm = np.linalg.norm(forward)
        # "not >" also rejects NaN, which would otherwise poison the matrix
        if not forward_norm > 0.0:
            raise ValueError(
                "look_at_quaternion requires eye and center to be distinct points"
            )
        forward = forward / forward_norm

        right = np.cross(global_up, forward)
        right_norm = np.linalg.norm(right)
        # Relative threshold so the test does not depend on the length of
        # global_up; a zero global_up also lands here (0 > 0 is False).
        if not right_norm > 1e-8 * np.linalg.norm(global_up):
            # The world axis with the smallest component along forward is
            # guaranteed not to be parallel to it.
            fallback_up = np.zeros(3)
            fallback_up[np.argmin(np.abs(forward))] = 1.0
            right = np.cross(fallback_up, forward)
            right_norm = np.linalg.norm(right)
        right = right / right_norm

        local_up = np.cross(forward, right)
        local_up = local_up / np.linalg.norm(local_up)

        # Columns are the images of the local x, y and z axes
        rotation_matrix = np.column_stack((forward, right, local_up))
        q = Quaternion(matrix=rotation_matrix)

        return [q.x, q.y, q.z, q.w]

    @staticmethod
    def get_quaternion_forward_vector(q):
        """Return the local x-axis (forward) after rotation by q.

        Args:
            q: Quaternion in [x, y, z, w] order.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_forward = [1, 0, 0]
        return rotation.rotate(local_forward)

    @staticmethod
    def get_quaternion_right_vector(q):
        """Return the local y-axis (called "right" in this class) after rotation by q.

        Args:
            q: Quaternion in [x, y, z, w] order.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_right = [0, 1, 0]
        return rotation.rotate(local_right)

    @staticmethod
    def get_quaternion_up_vector(q):
        """Return the local z-axis (up) after rotation by q.

        Args:
            q: Quaternion in [x, y, z, w] order.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_up = [0, 0, 1]
        return rotation.rotate(local_up)

    def update_robot_bounds(self, structure_record):
        """Update the bounding box of the robot based on voxel positions"""
        if not structure_record:
            return

        voxel_positions = structure_record.voxel_position()
        if voxel_positions is None or len(voxel_positions) == 0:
            return

        min_pos = np.min(voxel_positions, axis=0)
        max_pos = np.max(voxel_positions, axis=0)
        center_pos = (min_pos + max_pos) / 2.0

        self.robot_bounds_min = min_pos.tolist()
        self.robot_bounds_max = max_pos.tolist()
        self.robot_bounds_center = center_pos.tolist()
        self.robot_bounds_valid = True

    def compute_camera_view_from_bounds(self):
        """Compute camera position and orientation to view the robot from a diagonal"""
        if not self.robot_bounds_valid:
            print("Robot bounds not valid, cannot reset camera")
            return None, None

        size_x = self.robot_bounds_max[0] - self.robot_bounds_min[0]
        size_y = self.robot_bounds_max[1] - self.robot_bounds_min[1]
        size_z = self.robot_bounds_max[2] - self.robot_bounds_min[2]

        size_x = max(size_x, 0.1)
        size_y = max(size_y, 0.1)
        size_z = max(size_z, 0.1)

        max_size = max(size_x, size_y, size_z)

        camera_pos = [
            self.robot_bounds_center[0] + max_size * 0.7,
            self.robot_bounds_center[1] + max_size * 0.7,
            self.robot_bounds_center[2] + max_size * 0.7,
        ]

        camera_quat = self.look_at_quaternion(camera_pos, self.robot_bounds_center)

        return camera_pos, camera_quat

    def get_simulation_callback(self):
        """Create and return the simulation callback function"""

        def callback(sim_ids, controllers_list):
            """Callback function for the simulator to handle controls and rendering"""
            for i, (sim_id, controllers) in enumerate(zip(sim_ids, controllers_list)):
                for controller_tuple in controllers:
                    if (
                        controller_tuple[0]
                        == RSE_SimulationControllerType.RSE_SIMULATION_STRUCTURE_CONTROLLER
                    ):
                        controller = controller_tuple[1]  # type: RS_StructureController

                        structure_record = controller.structure_record()
                        if structure_record:
                            self.update_robot_bounds(structure_record)

                        rotation_signals = controller.signal().rotation_angle_signals()

                        # Apply joint states from key presses
                        for joint_idx, angle in self.joint_states.items():
                            if joint_idx < len(rotation_signals):
                                rotation_signals[joint_idx] = angle

                        # Handle robot-mounted cameras
                        for camera_record in controller.camera_records():
                            if camera_record.is_valid():
                                pixels = camera_record.pixels()
                                rb_surface = pygame.surfarray.make_surface(
                                    np.transpose(pixels, (1, 0, 2))
                                )
                                # Resize to display size
                                self.rigid_body_camera_surface = pygame.transform.scale(
                                    rb_surface,
                                    (
                                        self.RIGID_BODY_CAMERA_WIDTH,
                                        self.RIGID_BODY_CAMERA_HEIGHT,
                                    ),
                                )

                    elif (
                        controller_tuple[0]
                        == RSE_SimulationControllerType.RSE_SIMULATION_CAMERA_CONTROLLER
                    ):
                        controller = controller_tuple[1]  # type: RS_CameraController
                        cam_record = controller.camera_record()

                        if cam_record.is_valid():
                            quat = cam_record.orientation()
                            pos = cam_record.position()

                            if quat is not None and len(quat) == 4:
                                self.camera_quaternion = [
                                    quat[0],
                                    quat[1],
                                    quat[2],
                                    quat[3],
                                ]

                            if pos is not None and len(pos) == 3:
                                self.camera_position = [pos[0], pos[1], pos[2]]

                            # Handle camera reset
                            if self.pending_camera_reset:
                                new_pos, new_quat = (
                                    self.compute_camera_view_from_bounds()
                                )
                                if new_pos and new_quat:
                                    self.camera_position = new_pos
                                    self.camera_quaternion = new_quat
                                    # Update focus point and distance
                                    self.focus_point = self.robot_bounds_center.copy()
                                    cam_to_target = np.array(
                                        self.focus_point
                                    ) - np.array(self.camera_position)
                                    self.camera_distance = np.linalg.norm(cam_to_target)
                                self.pending_camera_reset = False

                            # Apply WASD focus point movement (camera-relative, projected to XY plane)
                            if (
                                self.focus_point_movement[0] != 0
                                or self.focus_point_movement[1] != 0
                            ):
                                # Get camera forward vector and project to XY plane
                                forward_vector = self.get_quaternion_forward_vector(
                                    self.camera_quaternion
                                )
                                forward_xy = np.array(
                                    [forward_vector[0], forward_vector[1], 0.0]
                                )
                                forward_xy_norm = np.linalg.norm(forward_xy)

                                if forward_xy_norm > 0.01:  # Avoid division by zero
                                    forward_xy = forward_xy / forward_xy_norm
                                    # Right vector is perpendicular in XY plane (rotate 90 degrees around Z)
                                    right_xy = np.array(
                                        [forward_xy[1], -forward_xy[0], 0.0]
                                    )

                                    move_speed = 0.05
                                    # self.focus_point_movement[1] is W/S (forward/back)
                                    # self.focus_point_movement[0] is A/D (left/right)
                                    move_offset = (
                                        forward_xy
                                        * self.focus_point_movement[1]
                                        * move_speed
                                        + right_xy
                                        * self.focus_point_movement[0]
                                        * move_speed
                                    )

                                    self.focus_point[0] += move_offset[0]
                                    self.focus_point[1] += move_offset[1]
                                    self.focus_point[2] += move_offset[2]

                            # Apply zoom (scroll wheel or middle mouse)
                            if self.pending_zoom_delta != 0:
                                zoom_speed = 0.1
                                self.camera_distance *= (
                                    1.0 - self.pending_zoom_delta * zoom_speed
                                )
                                # Clamp distance
                                self.camera_distance = max(
                                    0.1, min(self.camera_distance, 50.0)
                                )
                                self.pending_zoom_delta = 0

                            # Apply pan (right mouse button)
                            if self.pending_pan_x != 0 or self.pending_pan_y != 0:
                                # Pan moves the focus point in screen space
                                right_vector = self.get_quaternion_right_vector(
                                    self.camera_quaternion
                                )
                                up_vector = self.get_quaternion_up_vector(
                                    self.camera_quaternion
                                )

                                pan_speed = self.camera_distance * 0.001
                                pan_offset = (
                                    np.array(right_vector)
                                    * self.pending_pan_x
                                    * pan_speed
                                    + np.array(up_vector)
                                    * self.pending_pan_y
                                    * pan_speed
                                )

                                self.focus_point[0] += pan_offset[0]
                                self.focus_point[1] += pan_offset[1]
                                self.focus_point[2] += pan_offset[2]

                                self.pending_pan_x = 0
                                self.pending_pan_y = 0

                            # Apply orbit rotation (left mouse button)
                            if (
                                self.pending_orbit_pitch != 0
                                or self.pending_orbit_yaw != 0
                            ):
                                # Get current camera-to-target vector
                                cam_to_target = np.array(self.focus_point) - np.array(
                                    self.camera_position
                                )

                                # Create rotation quaternions
                                # Pitch around the right vector (local X axis)
                                right_axis = self.get_quaternion_right_vector(
                                    self.camera_quaternion
                                )
                                pitch_quat = Quaternion(
                                    axis=right_axis, angle=self.pending_orbit_pitch
                                )

                                # Yaw around the world Z axis (up)
                                yaw_quat = Quaternion(
                                    axis=[0, 0, 1], angle=self.pending_orbit_yaw
                                )

                                # Apply rotations to the camera-to-target vector
                                rotated_vector = pitch_quat.rotate(cam_to_target)
                                rotated_vector = yaw_quat.rotate(rotated_vector)

                                # Update camera position to maintain distance from focus point
                                self.camera_position = (
                                    np.array(self.focus_point) - rotated_vector
                                ).tolist()

                                # Update camera orientation to look at focus point
                                self.camera_quaternion = self.look_at_quaternion(
                                    self.camera_position, self.focus_point
                                )

                                self.pending_orbit_pitch = 0
                                self.pending_orbit_yaw = 0
                            else:
                                # Update camera position based on focus point and distance
                                # Calculate position from focus point using current orientation
                                forward_vector = self.get_quaternion_forward_vector(
                                    self.camera_quaternion
                                )
                                self.camera_position = (
                                    np.array(self.focus_point)
                                    - np.array(forward_vector) * self.camera_distance
                                ).tolist()

                            new_orientation = self.camera_quaternion

                            # Apply camera position and orientation
                            position_signals = controller.signal().position_signals()
                            orientation_signals = (
                                controller.signal().orientation_signals()
                            )

                            if (
                                position_signals is not None
                                and len(position_signals) >= 3
                            ):
                                position_signals[0] = self.camera_position[0]
                                position_signals[1] = self.camera_position[1]
                                position_signals[2] = self.camera_position[2]

                            if (
                                orientation_signals is not None
                                and len(orientation_signals) >= 4
                            ):
                                orientation_signals[0] = new_orientation[0]
                                orientation_signals[1] = new_orientation[1]
                                orientation_signals[2] = new_orientation[2]
                                orientation_signals[3] = new_orientation[3]

                            # Update pygame surface
                            pixels = cam_record.pixels()
                            pygame_surface = pygame.surfarray.make_surface(
                                np.transpose(pixels, (1, 0, 2))
                            )
                            self.surface = pygame_surface

            # Check for termination request
            if self.termination_requested and self.simulation_handle is not None:
                self.simulation_handle.terminate()

        return callback

    def pygame_thread(self):
        """Main PyGame thread to handle input and display.

        Opens the window, then loops at up to 60 frames per second: pygame
        events are translated into joint states and pending camera input, and
        the latest camera frames plus a text overlay are drawn.

        The loop ends when the window is closed, ESC or Ctrl+C is pressed, or
        termination_requested is set elsewhere. pygame is always shut down on
        the way out. If the loop dies from an unexpected error, termination
        is requested (and the simulation terminated directly when its handle
        is already known) so the simulation does not keep running without a
        window; the error is then re-raised.
        """
        self.pending_orbit_pitch = 0.0
        self.pending_orbit_yaw = 0.0
        self.pending_pan_x = 0.0
        self.pending_pan_y = 0.0
        self.pending_zoom_delta = 0.0
        self.pending_camera_reset = False

        try:
            pygame.init()

            self.screen = pygame.display.set_mode(
                (self.CAMERA_WIDTH, self.CAMERA_HEIGHT)
            )
            pygame.display.set_caption("Rise Simulator - Local Robot Test")

            # Black placeholder until the first rendered frame arrives
            self.surface = pygame.Surface((self.CAMERA_WIDTH, self.CAMERA_HEIGHT))
            self.surface.fill((0, 0, 0))

            # Only create rigid body camera surface if robot cameras are configured
            if self.robot_camera_configs:
                self.rigid_body_camera_surface = pygame.Surface(
                    (self.RIGID_BODY_CAMERA_WIDTH, self.RIGID_BODY_CAMERA_HEIGHT)
                )
                self.rigid_body_camera_surface.fill((0, 0, 0))

            font = pygame.font.Font(None, 24)
            clock = pygame.time.Clock()

            # Number keys 1-9 drive joint indices 0-8
            number_keys = (
                pygame.K_1,
                pygame.K_2,
                pygame.K_3,
                pygame.K_4,
                pygame.K_5,
                pygame.K_6,
                pygame.K_7,
                pygame.K_8,
                pygame.K_9,
            )
            joint_keys = {key: joint_idx for joint_idx, key in enumerate(number_keys)}

            self.last_mouse_pos = (0, 0)
            pygame.mouse.set_visible(True)

            running = True
            while running and not self.termination_requested:
                # Drain the whole queue even after a quit request, then draw
                # one last frame
                for event in pygame.event.get():
                    if not self._process_ui_event(event, joint_keys):
                        running = False

                self._render_ui(font)
                pygame.display.flip()
                clock.tick(60)
        finally:
            pygame.quit()

            # Every regular exit path has already set termination_requested,
            # so this only runs when an exception escaped the code above.
            if not self.termination_requested:
                self.termination_requested = True
                if self.simulation_handle is not None:
                    self.simulation_handle.terminate()

    def _process_ui_event(self, event, joint_keys):
        """Apply one pygame event to the input state; return False to close the window."""
        if event.type == pygame.QUIT:
            self.termination_requested = True
            return False

        if event.type == pygame.KEYDOWN:
            key = event.key
            if key == pygame.K_c and pygame.key.get_mods() & pygame.KMOD_CTRL:
                print("Ctrl+C pressed, terminating simulation...")
                self.termination_requested = True
                return False

            if key == pygame.K_ESCAPE:
                print("ESC pressed, terminating simulation...")
                self.termination_requested = True
                return False

            if key == pygame.K_t:
                print("T key pressed, resetting camera to view robot")
                self.pending_camera_reset = True

            elif key in joint_keys:
                joint_idx = joint_keys[key]
                # Shift+number selects the minimum angle, number alone the maximum
                if pygame.key.get_mods() & pygame.KMOD_SHIFT:
                    self.joint_states[joint_idx] = self.JOINT_ANGLE_MIN
                    print(
                        f"Joint {joint_idx + 1} set to minimum ({self.JOINT_ANGLE_MIN})"
                    )
                else:
                    self.joint_states[joint_idx] = self.JOINT_ANGLE_MAX
                    print(
                        f"Joint {joint_idx + 1} set to maximum ({self.JOINT_ANGLE_MAX})"
                    )

            # WASD for focus point movement (camera-relative, projected to XY plane)
            elif key == pygame.K_w:
                self.focus_point_movement[1] = 1.0  # Forward
            elif key == pygame.K_s:
                self.focus_point_movement[1] = -1.0  # Backward
            elif key == pygame.K_a:
                self.focus_point_movement[0] = -1.0  # Left
            elif key == pygame.K_d:
                self.focus_point_movement[0] = 1.0  # Right

        elif event.type == pygame.KEYUP:
            if event.key in joint_keys:
                # Releasing a joint key returns that joint to its default angle
                self.joint_states[joint_keys[event.key]] = self.JOINT_ANGLE_DEFAULT
            elif event.key in (pygame.K_w, pygame.K_s):
                self.focus_point_movement[1] = 0.0
            elif event.key in (pygame.K_a, pygame.K_d):
                self.focus_point_movement[0] = 0.0

        elif event.type == pygame.MOUSEBUTTONDOWN:
            if event.button in (1, 2, 3):
                if event.button == 1:  # Left mouse button
                    self.left_mouse_down = True
                elif event.button == 2:  # Middle mouse button
                    self.middle_mouse_down = True
                else:  # Right mouse button
                    self.right_mouse_down = True
                # Drag deltas are measured from where the button went down
                self.last_mouse_pos = pygame.mouse.get_pos()
            elif event.button == 4:  # Scroll up
                self.pending_zoom_delta = 1.0
            elif event.button == 5:  # Scroll down
                self.pending_zoom_delta = -1.0

        elif event.type == pygame.MOUSEBUTTONUP:
            if event.button == 1:  # Left mouse button
                self.left_mouse_down = False
            elif event.button == 2:  # Middle mouse button
                self.middle_mouse_down = False
            elif event.button == 3:  # Right mouse button
                self.right_mouse_down = False

        elif event.type == pygame.MOUSEMOTION:
            if self.left_mouse_down or self.right_mouse_down or self.middle_mouse_down:
                mouse_pos = pygame.mouse.get_pos()
                dx = mouse_pos[0] - self.last_mouse_pos[0]
                dy = mouse_pos[1] - self.last_mouse_pos[1]

                if dx != 0 or dy != 0:
                    # Only one drag mode is active at a time: left wins over
                    # right, right over middle
                    if self.left_mouse_down:
                        # Left mouse: Orbit rotation
                        self.pending_orbit_pitch += dy * self.MOUSE_SENSITIVITY
                        self.pending_orbit_yaw += -dx * self.MOUSE_SENSITIVITY
                    elif self.right_mouse_down:
                        # Right mouse: Pan in screen space
                        self.pending_pan_x += dx
                        self.pending_pan_y += dy
                    elif self.middle_mouse_down:
                        # Middle mouse: Zoom (dolly)
                        self.pending_zoom_delta += dy * 0.01

                self.last_mouse_pos = mouse_pos

        return True

    def _render_ui(self, font):
        """Draw the camera frames and the text overlay onto the window surface."""
        white = (255, 255, 255)
        hint_gray = (200, 200, 200)
        screen = self.screen

        def draw_text(message, position, color=white):
            screen.blit(font.render(message, True, color), position)

        screen.fill((0, 0, 0))
        screen.blit(self.surface, (0, 0))

        # Robot camera inset in the top-right corner, with frame and caption
        robot_view = self.rigid_body_camera_surface
        if robot_view is not None:
            inset_x = self.CAMERA_WIDTH - self.RIGID_BODY_CAMERA_WIDTH - 10
            screen.blit(robot_view, (inset_x, 10))
            pygame.draw.rect(
                screen,
                white,
                (
                    inset_x - 2,
                    8,
                    self.RIGID_BODY_CAMERA_WIDTH + 4,
                    self.RIGID_BODY_CAMERA_HEIGHT + 4,
                ),
                2,
            )
            draw_text("Robot Camera", (inset_x, self.RIGID_BODY_CAMERA_HEIGHT + 20))

        # Joint states, one row per joint
        for row, (joint_idx, angle) in enumerate(self.joint_states.items()):
            if angle == self.JOINT_ANGLE_MAX:
                status = "MAX"
            elif angle == self.JOINT_ANGLE_MIN:
                status = "MIN"
            else:
                status = "DEFAULT"
            draw_text(f"Joint {joint_idx + 1}: {status}", (10, 10 + row * 25))

        # Camera info starts below the joint rows
        y_offset = 10 + len(self.joint_states) * 25 + 20

        draw_text(
            f"Focus Point: X={self.focus_point[0]:.2f}, Y={self.focus_point[1]:.2f}, Z={self.focus_point[2]:.2f}",
            (10, y_offset),
        )
        draw_text(
            f"Camera Position: X={self.camera_position[0]:.2f}, Y={self.camera_position[1]:.2f}, Z={self.camera_position[2]:.2f}",
            (10, y_offset + 25),
        )
        draw_text(f"Distance: {self.camera_distance:.2f}", (10, y_offset + 50))

        if self.robot_bounds_valid:
            draw_text(
                f"Robot Bounds: Min={self.robot_bounds_min[0]:.2f},{self.robot_bounds_min[1]:.2f},{self.robot_bounds_min[2]:.2f}",
                (10, y_offset + 75),
            )
            draw_text(
                f"Robot Bounds: Max={self.robot_bounds_max[0]:.2f},{self.robot_bounds_max[1]:.2f},{self.robot_bounds_max[2]:.2f}",
                (10, y_offset + 100),
            )
        else:
            draw_text("Robot Bounds: Not available", (10, y_offset + 75))

        # Control hints
        draw_text(
            "Camera: Left Mouse=Orbit, Right Mouse=Pan, Scroll/Middle=Zoom, WASD=Move Focus (Camera Dir)",
            (10, y_offset + 125),
            hint_gray,
        )
        draw_text(
            "Controls: 1-9=Joints (Shift+#=Min), T=Reset Camera, ESC=Exit",
            (10, y_offset + 150),
            hint_gray,
        )

    def run(self):
        """Run the simulation with pygame visualization.

        Adds the robot structure, its record, the free camera and any
        robot-mounted cameras to the configuration, starts the pygame UI in a
        daemon thread, launches the simulation and blocks until it ends.

        However the wait ends (normal end, Ctrl-C in the terminal, or an
        error), the UI thread is told to stop before it is joined, so the
        window closes instead of lingering until the join times out.

        Note: the configuration passed to the constructor is modified in
        place.
        """
        # Focus point and distance are already initialized in __init__

        # Prepare configuration
        robot_record_config = self.create_structure_record(
            "robot_0_record", self.robot_structure_config.name
        )

        global_camera_config = self.create_free_perspective_camera(
            "camera_0",
            width=self.CAMERA_WIDTH,
            height=self.CAMERA_HEIGHT,
            origin=self.CAMERA_INITIAL_POSITION,
            orientation=self.CAMERA_INITIAL_ORIENTATION,
            samples=self.CAMERA_SAMPLES_PER_PIXEL,
        )

        self.config.simulation_config.warm_start_damping_z = 1
        self.config.structure_configs.append(self.robot_structure_config)
        self.config.record_configs.append(robot_record_config)
        self.config.camera_configs.append(global_camera_config)

        # Add robot cameras if provided
        for robot_camera_config in self.robot_camera_configs:
            self.config.camera_configs.append(robot_camera_config)

        # Effectively "run until stopped from the UI"
        self.config.simulation_config.stop_condition.from_string("t > 10000000.0")
        self.config.simulation_config.record_frequency = 20
        self.config.simulation_config.control_frequency = 20

        print(
            f"Robot joint num: {self.robot_structure_config.rotation_angle_signal_num}"
        )

        # Start pygame thread; daemon so it cannot keep the process alive
        pygame_thread_handle = threading.Thread(
            target=self.pygame_thread, daemon=True
        )
        pygame_thread_handle.start()

        # Create callback
        callback = self.get_simulation_callback()

        try:
            # Run simulation
            handles = self.rise.run_sims(
                [self.config],
                [0],
                callback=callback,
                dt_update_interval=10,
                collision_update_interval=10,
                constraint_update_interval=2,
                divergence_check_interval=100,
                record_buffer_size=5,
                max_time=10000000.0,
                checkpoint_time=1000,
                save_record=False,
                save_checkpoint=False,
                simulation_memory_accuracy="float32",
                simulation_compute_accuracy="float32",
                policy="sequential",
                log_level="info",
            )

            self.simulation_handle = handles[0]

            try:
                self.simulation_handle.wait_for_end()
            except KeyboardInterrupt:
                print("KeyboardInterrupt received, terminating...")
                self.simulation_handle.terminate()
        finally:
            # The UI loop only exits once this flag is set; without it the
            # join below would always hit its timeout when the simulation
            # ended by itself or was interrupted from the terminal.
            self.termination_requested = True
            pygame_thread_handle.join(timeout=1.0)

        status = self.simulation_handle.get_status()
        print("Simulation ended with status:", status)
        if status == 0:
            print("Simulation successful!")
        else:
            print("Simulation failed or was terminated.")
