import pybullet as p
import time
import os
import numpy as np

# Start PyBullet with its GUI so the simulation can be watched.
p.connect(p.GUI)

# Gravity along -Z (m/s^2).
p.setGravity(0, 0, -9.8)

# A static (mass 0) white ground plane at the world origin.
floor_shape = p.createCollisionShape(p.GEOM_PLANE)
floor_visual = p.createVisualShape(p.GEOM_PLANE, rgbaColor=[1, 1, 1, 1])
floor = p.createMultiBody(
    baseMass=0,
    baseCollisionShapeIndex=floor_shape,
    baseVisualShapeIndex=floor_visual,
    basePosition=[0, 0, 0],
)

# Resolve data paths relative to this script, not the working directory.
# fetch_ros (the parent of fetch_description) goes on the search path because
# the URDF's mesh references are relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))
fetch_ros_path = os.path.join(current_dir, os.pardir, "data", "fetch_ros")
p.setAdditionalSearchPath(fetch_ros_path)

# Load the stock Fetch URDF. useFixedBase=True pins the base in place, so
# spinning the wheels later will not translate the robot.
fetch_urdf_path = os.path.join("fetch_description", "robots", "fetch.urdf")
fetch_robot = p.loadURDF(fetch_urdf_path, basePosition=[0, 0, 0], useFixedBase=True)

# Arm joint targets (radians), one per arm joint in the order
# shoulder_pan, shoulder_lift, upperarm_roll, elbow_flex, forearm_roll,
# wrist_flex, wrist_roll.

# Tucked-arm configuration (not referenced elsewhere in this script).
tucked_joint_positions = [
    1.32,  # shoulder_pan_joint
    1.40,  # shoulder_lift_joint
    -0.2,  # upperarm_roll_joint
    1.72,  # elbow_flex_joint
    0.0,   # forearm_roll_joint
    1.66,  # wrist_flex_joint
    0.0,   # wrist_roll_joint
]

# Configuration the arm is reset to at start-up.
start_positions = [
    0.0,   # shoulder_pan_joint
    1.4,   # shoulder_lift_joint
    0.0,   # upperarm_roll_joint
    -2.2,  # elbow_flex_joint
    0.0,   # forearm_roll_joint
    0.7,   # wrist_flex_joint
    1.55,  # wrist_roll_joint
]

# Names of the arm joints, in the order used by start_positions /
# tucked_joint_positions and expected by move_arm().
arm_joint_names = [
    "shoulder_pan_joint",
    "shoulder_lift_joint",
    "upperarm_roll_joint",
    "elbow_flex_joint",
    "forearm_roll_joint",
    "wrist_flex_joint",
    "wrist_roll_joint",
]

# Build a joint-name -> joint-index map in a single pass instead of scanning
# the joint list once per name. setdefault keeps the first index for a name,
# matching a break-on-first-match search.
joint_name_to_index = {}
for joint_idx in range(p.getNumJoints(fetch_robot)):
    joint_name_to_index.setdefault(
        p.getJointInfo(fetch_robot, joint_idx)[1].decode('utf-8'), joint_idx
    )

# Each of these is None if the URDF has no joint with that name; users of
# these globals should check for None before passing them to PyBullet.
torso_joint_index = joint_name_to_index.get("torso_lift_joint")
left_wheel_index = joint_name_to_index.get("l_wheel_joint")
right_wheel_index = joint_name_to_index.get("r_wheel_joint")

# Arm joint indices in arm_joint_names order. Missing joints are skipped, so
# warn loudly: a short list makes move_arm() reject 7-element targets.
arm_joint_indices = [
    joint_name_to_index[name] for name in arm_joint_names if name in joint_name_to_index
]
missing_arm_joints = [name for name in arm_joint_names if name not in joint_name_to_index]
if missing_arm_joints:
    print(f"Warning: arm joints not found in URDF: {missing_arm_joints}")

def set_head_camera_position(fetch_robot, pan_angle, tilt_angle):
    """
    Point the Fetch head camera by resetting the head pan/tilt joints.

    The joints named head_pan_joint and head_tilt_joint are set directly with
    resetJointState (no motor control, no simulation stepping). If either
    joint is missing, a warning is printed and nothing is changed.

    Args:
        fetch_robot: The PyBullet body ID of the Fetch robot
        pan_angle: Head pan angle (in radians)
        tilt_angle: Head tilt angle (in radians)
    """
    # Scan every joint once; a later match overwrites an earlier one.
    head_joints = {"head_pan_joint": None, "head_tilt_joint": None}
    for joint_id in range(p.getNumJoints(fetch_robot)):
        name = p.getJointInfo(fetch_robot, joint_id)[1].decode('utf-8')
        if name in head_joints:
            head_joints[name] = joint_id

    pan_id = head_joints["head_pan_joint"]
    tilt_id = head_joints["head_tilt_joint"]

    if pan_id is not None and tilt_id is not None:
        p.resetJointState(fetch_robot, pan_id, pan_angle)
        p.resetJointState(fetch_robot, tilt_id, tilt_angle)
        print(f"Set head camera pan to: {pan_angle}, tilt to: {tilt_angle}")
    else:
        print("Warning: Could not find head pan or tilt joint")

# Aim the head camera: no pan, 1 rad of tilt. Tune these angles for the
# point you want the camera to look at.
set_head_camera_position(fetch_robot, 0.0, 1)


def move_torso(fetch_robot, target_height, max_force=500, speed=0.1, steps=500):
    """
    Move the Fetch robot's torso lift joint to the specified position.

    A POSITION_CONTROL target is set on the module-level torso_joint_index.
    The simulation is then stepped ``steps`` times, sleeping 1/240 s per step
    to pace the GUI. If the torso joint was not found, an error is printed and
    nothing happens.

    Args:
        fetch_robot: The PyBullet body ID of the Fetch robot
        target_height: Target position for the torso lift joint (nominally
            0.0 to 0.4; NOT clamped here)
        max_force: Maximum motor force to apply (default: 500)
        speed: Maximum joint velocity for the motor (default: 0.1)
        steps: Number of simulation steps to run after setting the target
            (default: 500)
    """
    if torso_joint_index is None:
        print("Error: Torso joint not found")
        return

    # No clamping is applied: target_height is passed to the motor unchanged,
    # and any enforcement of the URDF joint limits is left to PyBullet.
    p.setJointMotorControl2(
        bodyUniqueId=fetch_robot,
        jointIndex=torso_joint_index,
        controlMode=p.POSITION_CONTROL,
        targetPosition=target_height,
        force=max_force,
        maxVelocity=speed,
    )

    # Advance the physics so the motor can actually reach the target.
    for _ in range(steps):
        p.stepSimulation()
        time.sleep(1./240.)

    print(f"Torso moved to height: {target_height}")

def move_arm(fetch_robot, target_positions, max_force=100, speed=0.3, steps=5000):
    """
    Drive the Fetch arm to the given joint positions using position control.

    A POSITION_CONTROL target is set on each joint in the module-level
    arm_joint_indices. The simulation is then stepped ``steps`` times,
    sleeping 1/240 s per step. While stepping, the torso lift joint (if it
    exists) is re-commanded each step to hold the position it had when the
    call started. Current arm joint positions are printed every 50 steps.

    Args:
        fetch_robot: The PyBullet body ID of the Fetch robot
        target_positions: Joint targets, one per entry of arm_joint_indices
            (7 for the full Fetch arm)
        max_force: Maximum motor force for each arm joint (default: 100)
        speed: Maximum joint velocity for each arm motor (default: 0.3)
        steps: Number of simulation steps to run (default: 5000)
    """
    if len(target_positions) != len(arm_joint_indices):
        print(f"Error: Expected {len(arm_joint_indices)} joint positions, got {len(target_positions)}")
        return

    # torso_joint_index is None when the URDF lacks torso_lift_joint. In that
    # case there is nothing to hold, and it must not be passed to PyBullet.
    hold_torso = torso_joint_index is not None
    if hold_torso:
        current_torso_position = p.getJointState(fetch_robot, torso_joint_index)[0]

    for joint_index, target in zip(arm_joint_indices, target_positions):
        p.setJointMotorControl2(
            bodyUniqueId=fetch_robot,
            jointIndex=joint_index,
            controlMode=p.POSITION_CONTROL,
            targetPosition=target,
            force=max_force,
            maxVelocity=speed,
        )

    for step in range(steps):
        if hold_torso:
            # Re-assert the torso target every step with a high force and a
            # low velocity limit, so the arm motion does not drag it along.
            p.setJointMotorControl2(
                bodyUniqueId=fetch_robot,
                jointIndex=torso_joint_index,
                controlMode=p.POSITION_CONTROL,
                targetPosition=current_torso_position,
                force=1000,
                maxVelocity=0.01,
            )

        p.stepSimulation()
        time.sleep(1./240.)

        # Periodic progress report.
        if step % 50 == 0:
            current_positions = [
                p.getJointState(fetch_robot, joint_index)[0]
                for joint_index in arm_joint_indices
            ]
            print(f"Current arm positions: {current_positions}")

    print("Arm movement completed")



# Disable self-collision between the torso bellows links and the base, which
# were flagged as problematic (presumably overlapping meshes — confirm).
base_link_index = -1  # PyBullet addresses a body's base (root) link as -1
bellows_link_index = None
bellows_link2_index = None

# getJointInfo()[12] is the child link name of each joint, and a link's index
# equals the index of its parent joint.
for link_idx in range(p.getNumJoints(fetch_robot)):
    child_link_name = p.getJointInfo(fetch_robot, link_idx)[12].decode('utf-8')
    if child_link_name == "bellows_link":
        bellows_link_index = link_idx
    elif child_link_name == "bellows_link2":
        bellows_link2_index = link_idx

if bellows_link_index is not None and bellows_link2_index is not None:
    p.setCollisionFilterPair(fetch_robot, fetch_robot, bellows_link_index, bellows_link2_index, 0)

if bellows_link_index is not None:
    p.setCollisionFilterPair(fetch_robot, fetch_robot, base_link_index, bellows_link_index, 0)

# Initial torso height. resetJointState sets the joint instantly, so no
# settling delay is needed. Sleeping would not help anyway: the script only
# advances physics through explicit stepSimulation() calls.
print("Setting torso to 0.2 position...")
if torso_joint_index is not None:
    p.resetJointState(fetch_robot, torso_joint_index, 0.2)

# Initial arm pose: pair each arm joint with its start angle.
print("Setting arm to start position...")
for joint_index, start_angle in zip(arm_joint_indices, start_positions):
    p.resetJointState(fetch_robot, joint_index, start_angle)

def debug_base_linkoint(self):
    """
    Print debug information about a robot's base.

    Despite its name, ``self`` is the PyBullet body ID of the robot to
    inspect; this is a module-level function, not a method.

    A joint literally named "base_link" is searched for first. If found, its
    index, name, type, limits and current position/velocity are printed.
    "base_link" is normally the root link (PyBullet index -1), which has no
    joint entry. In that case the "not found" message is followed by the base
    pose and velocity, queried directly.

    Args:
        self: PyBullet body ID of the robot.
    """
    body_id = self
    base_joint_index = None

    for joint_index in range(p.getNumJoints(body_id)):
        if p.getJointInfo(body_id, joint_index)[1].decode('utf-8') == "base_link":
            base_joint_index = joint_index
            break

    if base_joint_index is None:
        print("Base joint (base_link) not found.")
        # The root link has no joint; report its state directly instead.
        base_position, base_orientation = p.getBasePositionAndOrientation(body_id)
        linear_velocity, angular_velocity = p.getBaseVelocity(body_id)
        print(f"Base Position: {base_position}")
        print(f"Base Orientation (quaternion x, y, z, w): {base_orientation}")
        print(f"Base Linear Velocity: {linear_velocity}")
        print(f"Base Angular Velocity: {angular_velocity}")
        return

    joint_info = p.getJointInfo(body_id, base_joint_index)
    joint_name = joint_info[1].decode('utf-8')
    joint_type = joint_info[2]
    joint_limits = joint_info[8:10]  # (lower limit, upper limit)
    joint_position, joint_velocity = p.getJointState(body_id, base_joint_index)[:2]

    print(f"Base Joint Index: {base_joint_index}")
    print(f"Base Joint Name: {joint_name}")
    print(f"Base Joint Type: {joint_type}")
    print(f"Joint Limits: {joint_limits}")
    print(f"Current Position: {joint_position}")
    print(f"Current Velocity: {joint_velocity}")


def show_all_joints(self):
    """
    Print the index, name and type code of every joint of a robot.

    Args:
        self: PyBullet body ID of the robot to inspect. Despite the name, this
            is a module-level function, and the argument is the body to list.
    """
    # Use the body passed in rather than the module-level fetch_robot, so
    # the function works for any loaded body.
    num_joints = p.getNumJoints(self)
    print(f"Total number of joints: {num_joints}")

    for i in range(num_joints):
        joint_info = p.getJointInfo(self, i)
        joint_index = joint_info[0]
        joint_name = joint_info[1].decode('utf-8')
        joint_type = joint_info[2]  # PyBullet joint type constant (int)

        print(f"Joint Index: {joint_index}, Name: {joint_name}, Type: {joint_type}")
# List every joint of the loaded robot for reference.
show_all_joints(fetch_robot)


def capture_depth_image(width=84, height=84, filename="depth_image.png"):
    """
    Capture a depth image from the Fetch robot's head camera perspective.

    The camera eye is placed at the head_camera_link frame, or at a fixed
    fallback position if that link is missing. It looks at a point given in
    the robot's base frame, with world +Z as up. The depth buffer is saved as
    a .npy file under ./images. If Pillow is installed, a min-max-normalized
    8-bit grayscale image is saved as well.

    Args:
        width: Width of the image in pixels (default: 84)
        height: Height of the image in pixels (default: 84)
        filename: Image file name inside ./images; the raw array is saved
            alongside it with the extension replaced by ".npy"

    Returns:
        numpy.ndarray of shape (height, width) with the depth-buffer values
        returned by getCameraImage (not converted to metric distance).
    """
    # Find the head camera link. getJointInfo()[12] is each joint's child link
    # name, and a link's index equals the index of its parent joint.
    head_camera_link = None
    for joint_index in range(p.getNumJoints(fetch_robot)):
        link_name = p.getJointInfo(fetch_robot, joint_index)[12].decode('utf-8')
        if link_name == "head_camera_link":
            head_camera_link = joint_index
            break

    if head_camera_link is None:
        print("Warning: Could not find head_camera_link, using default camera position")
        # Fallback eye position in world coordinates.
        camera_position = [0.5, 0.0, 1.2]
    else:
        # computeForwardKinematics=True recomputes the pose from the current
        # joint states, even if the simulation has not been stepped since they
        # were reset. Element 4 is the URDF link-frame position (where the
        # camera frame is defined); element 0 is the link's centre of mass.
        link_state = p.getLinkState(fetch_robot, head_camera_link, computeForwardKinematics=True)
        camera_position = link_state[4]

    # Aim point in the robot base frame: 1.6 m along +x, 0.38 m along +z.
    target_local = np.array([1.6, 0.0, 0.38])

    # Full rigid transform to world: p_world = t_base + R_base @ p_local.
    # getMatrixFromQuaternion returns the 3x3 rotation as 9 row-major values.
    base_position, base_orientation = p.getBasePositionAndOrientation(fetch_robot)
    rotation = np.reshape(p.getMatrixFromQuaternion(base_orientation), (3, 3))
    target_world = (np.asarray(base_position) + rotation @ target_local).tolist()

    camera_up = [0, 0, 1]  # world Z is "up" in the image

    print(f"Camera position: {camera_position}")
    print(f"Looking at target: {target_world}")

    view_matrix = p.computeViewMatrix(
        cameraEyePosition=camera_position,
        cameraTargetPosition=target_world,
        cameraUpVector=camera_up,
    )

    projection_matrix = p.computeProjectionMatrixFOV(
        fov=60.0,  # approximate FOV for the Fetch head camera
        aspect=float(width) / height,
        nearVal=0.1,
        farVal=100.0,
    )

    _, _, _, depth_buffer, _ = p.getCameraImage(
        width=width,
        height=height,
        viewMatrix=view_matrix,
        projectionMatrix=projection_matrix,
        renderer=p.ER_BULLET_HARDWARE_OPENGL,
    )

    # getCameraImage returns a shaped NumPy array only when PyBullet was built
    # with NumPy support; otherwise the depth buffer is a flat sequence of
    # width*height values. Reshape so both cases yield (height, width).
    depth_image = np.reshape(np.asarray(depth_buffer), (height, width))

    os.makedirs("images", exist_ok=True)

    # Swap only the final extension (str.replace would touch every ".png").
    npy_name = os.path.splitext(filename)[0] + ".npy"
    np.save(os.path.join("images", npy_name), depth_image)
    print(f"Depth data saved to images/{npy_name}")

    # Pillow is optional; only the import is guarded.
    try:
        from PIL import Image
    except ImportError:
        print("PIL not available, only saved raw numpy data")
    else:
        depth_min = np.min(depth_image)
        depth_range = np.max(depth_image) - depth_min
        if depth_range > 0:
            normalized = ((depth_image - depth_min) / depth_range * 255).astype(np.uint8)
        else:
            # Constant buffer (e.g. nothing in view): avoid 0/0 -> NaN.
            normalized = np.zeros(depth_image.shape, dtype=np.uint8)

        Image.fromarray(normalized).save(os.path.join("images", filename))
        print(f"Depth image saved to images/{filename}")

    return depth_image

# With the robot posed, grab an 84x84 depth frame from the head camera.
print("Capturing depth image...")
capture_depth_image(84, 84, "fetch_depth_84x84.png")

def move_wheels(fetch_robot, left_speed=0.5, right_speed=0.5, wheel_radius=0.1, steps=100):
    """
    Drive the Fetch robot's left and right wheels with velocity control.

    Each linear speed is converted to a wheel angular velocity
    (speed / wheel_radius). That value is set as a VELOCITY_CONTROL target on
    l_wheel_joint / r_wheel_joint. The simulation is then stepped ``steps``
    times, sleeping 1/240 s per step. The velocity targets stay active after
    the function returns. Joint indices, joint info and the final joint states
    are printed for debugging. If either wheel joint is missing, a warning is
    printed and nothing is commanded.

    Args:
        fetch_robot: The PyBullet body ID of the Fetch robot
        left_speed: Linear speed of the left wheel in m/s (default: 0.5)
        right_speed: Linear speed of the right wheel in m/s (default: 0.5)
        wheel_radius: Wheel radius in metres used for the m/s -> rad/s
            conversion (default: 0.1; NOTE(review): confirm against the Fetch
            URDF wheel geometry)
        steps: Number of simulation steps to run after applying the targets
            (default: 100)

    Raises:
        ValueError: If ``wheel_radius`` is not positive.
    """
    if wheel_radius <= 0:
        raise ValueError(f"wheel_radius must be positive, got {wheel_radius}")

    # One pass over the joints; keep the first match for each wheel name.
    wheel_joint_ids = {"l_wheel_joint": None, "r_wheel_joint": None}
    for joint_index in range(p.getNumJoints(fetch_robot)):
        joint_name = p.getJointInfo(fetch_robot, joint_index)[1].decode('utf-8')
        if joint_name in wheel_joint_ids and wheel_joint_ids[joint_name] is None:
            wheel_joint_ids[joint_name] = joint_index

    left_wheel_index = wheel_joint_ids["l_wheel_joint"]
    if left_wheel_index is None:
        print("Warning: Could not find left wheel joint")
        return

    right_wheel_index = wheel_joint_ids["r_wheel_joint"]
    if right_wheel_index is None:
        print("Warning: Could not find right wheel joint")
        return

    print(f"Left Wheel Index: {left_wheel_index}, Right Wheel Index: {right_wheel_index}")

    left_joint_info = p.getJointInfo(fetch_robot, left_wheel_index)
    right_joint_info = p.getJointInfo(fetch_robot, right_wheel_index)
    print(f"Left Joint Info: {left_joint_info}")
    print(f"Right Joint Info: {right_joint_info}")

    # Linear wheel speed (m/s) -> joint angular velocity (rad/s).
    left_angular_velocity = left_speed / wheel_radius
    right_angular_velocity = right_speed / wheel_radius

    # High force limit so the motors can hold the target velocity.
    for wheel_index, angular_velocity in (
        (left_wheel_index, left_angular_velocity),
        (right_wheel_index, right_angular_velocity),
    ):
        p.setJointMotorControl2(
            bodyUniqueId=fetch_robot,
            jointIndex=wheel_index,
            controlMode=p.VELOCITY_CONTROL,
            targetVelocity=angular_velocity,
            force=5000.0,
        )

    print(f"Applied forward velocity to left wheel: {left_speed:.3f} m/s (angular: {left_angular_velocity:.3f} rad/s)")
    print(f"Applied forward velocity to right wheel: {right_speed:.3f} m/s (angular: {right_angular_velocity:.3f} rad/s)")

    # Advance physics so the motor commands take effect.
    for _ in range(steps):
        p.stepSimulation()
        time.sleep(1./240.)

    left_joint_state = p.getJointState(fetch_robot, left_wheel_index)
    right_joint_state = p.getJointState(fetch_robot, right_wheel_index)
    print(f"Left Wheel State: {left_joint_state}, Right Wheel State: {right_joint_state}")

# Spin both wheels forward at 0.5 m/s. The robot was loaded with
# useFixedBase=True, so the wheels turn in place and the base stays put.
move_wheels(fetch_robot, 0.5, 0.5)

# Shut down the physics server / GUI.
p.disconnect()