import time
import queue
import uuid
import threading
import torch as t
import numpy as np
from copy import deepcopy
from pyquaternion import Quaternion
from typing import Tuple, List, Dict, Any, Union
from rl.algorithms import A2C, PPO
from utils.misc import list_of_dict_to_dict_of_list
from sim.env.env_interface import CameraState, CameraWebClient
from sim.create.create_camera import create_free_perspective_camera
from rise import *


class Timer:
    """Context manager that times the body of a ``with`` block.

    On entry the current ``time.time()`` reading is stored in ``start``; on
    exit the reading is stored in ``end`` and the difference (seconds) in
    ``interval``. Nothing is printed, and exceptions raised inside the block
    are not suppressed.
    """

    def __init__(self, name):
        # Label of the timed section; stored only, for the caller's use.
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        finished_at = time.time()
        self.end = finished_at
        self.interval = finished_at - self.start


class ThreadPool:
    """Small fixed-size worker pool for running one batch of tasks.

    Worker threads are started in the constructor and pull ``(func, args)``
    tuples from ``tasks``. Return values are pushed to ``results`` in
    completion order; exceptions raised by tasks are pushed to ``errors``.

    The pool is single-use: once ``close()`` (or ``map()``, which calls it)
    has run, the workers have exited and further submissions are rejected.
    """

    def __init__(self, num_threads):
        """Start ``num_threads`` worker threads.

        Raises:
            ValueError: if ``num_threads`` is smaller than 1 (a pool without
                workers would silently drop every task).
        """
        if num_threads < 1:
            raise ValueError("ThreadPool needs at least one worker thread")
        self.tasks = queue.Queue()
        self.results = queue.Queue()
        # Exceptions raised by tasks, in the order the workers caught them.
        self.errors = queue.Queue()
        self.threads = []
        self._closed = False
        for _ in range(num_threads):
            # Deliberately not named `t`: that is this module's torch alias.
            thread = threading.Thread(target=self.worker)
            thread.start()
            self.threads.append(thread)

    def worker(self):
        """Worker loop: run queued tasks until a ``None`` sentinel arrives."""
        while True:
            task = self.tasks.get()
            try:
                if task is None:  # None is the signal to shut down
                    break
                func, args = task
                try:
                    self.results.put(func(*args))
                except Exception as exc:
                    # Keep the worker alive so the remaining tasks still run;
                    # the failure is surfaced later by map() / `errors`.
                    self.errors.put(exc)
            finally:
                # Always balance the get() above, even for failed tasks.
                self.tasks.task_done()

    def submit(self, func, *args):
        """Submit a task to the thread pool.

        The call ``func(*args)`` is validated up front when the callable's
        signature can be introspected; callables without an introspectable
        signature (some builtins) are accepted unchecked.

        Raises:
            TypeError: if ``func`` is not callable or cannot be called with
                the given positional arguments.
            RuntimeError: if the pool has already been closed.
        """
        import inspect

        if not callable(func):
            raise TypeError("First argument must be callable")
        if self._closed:
            raise RuntimeError("Cannot submit tasks to a closed ThreadPool")

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None  # no signature available; skip validation
        if signature is not None:
            try:
                signature.bind(*args)
            except TypeError as exc:
                raise TypeError(
                    f"Cannot call {func!r} with {len(args)} positional "
                    f"args: {exc}"
                ) from exc
        self.tasks.put((func, args))

    def map(self, func, iterable, *fixed_args):
        """Apply 'func' to every item in 'iterable' and collect the results.

        Each call is ``func(item, *fixed_args)``. The pool is closed before
        returning, so it cannot be used again afterwards.
        Note: order is unspecified.

        Raises:
            Exception: the first exception raised by any task, re-raised
                after all workers have finished.
        """
        try:
            for item in iterable:
                self.submit(func, item, *fixed_args)
        finally:
            # Always stop the workers, even if a submission was rejected,
            # so no non-daemon thread is left blocking interpreter exit.
            self.close()

        if not self.errors.empty():
            raise self.errors.get()

        results = []
        while not self.results.empty():
            results.append(self.results.get())
        return results

    def close(self):
        """Stop the thread pool and wait for all tasks to complete.

        Safe to call more than once. Task exceptions are not raised here;
        they remain available in ``errors``.
        """
        if self._closed:
            return
        self._closed = True
        for _ in self.threads:
            self.tasks.put(None)  # Each None will shut down one thread
        for thread in self.threads:
            thread.join()


class RiseEnvForLatentConditionedMoE:
    def __init__(
        self,
        framework: Union[A2C, PPO],
        device: str,
        config: RS_Config,
        angles: Union[np.ndarray, List[float]],
        rank: int,
        voxel_sample_num: int = 5000,
        random_seed: int = 42,
        build_voxel_observations: bool = True,
        build_kinematic_graph: bool = True,
        build_reward_state: bool = True,
        webserver_url: str = "XXXX",
        sim_id: str = None,
        sim_name: str = None,
        global_camera_width: int = 1280,
        global_camera_height: int = 720,
        global_camera_samples: int = 4,
        title: str = None,
    ):
        """Set up the environment state, the global camera and the web link.

        Args:
            framework: RL framework stored on the instance.
            device: Passed to the ``Rise`` constructor.
            config: Simulation config; a deep copy is kept and the global
                camera config is appended to the copy's ``camera_configs``.
            angles: Converted to a float32 numpy array.
            rank: Process rank; only rank 0 pushes ``title`` to the web server.
            voxel_sample_num: Stored on the instance.
            random_seed: Stored on the instance.
            build_voxel_observations: Flag stored on the instance.
            build_kinematic_graph: Flag stored on the instance.
            build_reward_state: Flag stored on the instance.
            webserver_url: Web server address; ``None`` runs offline.
            sim_id: Simulation id; a random UUID4 string when falsy.
            sim_name: Display name; derived from ``sim_id`` when falsy.
            global_camera_width: Width given to the global camera config.
            global_camera_height: Height given to the global camera config.
            global_camera_samples: Sample count given to the camera config.
            title: Title sent to the web server by rank 0 after connecting.
        """
        # --- Core handles -------------------------------------------------
        self.framework = framework
        self.rank = rank
        self.rise = Rise(device)
        self.simulation_handle = None  # type: Union[None, RS_SimulationHandle]
        self.config = deepcopy(config)
        self.angles = np.array(angles).astype(np.float32)

        # --- Names of everything added to the simulation --------------------
        self.robots = []
        self.objects = []
        self.materials = []

        # --- Per-robot buffers, all keyed by robot name ---------------------
        self.prev_com = {}  # type: Dict[str, t.Tensor]
        self.observations = {}  # type: Dict[str, List[dict]]
        self.actions = {}  # type: Dict[str, List[List[t.Tensor]]]
        self.reward_states = {}  # type: Dict[str, List[dict]]
        # Latent tensor per robot, shape [1, latent_dim]
        self.robot_latents = {}  # type: Dict[str, t.Tensor]

        # --- Observation settings -------------------------------------------
        self.voxel_sample_num = voxel_sample_num
        self.random_seed = random_seed
        # NOTE: the next two instance attributes share their names with the
        # static builder methods of this class, so on an instance they shadow
        # them; the builders have to be reached through the class itself.
        self.build_voxel_observations = build_voxel_observations
        self.build_kinematic_graph = build_kinematic_graph
        self.build_reward_state = build_reward_state

        # --- Web server identity and client ---------------------------------
        self.webserver_url = webserver_url
        self.sim_id = sim_id if sim_id else str(uuid.uuid4())
        self.sim_name = sim_name if sim_name else f"Simulation {self.sim_id[:8]}"
        self.web_client = (
            CameraWebClient(webserver_url, self.sim_id, self.sim_name)
            if webserver_url is not None
            else None
        )

        # --- Global camera ----------------------------------------------------
        camera_position = [1.0, 0.0, 4.0]
        camera_orientation = [0, 0.7071, 0, 0.7071]
        self.global_camera_name = f"global_camera_{self.sim_id}"
        self.global_camera_width = global_camera_width
        self.global_camera_height = global_camera_height
        self.global_camera_samples = global_camera_samples
        self.global_camera_position = camera_position
        self.global_camera_orientation = camera_orientation
        self.global_camera_state = CameraState(
            position=camera_position,
            orientation=camera_orientation,
            movement_direction=[0.0, 0.0, 0.0],
            roll_input=0.0,
        )

        # Register the global camera in the (copied) simulation config
        self.config.camera_configs.append(
            create_free_perspective_camera(
                name=self.global_camera_name,
                origin=camera_position,
                orientation=camera_orientation,
                width=global_camera_width,
                height=global_camera_height,
                samples=global_camera_samples,
            )
        )

        # --- Connect to the web server, falling back to offline mode --------
        if self.web_client is None:
            print("Env running in offline mode")
        elif not self.web_client.connect():
            print(
                f"Warning: Failed to connect to web server {webserver_url}, running in offline mode"
            )
            self.web_client = None
        elif self.rank == 0:
            self.web_client.update_title(title)

    def terminate(self):
        if self.simulation_handle is not None:
            self.simulation_handle.terminate()

        # Disconnect from web server
        if self.web_client is not None:
            self.web_client.disconnect()

    def pause(self):
        if self.simulation_handle is not None:
            self.simulation_handle.pause()

    def resume(self):
        if self.simulation_handle is not None:
            self.simulation_handle.resume()

    def wait_for_end(self):
        if self.simulation_handle is not None:
            self.simulation_handle.wait_for_end()
            if self.web_client is not None:
                self.web_client.disconnect()

    def add_materials(
        self, material_configs: Union[RS_MaterialConfig, List[RS_MaterialConfig]]
    ):
        """
        Add materials to the simulation environment

        All names are validated before the simulation is touched, so a
        duplicate never leaves a half-applied batch or an open modification
        transaction behind. Nothing happens when there is no simulation handle.

        Args:
            material_configs: Single RS_MaterialConfig object or a list of
                RS_MaterialConfig objects

        Raises:
            ValueError: if a material name is already registered or appears
                twice in ``material_configs``.
        """
        if isinstance(material_configs, RS_MaterialConfig):
            material_configs = [material_configs]
        else:
            # Materialize once: the configs are walked twice below.
            material_configs = list(material_configs)

        if self.simulation_handle is None:
            return

        # Validate the whole batch up front (including duplicates inside it).
        taken_names = set(self.materials)
        for material_config in material_configs:
            if material_config.name in taken_names:
                raise ValueError(f"Material {material_config.name} already exists")
            taken_names.add(material_config.name)

        self.simulation_handle.begin_modification_transaction()
        try:
            for material_config in material_configs:
                self.simulation_handle.add_material(material_config)
                self.materials.append(material_config.name)
                print(f"Added material {material_config.name}")
        finally:
            # Close the transaction even if the simulator rejects a material.
            self.simulation_handle.end_modification_transaction()

    def add_objects(
        self, object_configs: Union[RS_StructureConfig, List[RS_StructureConfig]]
    ):
        """
        Add objects to the simulation environment

        All names are validated before the simulation is touched, so a
        duplicate never leaves a half-applied batch or an open modification
        transaction behind. Nothing happens when there is no simulation handle.

        Args:
            object_configs: Single RS_StructureConfig object or a list of RS_StructureConfig objects

        Raises:
            ValueError: if an object name is already registered or appears
                twice in ``object_configs``.
        """
        # If a single configuration object is passed, convert it to a list
        if isinstance(object_configs, RS_StructureConfig):
            object_configs = [object_configs]
        else:
            # Materialize once: the configs are walked twice below.
            object_configs = list(object_configs)

        if self.simulation_handle is None:
            return

        # Validate the whole batch up front (including duplicates inside it).
        taken_names = set(self.objects)
        for object_config in object_configs:
            if object_config.name in taken_names:
                raise ValueError(f"Object {object_config.name} already exists")
            taken_names.add(object_config.name)

        self.simulation_handle.begin_modification_transaction()
        try:
            for object_config in object_configs:
                # Record the name only after the simulator accepted the
                # structure, so a failed add leaves no phantom entry.
                self.simulation_handle.add_structure(object_config)
                self.objects.append(object_config.name)
                print(f"Added object {object_config.name}")
        finally:
            # Close the transaction even if the simulator rejects an object.
            self.simulation_handle.end_modification_transaction()

    def add_robot(self, robot_config: RS_StructureConfig, robot_latent: t.Tensor):
        """Add a robot structure plus its structure record to the simulation.

        Per-robot buffers (previous COM, observations, actions, reward states,
        latent) are created under the robot's name, and the web server's robot
        list is refreshed when a web client is connected. Nothing is
        registered when there is no simulation handle.

        Args:
            robot_config: Structure config of the robot; its ``name`` is the
                key for all per-robot buffers.
            robot_latent: Latent tensor stored for this robot, shape
                [1, latent_dim].

        Raises:
            ValueError: if a robot with the same name is already registered.
        """
        if robot_config.name in self.robots:
            raise ValueError(f"Robot {robot_config.name} already exists")
        if self.simulation_handle is not None:
            # NOTE(review): the buffers are registered before the structure is
            # added, presumably so they already exist once the simulation
            # starts reporting this robot; keep this order unless confirmed
            # otherwise.
            self.robots.append(robot_config.name)
            self.prev_com[robot_config.name] = None
            self.observations[robot_config.name] = []
            self.actions[robot_config.name] = []
            self.reward_states[robot_config.name] = []
            # Store per-robot latent: shape [1, latent_dim]
            self.robot_latents[robot_config.name] = robot_latent

            # Record rigid bodies, joints and voxels of the robot; links are
            # not recorded.
            struct_record_config = RS_StructureRecordConfig()
            struct_record_config.structure_name = robot_config.name
            struct_record_config.rigid_body_record_range = (
                RSE_StructureRigidBodyRecordRange.RSE_RIGID_BODY_RECORD_RANGE_ALL
            )
            struct_record_config.joint_record_range = (
                RSE_StructureJointRecordRange.RSE_JOINT_RECORD_RANGE_ALL
            )
            struct_record_config.voxel_record_range = (
                RSE_StructureVoxelRecordRange.RSE_VOXEL_RECORD_RANGE_ALL
            )
            struct_record_config.link_record_range = (
                RSE_StructureLinkRecordRange.RSE_LINK_RECORD_RANGE_NONE
            )

            robot_record_config = RS_RecordConfig()
            robot_record_config.name = robot_config.name + "_record"
            robot_record_config.type = RSE_RecordType.RSE_RECORD_STRUCTURE
            robot_record_config.config = struct_record_config

            self.simulation_handle.begin_modification_transaction()
            try:
                self.simulation_handle.add_structure(robot_config)
                self.simulation_handle.add_record(robot_record_config)
            finally:
                # Never leave the modification transaction open on failure.
                self.simulation_handle.end_modification_transaction()

            print(f"Added robot {robot_config.name}")

            if self.web_client is not None:
                # Update the robot list on the web server
                self.web_client.update_robot_list(self.robots)

    def remove_objects(self, object_names: Union[str, List[str]]):
        if isinstance(object_names, str):
            object_names = [object_names]

        for object_name in object_names:
            if object_name not in self.objects:
                raise ValueError(f"Object {object_name} not found")
            if self.simulation_handle is not None:
                self.simulation_handle.begin_modification_transaction()
                self.simulation_handle.delete_structure(object_name)
                self.simulation_handle.end_modification_transaction()
            self.objects.remove(object_name)
            print(f"Removed object {object_name}")

    def remove_robot(self, robot_name: str):
        if robot_name not in self.robots:
            raise ValueError(f"Robot {robot_name} not found")
        if self.simulation_handle is not None:
            robot_structure_gid = self.simulation_handle.find_structure(robot_name)
            robot_record_gid = self.simulation_handle.find_record(
                robot_name + "_record"
            )
            self.simulation_handle.begin_modification_transaction()
            self.simulation_handle.delete_structure(robot_structure_gid)
            self.simulation_handle.delete_record(robot_record_gid)
            self.simulation_handle.end_modification_transaction()
            # Remove cached latent for this robot if present
            self.robot_latents.pop(robot_name, None)
            self.robots.remove(robot_name)
            print(f"Removed robot {robot_name}")

            if self.web_client is not None:
                # Update the robot list on the web server
                self.web_client.update_robot_list(self.robots)

    def get_and_clear_robot_info(self, robot_name: str):
        if robot_name not in self.robots:
            raise ValueError(f"Robot {robot_name} not found")
        info = (
            self.observations[robot_name],
            self.actions[robot_name],
            self.reward_states[robot_name],
        )
        self.observations[robot_name] = []
        self.actions[robot_name] = []
        self.reward_states[robot_name] = []
        return info

    def set_global_camera_fixed_at(self, position: np.ndarray, orientation: Quaternion):
        """Pin the global camera to a fixed pose and stop robot tracking.

        Args:
            position: Camera position; a numpy array or any other iterable of
                coordinates, stored as a plain list.
            orientation: Camera orientation; a ``Quaternion`` (stored as
                [x, y, z, w]) or any other iterable, stored as a plain list.
        """
        state = self.global_camera_state

        if isinstance(position, np.ndarray):
            state.position = position.tolist()
        else:
            state.position = list(position)

        if isinstance(orientation, Quaternion):
            # Convert Quaternion to list [x, y, z, w]
            state.orientation = [
                getattr(orientation, component) for component in ("x", "y", "z", "w")
            ]
        else:
            state.orientation = list(orientation)

        # A fixed pose and robot tracking are mutually exclusive.
        state.tracking_robot = None

        client = self.web_client
        if client is not None:
            # Update the camera state on the web server
            client.camera_state = state

    def set_global_camera_track_robot(self, robot_name: str, track_height: float = 1.0):
        """Set the global camera to track a specific robot"""
        if robot_name is not None and robot_name not in self.robots:
            raise ValueError(f"Robot {robot_name} not found")

        self.global_camera_state.tracking_robot = robot_name
        self.global_camera_state.track_height = track_height

        if self.web_client is not None:
            # Update the camera state on the web server
            self.web_client.camera_state = self.global_camera_state

    @staticmethod
    def build_voxel_observations(structure_record: RS_StructureRecord, com: np.ndarray):
        # shape [voxel_num, 3]
        relative_voxel_positions = t.from_numpy(structure_record.voxel_position() - com)
        # shape [voxel_num, 1]
        pressures = t.from_numpy(
            np.mean(structure_record.voxel_poissons_strain(), axis=1, keepdims=True)
        )
        return relative_voxel_positions, pressures

    @staticmethod
    def build_kinematic_graph(structure_record: RS_StructureRecord, com: np.ndarray):
        rigid_body_mass = structure_record.rigid_body_mass()
        rigid_body_com = structure_record.rigid_body_com()
        rigid_body_orientations = structure_record.rigid_body_orientation()
        rigid_body_linear_velocities = structure_record.rigid_body_linear_velocity()
        rigid_body_angular_velocities = structure_record.rigid_body_angular_velocity()
        node_features = t.from_numpy(
            np.concatenate(
                (
                    rigid_body_com - com,
                    rigid_body_orientations,
                    rigid_body_linear_velocities,
                    rigid_body_angular_velocities,
                    rigid_body_mass[:, None],
                ),
                axis=1,
            )
        )
        joint_num = len(structure_record.joint_type())
        edges = t.zeros([2, joint_num * 2], dtype=t.long)
        edge_features = t.zeros([joint_num * 2, 9], dtype=t.float32)

        joint_rb_a = structure_record.joint_rigid_body_a_sid()
        joint_rb_b = structure_record.joint_rigid_body_b_sid()
        joint_positions = structure_record.joint_position()
        joint_axes = structure_record.joint_axis()
        joint_hinge_min = structure_record.joint_hinge_min()
        joint_hinge_max = structure_record.joint_hinge_max()
        joint_angles = structure_record.joint_angle()

        for idx in range(joint_num):
            rb_a = int(joint_rb_a[idx])
            rb_b = int(joint_rb_b[idx])
            j_pos = joint_positions[idx]
            j_axis = joint_axes[idx]
            j_angle_min = joint_hinge_min[idx]
            j_angle_max = joint_hinge_max[idx]
            j_angle = joint_angles[idx]

            edges[0][idx] = rb_a
            edges[1][idx] = rb_b
            edge_features[idx] = t.from_numpy(
                np.concatenate(
                    [
                        j_pos - rigid_body_com[rb_a],
                        j_axis,
                        np.array([j_angle_min, j_angle_max, j_angle]),
                    ]
                )
            )

            # Add reverse edge
            edges[0][idx + joint_num] = rb_b
            edges[1][idx + joint_num] = rb_a
            edge_features[idx + joint_num] = t.from_numpy(
                np.concatenate(
                    [
                        j_pos - rigid_body_com[rb_b],
                        -j_axis,
                        np.array([j_angle_min, j_angle_max, j_angle]),
                    ]
                )
            )
        return (
            t.from_numpy(rigid_body_com - com),
            node_features,
            edges,
            edge_features,
        )

    @staticmethod
    def get_observation_processor(
        build_voxel_observations: bool = True,
        build_kinematic_graph: bool = True,
        build_reward_state: bool = True,
        build_velocity_observations: bool = True,
    ):
        def process_observation(structure_controller, prev_com, robot_latents):
            structure_name = structure_controller.name()

            prev_com_structure = prev_com.get(structure_name)
            robot_latent_structure = robot_latents.get(structure_name)
            if robot_latent_structure is None:
                raise ValueError(f"Missing latent for robot {structure_name}")
            structure_record = structure_controller.structure_record()

            observation = {}
            reward_state = {}

            # shape [3]
            com = np.mean(structure_record.voxel_position(), axis=0)

            # Attach per-robot latent vector for gating, shape [1, latent_dim]
            observation["robot_latent"] = robot_latent_structure

            if build_voxel_observations:
                relative_voxel_positions, pressures = (
                    RiseEnvForLatentConditionedMoE.build_voxel_observations(
                        structure_record, com
                    )
                )
                observation["relative_voxel_positions"] = relative_voxel_positions
                observation["voxel_features"] = pressures
                observation["com"] = t.from_numpy(com).unsqueeze(0)

            if build_velocity_observations:
                if prev_com_structure is not None:
                    velocity = observation["com"] - prev_com_structure
                else:
                    velocity = t.zeros_like(observation["com"])
                observation["velocity"] = velocity

                prev_com[structure_name] = observation["com"].clone()

            if build_kinematic_graph:
                (
                    relative_node_positions,
                    node_features,
                    edges,
                    edge_features,
                ) = RiseEnvForLatentConditionedMoE.build_kinematic_graph(
                    structure_record, com
                )
                observation["relative_node_positions"] = relative_node_positions
                observation["node_features"] = node_features
                observation["edges"] = edges
                observation["edge_features"] = edge_features

            if build_reward_state:
                # numpy, shape [voxel_num, 3]
                reward_state["voxel_positions"] = structure_record.voxel_position()
                # numpy, shape [3]
                reward_state["com"] = com

            return (
                structure_name,
                observation,
                reward_state,
            )

        return process_observation

    @staticmethod
    def get_camera_quaternion_forward_vector(q):
        """Return the camera's forward direction for orientation ``q``.

        ``q`` is indexed as [x, y, z, w]; the result is the local x-axis
        (1, 0, 0) rotated by that quaternion.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_x_axis = [1, 0, 0]
        return rotation.rotate(local_x_axis)

    @staticmethod
    def get_camera_quaternion_right_vector(q):
        """Return the camera's right direction for orientation ``q``.

        ``q`` is indexed as [x, y, z, w]; the result is the local y-axis
        (0, 1, 0) rotated by that quaternion.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_y_axis = [0, 1, 0]
        return rotation.rotate(local_y_axis)

    @staticmethod
    def get_camera_quaternion_up_vector(q):
        """Return the camera's up direction for orientation ``q``.

        ``q`` is indexed as [x, y, z, w]; the result is the local z-axis
        (0, 0, 1) rotated by that quaternion.
        """
        qx, qy, qz, qw = q[0], q[1], q[2], q[3]
        rotation = Quaternion(w=qw, x=qx, y=qy, z=qz)
        local_z_axis = [0, 0, 1]
        return rotation.rotate(local_z_axis)

    @staticmethod
    def look_at_quaternion(eye, center, global_up=(0, 0, 1)):
        """
        Create a quaternion that makes an object at eye position look at center.
        For a camera whose local forward vector is (1,0,0) and up vector is (0,0,1).

        Viewing directions parallel to ``global_up`` (e.g. looking straight
        down) are supported: the canonical axis least aligned with the viewing
        direction is then used as the up reference instead.

        Parameters:
            eye (list): Position of the object [x, y, z]
            center (list): Position to look at [x, y, z]
            global_up (list): Global up direction, defaults to [0, 0, 1]

        Returns:
            list: Quaternion in [x, y, z, w] format

        Raises:
            ValueError: if ``eye`` and ``center`` coincide, since no viewing
                direction exists in that case.
        """
        # Convert inputs to numpy arrays
        eye = np.array(eye, dtype=np.float64)
        center = np.array(center, dtype=np.float64)
        global_up = np.array(global_up, dtype=np.float64)

        # Calculate the forward direction vector (from eye to center)
        forward = center - eye
        distance = np.linalg.norm(forward)
        if distance < 1e-12:
            raise ValueError(
                "look_at_quaternion: eye and center coincide, "
                "the viewing direction is undefined"
            )
        forward = forward / distance  # Normalize

        # Calculate the right vector using cross product
        right = np.cross(global_up, forward)
        right_norm = np.linalg.norm(right)
        if right_norm <= 1e-8 * np.linalg.norm(global_up):
            # forward is (anti)parallel to global_up, or global_up is zero:
            # the cross product vanishes and normalizing it would give NaN.
            # Fall back to the canonical axis least aligned with forward,
            # which is guaranteed not to be parallel to it.
            fallback_up = np.zeros(3, dtype=np.float64)
            fallback_up[int(np.argmin(np.abs(forward)))] = 1.0
            right = np.cross(fallback_up, forward)
            right_norm = np.linalg.norm(right)
        right = right / right_norm  # Normalize

        # Calculate the local up vector using cross product
        local_up = np.cross(forward, right)
        local_up = local_up / np.linalg.norm(local_up)  # Normalize

        # Create rotation matrix whose columns are the rotated local axes
        rotation_matrix = np.column_stack((forward, right, local_up))

        # Convert rotation matrix to quaternion using pyquaternion
        q = Quaternion(matrix=rotation_matrix)

        # Return the quaternion in [x, y, z, w] format
        return [q.x, q.y, q.z, q.w]

    @staticmethod
    def get_callback(
        framework: Union[A2C, PPO],
        observations: Dict[str, List[dict]],
        actions: Dict[str, List[List[t.Tensor]]],
        reward_states: Dict[str, List[dict]],
        prev_com: Dict[str, t.Tensor],
        robot_latents: Dict[str, t.Tensor],
        observation_processor,
        angles: np.ndarray,
        web_client: CameraWebClient,
        global_camera_name: str,
        camera_move_speed: float = 0.05,
    ):

        def callback(
            ids_: List[int],
            controllers: List[List[Tuple[RSE_SimulationControllerType, Any]]],
        ):
            # Organize controllers by type and name
            robot_name_to_controller_idx = {}
            camera_name_to_controller_idx = {}

            for idx, controller in enumerate(controllers[0]):
                if (
                    controller[0]
                    == RSE_SimulationControllerType.RSE_SIMULATION_STRUCTURE_CONTROLLER
                ):
                    robot_name_to_controller_idx[controller[1].name()] = idx
                elif (
                    controller[0]
                    == RSE_SimulationControllerType.RSE_SIMULATION_CAMERA_CONTROLLER
                ):
                    camera_name_to_controller_idx[controller[1].name()] = idx

            # Process structure observations
            if len(robot_name_to_controller_idx) != 0:
                robot_names = []
                with Timer("Preprocess"):
                    pool = ThreadPool(4)
                    structure_controllers = [
                        controller
                        for controller_type, controller in controllers[0]
                        if controller_type
                        == RSE_SimulationControllerType.RSE_SIMULATION_STRUCTURE_CONTROLLER
                    ]
                    results = list(
                        pool.map(
                            observation_processor,
                            structure_controllers,
                            prev_com,
                            robot_latents,
                        )
                    )
                    for structure_name, observation, reward_state in results:
                        robot_names.append(structure_name)
                        observations[structure_name].append(observation)
                        reward_states[structure_name].append(reward_state)
                    mapped_results = {r[0]: r[1] for r in results}

                    all_observations = list_of_dict_to_dict_of_list(
                        [mapped_results[robot_name] for robot_name in robot_names]
                    )

                # Calculate model outputs
                with Timer("Control"):
                    with t.no_grad():
                        all_output = framework.act(
                            all_observations,
                            call_dp_or_ddp_internal_module=True,
                        )
                    all_action = [angles[a.cpu().numpy()] for a in all_output[0]]

                    if len(all_output) > 3:
                        all_robot_routing_weights = [
                            r.unsqueeze(0).cpu().numpy() for r in all_output[3]
                        ]
                    else:
                        all_robot_routing_weights = [None] * len(robot_names)

                    # Apply actions to robots, and save routing weights to reward states
                    for robot_idx, robot_name in enumerate(robot_names):
                        controller_idx = robot_name_to_controller_idx[robot_name]
                        actions[robot_name].append(
                            [all_output[0][robot_idx], all_output[1][robot_idx]]
                        )
                        robot_controller = controllers[0][controller_idx][
                            1
                        ]  # type: RS_StructureController
                        rotation_signals = (
                            robot_controller.signal().rotation_angle_signals()
                        )
                        if rotation_signals is not None and len(rotation_signals) > 0:
                            rotation_signals[:] = all_action[robot_idx]

                        # inner numpy array shape [1, num_experts] or None
                        reward_states[robot_name][-1]["routing_weight"] = (
                            all_robot_routing_weights[robot_idx]
                        )

            with Timer("Web"):
                if web_client is not None:
                    # Get the latest camera state from the web client - this allows web UI control
                    camera_state = web_client.get_camera_state()

                    global_camera_image = None
                    global_camera_time = 0

                    # Handle global camera
                    if global_camera_name in camera_name_to_controller_idx:
                        camera_controller = controllers[0][
                            camera_name_to_controller_idx[global_camera_name]
                        ][
                            1
                        ]  # type: RS_CameraController
                        cam_record = camera_controller.camera_record()

                        if cam_record.is_valid():
                            # If tracking a robot, update the camera position to follow the robot
                            if (
                                camera_state.tracking_robot
                                and camera_state.tracking_robot
                                in robot_name_to_controller_idx
                            ):
                                robot_idx = robot_name_to_controller_idx[
                                    camera_state.tracking_robot
                                ]
                                robot_controller = controllers[0][robot_idx][
                                    1
                                ]  # type: RS_StructureController
                                structure_record = robot_controller.structure_record()

                                if structure_record:
                                    # Get the robot's center of mass
                                    robot_pos = np.mean(
                                        structure_record.voxel_position(), axis=0
                                    )

                                    # Set camera position to look at robot from a distance at specified height
                                    camera_pos = [
                                        robot_pos[0] - 1.5,  # Back a bit
                                        robot_pos[1],  # Centered
                                        robot_pos[2]
                                        + camera_state.track_height,  # Above
                                    ]

                                    # Create quaternion to look at robot
                                    camera_quat = RiseEnvForLatentConditionedMoE.look_at_quaternion(
                                        camera_pos, robot_pos
                                    )

                                    # Update camera state with tracking position/orientation
                                    camera_state.position = camera_pos
                                    camera_state.orientation = camera_quat

                            elif not camera_state.bird_view_mode:
                                # Not tracking - apply manual controls
                                # Get current position and orientation
                                quat = cam_record.orientation()
                                pos = cam_record.position()

                                if quat is not None and len(quat) == 4:
                                    # Store current orientation in camera state
                                    camera_state.orientation = [
                                        quat[0],
                                        quat[1],
                                        quat[2],
                                        quat[3],
                                    ]

                                if pos is not None and len(pos) == 3:
                                    # Store current position in camera state
                                    camera_state.position = [pos[0], pos[1], pos[2]]

                                # Apply pending mouse rotations
                                if (
                                    camera_state.pending_mouse_pitch != 0
                                    or camera_state.pending_mouse_yaw != 0
                                ):
                                    # Get axes for rotations based on current orientation
                                    pitch_axis = RiseEnvForLatentConditionedMoE.get_camera_quaternion_right_vector(
                                        camera_state.orientation
                                    )
                                    yaw_axis = RiseEnvForLatentConditionedMoE.get_camera_quaternion_up_vector(
                                        camera_state.orientation
                                    )

                                    # Create rotation quaternions for pitch and yaw
                                    pitch_quat = Quaternion(
                                        axis=pitch_axis,
                                        angle=camera_state.pending_mouse_pitch,
                                    )
                                    yaw_quat = Quaternion(
                                        axis=yaw_axis,
                                        angle=camera_state.pending_mouse_yaw,
                                    )

                                    # Convert camera quaternion to Quaternion object
                                    curr_quat = Quaternion(
                                        w=camera_state.orientation[3],
                                        x=camera_state.orientation[0],
                                        y=camera_state.orientation[1],
                                        z=camera_state.orientation[2],
                                    )

                                    # Apply rotations to camera quaternion
                                    new_quat = pitch_quat * curr_quat * yaw_quat

                                    # Update the camera quaternion in [x, y, z, w] format
                                    camera_state.orientation = [
                                        new_quat.x,
                                        new_quat.y,
                                        new_quat.z,
                                        new_quat.w,
                                    ]

                                    # Reset pending rotations
                                    camera_state.pending_mouse_pitch = 0
                                    camera_state.pending_mouse_yaw = 0

                                # Handle roll control (Q/E keys)
                                if camera_state.roll_input != 0:
                                    # Compute forward vector for roll axis
                                    forward_vector = RiseEnvForLatentConditionedMoE.get_camera_quaternion_forward_vector(
                                        camera_state.orientation
                                    )

                                    # Create a roll rotation around the forward vector
                                    roll_quat = Quaternion(
                                        axis=forward_vector,
                                        angle=camera_state.roll_input,
                                    )

                                    # Current quaternion as Quaternion object
                                    curr_quat = Quaternion(
                                        w=camera_state.orientation[3],
                                        x=camera_state.orientation[0],
                                        y=camera_state.orientation[1],
                                        z=camera_state.orientation[2],
                                    )

                                    # Apply roll to current orientation
                                    new_orientation = roll_quat * curr_quat

                                    # Convert back to list format
                                    camera_state.orientation = [
                                        new_orientation.x,
                                        new_orientation.y,
                                        new_orientation.z,
                                        new_orientation.w,
                                    ]

                                # Apply WASD movement
                                if np.any(camera_state.movement_direction):
                                    # Get movement vectors based on camera orientation
                                    forward_vector = RiseEnvForLatentConditionedMoE.get_camera_quaternion_forward_vector(
                                        camera_state.orientation
                                    )
                                    right_vector = RiseEnvForLatentConditionedMoE.get_camera_quaternion_right_vector(
                                        camera_state.orientation
                                    )

                                    # Calculate movement in world space
                                    world_forward = (
                                        np.array(forward_vector)
                                        * camera_state.movement_direction[0]
                                    )
                                    world_right = (
                                        np.array(right_vector)
                                        * camera_state.movement_direction[1]
                                    )
                                    world_move = world_forward + world_right

                                    # Calculate new position by applying movement
                                    new_position = [
                                        camera_state.position[0]
                                        + world_move[0] * camera_move_speed,
                                        camera_state.position[1]
                                        + world_move[1] * camera_move_speed,
                                        camera_state.position[2]
                                        + world_move[2] * camera_move_speed,
                                    ]

                                    # Update position in camera state
                                    camera_state.position = new_position

                            # Apply camera position and orientation to simulation
                            position_signals = (
                                camera_controller.signal().position_signals()
                            )
                            orientation_signals = (
                                camera_controller.signal().orientation_signals()
                            )

                            # Apply position to signals
                            if (
                                position_signals is not None
                                and len(position_signals) >= 3
                            ):
                                position_signals[0] = camera_state.position[0]
                                position_signals[1] = camera_state.position[1]
                                position_signals[2] = camera_state.position[2]

                            # Apply orientation to signals
                            if (
                                orientation_signals is not None
                                and len(orientation_signals) >= 4
                            ):
                                orientation_signals[0] = camera_state.orientation[
                                    0
                                ]  # x
                                orientation_signals[1] = camera_state.orientation[
                                    1
                                ]  # y
                                orientation_signals[2] = camera_state.orientation[
                                    2
                                ]  # z
                                orientation_signals[3] = camera_state.orientation[
                                    3
                                ]  # w

                            # Get pixel data and send to web server
                            global_camera_image = (
                                cam_record.pixels()
                            )  # shape [height, width, 3]
                            global_camera_time = cam_record.time()

                    # Only send data if global camera is valid
                    if global_camera_image is not None:
                        # Send updated camera state and feed to the web server
                        web_client.camera_state = camera_state
                        web_client.update_camera_feed(
                            global_camera_time,
                            global_camera_image,
                            {},  # No robot cameras
                        )

        return callback

    def run(self):
        """Start the simulation with the web-streaming callback attached.

        The observation processor is built from this object's three builder
        callables and handed, together with the rest of the per-run state,
        to ``get_callback``. The resulting callback is passed to
        ``self.rise.run_sims``; the first element of the value it returns is
        stored on ``self.simulation_handle``.
        """
        processor = self.get_observation_processor(
            self.build_voxel_observations,
            self.build_kinematic_graph,
            self.build_reward_state,
        )

        # Positional order must match the signature of get_callback.
        sim_callback = self.get_callback(
            self.framework,
            self.observations,
            self.actions,
            self.reward_states,
            self.prev_com,
            self.robot_latents,
            processor,
            self.angles,
            self.web_client,
            self.global_camera_name,
        )

        # Keyword options forwarded unchanged to run_sims.
        run_options = {
            "save_record": True,
            "max_time": 100000.0,
            "log_level": "info",
            "policy": "sequential",
            "constraint_update_interval": 2,
        }

        # run_sims receives one-element lists, so only the first entry of
        # its result is relevant here.
        handles = self.rise.run_sims(
            [self.config],
            [0],
            sim_callback,
            **run_options,
        )
        self.simulation_handle = handles[0]
