import pybullet as p
import numpy as np
from typing import Tuple, List
import os
from datetime import datetime
import open3d as o3d

class TurtlebotController:
    """PyBullet controller for a differential-drive TurtleBot.

    Provides wheel/twist velocity commands, a low-dimensional state vector,
    a simulated forward camera, a simulated planar laser scan, point-cloud
    export and keyboard teleoperation.
    """

    # Differential-drive geometry (metres), shared by every velocity conversion.
    WHEEL_RADIUS = 0.038
    WHEEL_SEPARATION = 0.287

    def __init__(self, robot_id: int, obs_size: int, action_size: int, wheel_move: bool = False, max_linear: float = 0.65):
        """Store the PyBullet body id and controller/camera parameters.

        Args:
            robot_id: PyBullet unique id of the robot body.
            obs_size: Observation size (stored only; not used by this class).
            action_size: Action size (stored only; not used by this class).
            wheel_move: If True, ``move`` interprets its two inputs as
                normalized left/right wheel commands instead of a twist.
            max_linear: Maximum linear velocity in m/s.
        """
        self.robot_id = robot_id
        # Speed limits
        self.max_linear_speed = max_linear  # m/s
        self.max_angular_speed = 1.0  # rad/s
        self.max_force = 100.0  # motor force passed to setJointMotorControl2

        # Wheel joint indices
        self.left_wheel = 0
        self.right_wheel = 1

        # Virtual camera parameters (modelled on a RealSense D435i)
        self.camera_link = 27  # not used by get_camera_image, which derives the pose from the base
        self.width = 84
        self.height = 84
        # Passed to computeProjectionMatrixFOV, which interprets it as the vertical
        # FOV; with a square image it equals the horizontal FOV as well.
        self.fov = 87
        self.near = 0.1  # near clipping plane (m)
        self.far = 10.0  # far clipping plane (m)

        self.obs_size = obs_size
        self.action_size = action_size

        # Keyboard teleop state; keyboard_speed is a wheel angular velocity (rad/s)
        self.keyboard_speed = 10
        self.keyboard_forward = 0
        self.keyboard_turn = 0
        self.wheel_move = wheel_move

        # Reported in get_state(); presumably set externally by the environment.
        self.has_stopped_at_line = False

    def get_state(self, target_pos) -> np.ndarray:
        """Return the low-dimensional observation vector of the robot.

        Args:
            target_pos: Target position. Only the x and y components
                (``target_pos[0]``, ``target_pos[1]``) are used.

        Returns:
            np.ndarray: float32 vector of length 9:
                [0:2] base x, y position
                [2:4] sin(yaw), cos(yaw)
                [4]   base linear velocity along the world x axis
                [5]   base angular velocity about the world z axis
                [6]   planar (x/y) distance to the target
                [7]   1.0 if ``has_stopped_at_line`` is truthy, else 0.0
                [8]   raw yaw angle (rad)

        Depth features from ``get_camera_image`` are not included.
        """
        position, orientation = self.get_pose()
        yaw = p.getEulerFromQuaternion(orientation)[2]

        # NOTE(review): getBaseVelocity reports world-frame velocities, so index 4
        # is world-x velocity, not forward speed, unless the robot heads along +x.
        linear_vel, angular_vel = self.get_velocity()

        distance_to_target = np.hypot(position[0] - target_pos[0],
                                      position[1] - target_pos[1])

        # getattr keeps this working for instances created without __init__.
        has_stopped_at_line = float(getattr(self, 'has_stopped_at_line', False))

        return np.array([
            position[0], position[1],           # position (2)
            np.sin(yaw), np.cos(yaw),           # orientation (2)
            linear_vel[0], angular_vel[2],      # velocities (2)
            distance_to_target,                 # distance to target (1)
            has_stopped_at_line,                # stop line status (1)
            yaw,                                # raw yaw (1)
        ], dtype=np.float32)

    def set_wheel_velocity(self, left: float, right: float) -> None:
        """Command both wheel joints with velocity control.

        Args:
            left: Target angular velocity of the left wheel (rad/s).
            right: Target angular velocity of the right wheel (rad/s).

        Both targets are clipped to ``max_linear_speed / WHEEL_RADIUS``, the
        wheel speed corresponding to the robot's maximum linear speed.
        """
        max_wheel_velocity = self.max_linear_speed / self.WHEEL_RADIUS  # rad/s
        for joint, target in ((self.left_wheel, left), (self.right_wheel, right)):
            p.setJointMotorControl2(self.robot_id, joint,
                                    p.VELOCITY_CONTROL,
                                    targetVelocity=np.clip(target, -max_wheel_velocity, max_wheel_velocity),
                                    force=self.max_force)

        # p.stepSimulation()  # uncomment for keyboard teleop

    def get_pose(self) -> Tuple[List[float], List[float]]:
        """Return the base position [x, y, z] and orientation quaternion as lists."""
        position, orientation = p.getBasePositionAndOrientation(self.robot_id)
        return list(position), list(orientation)

    def get_velocity(self) -> Tuple[List[float], List[float]]:
        """Return the base linear and angular velocity (as given by p.getBaseVelocity) as lists."""
        linear, angular = p.getBaseVelocity(self.robot_id)
        return list(linear), list(angular)

    def get_camera_image(self):
        """Render RGB, depth and segmentation images from a virtual forward camera.

        The camera sits 0.12 m ahead of and 0.23 m above the base origin and
        looks horizontally along the robot's heading. Only yaw is used; roll
        and pitch are ignored.

        Returns:
            tuple: ``(rgb_image, depth_image, segmentation)`` exactly as returned
            by ``p.getCameraImage``. ``depth_image`` is PyBullet's non-linear
            depth buffer, not metric depth.
        """
        pos, orn = p.getBasePositionAndOrientation(self.robot_id)
        yaw = p.getEulerFromQuaternion(orn)[2]  # heading angle

        offset_x = 0.12  # forward offset along the robot's heading (m)
        offset_y = 0.0   # lateral offset along the robot's local y axis (m)
        offset_z = 0.23  # height above the base origin (m)

        # Rotate the planar offset into the world frame by the robot's yaw.
        rotated_x = offset_x * np.cos(yaw) - offset_y * np.sin(yaw)
        rotated_y = offset_x * np.sin(yaw) + offset_y * np.cos(yaw)
        camera_pos = [pos[0] + rotated_x, pos[1] + rotated_y, pos[2] + offset_z]

        # Look at a point 1 m ahead at the same height (level camera).
        target_pos = [camera_pos[0] + np.cos(yaw), camera_pos[1] + np.sin(yaw), camera_pos[2]]

        view_matrix = p.computeViewMatrix(
            cameraEyePosition=camera_pos,
            cameraTargetPosition=target_pos,
            cameraUpVector=[0, 0, 1],  # world up
        )
        projection_matrix = p.computeProjectionMatrixFOV(
            fov=self.fov,
            aspect=self.width / self.height,
            nearVal=self.near,
            farVal=self.far,
        )

        _, _, rgb_array, depth_array, segmentation = p.getCameraImage(
            width=self.width,
            height=self.height,
            viewMatrix=view_matrix,
            projectionMatrix=projection_matrix,
            renderer=p.ER_BULLET_HARDWARE_OPENGL,
        )
        return rgb_array, depth_array, segmentation

    def _twist_to_wheel_velocities(self, linear_vel: float, angular_vel: float) -> Tuple[float, float]:
        """Convert a motion command into (left, right) wheel angular velocities in rad/s.

        In ``wheel_move`` mode the inputs are normalized left/right wheel commands
        clipped to [-1, 1] and scaled by ``max_linear_speed / WHEEL_RADIUS``.
        Otherwise they are a twist (m/s, rad/s) clipped to the speed limits and
        converted with differential-drive kinematics.
        """
        if self.wheel_move:
            max_wheel_speed = self.max_linear_speed / self.WHEEL_RADIUS
            return (np.clip(linear_vel, -1, 1) * max_wheel_speed,
                    np.clip(angular_vel, -1, 1) * max_wheel_speed)

        linear = np.clip(linear_vel, -self.max_linear_speed, self.max_linear_speed)
        angular = np.clip(angular_vel, -self.max_angular_speed, self.max_angular_speed)
        half_track = angular * self.WHEEL_SEPARATION / 2  # wheel speed difference from rotation
        return ((linear - half_track) / self.WHEEL_RADIUS,
                (linear + half_track) / self.WHEEL_RADIUS)

    def move(self, linear_vel: float, angular_vel: float) -> None:
        """Drive the robot.

        The inputs are a twist (m/s, rad/s), or normalized left/right wheel
        commands when ``wheel_move`` is set. The resulting wheel velocities are
        sent to ``set_wheel_velocity``, which clips them again.
        """
        left_velocity, right_velocity = self._twist_to_wheel_velocities(linear_vel, angular_vel)
        self.set_wheel_velocity(left_velocity, right_velocity)

    def stop(self) -> None:
        """Command zero velocity on both wheels."""
        self.set_wheel_velocity(0, 0)

    @staticmethod
    def _ray_end_points(position, rot_matrix, num_rays: int, max_range: float) -> List[List[float]]:
        """Return world-frame end points of a planar fan of rays around the robot.

        Ray ``i`` points at angle ``2*pi*i/num_rays`` counter-clockwise from the
        robot's heading. ``rot_matrix`` is the flat, row-major 3x3 matrix from
        ``p.getMatrixFromQuaternion``. All end points share height ``position[2]``.
        """
        angles = np.linspace(0.0, 2.0 * np.pi, num_rays, endpoint=False)
        local_dirs = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (num_rays, 2)
        rotation = np.asarray(rot_matrix, dtype=np.float64).reshape(3, 3)
        # world_dir = R @ local_dir, restricted to the in-plane (x, y) block of R.
        world_dirs = local_dirs @ rotation[:2, :2].T
        ends = np.empty((num_rays, 3))
        ends[:, 0] = position[0] + max_range * world_dirs[:, 0]
        ends[:, 1] = position[1] + max_range * world_dirs[:, 1]
        ends[:, 2] = position[2]
        return ends.tolist()

    def get_laser_scan(self, num_rays: int = 360, max_range: float = 5.0) -> List[float]:
        """Simulate a planar laser scan by ray casting from the base position.

        Args:
            num_rays: Number of evenly spaced rays over a full turn (360 means
                one ray per degree). Ray 0 points along the robot's heading, and
                angles increase counter-clockwise.
            max_range: Ray length in metres.

        Returns:
            Distances in metres, one per ray. A ray that hits nothing reports
            ``max_range``.
        """
        position, orientation = self.get_pose()
        rot_matrix = p.getMatrixFromQuaternion(orientation)
        ray_ends = self._ray_end_points(position, rot_matrix, num_rays, max_range)
        ray_starts = [position] * num_rays

        results = p.rayTestBatch(ray_starts, ray_ends)
        # result[2] is the hit fraction along the ray (1.0 when nothing is hit).
        return [result[2] * max_range for result in results]

    def _depth_buffer_to_meters(self, depth_buffer) -> np.ndarray:
        """Convert PyBullet's non-linear depth buffer to metric depth.

        Uses ``far * near / (far - (far - near) * d)``, which maps a buffer
        value of 0 to ``self.near`` and 1 to ``self.far``.
        """
        d = np.asarray(depth_buffer, dtype=np.float64)
        return self.far * self.near / (self.far - (self.far - self.near) * d)

    def save_pointcloud(self, rgb: np.ndarray, depth: np.ndarray, depth_is_buffer: bool = True) -> None:
        """Back-project an RGB-D frame into a coloured point cloud and save it as PLY.

        Files go to ``<cwd>/data/pointclouds/pointcloud_<timestamp>.ply``. Points
        are in the camera frame: x right, y down, z along the viewing direction.

        Args:
            rgb: RGBA image from get_camera_image(); reshaped to (height*width, 4).
            depth: Depth image from get_camera_image(); reshaped to (height, width).
            depth_is_buffer: If True (default), ``depth`` is PyBullet's non-linear
                depth buffer and is converted to metres using ``self.near`` and
                ``self.far``. Pass False when ``depth`` is already metric.
        """
        save_path = os.path.join(os.getcwd(), "data", "pointclouds")
        os.makedirs(save_path, exist_ok=True)

        # np.asarray accepts both numpy arrays and the flat tuples non-numpy PyBullet builds return.
        depth_arr = np.asarray(depth, dtype=np.float64).reshape(self.height, self.width)
        Z = self._depth_buffer_to_meters(depth_arr) if depth_is_buffer else depth_arr

        # computeProjectionMatrixFOV uses fov vertically with square pixels, so fx == fy.
        fy = (self.height / 2) / np.tan(np.radians(self.fov / 2))
        fx = fy
        cx = self.width / 2
        cy = self.height / 2

        x_coords, y_coords = np.meshgrid(np.arange(self.width), np.arange(self.height))
        X = (x_coords - cx) * Z / fx
        Y = (y_coords - cy) * Z / fy
        xyz = np.stack([X, Y, Z], axis=-1).reshape(-1, 3)

        # Drop alpha and scale colours to [0, 1] as Open3D expects.
        colors = np.asarray(rgb).reshape(-1, 4)[:, :3] / 255.0

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(xyz)
        pcd.colors = o3d.utility.Vector3dVector(colors)

        # Microseconds avoid overwriting when several clouds are saved in one second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = os.path.join(save_path, f"pointcloud_{timestamp}.ply")

        o3d.io.write_point_cloud(filename, pcd)
        print(f"Saved pointcloud to: {filename}")

    def process_keyboard_events(self):
        """Update teleop state from PyBullet keyboard events.

        Key bindings:
            'q': disconnect from PyBullet and exit the process.
            Left/right arrows: turn while held.
            Up/down arrows: drive forward/backward while held.
        """
        turn_values = {p.B3G_RIGHT_ARROW: -0.5, p.B3G_LEFT_ARROW: 0.5}
        forward_values = {p.B3G_UP_ARROW: 1, p.B3G_DOWN_ARROW: -1}

        for key, state in p.getKeyboardEvents().items():
            if key == ord('q') and (state & p.KEY_WAS_TRIGGERED):
                p.disconnect()
                raise SystemExit(0)  # same as sys.exit(0), without relying on site's exit()

            # Check "triggered" before "released" so a release in the same event wins.
            if key in turn_values:
                if state & p.KEY_WAS_TRIGGERED:
                    self.keyboard_turn = turn_values[key]
                if state & p.KEY_WAS_RELEASED:
                    self.keyboard_turn = 0

            if key in forward_values:
                if state & p.KEY_WAS_TRIGGERED:
                    self.keyboard_forward = forward_values[key]
                if state & p.KEY_WAS_RELEASED:
                    self.keyboard_forward = 0

    def keyboard_teleop(self):
        """Apply the current keyboard teleop state as wheel velocities (rad/s)."""
        right_wheel_velocity = (self.keyboard_forward + self.keyboard_turn) * self.keyboard_speed
        left_wheel_velocity = (self.keyboard_forward - self.keyboard_turn) * self.keyboard_speed

        # Route through set_wheel_velocity so the usual clipping applies.
        self.set_wheel_velocity(left_wheel_velocity, right_wheel_velocity)