import pybullet as p
import numpy as np
import time
import os

class FetchController:
    """Controller for a Fetch robot loaded in a PyBullet simulation.

    Provides joint lookup, arm/torso posing, wrist velocity control,
    differential-drive base motion, an RGB-D camera mounted above the base,
    and a ray-cast LIDAR emulation.
    """

    # Arm joints that apply_action() keeps static while only the wrist moves.
    _LOCKED_ARM_JOINTS = (
        "shoulder_pan_joint",
        "shoulder_lift_joint",
        "upperarm_roll_joint",
        "elbow_flex_joint",
        "forearm_roll_joint",
    )

    def __init__(self, robot_id, debug=False):
        """Initialize the Fetch robot controller

        Looks up the arm joints, locks the arm and torso, then poses the arm at
        ``start_positions`` and the torso at ``start_torso_height``.

        Args:
            robot_id: PyBullet body ID for the Fetch robot
            debug: Enable debug output (also enables LIDAR plotting/image saving)
        """
        self.robot_id = robot_id
        self.debug = debug
        # Camera image size and projection parameters used by get_camera_image()
        self.width = 84
        self.height = 84
        self.fov = 54.0  # degrees, passed to computeProjectionMatrixFOV
        self.near = 0.3
        self.far = 3.0
        # Latest teleop commands written by process_keyboard_events()
        self.keyboard_forward = 0  # linear velocity, m/s
        self.keyboard_turn = 0  # angular velocity, rad/s
        
        # For saving LIDAR images
        self.lidar_save_dir = "lidar_scans"
        self.scan_counter = 0
        # Create directory for LIDAR scans if it doesn't exist
        if debug and not os.path.exists(self.lidar_save_dir):
            try:
                os.makedirs(self.lidar_save_dir)
                print(f"Created directory '{self.lidar_save_dir}' for LIDAR scan images")
            except Exception as e:
                print(f"Warning: Could not create directory for LIDAR scans: {e}")
        
        # LIDAR specifications (SICK TIM571)
        self.lidar_range = 25.0  # 25 meters range
        self.lidar_fov = 220.0  # 220 degree field of view
        self.lidar_resolution = 1/3  # 1/3 degree angular resolution
        self.lidar_update_rate = 15  # 15Hz update rate (not used by this class)
        self.num_lidar_rays = int(self.lidar_fov / self.lidar_resolution)
        # Position LIDAR at the front of the robot base, slightly elevated
        self.lidar_offset = [0.3, 0, 0.4]  # 30cm forward, 40cm up from base
        
        # Debug line IDs for visualization (-1 means "create a new line")
        self.debug_line_view_id = -1
        self.debug_line_up_id = -1
        self.debug_line_right_id = -1
        self.lidar_debug_lines = []
        self.debug_draw_failed = False  # Track if debug drawing has failed

        # Joint information
        self.arm_joint_indices = []
        self.arm_joint_map = {}  # joint name -> PyBullet joint index
        self.arm_joint_names = [
            "shoulder_pan_joint", 
            "shoulder_lift_joint", 
            "upperarm_roll_joint", 
            "elbow_flex_joint", 
            "forearm_roll_joint", 
            "wrist_flex_joint", 
            "wrist_roll_joint"
        ]
        
        # Define joint limits for the wrist joints
        self.joint_limits = {
            "wrist_flex_joint": {"min": -1.5, "max": 2.0},  # Real Robot Limits
            "wrist_roll_joint": {"min": -2.5, "max": 2.5}   # Real Robot Limits
        }
        
        # One start position per entry in arm_joint_names, same order
        self.start_positions = [0.0, 1.4, 0.0, -2.2, 0.0, 2, 1.55]
        self.start_torso_height = 0.2
        
        # Find arm joint indices
        self._find_arm_joint_indices()
        self.lock_all_joints()
        # Starting arm position and torso height
        self.set_arm_position(self.start_positions)
        self.set_torso_height(self.start_torso_height)
        
        # Flag to track if arm is currently moving
        self.arm_is_moving = False
        
    def _print_joint_info(self):
        """Print information about all joints in the robot"""
        print(f"\nJoint information for robot {self.robot_id}:")
        for i in range(p.getNumJoints(self.robot_id)):
            joint_info = p.getJointInfo(self.robot_id, i)
            print(f"Joint {i}: {joint_info[1].decode('utf-8')}, Type: {joint_info[2]}")

    def _find_joint_index(self, joint_name):
        """Return the index of the first joint named ``joint_name``, or None if absent."""
        for i in range(p.getNumJoints(self.robot_id)):
            if p.getJointInfo(self.robot_id, i)[1].decode('utf-8') == joint_name:
                return i
        return None
    
    def _find_arm_joint_indices(self):
        """Find the joint indices for the arm joints.

        Fills ``arm_joint_indices`` (found joints only, in ``arm_joint_names``
        order) and ``arm_joint_map`` (name -> index). Use the map for lookups
        by name: it stays correct even if some joints are missing, whereas
        positions in the list shift in that case.
        """
        # Scan the robot once instead of once per arm joint.
        name_to_index = {}
        for i in range(p.getNumJoints(self.robot_id)):
            name_to_index.setdefault(p.getJointInfo(self.robot_id, i)[1].decode('utf-8'), i)

        self.arm_joint_map = {}
        self.arm_joint_indices = []
        for joint_name in self.arm_joint_names:
            joint_index = name_to_index.get(joint_name)
            if joint_index is None:
                if self.debug:
                    print(f"Warning: Joint '{joint_name}' not found in the robot")
                continue
            self.arm_joint_map[joint_name] = joint_index
            self.arm_joint_indices.append(joint_index)
        
        if self.debug:
            print(f"Arm joint indices: {self.arm_joint_indices}")

    @staticmethod
    def _limit_velocity(position, velocity, limit_min, limit_max):
        """Return ``velocity``, or 0 if it would drive a joint already at/beyond a limit further out."""
        if (position <= limit_min and velocity < 0) or (position >= limit_max and velocity > 0):
            return 0
        return velocity

    def _hold_locked_arm_joints(self):
        """Re-command any drifting non-wrist arm joint (|velocity| > 0.01) to hold its current position."""
        for joint_name in self._LOCKED_ARM_JOINTS:
            joint_index = self.arm_joint_map.get(joint_name)
            if joint_index is None:
                continue
            joint_state = p.getJointState(self.robot_id, joint_index)
            current_position, joint_velocity = joint_state[0], joint_state[1]
            if abs(joint_velocity) > 0.01:
                if self.debug:
                    print(f"WARNING: Locked joint {joint_name} is moving with velocity {joint_velocity:.4f}")
                p.setJointMotorControl2(
                    bodyUniqueId=self.robot_id,
                    jointIndex=joint_index,
                    controlMode=p.POSITION_CONTROL,
                    targetPosition=current_position,
                    force=5000.0
                )
            
    def apply_action(self, action, duration=1.0, move_base=False):
        """
        Apply an action to control the velocity of the wrist joints in the simulated Fetch robot.

        The velocities are set as motor targets and take effect on subsequent
        simulation steps; this method does not step the simulation itself.
        
        Args:
            action: A list of 2 values between -1.0 and 1.0, representing the velocity multipliers for
                    wrist_flex_joint and wrist_roll_joint (negative rotates counterclockwise, positive clockwise).
                    Values are not clipped here.
            duration: Currently unused; kept for interface compatibility.
            move_base: If True, also drive the base with the current keyboard
                    teleop commands (keyboard_forward / keyboard_turn).
                    
        Returns:
            bool: True if action was applied successfully, False otherwise
            (including when a previous action is still marked as in progress).
        """
        if self.arm_is_moving:
            if self.debug:
                print("Action blocked: arm is already moving")
            return False

        self.arm_is_moving = True
        try:
            # Keep the non-wrist arm joints static.
            self._hold_locked_arm_joints()

            if move_base:
                self.move_base_velocity(self.keyboard_forward, self.keyboard_turn)

            max_velocity = 1.0  # rad/s per unit of action

            wrist_flex_index = self.arm_joint_map["wrist_flex_joint"]
            wrist_roll_index = self.arm_joint_map["wrist_roll_joint"]

            wrist_flex_pos = p.getJointState(self.robot_id, wrist_flex_index)[0]
            wrist_roll_pos = p.getJointState(self.robot_id, wrist_roll_index)[0]

            flex_limits = self.joint_limits["wrist_flex_joint"]
            roll_limits = self.joint_limits["wrist_roll_joint"]

            # Zero out any command that would push a wrist joint past its limit.
            wrist_flex_velocity = self._limit_velocity(
                wrist_flex_pos, action[0] * max_velocity, flex_limits["min"], flex_limits["max"]
            )
            wrist_roll_velocity = self._limit_velocity(
                wrist_roll_pos, action[1] * max_velocity, roll_limits["min"], roll_limits["max"]
            )

            for joint_index, target_velocity in (
                (wrist_flex_index, wrist_flex_velocity),
                (wrist_roll_index, wrist_roll_velocity),
            ):
                p.setJointMotorControl2(
                    bodyUniqueId=self.robot_id,
                    jointIndex=joint_index,
                    controlMode=p.VELOCITY_CONTROL,
                    targetVelocity=target_velocity,
                    force=100.0
                )
            return True

        except Exception as e:
            print(f"Error applying action: {e}")
            return False
        finally:
            # Always release the busy flag so new actions can be applied.
            self.arm_is_moving = False
           
    def set_torso_height(self, height):
        """Reset the torso_lift_joint state to ``height`` (via resetJointState).

        The torso joint index is looked up on first use and cached on the
        instance. Only the joint state is reset; no motor command is issued.
        """
        if not hasattr(self, 'torso_joint_index'):
            self.torso_joint_index = self._find_joint_index("torso_lift_joint")
        
        if self.torso_joint_index is not None:
            p.resetJointState(self.robot_id, self.torso_joint_index, height)
            if self.debug:
                print(f"Set torso height to {height}")
        else:
            print("Error: Torso joint not found")

    def set_arm_position(self, joint_positions):
        """Reset the arm joints to ``joint_positions`` and hold them there.

        Args:
            joint_positions: One target position per found arm joint, in
                ``arm_joint_names`` order.
        """
        if len(joint_positions) != len(self.arm_joint_indices):
            print(f"Error: Expected {len(self.arm_joint_indices)} joint positions, got {len(joint_positions)}")
            return
        
        for joint_index, target in zip(self.arm_joint_indices, joint_positions):
            # Teleport the joint, then hold it with a stiff position motor.
            p.resetJointState(self.robot_id, joint_index, target)
            p.setJointMotorControl2(
                bodyUniqueId=self.robot_id,
                jointIndex=joint_index,
                controlMode=p.POSITION_CONTROL,
                targetPosition=target,
                force=5000.0  # High force to maintain position
            )

    @staticmethod
    def _diff_drive_wheel_velocities(linear_velocity, angular_velocity,
                                     wheel_radius=0.07, wheel_separation=0.38):
        """Convert body linear/angular velocity into (left, right) wheel angular velocities in rad/s."""
        half_track_speed = angular_velocity * wheel_separation / 2
        left_wheel_velocity = (linear_velocity - half_track_speed) / wheel_radius
        right_wheel_velocity = (linear_velocity + half_track_speed) / wheel_radius
        return left_wheel_velocity, right_wheel_velocity

    def move_base_velocity(self, linear_velocity, angular_velocity):
        """
        Move the base of the robot with linear and angular velocities.
        
        Args:
            linear_velocity: Linear velocity in m/s (positive = forward, negative = backward)
            angular_velocity: Angular velocity in rad/s (positive = left turn, negative = right turn)
        """
        # Look up the wheel joints; retry while either is missing so a failed lookup
        # is never cached as a None index that would later reach setJointMotorControl2.
        if getattr(self, 'left_wheel_index', None) is None or getattr(self, 'right_wheel_index', None) is None:
            self.left_wheel_index = self._find_joint_index("l_wheel_joint")
            self.right_wheel_index = self._find_joint_index("r_wheel_joint")
            
            if self.left_wheel_index is None or self.right_wheel_index is None:
                if self.debug:
                    print("Error: Could not find wheel joints")
                return
        
        # Approximate Fetch geometry (wheel radius 0.07 m, separation 0.38 m).
        # NOTE(review): confirm these against the robot's URDF.
        left_wheel_velocity, right_wheel_velocity = self._diff_drive_wheel_velocities(
            linear_velocity, angular_velocity
        )
        
        for joint_index, target_velocity in (
            (self.left_wheel_index, left_wheel_velocity),
            (self.right_wheel_index, right_wheel_velocity),
        ):
            p.setJointMotorControl2(
                bodyUniqueId=self.robot_id,
                jointIndex=joint_index,
                controlMode=p.VELOCITY_CONTROL,
                targetVelocity=target_velocity,
                force=100.0
            )
        
        if self.debug:
            print(f"Base velocity: linear={linear_velocity:.2f} m/s, angular={angular_velocity:.2f} rad/s")
            print(f"Wheel velocities: left={left_wheel_velocity:.2f} rad/s, right={right_wheel_velocity:.2f} rad/s")

    def _get_camera_transform(self):
        """Calculate the camera position, target position, up vector and right vector.

        The camera sits 0.15 m forward / 1.3 m above the base and looks at a
        point 1.6 m ahead / 0.38 m above the base, both in the base frame.

        Returns:
            tuple: (camera_pos, target_pos, up_vector, right) as 3-element lists
            in world coordinates.
        """
        # Get robot's current position and orientation
        pos, orn = p.getBasePositionAndOrientation(self.robot_id)
        
        # Convert quaternion to rotation matrix
        rot_matrix = p.getMatrixFromQuaternion(orn)
        rot_matrix = np.array(rot_matrix).reshape(3, 3)
        
        # Define camera offset in robot's local frame
        camera_offset = [0.15, 0, 1.3]  # 15cm forward, 1.3m above
        
        # Define target offset in robot's local frame
        target_offset = [1.6, 0, 0.38]  # 1.6m ahead, 0.38m above base
        
        # Transform offsets to world frame using rotation matrix
        camera_world_offset = np.dot(rot_matrix, camera_offset)
        target_world_offset = np.dot(rot_matrix, target_offset)
        
        # Calculate actual positions in world frame
        camera_pos = [
            pos[0] + camera_world_offset[0],
            pos[1] + camera_world_offset[1],
            pos[2] + camera_world_offset[2]
        ]
        
        target_pos = [
            pos[0] + target_world_offset[0],
            pos[1] + target_world_offset[1],
            pos[2] + target_world_offset[2]
        ]
        
        # Calculate the direction vector from camera to target
        direction = [
            target_pos[0] - camera_pos[0],
            target_pos[1] - camera_pos[1],
            target_pos[2] - camera_pos[2]
        ]
        
        # Normalize the direction vector
        length = np.sqrt(direction[0]**2 + direction[1]**2 + direction[2]**2)
        if length > 0:
            direction = [d/length for d in direction]
        
        # Right vector = direction x world_up (zero if direction is vertical)
        world_up = [0, 0, 1]
        right = [
            direction[1] * world_up[2] - direction[2] * world_up[1],
            direction[2] * world_up[0] - direction[0] * world_up[2],
            direction[0] * world_up[1] - direction[1] * world_up[0]
        ]
        
        # Normalize right vector
        right_length = np.sqrt(right[0]**2 + right[1]**2 + right[2]**2)
        if right_length > 0:
            right = [r/right_length for r in right]
            
        # Camera up vector = right x direction
        up_vector = [
            right[1] * direction[2] - right[2] * direction[1],
            right[2] * direction[0] - right[0] * direction[2],
            right[0] * direction[1] - right[1] * direction[0]
        ]

        return camera_pos, target_pos, up_vector, right

    def get_camera_image(self):
        """
        Get RGB and depth images from the camera with the camera looking
        at a fixed target position where the arm is located.
        
        Returns:
            tuple: (rgb_image, depth_image, segmentation) as returned by
            p.getCameraImage with the TinyRenderer.
        """
        camera_pos, target_pos, up_vector, _ = self._get_camera_transform()
        
        view_matrix = p.computeViewMatrix(
            cameraEyePosition=camera_pos,
            cameraTargetPosition=target_pos,
            cameraUpVector=up_vector
        )
        
        aspect = self.width / self.height
        projection_matrix = p.computeProjectionMatrixFOV(
            fov=self.fov,
            aspect=aspect,
            nearVal=self.near,
            farVal=self.far
        )
        
        _, _, rgb_array, depth_array, segmentation = p.getCameraImage(
            width=self.width,
            height=self.height,
            viewMatrix=view_matrix,
            projectionMatrix=projection_matrix,
            renderer=p.ER_TINY_RENDERER
        )
        return rgb_array, depth_array, segmentation
        
    def visualize_camera_view(self, line_length=0.2, line_width=2):
        """
        Draw debug lines in the simulation to show the camera's position, 
        target direction, up vector, and right vector.

        Line IDs are stored so subsequent calls replace the previous lines.
        
        Args:
            line_length: Length of the debug lines for up and right vectors.
            line_width: Width of the debug lines.
        """
        camera_pos, target_pos, up_vector, right_vector = self._get_camera_transform()

        up_pos_end = [camera_pos[i] + up_vector[i] * line_length for i in range(3)]
        right_pos_end = [camera_pos[i] + right_vector[i] * line_length for i in range(3)]

        # View direction - Red
        self.debug_line_view_id = p.addUserDebugLine(
            camera_pos, 
            target_pos, 
            [1, 0, 0], 
            lineWidth=line_width, 
            replaceItemUniqueId=self.debug_line_view_id
        )
        
        # Up vector - Blue
        self.debug_line_up_id = p.addUserDebugLine(
            camera_pos, 
            up_pos_end, 
            [0, 0, 1], 
            lineWidth=line_width, 
            replaceItemUniqueId=self.debug_line_up_id
        )
        
        # Right vector - Green
        self.debug_line_right_id = p.addUserDebugLine(
            camera_pos, 
            right_pos_end, 
            [0, 1, 0], 
            lineWidth=line_width, 
            replaceItemUniqueId=self.debug_line_right_id
        )
    
    def lock_all_joints(self):
        """Lock all arm joints (including the wrist) and the torso at their current positions.

        Prints an error and does nothing unless every arm joint in
        ``arm_joint_names`` was found.
        """
        if len(self.arm_joint_indices) < len(self.arm_joint_names):
            print("Error: Arm joint indices not properly initialized")
            return
        
        joints_to_lock = list(self.arm_joint_names)
        lock_indices = [self.arm_joint_map[name] for name in joints_to_lock]
        
        torso_joint_index = self._find_joint_index("torso_lift_joint")
        if torso_joint_index is not None:
            lock_indices.append(torso_joint_index)
            if self.debug:
                print("Adding torso joint to the list of joints to lock")
        else:
            if self.debug:
                print("Warning: Could not find torso joint to lock")
        
        if self.debug:
            print(f"Locking the following joints: {joints_to_lock + ['torso_lift_joint']}")
            print(f"Joint indices to lock: {lock_indices}")
        
        for joint_index in lock_indices:
            current_position = p.getJointState(self.robot_id, joint_index)[0]
            
            p.setJointMotorControl2(
                bodyUniqueId=self.robot_id,
                jointIndex=joint_index,
                controlMode=p.POSITION_CONTROL,
                targetPosition=current_position,
                force=5000.0,  # Very high force to ensure the joint stays locked
                positionGain=1.0,
                velocityGain=1.0
            )
            
            # NOTE(review): this second command switches the joint to VELOCITY_CONTROL
            # (zero target velocity), replacing the control mode set just above. Whether
            # the position target/gain from the first call still has any effect depends
            # on PyBullet internals - confirm before relying on a true position lock.
            p.setJointMotorControl2(
                bodyUniqueId=self.robot_id,
                jointIndex=joint_index,
                controlMode=p.VELOCITY_CONTROL,
                targetVelocity=0,
                force=5000.0
            )
            
            if self.debug:
                joint_name = p.getJointInfo(self.robot_id, joint_index)[1].decode('utf-8')
                print(f"Locked joint {joint_name} at position {current_position:.4f}")
        
        if self.debug:
            print("All arm joints and torso have been locked")
            
    def process_keyboard_events(self):
        """Process keyboard events for teleop control.

        Arrow keys set ``keyboard_forward`` (up/down, +/-0.5) and
        ``keyboard_turn`` (left/right, +/-0.2) on press and reset them to 0 on
        release. Pressing 'q' disconnects from PyBullet and exits the process.
        """
        import sys

        turn_keys = {p.B3G_RIGHT_ARROW: -0.2, p.B3G_LEFT_ARROW: 0.2}
        forward_keys = {p.B3G_UP_ARROW: 0.5, p.B3G_DOWN_ARROW: -0.5}

        for key, state in p.getKeyboardEvents().items():
            if key == ord('q') and (state & p.KEY_WAS_TRIGGERED):
                p.disconnect()
                sys.exit(0)

            # Press sets the command, release clears it (press handled first, as before).
            if key in turn_keys:
                if state & p.KEY_WAS_TRIGGERED:
                    self.keyboard_turn = turn_keys[key]
                if state & p.KEY_WAS_RELEASED:
                    self.keyboard_turn = 0
            if key in forward_keys:
                if state & p.KEY_WAS_TRIGGERED:
                    self.keyboard_forward = forward_keys[key]
                if state & p.KEY_WAS_RELEASED:
                    self.keyboard_forward = 0
        
    def keyboard_control(self):
        """Read keyboard events and update the teleop commands."""
        self.process_keyboard_events()
        
    def keyboard_teleop(self):
        """Drive the base with the current keyboard teleop commands."""
        self.move_base_velocity(self.keyboard_forward, self.keyboard_turn)

    def _ensure_lidar_plot(self):
        """Lazily create the interactive polar figure used by visualize_lidar_plot()."""
        if not hasattr(self, 'plt'):
            import matplotlib.pyplot as plt
            self.plt = plt
        if not hasattr(self, 'fig') or not hasattr(self, 'ax'):
            self.plt.ion()
            self.fig, self.ax = self.plt.subplots(subplot_kw={'projection': 'polar'})

    def visualize_lidar_plot(self, distances, intensities):
        """
        Create a polar plot visualization of LIDAR scan data (debug mode only).
        
        Args:
            distances: Array of distance measurements
            intensities: Array of intensity values
        """
        if not self.debug:
            return

        self._ensure_lidar_plot()
        self.ax.clear()
        
        # Beam i points at -fov/2 + i * resolution degrees.
        angles_rad = np.radians(-self.lidar_fov / 2 + np.arange(len(distances)) * self.lidar_resolution)
        
        scatter = self.ax.scatter(angles_rad, distances, 
                                c=intensities, 
                                cmap='jet',
                                alpha=0.6,
                                s=10)
        
        self.ax.set_rlim(0, self.lidar_range)
        self.ax.set_title('LIDAR Scan (Top View)')
        self.ax.grid(True)
        
        # Create the colorbar only once; it persists across ax.clear() calls.
        if not hasattr(self, 'colorbar'):
            self.colorbar = self.plt.colorbar(scatter, ax=self.ax)
            self.colorbar.set_label('Intensity')
        
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        
    def save_lidar_scan_image(self, distances, intensities, save_every=10):
        """
        Save the LIDAR scan as an image file (debug mode only).
        
        Args:
            distances: Array of distance measurements
            intensities: Array of intensity values
            save_every: Only save every Nth scan to avoid too many files
        
        Returns:
            str: Path to the saved image, or None if no image was saved
        """
        if not self.debug:
            return None
            
        self.scan_counter += 1
        if self.scan_counter % save_every != 0:
            return None
            
        try:
            import matplotlib
            matplotlib.use('Agg')  # Use non-interactive backend
            import matplotlib.pyplot as plt
            
            fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})
            try:
                angles_rad = np.radians(-self.lidar_fov / 2 + np.arange(len(distances)) * self.lidar_resolution)
                
                scatter = ax.scatter(angles_rad, distances, 
                                  c=intensities, 
                                  cmap='jet',
                                  alpha=0.8,
                                  s=10)
                
                # Dashed range rings drawn with the public polar plotting API.
                ring_theta = np.linspace(0.0, 2 * np.pi, 361)
                for r in (5, 10, 15, 20):
                    if r <= self.lidar_range:
                        ax.plot(ring_theta, np.full_like(ring_theta, r),
                                color='gray', alpha=0.5, linestyle='--', linewidth=1)
                        ax.text(0, r, f"{r}m", ha='center', va='bottom', color='gray')
                
                ax.set_rlim(0, min(self.lidar_range, 25))  # Limit to 25m for better visualization
                ax.set_theta_zero_location('N')  # 0 degrees at North (front of robot)
                ax.set_theta_direction(-1)  # Clockwise
                ax.set_title('LIDAR Scan (Top View)')
                ax.grid(True)
                
                cbar = fig.colorbar(scatter, ax=ax)
                cbar.set_label('Intensity')
                
                timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
                fig.text(0.02, 0.02, f"Time: {timestamp}\nRange: {self.lidar_range}m, FOV: {self.lidar_fov}°", fontsize=8)
                
                # The directory may not exist if creation in __init__ was skipped or failed.
                os.makedirs(self.lidar_save_dir, exist_ok=True)
                filename = os.path.join(self.lidar_save_dir, f"lidar_scan_{self.scan_counter}.png")
                fig.savefig(filename, dpi=150, bbox_inches='tight')
            finally:
                plt.close(fig)  # Free the figure even if drawing/saving failed
            
            print(f"Saved LIDAR scan to {filename}")
            return filename
            
        except Exception as e:
            print(f"Error saving LIDAR scan image: {e}")
            return None
            
    def get_lidar_scan(self):
        """
        Get LIDAR scan data that matches SICK TIM571 specifications.

        Casts ``num_lidar_rays`` horizontal rays (in the base frame) from a point
        ``lidar_offset`` relative to the robot base, starting at -fov/2 and
        stepping by ``lidar_resolution`` degrees.

        Returns:
            tuple: (distances, intensities) where:
                  distances is array of distance measurements (lidar_range on a miss)
                  intensities is array of simulated RSSI values (0 on a miss).
                  On any error both arrays are zeros of length num_lidar_rays.
        """
        try:
            # Use the base pose, as _get_camera_transform does. In PyBullet, link
            # index 0 is the first child link (not the base), so getLinkState(robot, 0)
            # can track a moving link such as a wheel rather than the chassis.
            base_pos, base_orn = p.getBasePositionAndOrientation(self.robot_id)
            rot_matrix = np.array(p.getMatrixFromQuaternion(base_orn)).reshape(3, 3)
            
            # LIDAR origin in world frame
            lidar_origin = np.asarray(base_pos) + rot_matrix.dot(np.asarray(self.lidar_offset))
            lidar_pos = lidar_origin.tolist()
            
            start_angle = -np.pi * (self.lidar_fov/2) / 180  # radians
            
            distances = []
            intensities = []
            for i in range(self.num_lidar_rays):
                angle = start_angle + (i * np.pi * self.lidar_resolution / 180)
                
                # Horizontal ray in the base frame, rotated into the world frame
                ray_world_dir = rot_matrix.dot([np.cos(angle), np.sin(angle), 0])
                ray_end = (lidar_origin + ray_world_dir * self.lidar_range).tolist()
                
                hit_fraction = p.rayTest(lidar_pos, ray_end)[0][2]
                
                if hit_fraction < 1.0:
                    distance = hit_fraction * self.lidar_range
                    # Simulated intensity: closer objects return a stronger signal
                    intensity = max(0, 255 * (1.0 - distance/self.lidar_range))
                else:
                    distance = self.lidar_range
                    intensity = 0
                    
                distances.append(distance)
                intensities.append(intensity)
            
            if self.debug:
                min_dist = min(distances)
                avg_dist = sum(distances) / len(distances)
                print(f"LIDAR scan: min distance={min_dist:.2f}m, avg distance={avg_dist:.2f}m")
            
            return np.array(distances), np.array(intensities)
            
        except Exception as e:
            print(f"Error in LIDAR scan: {e}")
            return np.zeros(self.num_lidar_rays), np.zeros(self.num_lidar_rays)